From 4d88fdf95812eeab064fb8713959697e2bb3c647 Mon Sep 17 00:00:00 2001
From: Anton Akhmerov <anton.akhmerov@gmail.com>
Date: Mon, 2 Sep 2013 14:48:42 +0200
Subject: [PATCH] clean-up of TranslationalSymmetry.add_site_family

---
 kwant/lattice.py            | 82 ++++++++++++++++++++++++-------------
 kwant/tests/test_lattice.py | 48 +++++++++++++++-------
 2 files changed, 86 insertions(+), 44 deletions(-)

diff --git a/kwant/lattice.py b/kwant/lattice.py
index 78ab556..15669ef 100644
--- a/kwant/lattice.py
+++ b/kwant/lattice.py
@@ -502,71 +502,95 @@ class TranslationalSymmetry(builder.Symmetry):
         self.site_family_data = {}
         self.is_reversed = False
 
-    def add_site_family(self, gr, other_vectors=None):
+    def add_site_family(self, fam, other_vectors=None):
         """
         Select a fundamental domain for site family and cache associated data.
 
         Parameters
         ----------
-        gr : `SiteFamily`
+        fam : `SiteFamily`
             the site family which has to be processed.  Be sure to delete the
             previously processed site families from `site_family_data` if you
             want to modify the cache.
 
         other_vectors : list of lists of integers
             Bravais lattice vectors used to complement the periods in forming
-            a basis. The fundamental domain belongs to the linear space
-            spanned by these vectors.
+            a basis. The fundamental domain consists of all the lattice sites
+            for which the zero coefficients corresponding to the symmetry
+            periods in the basis formed by the symmetry periods and
+            `other_vectors`. If an insufficient number of `other_vectors` is
+            provided to form a basis, the missing ones are selected
+            automatically.
 
         Raises
         ------
         KeyError
-            If `gr` is already stored in `site_family_data`.
+            If `fam` is already stored in `site_family_data`.
         ValueError
-            If lattice shape of `gr` cannot have the given `periods`.
+            If lattice `fam` is incompatible with given periods.
         """
-        if gr in self.site_family_data:
+        dim = self._periods.shape[1]
+        if fam in self.site_family_data:
             raise KeyError('Family already processed, delete it from '
                            'site_family_data first.')
-        inv = np.linalg.pinv(gr.prim_vecs)
-        bravais_periods = [np.dot(i, inv) for i in self._periods]
+        inv = np.linalg.pinv(fam.prim_vecs)
+        bravais_periods = np.dot(self._periods, inv)
+        # Absolute tolerance is correct in the following since we want an error
+        # relative to the closest integer.
         if not np.allclose(bravais_periods, np.round(bravais_periods),
                            rtol=0, atol=1e-8) or \
-           not np.allclose([gr.vec(i) for i in bravais_periods],
+           not np.allclose([fam.vec(i) for i in bravais_periods],
                            self._periods):
             msg = 'Site family {0} does not have commensurate periods with ' +\
                   'symmetry {1}.'
-            raise ValueError(msg.format(gr, self))
+            raise ValueError(msg.format(fam, self))
         bravais_periods = np.array(np.round(bravais_periods), dtype='int')
-        (num_dir, dim) = bravais_periods.shape
+        (num_dir, lat_dim) = bravais_periods.shape
         if other_vectors is None:
-            other_vectors = []
-        for vec in other_vectors:
-            for a in vec:
-                if not isinstance(a, int):
-                    raise ValueError('Only integer other_vectors are allowed.')
-        m = np.zeros((dim, dim), dtype=int)
-
-        m.T[: num_dir] = bravais_periods
+            other_vectors = np.zeros((0, lat_dim), dtype=int)
+        else:
+            other_vectors = np.array(other_vectors)
+            if np.any(np.round(other_vectors) - other_vectors):
+                raise ValueError('Only integer other_vectors are allowed.')
+            other_vectors = np.array(np.round(other_vectors), dtype=int)
+
+        m = np.zeros((lat_dim, lat_dim), dtype=int)
+
+        m.T[:num_dir] = bravais_periods
         num_vec = num_dir + len(other_vectors)
-        if len(other_vectors) != 0:
-            m.T[num_dir:num_vec] = other_vectors
-        norms = np.apply_along_axis(np.linalg.norm, 1, m)
-        indices = np.argsort(norms)
-        for coord in zip(indices, range(num_vec, dim)):
-            m[coord] = 1
+        m.T[num_dir:num_vec] = other_vectors
+
+        if np.linalg.matrix_rank(m) < num_vec:
+            raise ValueError('other_vectors and symmetry periods are not '
+                             'linearly independent.')
+
+        # To define the fundamental domain of the new site family we now need to
+        # choose `lat_dim - num_vec` extra lattice vectors that are not
+        # linearly dependent on the vectors we already have. To do so we
+        # continuously add the lattice basis vectors one by one such that they
+        # are not linearly dependent on the existent vectors
+        while num_vec < lat_dim:
+            vh = np.linalg.svd(np.dot(m[:, :num_vec].T, fam.prim_vecs),
+                               full_matrices=False)[2]
+            projector = np.identity(dim) - np.dot(vh.T, vh)
+
+            residuals = np.dot(fam.prim_vecs, projector)
+            residuals = np.apply_along_axis(np.linalg.norm, 1, residuals)
+            m[np.argmax(residuals), num_vec] = 1
+            num_vec += 1
 
         det_m = int(round(np.linalg.det(m)))
         if det_m == 0:
-            raise RuntimeError('Adding site group failed.')
+            print m
+            raise RuntimeError('Adding site family failed.')
 
         det_x_inv_m = \
             np.array(np.round(det_m * np.linalg.inv(m)), dtype=int)
-        assert (np.dot(m, det_x_inv_m) // det_m == np.identity(dim)).all()
+        assert (np.dot(m, det_x_inv_m) // det_m == np.identity(lat_dim)).all()
 
         det_x_inv_m_part = det_x_inv_m[:num_dir, :]
         m_part = m[:, :num_dir]
-        self.site_family_data[gr] = (ta.array(m_part),
+        self.site_family_data[fam] = (ta.array(m_part),
                                      ta.array(det_x_inv_m_part), det_m)
 
     @property
diff --git a/kwant/tests/test_lattice.py b/kwant/tests/test_lattice.py
index 74800a3..2e739d6 100644
--- a/kwant/tests/test_lattice.py
+++ b/kwant/tests/test_lattice.py
@@ -103,35 +103,35 @@ def test_wire():
 
 def test_translational_symmetry():
     ts = lattice.TranslationalSymmetry
-    g2 = lattice.general(np.identity(2))
-    g3 = lattice.general(np.identity(3))
+    f2 = lattice.general(np.identity(2))
+    f3 = lattice.general(np.identity(3))
     shifted = lambda site, delta: site.family(*ta.add(site.tag, delta))
 
     assert_raises(ValueError, ts, (0, 0, 4), (0, 5, 0), (0, 0, 2))
     sym = ts((3.3, 0))
-    assert_raises(ValueError, sym.add_site_family, g2)
+    assert_raises(ValueError, sym.add_site_family, f2)
 
     # Test lattices with dimension smaller than dimension of space.
-    g2in3 = lattice.general([[4, 4, 0], [4, -4, 0]])
+    f2in3 = lattice.general([[4, 4, 0], [4, -4, 0]])
     sym = ts((8, 0, 0))
-    sym.add_site_family(g2in3)
+    sym.add_site_family(f2in3)
     sym = ts((8, 0, 1))
-    assert_raises(ValueError, sym.add_site_family, g2in3)
+    assert_raises(ValueError, sym.add_site_family, f2in3)
 
     # Test automatic fill-in of transverse vectors.
     sym = ts((1, 2))
-    sym.add_site_family(g2)
-    assert_not_equal(sym.site_family_data[g2][2], 0)
+    sym.add_site_family(f2)
+    assert_not_equal(sym.site_family_data[f2][2], 0)
     sym = ts((1, 0, 2), (3, 0, 2))
-    sym.add_site_family(g3)
-    assert_not_equal(sym.site_family_data[g3][2], 0)
+    sym.add_site_family(f3)
+    assert_not_equal(sym.site_family_data[f3][2], 0)
 
     transl_vecs = np.array([[10, 0], [7, 7]], dtype=int)
     sym = ts(*transl_vecs)
     assert_equal(sym.num_directions, 2)
     sym2 = ts(*transl_vecs[: 1, :])
-    sym2.add_site_family(g2, transl_vecs[1:, :])
-    for site in [g2(0, 0), g2(4, 0), g2(2, 1), g2(5, 5), g2(15, 6)]:
+    sym2.add_site_family(f2, transl_vecs[1:, :])
+    for site in [f2(0, 0), f2(4, 0), f2(2, 1), f2(5, 5), f2(15, 6)]:
         assert sym.in_fd(site)
         assert sym2.in_fd(site)
         assert_equal(sym.which(site), (0, 0))
@@ -150,10 +150,28 @@ def test_translational_symmetry():
                              (site, shifted(site, hop)))
 
     # Test act for hoppings belonging to different lattices.
-    g2p = lattice.general(2 * np.identity(2))
+    f2p = lattice.general(2 * np.identity(2))
     sym = ts(*(2 * np.identity(2)))
-    assert sym.act((1, 1), g2(0, 0), g2p(0, 0)) == (g2(2, 2), g2p(1, 1))
-    assert sym.act((1, 1), g2p(0, 0), g2(0, 0)) == (g2p(1, 1), g2(2, 2))
+    assert sym.act((1, 1), f2(0, 0), f2p(0, 0)) == (f2(2, 2), f2p(1, 1))
+    assert sym.act((1, 1), f2p(0, 0), f2(0, 0)) == (f2p(1, 1), f2(2, 2))
+
+    # Test add_site_family on random lattices and symmetries by ensuring that
+    # it's possible to add site groups that are compatible with a randomly
+    # generated symmetry with proper vectors.
+    np.random.seed(30)
+    vec = np.random.randn(3, 5)
+    lat = lattice.general(vec)
+    total = 0
+    for k in range(1, 4):
+        for i in range(10):
+            sym_vec = np.random.randint(-10, 10, size=(k, 3))
+            if np.linalg.matrix_rank(sym_vec) < k:
+                continue
+            total += 1
+            sym_vec = np.dot(sym_vec, vec)
+            sym = ts(*sym_vec)
+            sym.add_site_family(lat)
+    assert total > 20
 
 
 def test_translational_symmetry_reversed():
-- 
GitLab