Skip to content
Snippets Groups Projects
Commit 1470558c authored by Christoph Groth's avatar Christoph Groth
Browse files

generalize indexing: valid keys are now: sites/hoppings, iterables and functions

parent 10443390
Branches
Tags
No related merge requests found
......@@ -312,7 +312,7 @@ class HoppingKind(object):
self.group_a = group_a
self.group_b = group_b if group_b is not None else group_a
def match(self, builder):
def __call__(self, builder):
"""Return an iterator over all possible matching hoppings whose sites
are already present in the system. The hoppings do *not* have to be
already present in the system.
......@@ -637,44 +637,42 @@ class Builder(object):
def _for_each_in_key(self, key, f_site, f_hopp):
"""Perform an operation on each site or hopping in key.
Key may be
* a single site or hopping object,
* a non-tuple iterable of sites,
* a non-tuple iterable of hoppings.
Key may be (tested in this order)
* a single site or a hopping (a tuple of two sites),
* a function that returns a key when called with the builder as only
parameter,
* a non-tuple iterable of keys.
"""
if isinstance(key, Site):
f_site(key)
return 0
elif isinstance(key, tuple):
f_hopp(key)
elif isinstance(key, HoppingKind):
for hopping in key.match(self):
f_hopp(hopping)
return 1
elif callable(key):
return self._for_each_in_key(key(self), f_site, f_hopp)
else:
try:
ikey = iter(key)
except:
ret = None
for item in key:
last = self._for_each_in_key(item, f_site, f_hopp)
if last != ret:
if ret is None:
ret = last
elif last is not None:
raise KeyError(item)
return ret
except TypeError:
raise KeyError(key)
try:
first = next(ikey)
except StopIteration:
return
if isinstance(first, Site):
f_site(first)
for site in ikey:
f_site(site)
elif isinstance(first, tuple):
f_hopp(first)
for hopping in ikey:
f_hopp(hopping)
elif isinstance(first, HoppingKind):
for hopping in first.match(self):
f_hopp(hopping)
for kind in ikey:
for hopping in kind.match(self):
f_hopp(hopping)
else:
raise KeyError(first)
# The following clauses make sure that a useful error message is
# generated for infinitely iterable keys (like strings).
except KeyError as e:
if not e.args and key != item:
raise KeyError(key)
else:
raise
except RuntimeError:
raise KeyError()
def _get_site(self, site):
site = self.symmetry.to_fd(site)
......
......@@ -388,7 +388,7 @@ def test_wavefunc_ldos_consistency(wave_func, ldos):
h += h.conjugate().transpose()
b[site] = h
for hopping_kind in square.nearest:
for hop in hopping_kind.match(b):
for hop in hopping_kind(b):
b[hop] = 10 * np.random.rand(n, n) + 1j * np.random.rand(n, n)
sys.attach_lead(left_lead)
sys.attach_lead(top_lead)
......
......@@ -537,19 +537,18 @@ def test_HoppingKind():
sys[((h if max(x, y, z) % 2 else g)(x, y, z)
for x in range(4) for y in range(2) for z in range(4))] = None
for delta, ga, gb, n in [((1, 0, 0), g, h, 4),
((1, 0, 0), h, g, 7),
((0, 1, 0), g, h, 1),
((0, 4, 0), h, h, 21),
((0, 0, 1), g, h, 4)
]:
ph = list(builder.HoppingKind(delta, ga, gb).match(sys))
((1, 0, 0), h, g, 7),
((0, 1, 0), g, h, 1),
((0, 4, 0), h, h, 21),
((0, 0, 1), g, h, 4)]:
ph = list(builder.HoppingKind(delta, ga, gb)(sys))
assert_equal(len(ph), n)
ph = set(ph)
assert_equal(len(ph), n)
ph2 = list((
sym.to_fd(b, a) for a, b in
builder.HoppingKind(ta.negative(delta), gb, ga).match(sys)))
builder.HoppingKind(ta.negative(delta), gb, ga)(sys)))
assert_equal(len(ph2), n)
ph2 = set(ph2)
assert_equal(ph2, ph)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment