From 9945ee547b929cabd00f9f2712bb0508a144d788 Mon Sep 17 00:00:00 2001 From: Christoph Groth <christoph.groth@cea.fr> Date: Fri, 3 May 2013 12:23:11 +0200 Subject: [PATCH] lattice.Shape: code and documentation reformulations --- kwant/lattice.py | 101 +++++++++++++++--------------------- kwant/tests/test_lattice.py | 5 +- 2 files changed, 45 insertions(+), 61 deletions(-) diff --git a/kwant/lattice.py b/kwant/lattice.py index 2885fd37..6dc6c2ec 100644 --- a/kwant/lattice.py +++ b/kwant/lattice.py @@ -471,28 +471,29 @@ class TranslationalSymmetry(builder.Symmetry): class Shape(object): def __init__(self, lattice, function, start): - """A class for finding all the lattice sites in a shape. + """A class for finding all the lattice sites inside a shape. - It uses a flood-fill algorithm, and takes into account - the symmetry of the builder to which it is provided, or - the symmetry, that is supplied to it after initialization. + When an instance of this class is called, a flood-fill algorithm finds + and yields all the sites inside the specified shape starting from the + specified position. Parameters ---------- lattice : Polyatomic or Monoatomic lattice Lattice, to which the resulting sites should belong. function : callable - A function of real space coordinates, which should - return True for coordinates inside the shape, - and False otherwise. + A function of real space coordinates that returns a truth value: + true for coordinates inside the shape, and false otherwise. start : float vector The origin for the flood-fill algorithm. Notes ----- - A ``Shape`` is a callable object: When called with a - `~kwant.builder.Builder` as sole argument, an instance of this class will - return an iterator over all the sites from the shape that are in the fundamental domain of the builder's symmetry. + A `~kwant.builder.Symmetry` or `~kwant.builder.Builder` may be passed as + sole argument when calling an instance of this class. This will + restrict the flood-fill to the fundamental domain of the symmetry (or + the builder's symmetry). Note that unless the shape function has that + symmetry itself, the result may be unexpected. Because a `~kwant.builder.Builder` can be indexed with functions or iterables of functions, ``Shape`` instances (or any non-tuple @@ -501,70 +502,52 @@ class Shape(object): """ self.lat, self.func, self.start = lattice, function, start - def __call__(self, builder_or_symmetry=None): - """ - Yield all the lattice sites which belong to a certain shape. - - Parameters - ---------- - builder_or_symmetry : Builder or Symmetry instance - The builder to which the site from the shape are added, or - the symmetry, such that the sites from the shape belong to - its fundamental domain. If not provided, trivial symmetry is - used. - - Returns - ------- - sites : sequence of `Site` objects - all the sites that belong to the lattice and fit inside the shape. - """ + def __call__(self, symmetry=None): Site = builder.Site lat, func, start = self.lat, self.func, self.start - try: - symmetry = builder_or_symmetry.symmetry - except AttributeError: - symmetry = builder_or_symmetry + if symmetry is None: - symmetry = builder.NoSymmetry() + symmetry = builder.NoSymmetry() + elif not isinstance(symmetry, builder.Symmetry): + symmetry = symmetry.symmetry - sym_site = lambda lat, tag: symmetry.to_fd(Site(lat, tag, True)) + def sym_site(lat, tag): + return symmetry.to_fd(Site(lat, tag, True)) dim = len(start) if dim != lat.prim_vecs.shape[1]: raise ValueError('Dimensionality of start position does not match' ' the space dimensionality.') sls = lat.sublattices - deltas = [ta.array(i) for i in lat._voronoi] - - # Check if no sites are going to be added, to catch a common error. - empty = True - for sl in sls: - if func(sym_site(sl, sl.closest(start)).pos): - empty = False - if empty: + deltas = [ta.array(delta) for delta in lat._voronoi] + + #### Flood-fill #### + sites = [] + for tag in set(sl.closest(start) for sl in sls): + for sl in sls: + site = sym_site(sl, tag) + if func(site.pos): + sites.append(site) + if not sites: msg = 'No sites close to {0} are inside the desired shape.' raise ValueError(msg.format(start)) - # Continue to flood fill. - tags = set([sl.closest(start) for sl in sls]) - new_sites = [sym_site(sl, tag) for sl in sls for tag in tags] - new_sites = [i for i in new_sites if func(i.pos)] old_sites = set() - while new_sites: - tmp = set() - for site in new_sites: + while sites: + tags = set() + for site in sites: yield site - tags = set((i.tag for i in new_sites)) + tags.add(site.tag) + tags = set(tag + delta for tag in tags for delta in deltas) + new_sites = set() for tag in tags: - for shift in deltas: - for sl in sls: - site = sym_site(sl, tag + shift) - if site not in old_sites and \ - site not in new_sites and \ - func(site.pos): - tmp.add(site) - old_sites = new_sites - new_sites = tmp + for sl in sls: + site = sym_site(sl, tag) + if site not in old_sites and site not in sites \ + and func(site.pos): + new_sites.add(site) + old_sites = sites + sites = new_sites ################ Library of lattices (to be extended) @@ -584,6 +567,6 @@ def square(a=1, name=''): def honeycomb(a=1, name=''): """Create a honeycomb lattice.""" lat = Polyatomic(((a, 0), (0.5 * a, 0.5 * a * sqrt(3))), - ((0, 0), (0, a / sqrt(3))), name=name) + ((0, 0), (0, a / sqrt(3))), name=name) lat.a, lat.b = lat.sublattices return lat diff --git a/kwant/tests/test_lattice.py b/kwant/tests/test_lattice.py index fdad446e..8e34f9be 100644 --- a/kwant/tests/test_lattice.py +++ b/kwant/tests/test_lattice.py @@ -56,8 +56,7 @@ def test_shape(): def in_circle(pos): return pos[0] ** 2 + pos[1] ** 2 < 3 - lat = lattice.general(((1, 0), (0.5, sqrt(3) / 2)), - ((0, 0), (0, 1 / sqrt(3)))) + lat = lattice.honeycomb() sites = list(lat.shape(in_circle, (0, 0))()) sites_alt = list() sl0, sl1 = lat.sublattices @@ -70,6 +69,7 @@ def test_shape(): assert len(sites) == len(sites_alt) assert_equal(set(sites), set(sites_alt)) assert_raises(ValueError, lat.shape(in_circle, (10, 10))().next) + # Check if narrow ribbons work. for period in (0, 1), (1, 0), (1, -1): vec = lat.vec(period) @@ -79,6 +79,7 @@ def test_shape(): sites = list(lat.shape(shape, (0, 0))(sym)) assert len(sites) > 35 + def test_translational_symmetry(): ts = lattice.TranslationalSymmetry g2 = lattice.general(np.identity(2)) -- GitLab