From e7c4fb137fc9fbf5b14841e4d3b93e1adba5aa93 Mon Sep 17 00:00:00 2001 From: Christoph Groth <christoph.groth@cea.fr> Date: Fri, 5 May 2017 00:17:35 +0200 Subject: [PATCH] implement Builder.closest() --- kwant/builder.py | 55 +++++++++++++++++++++++++++++++++++++ kwant/tests/test_builder.py | 41 +++++++++++++++++++++++++++ 2 files changed, 96 insertions(+) diff --git a/kwant/builder.py b/kwant/builder.py index d711463d..f245a0fb 100644 --- a/kwant/builder.py +++ b/kwant/builder.py @@ -20,6 +20,7 @@ import tinyarray as ta import numpy as np from scipy import sparse from . import system, graph, KwantDeprecationWarning, UserCodeError +from .linalg import lll from .operator import Density from .physics import DiscreteSymmetry from ._common import ensure_isinstance, get_parameters @@ -1177,6 +1178,60 @@ class Builder: a = self.symmetry.to_fd(site) return self._out_neighbors(a) + def closest(self, pos): + """Return the site that is closest to the given position. + + This function takes into account the symmetry of the builder. It is + assumed that the symmetry is a translational symmetry. + + This function executes in a time proportional to the number of sites, + so it is not efficient for large builders. It is especially slow for + builders with a symmetry, but such systems often contain only a limited + number of sites. + + """ + errmsg = ("Builder.closest() requires site families that provide " + "pos().\nThe following one does not:\n") + sym = self.symmetry + n = sym.num_directions + + if n: + # Determine basis in real space from first site. (The result from + # any site would do.) + I = ta.identity(n, int) + site = next(iter(self.H)) + space_basis = [sym.act(element, site).pos - site.pos + for element in I] + space_basis, transf = lll.lll(space_basis) + transf = ta.array(transf.T, int) + + tag_basis_cache = {} + dist = float('inf') + result = None + for site in self.H: + try: + site_pos = site.pos + except AttributeError: + raise AttributeError(errmsg + str(site.family)) + if n: + fam = site.family + tag_basis = tag_basis_cache.get(fam) + if tag_basis is None: + zero_site = Site(fam, ta.zeros(len(site.tag), int)) + tag_basis = [sym.act(element, zero_site).tag + for element in I] + tag_basis = ta.dot(transf, tag_basis) + tag_basis_cache[fam] = tag_basis + shift = lll.cvp(pos - site_pos, space_basis, 1)[0] + site = Site(fam, ta.dot(shift, tag_basis) + site.tag) + site_pos = site.pos + d = site_pos - pos + d = ta.dot(d, d) + if d < dist: + dist = d + result = site + return result + def __iadd__(self, other): for site, value in other.site_value_pairs(): self[site] = value diff --git a/kwant/tests/test_builder.py b/kwant/tests/test_builder.py index 8c0d5747..6e22d183 100644 --- a/kwant/tests/test_builder.py +++ b/kwant/tests/test_builder.py @@ -16,6 +16,7 @@ import tinyarray as ta import numpy as np import kwant from kwant import builder +from kwant._common import ensure_rng def test_bad_keys(): @@ -824,6 +825,46 @@ def test_neighbors_not_in_single_domain(): raises(ValueError, sr.finalized) +def inside_disc(center, rr): + def shape(site): + d = site.pos - center + dd = ta.dot(d, d) + return dd <= rr + return shape + + +def test_closest(): + rng = ensure_rng(10) + for sym_dim in range(1, 4): + for space_dim in range(sym_dim, 4): + lat = kwant.lattice.general(ta.identity(space_dim)) + + # Choose random periods. + while True: + periods = rng.randint(-10, 11, (sym_dim, space_dim)) + if np.linalg.det(np.dot(periods, periods.T)) > 0.1: + # Periods are reasonably linearly independent. + break + syst = builder.Builder(kwant.TranslationalSymmetry(*periods)) + + for tag in rng.randint(-30, 31, (4, space_dim)): + # Add site. + syst[lat(*tag)] = None + + # Test consistency with fill(). + for point in 200 * rng.random_sample((10, space_dim)) - 100: + closest = syst.closest(point) + dist = closest.pos - point + dist = ta.dot(dist, dist) + syst2 = builder.Builder() + syst2.fill(syst, inside_disc(point, 2 * dist), closest) + assert syst2.closest(point) == closest + for site in syst2.sites(): + dd = site.pos - point + dd = ta.dot(dd, dd) + assert dd >= 0.999999 * dist + + def test_iadd(): lat = builder.SimpleSiteFamily() -- GitLab