Skip to content
Snippets Groups Projects
Commit c0d082aa authored by Joseph Weston's avatar Joseph Weston
Browse files

add failing vectorization test

The next commit(s) will implement the necessary changes to the
operator module to make this test pass.
parent 39b4ca8d
No related branches found
No related tags found
No related merge requests found
......@@ -538,10 +538,24 @@ def random_onsite(i):
return (2 + kwant.digest.uniform(i.tag)) * sigmaz
def vectorized_random_onsite(sites):
t = (np.array([kwant.digest.uniform(tag) for tag in sites.tags])
.reshape(-1, 1, 1))
return (2 + t) * sigmaz
def random_hopping(i, j):
return (-1 + kwant.digest.uniform(i.tag + j.tag)) * sigmay
def vectorized_random_hopping(sites_a, sites_b):
t = np.array([
kwant.digest.uniform(tag_a + tag_b)
for tag_a, tag_b in zip(sites_a.tags, sites_b.tags)
]).reshape(-1, 1, 1)
return (-1 + t) * sigmay
def f_sigmay(i):
return sigma0
......@@ -575,3 +589,47 @@ def test_pickling(A):
for op in ops:
loaded_op = pickle.loads(pickle.dumps(op))
assert np.all(op(wf) == loaded_op(wf))
@pytest.mark.parametrize("A", opservables)
def test_vectorization(A):
# We need to test non/vectorized systems with non/vectorized operators
def onsite(site):
t = kwant.digest.uniform(site.tag)
return t * sigmay + (1 - t) * sigmaz
def vectorized_onsite(sites):
t = np.array([kwant.digest.uniform(tag) for tag in sites.tags])
t = t.reshape(-1, 1, 1)
return t * sigmay + (1 - t) * sigmaz
lat = kwant.lattice.square(norbs=2)
# non-vectorized system
syst = kwant.Builder(vectorize=False)
syst[(lat(i, j) for i in range(5) for j in range(5))] = random_onsite
syst[lat.neighbors()] = random_hopping
fsyst = syst.finalized()
# vectorized system
vsyst = kwant.Builder(vectorize=True)
vsyst[(lat(i, j) for i in range(5) for j in range(5))] = vectorized_random_onsite
vsyst[lat.neighbors()] = vectorized_random_hopping
vfsyst = vsyst.finalized()
wf = np.random.rand(2 * len(fsyst.sites))
# vectorized and non-vectorized operators
op = A(fsyst, onsite)
vectorized_op = A(vfsyst, vectorized_onsite)
np.testing.assert_array_equal(op(wf), vectorized_op(wf))
# System is vectorized, and onsite is not *and* is incompatible
# because it uses 'site.tag', which does not exist for SiteArrays.
bad_operator = A(vfsyst, onsite)
with pytest.raises(kwant._common.UserCodeError) as excinfo:
bad_operator(wf)
assert "did you remember to vectorize" in str(excinfo.value).lower()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment