diff --git a/kwant/lattice.py b/kwant/lattice.py index 9b2e0995b76ab4ad8b598ef055868c5049833ce2..d117e04aa443c224f7843aab2d33f70224fb31cd 100644 --- a/kwant/lattice.py +++ b/kwant/lattice.py @@ -14,6 +14,7 @@ from math import sqrt import numpy as np import tinyarray as ta from . import builder +from .linalg import lll def general(prim_vecs, basis=None, name=''): @@ -96,67 +97,16 @@ class Polyatomic(object): for offset, sname in zip(basis, name)] # Sequence of primitive vectors of the lattice. self.prim_vecs = prim_vecs + # Precalculation of auxiliary arrays for real space calculations. + self._reduced_vecs, self._transf = lll.lll(prim_vecs) + self._voronoi = ta.dot(lll.voronoi(self._reduced_vecs), self._transf) def shape(self, function, start): - """ - Yield all the lattice sites which belong to a certain shape. - - Parameters - ---------- - function : a boolean function of real-space coordinates - A function which evaluates to True inside the desired shape. - start : real-valued vector - The starting point to the flood-fill algorithm. If the site - nearest to `start` is not inside the shape, no sites are returned. + """Yield sites belonging to a certain shape. - Returns - ------- - sites : sequence of `Site` objects - all the sites that belong to the lattice and fit inside the shape. + See `~kwant.lattice.Shape` for more information. """ - Site = builder.Site - - dim = len(start) - num_vecs = len(self.prim_vecs) - if dim != self.prim_vecs.shape[1]: - raise ValueError('Dimensionality of start position does not match' - ' the space dimensionality.') - sls = self.sublattices - deltas = [ta.array(i * (0,) + (1,) + (num_vecs - 1 - i) * (0,)) - for i in xrange(num_vecs)] - deltas += [-delta for delta in deltas] - - # Check if no sites are going to be added, to catch a common error. - empty = True - for sl in sls: - if function(sl(*sl.closest(start)).pos): - empty = False - if empty: - msg = 'No sites close to {0} are inside the desired shape.' - raise ValueError(msg.format(start)) - - # Continue to flood fill. - outer_shell = set(sl.closest(start) for sl in sls) - inner_shell = set() - while outer_shell: - tmp = set() - for tag in outer_shell: - vec = ta.dot(tag, self.prim_vecs) - any_hits = False - for sl in sls: - if not function(vec + sl.offset): - continue - yield Site(sl, tag, True) - any_hits = True - if not any_hits: - continue - for shift in deltas: - new_tag = tag + shift - if new_tag not in inner_shell and \ - new_tag not in outer_shell: - tmp.add(new_tag) - inner_shell = outer_shell - outer_shell = tmp + return Shape(self, function, start) def vec(self, int_vec): """ @@ -208,6 +158,9 @@ class Monatomic(builder.SiteFamily, Polyatomic): self.prim_vecs = prim_vecs self.inv_pv = ta.array(np.linalg.pinv(prim_vecs)) self.offset = offset + # Precalculation of auxiliary arrays for real space calculations. + self._reduced_vecs, self._transf = lll.lll(prim_vecs) + self._voronoi = ta.dot(lll.voronoi(self._reduced_vecs), self._transf) def short_array_repr(array): full = ' '.join([i.lstrip() for i in repr(array).split('\n')]) @@ -245,9 +198,23 @@ class Monatomic(builder.SiteFamily, Polyatomic): raise ValueError("Dimensionality mismatch.") return tag + def n_closest(self, pos, n=1): + """Find n sites closest to position `pos`. + + Returns + ------- + sites : numpy array + An array with sites coordinates. + """ + # TODO (Anton): transform to tinyarrays, once ta indexing is better. + return np.dot(lll.cvp(pos - self.offset, self._reduced_vecs, n), + self._transf.T) + def closest(self, pos): - """Find the site closest to position `pos`.""" - return ta.array(ta.round(ta.dot(pos - self.offset, self.inv_pv)), int) + """ + Find the lattice coordinates of the site closest to position `pos`. + """ + return ta.array(self.n_closest(pos)[0]) def pos(self, tag): """Return the real-space position of the site with a given tag.""" @@ -423,6 +390,93 @@ class TranslationalSymmetry(builder.Symmetry): return result +class Shape(object): + def __init__(self, lattice, function, start): + """A class for finding all the lattice sites in a shape. + + It uses a flood-fill algorithm, and takes into account + the symmetry of the builder to which it is provided, or + the symmetry, that is supplied to it after initialization. + + Parameters + ---------- + lattice : Polyatomic or Monoatomic lattice + Lattice, to which the resulting sites should belong. + function : callable + A function of real space coordinates, which should + return True for coordinates inside the shape, + and False otherwise. + start : float vector + The origin for the flood-fill algorithm. + """ + self.lat, self.func, self.start = lattice, function, start + + def __call__(self, builder_or_symmetry=None): + """ + Yield all the lattice sites which belong to a certain shape. + + Parameters + ---------- + builder_or_symmetry : Builder or Symmetry instance + The builder to which the site from the shape are added, or + the symmetry, such that the sites from the shape belong to + its fundamental domain. If not provided, trivial symmetry is + used. + + Returns + ------- + sites : sequence of `Site` objects + all the sites that belong to the lattice and fit inside the shape. + """ + Site = builder.Site + lat, func, start = self.lat, self.func, self.start + try: + symmetry = builder_or_symmetry.symmetry + except AttributeError: + symmetry = builder_or_symmetry + if symmetry is None: + symmetry = builder.NoSymmetry() + + sym_site = lambda lat, tag: symmetry.to_fd(Site(lat, tag, True)) + + dim = len(start) + if dim != lat.prim_vecs.shape[1]: + raise ValueError('Dimensionality of start position does not match' + ' the space dimensionality.') + sls = lat.sublattices + deltas = [ta.array(i) for i in lat._voronoi] + + # Check if no sites are going to be added, to catch a common error. + empty = True + for sl in sls: + if func(sym_site(sl, sl.closest(start)).pos): + empty = False + if empty: + msg = 'No sites close to {0} are inside the desired shape.' + raise ValueError(msg.format(start)) + + # Continue to flood fill. + tags = set([sl.closest(start) for sl in sls]) + new_sites = [sym_site(sl, tag) for sl in sls for tag in tags] + new_sites = [i for i in new_sites if func(i.pos)] + old_sites = set() + while new_sites: + tmp = set() + for site in new_sites: + yield site + tags = set((i.tag for i in new_sites)) + for tag in tags: + for shift in deltas: + for sl in sls: + site = sym_site(sl, tag + shift) + if site not in old_sites and \ + site not in new_sites and \ + func(site.pos): + tmp.add(site) + old_sites = new_sites + new_sites = tmp + + ################ Library of lattices (to be extended) def chain(a=1, name=''): diff --git a/kwant/linalg/tests/test_mumps.py b/kwant/linalg/tests/test_mumps.py index 857c5cc66fe9f5b29e46bec1cb0d8b1250527f72..71508633461b1695874d638034e1e2a9c12da9e2 100644 --- a/kwant/linalg/tests/test_mumps.py +++ b/kwant/linalg/tests/test_mumps.py @@ -14,7 +14,6 @@ except ImportError: from kwant.lattice import honeycomb from kwant.builder import Builder, HoppingKind -from nose.tools import assert_equal, assert_true from numpy.testing.decorators import skipif import numpy as np import scipy.sparse as sp diff --git a/kwant/tests/test_lattice.py b/kwant/tests/test_lattice.py index 6bb6d6e2f7c890b6607038d6105e4a21f506b5c3..3ef7b918c34e2ab518c2cbac17b709731cbb3ef7 100644 --- a/kwant/tests/test_lattice.py +++ b/kwant/tests/test_lattice.py @@ -15,6 +15,19 @@ from numpy.testing import assert_equal from kwant import lattice +def test_closest(): + np.random.seed(4) + lat = lattice.general(((1, 0), (0.5, sqrt(3)/2))) + for i in range(50): + point = 20 * np.random.rand(2) + closest = lat(*lat.closest(point)).pos + assert np.linalg.norm(point - closest) <= 1 / sqrt(3) + lat = lattice.general(np.random.randn(3, 3)) + for i in range(50): + tag = np.random.randint(10, size=(3,)) + assert_equal(lat.closest(lat(*tag).pos), tag) + + def test_general(): for lat in (lattice.general(((1, 0), (0.5, 0.5))), lattice.general(((1, 0), (0.5, sqrt(3)/2)), @@ -36,7 +49,7 @@ def test_shape(): lat = lattice.general(((1, 0), (0.5, sqrt(3) / 2)), ((0, 0), (0, 1 / sqrt(3)))) - sites = list(lat.shape(in_circle, (0, 0))) + sites = list(lat.shape(in_circle, (0, 0))()) sites_alt = list() sl0, sl1 = lat.sublattices for x in xrange(-2, 3): @@ -47,8 +60,15 @@ def test_shape(): sites_alt.append(site) assert len(sites) == len(sites_alt) assert_equal(set(sites), set(sites_alt)) - assert_raises(ValueError, lat.shape(in_circle, (10, 10)).next) - + assert_raises(ValueError, lat.shape(in_circle, (10, 10))().next) + # Check if narrow ribbons work. + for period in (0, 1), (1, 0), (1, -1): + vec = lat.vec(period) + sym = lattice.TranslationalSymmetry(vec) + def shape(pos): + return abs(pos[0] * vec[1] - pos[1] * vec[0]) < 10 + sites = list(lat.shape(shape, (0, 0))(sym)) + assert len(sites) > 35 def test_translational_symmetry(): ts = lattice.TranslationalSymmetry diff --git a/kwant/tests/test_plotter.py b/kwant/tests/test_plotter.py index 7290117cd546e624e04c4d337bbc33495c5ad5b7..62056470a8dbb40a7de1f9910d20765bb766fb8c 100644 --- a/kwant/tests/test_plotter.py +++ b/kwant/tests/test_plotter.py @@ -33,9 +33,7 @@ def sys_2d(W=3, r1=3, r2=8): lead0 = kwant.Builder(sym_lead0) lead2 = kwant.Builder(sym_lead0) - def lead_shape(pos): - (x, y) = pos - return (-1 < x < 1) and (-W / 2 < y < W / 2) + lead_shape = lambda pos: (-W / 2 < pos[1] < W / 2) lead0[lat.shape(lead_shape, (0, 0))] = 4 * t lead2[lat.shape(lead_shape, (0, 0))] = 4 * t @@ -62,9 +60,7 @@ def sys_3d(W=3, r1=2, r2=4, a=1, t=1.0): sym_lead0 = kwant.TranslationalSymmetry(lat.vec((-1, 0, 0))) lead0 = kwant.Builder(sym_lead0) - def lead_shape(pos): - (x, y, z) = pos - return (-1 < x < 1) and (-W / 2 < y < W / 2) and abs(z) < 2 + lead_shape = lambda pos: (-W / 2 < pos[1] < W / 2) and abs(pos[2]) < 2 lead0[lat.shape(lead_shape, (0, 0, 0))] = 4 * t lead0[lat.nearest] = - t