Commit 72a96c4a authored by Joseph Weston's avatar Joseph Weston
Browse files

merge change of Builder.fill() behavior

parents dc02e80e 20b86bd2
Pipeline #3895 failed with stages
in 69 minutes and 50 seconds
...@@ -16,7 +16,7 @@ import inspect ...@@ -16,7 +16,7 @@ import inspect
import tinyarray as ta import tinyarray as ta
import numpy as np import numpy as np
from scipy import sparse from scipy import sparse
from . import system, graph, UserCodeError from . import system, graph, KwantDeprecationWarning, UserCodeError
from .linalg import lll from .linalg import lll
from .operator import Density from .operator import Density
from .physics import DiscreteSymmetry from .physics import DiscreteSymmetry
...@@ -333,22 +333,16 @@ class Symmetry(metaclass=abc.ABCMeta): ...@@ -333,22 +333,16 @@ class Symmetry(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
def isstrictsupergroup(self, other): def has_subgroup(self, other):
"""Test whether `self` is a strict supergroup of `other`... """Test whether `self` has the subgroup `other`...
or, in other words, whether `other` is a subgroup of `self`. The or, in other words, whether `other` is a subgroup of `self`. The
reason why this is the abstract method (and not `issubgroup`) is that reason why this is the abstract method (and not `is_subgroup`) is that
in general it's not possible for a subgroup to know its supergroups. in general it's not possible for a subgroup to know its supergroups.
""" """
pass pass
def issubgroup(self, other):
"""Test whether `self` is a subgroup of `other`."""
return other.isstrictsupergroup(self)
__le__ = issubgroup
class NoSymmetry(Symmetry): class NoSymmetry(Symmetry):
"""A symmetry with a trivial symmetry group.""" """A symmetry with a trivial symmetry group."""
...@@ -389,8 +383,8 @@ class NoSymmetry(Symmetry): ...@@ -389,8 +383,8 @@ class NoSymmetry(Symmetry):
raise ValueError('Generators must be empty for NoSymmetry.') raise ValueError('Generators must be empty for NoSymmetry.')
return NoSymmetry(generators) return NoSymmetry(generators)
def isstrictsupergroup(self, other): def has_subgroup(self, other):
return False return isinstance(other, NoSymmetry)
...@@ -778,11 +772,6 @@ class Builder: ...@@ -778,11 +772,6 @@ class Builder:
Attaching a lead manually (without the use of `~Builder.attach_lead`) Attaching a lead manually (without the use of `~Builder.attach_lead`)
amounts to creating a `Lead` object and appending it to this list. amounts to creating a `Lead` object and appending it to this list.
``builder0 += builder1`` adds all the sites, hoppings, and leads of
``builder1`` to ``builder0``. Sites and hoppings present in both systems
are overwritten by those in ``builder1``. The leads of ``builder1`` are
appended to the leads of the system being extended.
.. warning:: .. warning::
If functions are used to set values in a builder with a symmetry, then If functions are used to set values in a builder with a symmetry, then
...@@ -1244,17 +1233,48 @@ class Builder: ...@@ -1244,17 +1233,48 @@ class Builder:
result = site result = site
return result return result
def __iadd__(self, other): def update(self, other):
"""Update builder from `other`.
All sites and hoppings of `other`, together with their values, are
written to `self`, overwriting already existing sites and hoppings.
The leads of `other` are appended to the leads of the system being
updated.
This method requires that both builders share the same symmetry.
"""
if (not self.symmetry.has_subgroup(other.symmetry)
or not other.symmetry.has_subgroup(self.symmetry)):
raise ValueError("Both builders involved in update() must have "
"equal symmetries.")
for site, value in other.site_value_pairs(): for site, value in other.site_value_pairs():
self[site] = value self[site] = value
for hop, value in other.hopping_value_pairs(): for hop, value in other.hopping_value_pairs():
self[hop] = value self[hop] = value
self.leads.extend(other.leads) self.leads.extend(other.leads)
def __iadd__(self, other):
warnings.warn("The += operator of builders is deprecated. Use "
"'Builder.update()' instead.", KwantDeprecationWarning,
stacklevel=2)
self.update(other)
return self return self
def fill(self, template, shape, start, *, overwrite=False, max_sites=10**7): def fill(self, template, shape, start, *, max_sites=10**7):
"""Populate builder using another one as a template. """Populate builder using another one as a template.
Starting from one or multiple sites, traverse the graph of the template
builder and copy sites and hoppings to the target builder. The
traversal stops at sites that are already present in the target and on
sites that are not inside the provided shape.
This function takes into account translational symmetry. As such,
typically the template will have a higher symmetry than the target.
Newly added sites are connected by hoppings to sites that were already
present. This facilitates construction of a system by a series of
calls to 'fill'.
Parameters Parameters
---------- ----------
template : `Builder` instance template : `Builder` instance
...@@ -1268,10 +1288,6 @@ class Builder: ...@@ -1268,10 +1288,6 @@ class Builder:
The site(s) at which the the flood-fill starts. If start is an The site(s) at which the the flood-fill starts. If start is an
iterable of numbers, the starting site will be iterable of numbers, the starting site will be
``template.closest(start)``. ``template.closest(start)``.
overwrite : boolean
Whether existing sites or hoppings in the target builder should be
overwritten. When overwriting is disabled (the default), existing
sites act as boundaries for the flood-fill.
max_sites : positive number max_sites : positive number
The maximal number of sites that may be added before The maximal number of sites that may be added before
``RuntimeError`` is raised. Used to prevent using up all memory. ``RuntimeError`` is raised. Used to prevent using up all memory.
...@@ -1303,7 +1319,7 @@ class Builder: ...@@ -1303,7 +1319,7 @@ class Builder:
templ_sym = template.symmetry templ_sym = template.symmetry
# Check that symmetries are commensurate. # Check that symmetries are commensurate.
if not self.symmetry <= templ_sym: if not templ_sym.has_subgroup(self.symmetry):
raise ValueError("Builder symmetry is not a subgroup of the " raise ValueError("Builder symmetry is not a subgroup of the "
"template symmetry") "template symmetry")
...@@ -1322,7 +1338,7 @@ class Builder: ...@@ -1322,7 +1338,7 @@ class Builder:
congested = True congested = True
for s in start: for s in start:
s = to_fd(s) s = to_fd(s)
if overwrite or s not in H: if s not in H:
congested = False congested = False
if shape(s): if shape(s):
active.add(s) active.add(s)
...@@ -1373,21 +1389,21 @@ class Builder: ...@@ -1373,21 +1389,21 @@ class Builder:
if (head_fd not in old_active if (head_fd not in old_active
and head_fd not in new_active): and head_fd not in new_active):
# The 'head' site has not been filled yet. # The 'head' site has not been filled yet.
if not shape(head_fd): if head_fd in H:
continue # The 'head' site exists. (It doesn't matter
# whether it's in the shape or not.) Fill the
if overwrite or head_fd not in H: # incoming edge as well to balance the hopping.
# Fill 'head' site.
new_active.add(head_fd)
H.setdefault(head_fd, [head_fd, None])
else:
# The 'head' site exists and won't be visited:
# fill the incoming edge as well to balance the
# hopping.
other_value = template._get_edge( other_value = template._get_edge(
*templ_sym.to_fd(head, tail)) *templ_sym.to_fd(head, tail))
self._set_edge(*to_fd(head, tail) self._set_edge(*to_fd(head, tail)
+ (other_value,)) + (other_value,))
else:
if not shape(head_fd):
# There is no site at 'head' and it's
# outside the shape.
continue
new_active.add(head_fd)
H.setdefault(head_fd, [head_fd, None])
# Fill the outgoing edge. # Fill the outgoing edge.
if head in old_heads: if head in old_heads:
......
...@@ -555,7 +555,7 @@ class TranslationalSymmetry(builder.Symmetry): ...@@ -555,7 +555,7 @@ class TranslationalSymmetry(builder.Symmetry):
raise ValueError('Generators must be sequences of integers.') raise ValueError('Generators must be sequences of integers.')
return TranslationalSymmetry(*ta.dot(generators, self.periods)) return TranslationalSymmetry(*ta.dot(generators, self.periods))
def isstrictsupergroup(self, other): def has_subgroup(self, other):
if isinstance(other, builder.NoSymmetry): if isinstance(other, builder.NoSymmetry):
return True return True
elif not isinstance(other, TranslationalSymmetry): elif not isinstance(other, TranslationalSymmetry):
......
...@@ -129,7 +129,7 @@ class VerySimpleSymmetry(builder.Symmetry): ...@@ -129,7 +129,7 @@ class VerySimpleSymmetry(builder.Symmetry):
def num_directions(self): def num_directions(self):
return 1 return 1
def isstrictsupergroup(self, other): def has_subgroup(self, other):
if isinstance(other, builder.NoSymmetry): if isinstance(other, builder.NoSymmetry):
return True return True
elif isinstance(other, VerySimpleSymmetry): elif isinstance(other, VerySimpleSymmetry):
...@@ -556,7 +556,8 @@ def test_hamiltonian_evaluation(): ...@@ -556,7 +556,8 @@ def test_hamiltonian_evaluation():
# test with infinite system # test with infinite system
inf_syst = kwant.Builder(VerySimpleSymmetry(2)) inf_syst = kwant.Builder(VerySimpleSymmetry(2))
inf_syst += syst for k, v in it.chain(syst.site_value_pairs(), syst.hopping_value_pairs()):
inf_syst[k] = v
inf_fsyst = inf_syst.finalized() inf_fsyst = inf_syst.finalized()
hop = tuple(map(inf_fsyst.sites.index, new_hop)) hop = tuple(map(inf_fsyst.sites.index, new_hop))
test_raising(inf_fsyst, hop) test_raising(inf_fsyst, hop)
...@@ -652,32 +653,36 @@ def test_fill(): ...@@ -652,32 +653,36 @@ def test_fill():
return -100 <= site.pos[0] < 100 return -100 <= site.pos[0] < 100
## Test that copying a builder by "fill" preserves everything. ## Test that copying a builder by "fill" preserves everything.
cubic = kwant.lattice.general(ta.identity(3)) for sym, func in [(kwant.TranslationalSymmetry(*np.diag([3, 4, 5])),
sym = kwant.TranslationalSymmetry((3, 0, 0), (0, 4, 0), (0, 0, 5)) lambda pos: True),
(builder.NoSymmetry(),
# Make a weird system. lambda pos: ta.dot(pos, pos) < 17)]:
orig = kwant.Builder(sym) cubic = kwant.lattice.general(ta.identity(3))
sites = cubic.shape(lambda pos: True, (0, 0, 0))
for i, site in enumerate(orig.expand(sites)): # Make a weird system.
if i % 7 == 0: orig = kwant.Builder(sym)
continue sites = cubic.shape(func, (0, 0, 0))
orig[site] = i for i, site in enumerate(orig.expand(sites)):
for i, hopp in enumerate(orig.expand(cubic.neighbors(1))): if i % 7 == 0:
if i % 11 == 0: continue
continue orig[site] = i
orig[hopp] = i * 1.2345 for i, hopp in enumerate(orig.expand(cubic.neighbors(1))):
for i, hopp in enumerate(orig.expand(cubic.neighbors(2))): if i % 11 == 0:
if i % 13 == 0: continue
continue orig[hopp] = i * 1.2345
orig[hopp] = i * 1j for i, hopp in enumerate(orig.expand(cubic.neighbors(2))):
if i % 13 == 0:
# Clone the original using fill. continue
clone = kwant.Builder(sym) orig[hopp] = i * 1j
clone.fill(orig, lambda s: True, (0, 0, 0))
# Clone the original using fill.
# Verify that both are identical. clone = kwant.Builder(sym)
assert set(clone.site_value_pairs()) == set(orig.site_value_pairs()) clone.fill(orig, lambda s: True, (0, 0, 0))
assert set(clone.hopping_value_pairs()) == set(orig.hopping_value_pairs())
# Verify that both are identical.
assert set(clone.site_value_pairs()) == set(orig.site_value_pairs())
assert (set(clone.hopping_value_pairs())
== set(orig.hopping_value_pairs()))
## Test for warning when "start" is out. ## Test for warning when "start" is out.
target = builder.Builder() target = builder.Builder()
...@@ -688,14 +693,13 @@ def test_fill(): ...@@ -688,14 +693,13 @@ def test_fill():
## Test filling of infinite builder. ## Test filling of infinite builder.
for n in [1, 2, 4]: for n in [1, 2, 4]:
sym_n = kwant.TranslationalSymmetry((n, 0)) sym_n = kwant.TranslationalSymmetry((n, 0))
for ow in [False, True]: for start in [g(0, 0), g(20, 0)]:
for start in [g(0, 0), g(20, 0)]: target = builder.Builder(sym_n)
target = builder.Builder(sym_n) sites = target.fill(template_1d, lambda s: True, start,
sites = target.fill(template_1d, lambda s: True, start, max_sites=10)
overwrite=ow, max_sites=10) assert len(sites) == n
assert len(sites) == n assert len(list(target.hoppings())) == n
assert len(list(target.hoppings())) == n assert set(sym_n.to_fd(s) for s in sites) == set(target.sites())
assert set(sym_n.to_fd(s) for s in sites) == set(target.sites())
## test max_sites ## test max_sites
target = builder.Builder() target = builder.Builder()
...@@ -711,14 +715,8 @@ def test_fill(): ...@@ -711,14 +715,8 @@ def test_fill():
target = builder.Builder() target = builder.Builder()
added_sites = target.fill(template_1d, line_200, g(0, 0)) added_sites = target.fill(template_1d, line_200, g(0, 0))
assert len(added_sites) == 200 assert len(added_sites) == 200
## test overwrite=False
with warns(RuntimeWarning): with warns(RuntimeWarning):
target.fill(template_1d, line_200, g(0, 0)) target.fill(template_1d, line_200, g(0, 0))
## test overwrite=True
added_sites = target.fill(template_1d, line_200, g(0, 0),
overwrite=True)
assert len(added_sites) == 200
## test multiplying unit cell size in 1D ## test multiplying unit cell size in 1D
n_cells = 10 n_cells = 10
...@@ -779,19 +777,55 @@ def test_fill(): ...@@ -779,19 +777,55 @@ def test_fill():
target = builder.Builder(kwant.TranslationalSymmetry((-2,))) target = builder.Builder(kwant.TranslationalSymmetry((-2,)))
target[lat(0)] = None target[lat(0)] = None
to_target_fd = target.symmetry.to_fd to_target_fd = target.symmetry.to_fd
# refuses to fill the target because target already contains the starting # Refuses to fill the target because target already contains the starting
# site and 'overwrite == False'. # site.
with warns(RuntimeWarning): with warns(RuntimeWarning):
target.fill(template, lambda x: True, lat(0)) target.fill(template, lambda x: True, lat(0))
# should only add a single site (and hopping) # should only add a single site (and hopping)
new_sites = target.fill(template, lambda x: True, lat(1), overwrite=False) new_sites = target.fill(template, lambda x: True, lat(1))
assert target[lat(0)] is None # should not be overwritten by template assert target[lat(0)] is None # should not be overwritten by template
assert target[lat(-1)] == template[lat(0)] assert target[lat(-1)] == template[lat(0)]
assert len(new_sites) == 1 assert len(new_sites) == 1
assert to_target_fd(new_sites[0]) == to_target_fd(lat(-1)) assert to_target_fd(new_sites[0]) == to_target_fd(lat(-1))
def test_fill_sticky():
"""Test that adjacent regions are properly interconnected when filled
separately.
"""
# Generate model template.
lat = kwant.lattice.kagome()
template = kwant.Builder(kwant.TranslationalSymmetry(
lat.vec((1, 0)), lat.vec((0, 1))))
for i, sl in enumerate(lat.sublattices):
template[sl(0, 0)] = i
for i in range(1, 3):
for j, hop in enumerate(template.expand(lat.neighbors(i))):
template[hop] = j * 1j
def disk(site):
pos = site.pos
return ta.dot(pos, pos) < 13
def halfplane(site):
return ta.dot(site.pos - (-1, 1), (-0.9, 0.63)) > 0
# Fill in one go.
syst0 = kwant.Builder()
syst0.fill(template, disk, (0, 0))
# Fill in two stages.
syst1 = kwant.Builder()
syst1.fill(template, lambda s: disk(s) and halfplane(s), (-2, 1))
syst1.fill(template, lambda s: disk(s) and not halfplane(s), (0, 0))
# Verify that both results are identical.
assert set(syst0.site_value_pairs()) == set(syst1.site_value_pairs())
assert (set(syst0.hopping_value_pairs())
== set(syst1.hopping_value_pairs()))
def test_attach_lead(): def test_attach_lead():
fam = builder.SimpleSiteFamily() fam = builder.SimpleSiteFamily()
fam_noncommensurate = builder.SimpleSiteFamily(name='other') fam_noncommensurate = builder.SimpleSiteFamily(name='other')
...@@ -907,7 +941,7 @@ def test_closest(): ...@@ -907,7 +941,7 @@ def test_closest():
assert dd >= 0.999999 * dist assert dd >= 0.999999 * dist
def test_iadd(): def test_update():
lat = builder.SimpleSiteFamily() lat = builder.SimpleSiteFamily()
syst = builder.Builder() syst = builder.Builder()
...@@ -930,7 +964,7 @@ def test_iadd(): ...@@ -930,7 +964,7 @@ def test_iadd():
lead1 = builder.BuilderLead(lead1, [lat(2,)]) lead1 = builder.BuilderLead(lead1, [lat(2,)])
other_syst.leads.append(lead1) other_syst.leads.append(lead1)
syst += other_syst syst.update(other_syst)
assert syst.leads == [lead0, lead1] assert syst.leads == [lead0, lead1]
expected = sorted([((0,), 1), ((1,), 2), ((2,), 2)]) expected = sorted([((0,), 1), ((1,), 2), ((2,), 2)])
assert sorted(((s.tag, v) for s, v in syst.site_value_pairs())) == expected assert sorted(((s.tag, v) for s, v in syst.site_value_pairs())) == expected
......
...@@ -251,21 +251,23 @@ def test_symmetry_act(): ...@@ -251,21 +251,23 @@ def test_symmetry_act():
sym.act(ta.array(el), site) sym.act(ta.array(el), site)
def test_symmetry_subgroup(): def test_symmetry_has_subgroup():
rng = np.random.RandomState(0) rng = np.random.RandomState(0)
## test whether actual subgroups are detected as such ## test whether actual subgroups are detected as such
vecs = rng.randn(3, 3) vecs = rng.randn(3, 3)
sym1 = lattice.TranslationalSymmetry(*vecs) sym1 = lattice.TranslationalSymmetry(*vecs)
assert sym1 >= sym1 ns = builder.NoSymmetry()
assert sym1 >= builder.NoSymmetry() assert ns.has_subgroup(ns)
assert sym1 >= lattice.TranslationalSymmetry(2 * vecs[0], assert sym1.has_subgroup(sym1)
3 * vecs[1] + 4 * vecs[2]) assert sym1.has_subgroup(ns)
assert not sym1 <= lattice.TranslationalSymmetry(*(0.8 * vecs)) assert sym1.has_subgroup(lattice.TranslationalSymmetry(
2 * vecs[0], 3 * vecs[1] + 4 * vecs[2]))
assert not lattice.TranslationalSymmetry(*(0.8 * vecs)).has_subgroup(sym1)
## test subgroup creation ## test subgroup creation
for dim in range(1, 4): for dim in range(1, 4):
generators = rng.randint(10, size=(dim, 3)) generators = rng.randint(10, size=(dim, 3))
assert sym1.subgroup(*generators) <= sym1 assert sym1.has_subgroup(sym1.subgroup(*generators))
# generators are not linearly independent # generators are not linearly independent
with raises(ValueError): with raises(ValueError):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment