From 8252c1f2117e0ed81f601ad3336dce79eb9afbf0 Mon Sep 17 00:00:00 2001
From: Anton Akhmerov <anton.akhmerov@gmail.com>
Date: Wed, 21 Dec 2016 18:09:14 +0100
Subject: [PATCH] add fill method to Builder

---
 kwant/builder.py            | 162 +++++++++++++++++++++++++++---------
 kwant/tests/test_builder.py |  87 +++++++++++++++++++
 2 files changed, 211 insertions(+), 38 deletions(-)

diff --git a/kwant/builder.py b/kwant/builder.py
index 89e4377e..90c22623 100644
--- a/kwant/builder.py
+++ b/kwant/builder.py
@@ -708,11 +708,9 @@ class Builder:
 
     Builder instances can be made to automatically respect a `Symmetry` that is
     passed to them during creation.  The behavior of builders with a symmetry
-    is slightly more sophisticated.  First of all, it is implicitly assumed
-    throughout Kwant that **every** function assigned as a value to a builder
-    with a symmetry possesses the same symmetry.  Secondly, all keys are mapped
-    to the fundamental domain of the symmetry before storing them.  This may
-    produce confusing results when neighbors of a site are queried.
+    is slightly more sophisticated: all keys are mapped to the fundamental
+    domain of the symmetry before storing them.  This may produce confusing
+    results when neighbors of a site are queried.
 
     The method `attach_lead` *works* only if the sites affected by them have
     tags which are sequences of integers.  It *makes sense* only when these
@@ -1120,6 +1118,111 @@ class Builder:
         self.leads.extend(other.leads)
         return self
 
+    def fill(self, other, start, shape=None, *, overwrite=False, max_sites=1e7):
+        """Populate a builder using another as a template.
+
+
+        Parameters
+        ----------
+        other : `Builder`
+            A Builder used as a template. The symmetry of the target `Builder`
+            must be a subgroup of the symmetry of the template.
+        start : tuple of integers or `Site`
+            The initial domain to be used to start the flood-fill or a `Site`
+            belonging to that domain.
+        shape : callable
+            A boolean function of site returning whether the site should be
+            added to the target builder or not. If not provided, all sites are
+            added (beware that there may be infinitely many).
+        overwrite : boolean
+            If existing sites or hoppings in the target `Builder` should be
+            overwritten.
+        max_sites : positive number
+            The maximal number of sites that may be added, used to prevent
+            memory overflow.
+
+        Returns
+        -------
+        added_sites : list of `Site` objects that were added to the system.
+
+        Raises
+        ------
+        ValueError
+            If the symmetry of the target isn't a subgroup of the template
+            symmetry.
+        RuntimeError
+            If `max_sites` sites are added.
+
+        Notes
+        -----
+        This function uses a flood-fill algorithm, so all sites in the template
+        builder should be reachable from all other sites.
+        """
+        if shape is None:
+            def shape(site): return True
+
+        if not max_sites > 0:
+            raise ValueError("max_sites must be positive.")
+
+        sym = other.symmetry
+        H = other.H
+
+        # if start is a site, convert it to a domain.
+        if isinstance(start, Site):
+            start = sym.which(start)
+
+        # Check that symmetries are commensurate.
+        if not self.symmetry <= sym:
+            raise ValueError("Builder symmetry is not a subgroup of the "
+                             "template symmetry")
+
+        # map site to the given domain of other.symmetry,
+        # while ensuring that it is in the fundamental domain
+        # of self.symmetry
+        def to_domain(domain, sites_or_hops):
+            for site_or_hop in sites_or_hops:
+                if isinstance(site_or_hop, Site):
+                    site_or_hop = (sym.act(domain, site_or_hop),)
+                else:
+                    site_or_hop = sym.act(domain, *site_or_hop)
+                result = self.symmetry.to_fd(*site_or_hop)
+                yield result
+
+        def add_site(candidate):
+            may_add = overwrite or candidate not in self.H
+            # Delay calling shape because it may raise an error.
+            if not may_add or candidate in all_added:
+                return
+            if shape(candidate):
+                if len(all_added) == max_sites:
+                    raise RuntimeError("Maximal number of sites (max_sites "
+                                       "parameter of fill()) added, stopping.")
+                new_sites.add(candidate)
+                all_added.add(candidate)
+                self[candidate] = other[candidate]
+
+        # Initialize the flood-fill
+        new_sites = set()
+        all_added = set()
+        for site in to_domain(start, H):
+            add_site(site)
+
+        # Flood-fill
+        while new_sites:
+            site = new_sites.pop()
+            domain = sym.which(site)
+            # other.neighbors(site) gives neighbors of the *image* of
+            # site in FD of other.symmetry, so we must map it correctly
+            hoppings = [(sym.to_fd(site), n) for n in other.neighbors(site)]
+            for hopping in to_domain(domain, hoppings):
+                add_site(hopping[1])
+                try:
+                    self[hopping] = other[hopping]
+                except KeyError:
+                    pass
+
+        return list(all_added)
+
     def attach_lead(self, lead_builder, origin=None, add_cells=0):
         """Attach a lead to the builder, possibly adding missing sites.
 
@@ -1139,7 +1242,7 @@ class Builder:
 
         Returns
         -------
-            added_sites : list of `Site` objects that were added to the system.
+        added_sites : list of `Site` objects that were added to the system.
 
         Raises
         ------
@@ -1207,44 +1310,23 @@ class Builder:
         min_dom = min(all_doms)
         del all_doms
 
+        def shape(site):
+            domain, = sym.which(site)
+            if domain < min_dom:
+                raise ValueError('Builder does not interrupt the lead,'
+                                 ' this lead cannot be attached.')
+            return domain < max_dom + 1
+
+        all_added = self.fill(lead_builder, (max_dom,), shape=shape,
+                              max_sites=float('inf'))
+
+        # Calculate the interface.
         interface = set()
-        added = set()
-        all_added = []
-        # Initialize flood-fill: create the outermost sites.
         for site in H:
             for neighbor in lead_builder.neighbors(site):
                 neighbor = sym.act((max_dom + 1,), neighbor)
                 if sym.which(neighbor)[0] == max_dom:
-                    if neighbor not in self:
-                        self[neighbor] = lead_builder[neighbor]
-                        added.add(neighbor)
                     interface.add(neighbor)
-        all_added.extend(added)
-
-        # Do flood-fill.
-        covered = True
-        while covered:
-            covered = False
-            added2 = set()
-            for site in added:
-                site_dom = sym.which(site)
-                move = lambda x: sym.act(site_dom, x)
-                for site_new in lead_builder.neighbors(site):
-                    site_new = move(site_new)
-                    new_dom = sym.which(site_new)[0]
-                    if new_dom == max_dom + 1:
-                        continue
-                    elif new_dom < min_dom:
-                        raise ValueError('Builder does not interrupt the lead,'
-                                         ' this lead cannot be attached.')
-                    if (site_new not in self
-                        and sym.which(site_new)[0] != max_dom + 1):
-                        self[site_new] = lead_builder[site_new]
-                        added2.add(site_new)
-                        covered = True
-                    self[site_new, site] = lead_builder[site_new, site]
-            added = added2
-            all_added.extend(added)
 
         self.leads.append(BuilderLead(lead_builder, tuple(interface)))
         return all_added
@@ -1264,6 +1346,10 @@ class Builder:
         This method does not modify the Builder instance for which it is
         called.
 
+        Upon finalization, it is implicitly assumed that **every** function
+        assigned as a value to a builder with a symmetry possesses the same
+        symmetry.
+
         Attached leads are also finalized and will be present in the finalized
         system to be returned.
 
diff --git a/kwant/tests/test_builder.py b/kwant/tests/test_builder.py
index bfe757ed..2f497277 100644
--- a/kwant/tests/test_builder.py
+++ b/kwant/tests/test_builder.py
@@ -635,6 +635,93 @@ def test_builder_with_symmetry():
                                                                   (5, 0, -3))])
 
 
+def test_fill():
+    # Use function as a value since otherwise a hopping in the opposite
+    # direction may be stored after fill.
+    def f(*sites): pass
+    g = kwant.lattice.square()
+    sym_x = kwant.TranslationalSymmetry((-1, 0))
+    sym_xy = kwant.TranslationalSymmetry((-1, 0), (0, 1))
+
+    template_1d = builder.Builder(sym_x)
+    template_1d[g(0, 0)] = f
+    template_1d[g.neighbors()] = f
+
+    def line_200(site):
+        return -100 <= site.pos[0] < 100
+
+    ## test max_sites
+    target = builder.Builder()
+    for max_sites in (-1, 0):
+        with raises(ValueError):
+            target.fill(template_1d, g(0, 0), max_sites=max_sites)
+    target = builder.Builder()
+    with raises(RuntimeError):
+        target.fill(template_1d, g(0, 0), shape=line_200, max_sites=10)
+    ## test filling
+    target = builder.Builder()
+    added_sites = target.fill(template_1d, g(0, 0), shape=line_200)
+    assert len(added_sites) == 200
+    ## test overwrite=False
+    added_sites = target.fill(template_1d, g(0, 0), shape=line_200)
+    assert len(added_sites) == 0
+    ## test overwrite=True
+    added_sites = target.fill(template_1d, g(0, 0),
+                              shape=line_200, overwrite=True)
+    assert len(added_sites) == 200
+
+
+    ## test multiplying unit cell size in 1D
+    n_cells = 10
+    sym_nx = kwant.TranslationalSymmetry(*(sym_x.periods * n_cells))
+    target = builder.Builder(sym_nx)
+    target.fill(template_1d, g(0, 0))
+
+    should_be_syst = builder.Builder(sym_nx)
+    should_be_syst[(g(i, 0) for i in range(n_cells))] = f
+    should_be_syst[g.neighbors()] = f
+
+    assert sorted(target.sites()) == sorted(should_be_syst.sites())
+    assert sorted(target.hoppings()) == sorted(should_be_syst.hoppings())
+
+
+    ## test multiplying unit cell size in 2D
+    template_2d = builder.Builder(sym_xy)
+    template_2d[g(0, 0)] = f
+    template_2d[g.neighbors()] = f
+    template_2d[builder.HoppingKind((2, 2), g)] = f
+
+    nm_cells = (3, 5)
+    sym_nmxy = kwant.TranslationalSymmetry(*(sym_xy.periods * nm_cells))
+    target = builder.Builder(sym_nmxy)
+    target.fill(template_2d, g(0, 0))
+
+    should_be_syst = builder.Builder(sym_nmxy)
+    should_be_syst[(g(i, j) for i in range(10) for j in range(10))] = f
+    should_be_syst[g.neighbors()] = f
+    should_be_syst[builder.HoppingKind((2, 2), g)] = f
+
+    assert sorted(target.sites()) == sorted(should_be_syst.sites())
+    assert sorted(target.hoppings()) == sorted(should_be_syst.hoppings())
+
+
+    ## test filling 0D builder with 2D builder
+    def square_shape(site):
+        x, y = site.tag
+        return 0 <= x < 10 and 0 <= y < 10
+
+    target = builder.Builder()
+    target.fill(template_2d, g(0, 0), square_shape)
+
+    should_be_syst = builder.Builder()
+    should_be_syst[(g(i, j) for i in range(10) for j in range(10))] = f
+    should_be_syst[g.neighbors()] = f
+    should_be_syst[builder.HoppingKind((2, 2), g)] = f
+
+    assert sorted(target.sites()) == sorted(should_be_syst.sites())
+    assert sorted(target.hoppings()) == sorted(should_be_syst.hoppings())
+
+
 def test_attach_lead():
     fam = builder.SimpleSiteFamily()
     fam_noncommensurate = builder.SimpleSiteFamily(name='other')
-- 
GitLab