From 9945ee547b929cabd00f9f2712bb0508a144d788 Mon Sep 17 00:00:00 2001
From: Christoph Groth <christoph.groth@cea.fr>
Date: Fri, 3 May 2013 12:23:11 +0200
Subject: [PATCH] lattice.Shape: code and documentation reformulations

---
 kwant/lattice.py            | 101 +++++++++++++++---------------------
 kwant/tests/test_lattice.py |   5 +-
 2 files changed, 45 insertions(+), 61 deletions(-)

diff --git a/kwant/lattice.py b/kwant/lattice.py
index 2885fd37..6dc6c2ec 100644
--- a/kwant/lattice.py
+++ b/kwant/lattice.py
@@ -471,28 +471,29 @@ class TranslationalSymmetry(builder.Symmetry):
 
 class Shape(object):
     def __init__(self, lattice, function, start):
-        """A class for finding all the lattice sites in a shape.
+        """A class for finding all the lattice sites inside a shape.
 
-        It uses a flood-fill algorithm, and takes into account
-        the symmetry of the builder to which it is provided, or
-        the symmetry, that is supplied to it after initialization.
+        When an instance of this class is called, a flood-fill algorithm finds
+        and yields all the sites inside the specified shape starting from the
+        specified position.
 
         Parameters
         ----------
         lattice : Polyatomic or Monoatomic lattice
             Lattice, to which the resulting sites should belong.
         function : callable
-            A function of real space coordinates, which should
-            return True for coordinates inside the shape,
-            and False otherwise.
+            A function of real space coordinates that returns a truth value:
+            true for coordinates inside the shape, and false otherwise.
         start : float vector
             The origin for the flood-fill algorithm.
 
         Notes
         -----
-        A ``Shape`` is a callable object: When called with a
-        `~kwant.builder.Builder` as sole argument, an instance of this class will
-        return an iterator over all the sites from the shape that are in the fundamental domain of the builder's symmetry.
+        A `~kwant.builder.Symmetry` or `~kwant.builder.Builder` may be passed as
+        sole argument when calling an instance of this class.  This will
+        restrict the flood-fill to the fundamental domain of the symmetry (or
+        the builder's symmetry).  Note that unless the shape function has that
+        symmetry itself, the result may be unexpected.
 
         Because a `~kwant.builder.Builder` can be indexed with functions or
         iterables of functions, ``Shape`` instances (or any non-tuple
@@ -501,70 +502,52 @@ class Shape(object):
         """
         self.lat, self.func, self.start = lattice, function, start
 
-    def __call__(self, builder_or_symmetry=None):
-        """
-        Yield all the lattice sites which belong to a certain shape.
-
-        Parameters
-        ----------
-        builder_or_symmetry : Builder or Symmetry instance
-            The builder to which the site from the shape are added, or
-            the symmetry, such that the sites from the shape belong to
-            its fundamental domain. If not provided, trivial symmetry is
-            used.
-
-        Returns
-        -------
-        sites : sequence of `Site` objects
-            all the sites that belong to the lattice and fit inside the shape.
-        """
+    def __call__(self, symmetry=None):
         Site = builder.Site
         lat, func, start = self.lat, self.func, self.start
-        try:
-            symmetry = builder_or_symmetry.symmetry
-        except AttributeError:
-            symmetry = builder_or_symmetry
+
         if symmetry is None:
-            symmetry = builder.NoSymmetry()
+             symmetry = builder.NoSymmetry()
+        elif not isinstance(symmetry, builder.Symmetry):
+            symmetry = symmetry.symmetry
 
-        sym_site = lambda lat, tag: symmetry.to_fd(Site(lat, tag, True))
+        def sym_site(lat, tag):
+            return symmetry.to_fd(Site(lat, tag, True))
 
         dim = len(start)
         if dim != lat.prim_vecs.shape[1]:
             raise ValueError('Dimensionality of start position does not match'
                              ' the space dimensionality.')
         sls = lat.sublattices
-        deltas = [ta.array(i) for i in lat._voronoi]
-
-        # Check if no sites are going to be added, to catch a common error.
-        empty = True
-        for sl in sls:
-            if func(sym_site(sl, sl.closest(start)).pos):
-                empty = False
-        if empty:
+        deltas = [ta.array(delta) for delta in lat._voronoi]
+
+        #### Flood-fill ####
+        sites = []
+        for tag in set(sl.closest(start) for sl in sls):
+            for sl in sls:
+                site = sym_site(sl, tag)
+                if func(site.pos):
+                    sites.append(site)
+        if not sites:
             msg = 'No sites close to {0} are inside the desired shape.'
             raise ValueError(msg.format(start))
 
-        # Continue to flood fill.
-        tags = set([sl.closest(start) for sl in sls])
-        new_sites = [sym_site(sl, tag) for sl in sls for tag in tags]
-        new_sites = [i for i in new_sites if func(i.pos)]
         old_sites = set()
-        while new_sites:
-            tmp = set()
-            for site in new_sites:
+        while sites:
+            tags = set()
+            for site in sites:
                 yield site
-            tags = set((i.tag for i in new_sites))
+                tags.add(site.tag)
+            tags = set(tag + delta for tag in tags for delta in deltas)
+            new_sites = set()
             for tag in tags:
-                for shift in deltas:
-                    for sl in sls:
-                        site = sym_site(sl, tag + shift)
-                        if site not in old_sites and \
-                           site not in new_sites and \
-                           func(site.pos):
-                            tmp.add(site)
-            old_sites = new_sites
-            new_sites = tmp
+                for sl in sls:
+                    site = sym_site(sl, tag)
+                    if site not in old_sites and site not in sites \
+                            and func(site.pos):
+                        new_sites.add(site)
+            old_sites = sites
+            sites = new_sites
 
 
 ################ Library of lattices (to be extended)
@@ -584,6 +567,6 @@ def square(a=1, name=''):
 def honeycomb(a=1, name=''):
     """Create a honeycomb lattice."""
     lat = Polyatomic(((a, 0), (0.5 * a, 0.5 * a * sqrt(3))),
-                            ((0, 0), (0, a / sqrt(3))), name=name)
+                     ((0, 0), (0, a / sqrt(3))), name=name)
     lat.a, lat.b = lat.sublattices
     return lat
diff --git a/kwant/tests/test_lattice.py b/kwant/tests/test_lattice.py
index fdad446e..8e34f9be 100644
--- a/kwant/tests/test_lattice.py
+++ b/kwant/tests/test_lattice.py
@@ -56,8 +56,7 @@ def test_shape():
     def in_circle(pos):
         return pos[0] ** 2 + pos[1] ** 2 < 3
 
-    lat = lattice.general(((1, 0), (0.5, sqrt(3) / 2)),
-                               ((0, 0), (0, 1 / sqrt(3))))
+    lat = lattice.honeycomb()
     sites = list(lat.shape(in_circle, (0, 0))())
     sites_alt = list()
     sl0, sl1 = lat.sublattices
@@ -70,6 +69,7 @@ def test_shape():
     assert len(sites) == len(sites_alt)
     assert_equal(set(sites), set(sites_alt))
     assert_raises(ValueError, lat.shape(in_circle, (10, 10))().next)
+
     # Check if narrow ribbons work.
     for period in (0, 1), (1, 0), (1, -1):
         vec = lat.vec(period)
@@ -79,6 +79,7 @@ def test_shape():
         sites = list(lat.shape(shape, (0, 0))(sym))
         assert len(sites) > 35
 
+
 def test_translational_symmetry():
     ts = lattice.TranslationalSymmetry
     g2 = lattice.general(np.identity(2))
-- 
GitLab