From e7c4fb137fc9fbf5b14841e4d3b93e1adba5aa93 Mon Sep 17 00:00:00 2001
From: Christoph Groth <christoph.groth@cea.fr>
Date: Fri, 5 May 2017 00:17:35 +0200
Subject: [PATCH] implement Builder.closest()

---
 kwant/builder.py            | 55 +++++++++++++++++++++++++++++++++++++
 kwant/tests/test_builder.py | 41 +++++++++++++++++++++++++++
 2 files changed, 96 insertions(+)

diff --git a/kwant/builder.py b/kwant/builder.py
index d711463d..f245a0fb 100644
--- a/kwant/builder.py
+++ b/kwant/builder.py
@@ -20,6 +20,7 @@ import tinyarray as ta
 import numpy as np
 from scipy import sparse
 from . import system, graph, KwantDeprecationWarning, UserCodeError
+from .linalg import lll
 from .operator import Density
 from .physics import DiscreteSymmetry
 from ._common import ensure_isinstance, get_parameters
@@ -1177,6 +1178,60 @@ class Builder:
         a = self.symmetry.to_fd(site)
         return self._out_neighbors(a)
 
+    def closest(self, pos):
+        """Return the site that is closest to the given position.
+
+        This function takes into account the symmetry of the builder.  It is
+        assumed that the symmetry is a translational symmetry.
+
+        This function executes in a time proportional to the number of sites,
+        so it is not efficient for large builders.  It is especially slow for
+        builders with a symmetry, but such systems often contain only a limited
+        number of sites.
+
+        """
+        errmsg = ("Builder.closest() requires site families that provide "
+                  "pos().\nThe following one does not:\n")
+        sym = self.symmetry
+        n = sym.num_directions
+
+        if n:
+            # Determine basis in real space from first site.  (The result from
+            # any site would do.)
+            I = ta.identity(n, int)
+            site = next(iter(self.H))
+            space_basis = [sym.act(element, site).pos - site.pos
+                           for element in I]
+            space_basis, transf = lll.lll(space_basis)
+            transf = ta.array(transf.T, int)
+
+        tag_basis_cache = {}
+        dist = float('inf')
+        result = None
+        for site in self.H:
+            try:
+                site_pos = site.pos
+            except AttributeError:
+                raise AttributeError(errmsg + str(site.family))
+            if n:
+                fam = site.family
+                tag_basis = tag_basis_cache.get(fam)
+                if tag_basis is None:
+                    zero_site = Site(fam, ta.zeros(len(site.tag), int))
+                    tag_basis = [sym.act(element, zero_site).tag
+                                 for element in I]
+                    tag_basis = ta.dot(transf, tag_basis)
+                    tag_basis_cache[fam] = tag_basis
+                shift = lll.cvp(pos - site_pos, space_basis, 1)[0]
+                site = Site(fam, ta.dot(shift, tag_basis) + site.tag)
+                site_pos = site.pos
+            d = site_pos - pos
+            d = ta.dot(d, d)
+            if d < dist:
+                dist = d
+                result = site
+        return result
+
     def __iadd__(self, other):
         for site, value in other.site_value_pairs():
             self[site] = value
diff --git a/kwant/tests/test_builder.py b/kwant/tests/test_builder.py
index 8c0d5747..6e22d183 100644
--- a/kwant/tests/test_builder.py
+++ b/kwant/tests/test_builder.py
@@ -16,6 +16,7 @@ import tinyarray as ta
 import numpy as np
 import kwant
 from kwant import builder
+from kwant._common import ensure_rng
 
 
 def test_bad_keys():
@@ -824,6 +825,46 @@ def test_neighbors_not_in_single_domain():
     raises(ValueError, sr.finalized)
 
 
+def inside_disc(center, rr):
+    def shape(site):
+        d = site.pos - center
+        dd = ta.dot(d, d)
+        return dd <= rr
+    return shape
+
+
+def test_closest():
+    rng = ensure_rng(10)
+    for sym_dim in range(1, 4):
+        for space_dim in range(sym_dim, 4):
+            lat = kwant.lattice.general(ta.identity(space_dim))
+
+            # Choose random periods.
+            while True:
+                periods = rng.randint(-10, 11, (sym_dim, space_dim))
+                if np.linalg.det(np.dot(periods, periods.T)) > 0.1:
+                    # Periods are reasonably linearly independent.
+                    break
+            syst = builder.Builder(kwant.TranslationalSymmetry(*periods))
+
+            for tag in rng.randint(-30, 31, (4, space_dim)):
+                # Add site.
+                syst[lat(*tag)] = None
+
+                # Test consistency with fill().
+                for point in 200 * rng.random_sample((10, space_dim)) - 100:
+                    closest = syst.closest(point)
+                    dist = closest.pos - point
+                    dist = ta.dot(dist, dist)
+                    syst2 = builder.Builder()
+                    syst2.fill(syst, inside_disc(point, 2 * dist), closest)
+                    assert syst2.closest(point) == closest
+                    for site in syst2.sites():
+                        dd = site.pos - point
+                        dd = ta.dot(dd, dd)
+                        assert dd >= 0.999999 * dist
+
+
 def test_iadd():
     lat = builder.SimpleSiteFamily()
 
-- 
GitLab