Commit 75ef39f9 authored by Joseph Weston's avatar Joseph Weston
Browse files

Merge branch 'vectorize_wraparound' into 'master'

Vectorize wraparound

See merge request !363
parents 61eef7d3 1fa9874b
Pipeline #34527 passed with stages
in 10 minutes and 7 seconds
......@@ -174,8 +174,15 @@ def herm_conj(value):
"""
if hasattr(value, 'conjugate'):
value = value.conjugate()
if hasattr(value, 'transpose'):
value = value.transpose()
if hasattr(value, 'shape'):
if len(value.shape) > 2:
is_ta = isinstance(value, (
ta.ndarray_int, ta.ndarray_float, ta.ndarray_complex))
value = np.swapaxes(value, -1, -2)
if is_ta:
value = ta.array(value)
else:
value = value.transpose()
return value
......
......@@ -23,17 +23,28 @@ if _plotter.mpl_available:
from matplotlib import pyplot # pragma: no flakes
def _simple_syst(lat, E=0, t=1+1j, sym=None):
def _simple_syst(lat, E=0, t=1+1j, sym=None, vectorize=False):
"""Create a builder for a simple infinite system."""
if not sym:
sym = kwant.TranslationalSymmetry(lat.vec((1, 0)), lat.vec((0, 1)))
# Build system with 2d periodic BCs. This system cannot be finalized in
# Kwant <= 1.2.
syst = kwant.Builder(sym)
syst = kwant.Builder(sym, vectorize=vectorize)
syst[lat.shape(lambda p: True, (0, 0))] = E
syst[lat.neighbors(1)] = t
return syst
def _onsite(site, arg1):
return arg1
def _hopping(site1, site2, arg2):
return len(site1) * arg2
def _make_bloch(symm, lat, vectorize=True):
syst = kwant.Builder(symmetry=symm, vectorize=vectorize)
syst[lat.shape(lambda x: True, [0] * lat.dim)] = _onsite
syst[lat.neighbors()] = _hopping
return wraparound(syst).finalized()
def test_consistence_with_bands(kx=1.9, nkys=31):
kys = np.linspace(-np.pi, np.pi, nkys)
......@@ -190,6 +201,104 @@ def test_symmetry():
assert np.all(orig == new)
def test_vectorize():
params = dict(k_x=0, k_y=0)
square = kwant.lattice.square(norbs=1)
syst = _simple_syst(square)
syst_vec = _simple_syst(square, vectorize=True)
# test FiniteVectorizedSystem
keep = None
wrapped = wraparound(syst, keep=keep).finalized()
vectorized = wraparound(syst_vec, keep=keep).finalized()
assert np.allclose(wrapped.hamiltonian_submatrix(params=params),
vectorized.hamiltonian_submatrix(params=params))
# test InfiniteVectorizedSystem
for keep in (0, 1):
wrapped = wraparound(syst, keep=keep).finalized()
vectorized = wraparound(syst_vec, keep=keep).finalized()
assert np.allclose(wrapped.cell_hamiltonian(params=params),
vectorized.cell_hamiltonian(params=params))
assert np.allclose(wrapped.inter_cell_hopping(params=params),
vectorized.inter_cell_hopping(params=params))
def test_minimal_terms():
for dim in [1, 2, 3]:
prim_vecs = np.eye(dim)
lat = kwant.lattice.general(prim_vecs, norbs=1)
for size_short in range(1, 4):
for size_long in range(1, 6):
size = [size_long] + [size_short] * (dim-1)
symm = kwant.TranslationalSymmetry(*(size * prim_vecs))
fsyst = _make_bloch(symm, lat)
if dim == 1:
assert len(fsyst.terms) <= 3
if dim == 2:
assert len(fsyst.terms) <= 5
if dim == 3:
assert len(fsyst.terms) <= 7
def test_wrap_vectorize_value_functions():
def onsite_simple(site, arg1):
return arg1
def onsite_vec(sa, arg1):
num_sites = len(sa.tags)
return np.repeat([arg1], repeats=num_sites, axis=0)
def hopping_simple(site1, site2, arg2):
return arg2
def hopping_vec(sa1, sa2, arg2):
num_sites = len(sa1.tags)
return np.repeat([arg2], repeats=num_sites, axis=0)
for norbs in [1, 2]:
lat = kwant.lattice.chain(norbs=norbs)
lat_shape = lat.shape(lambda x: True, [0] * lat.dim)
params = {'arg1': np.diag(np.random.rand(norbs)),
'arg2': (np.random.rand(norbs, norbs)
+ 1j * np.random.rand(norbs, norbs)),
'k_x': 0}
for num_sites in [1, 2, 3]:
symm = kwant.TranslationalSymmetry(lat.vec([num_sites]))
builder = kwant.Builder(symmetry=symm, vectorize=False)
builder_vec_simple = kwant.Builder(symmetry=symm, vectorize=True)
builder_vec = kwant.Builder(symmetry=symm, vectorize=True)
builder[lat_shape] = onsite_simple
builder[lat.neighbors()] = hopping_simple
builder_vec_simple[lat_shape] = onsite_simple
builder_vec_simple[lat.neighbors()] = hopping_simple
builder_vec[lat_shape] = onsite_vec
builder_vec[lat.neighbors()] = hopping_vec
wrapped = wraparound(builder).finalized()
vectorized_simple = wraparound(builder_vec_simple).finalized()
vectorized = wraparound(builder_vec).finalized()
ham = wrapped.hamiltonian_submatrix(params=params)
ham_vec_simple = vectorized_simple.hamiltonian_submatrix(
params=params)
ham_vec = vectorized.hamiltonian_submatrix(params=params)
assert (ham == ham_vec).all()
assert (ham_vec_simple == ham_vec).all()
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
def test_plot_2d_bands():
chain = kwant.lattice.chain(norbs=1)
......
......@@ -40,6 +40,10 @@ def _set_signature(func, params):
for name in params]
func.__signature__ = inspect.Signature(params)
@memoize
def _callable_herm_conj(val):
"""Keep the same id for every 'val'."""
return HermConjOfFunc(val)
## This wrapper is needed so that finalized systems that
## have been wrapped can be queried for their symmetry, which
......@@ -183,9 +187,6 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
f.__signature__ = inspect.Signature(params.values())
return f
if builder.vectorize:
raise TypeError("'wraparound' does not work with vectorized Builders.")
try:
momenta = ['k_{}'.format(coordinate_names[i])
for i in range(len(builder.symmetry.periods))]
......@@ -225,6 +226,8 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
ret.particle_hole = None
ret.time_reversal = None
ret.vectorize = builder.vectorize
sites = {}
hops = collections.defaultdict(list)
......@@ -235,7 +238,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
# Move the sites to the FD of the remaining symmetry, this guarantees that
# every site in the new system is an image of an original FD site translated
# purely by the remaining symmetry.
sites[ret.symmetry.to_fd(site)] = [bind_site(val) if callable(val) else val]
sites[ret.symmetry.to_fd(site)] = [val] # a list to append wrapped hoppings
for hop, val in builder.hopping_value_pairs():
a, b = hop
......@@ -259,8 +262,8 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
sites[a].append(bind_hopping_as_site(b_dom, val))
else:
# The hopping remains a hopping.
if any(b_dom) or callable(val):
# The hopping got wrapped-around or is a function.
if any(b_dom):
# The hopping got wrapped-around.
val = bind_hopping(b_dom, val)
# Make sure that there is only one entry for each hopping
......@@ -269,23 +272,35 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
if (b_wa_r, a_r) in hops:
assert (a, b_wa) not in hops
if callable(val):
assert not isinstance(val, HermConjOfFunc)
val = HermConjOfFunc(val)
val = _callable_herm_conj(val)
else:
val = herm_conj(val)
hops[b_wa_r, a_r].append(val)
hops[b_wa_r, a_r].append((val, b_dom))
else:
hops[a, b_wa].append(val)
hops[a, b_wa].append((val, b_dom))
# Copy stuff into result builder, converting lists of more than one element
# into summing functions.
for site, vals in sites.items():
ret[site] = vals[0] if len(vals) == 1 else bind_sum(1, *vals)
for hop, vals in hops.items():
ret[hop] = vals[0] if len(vals) == 1 else bind_sum(2, *vals)
if len(vals) == 1:
# no need to bind onsites without extra wrapped hoppings
ret[site] = vals[0]
else:
val = vals[0]
vals[0] = bind_site(val) if callable(val) else val
ret[site] = bind_sum(1, *vals)
for hop, vals_doms in hops.items():
if len(vals_doms) == 1:
# no need to bind hoppings that are not already bound
val, b_dom = vals_doms[0]
ret[hop] = val
else:
new_vals = [bind_hopping(b_dom, val) if callable(val)
and not any(b_dom) # skip hoppings already bound
else val for val, b_dom in vals_doms]
ret[hop] = bind_sum(2, *new_vals)
return ret
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment