From 488dd431e93f51c1fe420ae6cde434be88f34e73 Mon Sep 17 00:00:00 2001
From: Joseph Weston <joseph@weston.cloud>
Date: Mon, 9 Sep 2019 13:41:44 +0200
Subject: [PATCH] add simple test for vectorized finalized Builders

---
 kwant/tests/test_builder.py | 104 ++++++++++++++++++++++++++++++++++++
 1 file changed, 104 insertions(+)

diff --git a/kwant/tests/test_builder.py b/kwant/tests/test_builder.py
index f9832e38..575d0caf 100644
--- a/kwant/tests/test_builder.py
+++ b/kwant/tests/test_builder.py
@@ -636,6 +636,110 @@ def test_hamiltonian_evaluation(vectorize):
     test_raising(inf_fsyst, hop)
 
 
+def test_vectorized_hamiltonian_evaluation():
+
+    def onsite(site):
+        return site.tag[0]
+
+    def vectorized_onsite(sites):
+        return sites.tags[:, 0]
+
+    def hopping(to_site, from_site):
+        a, b = to_site.tag, from_site.tag
+        return a[0] + b[0] + 1j * (a[1] - b[1])
+
+    def vectorized_hopping(to_sites, from_sites):
+        a, b = to_sites.tags, from_sites.tags
+        return a[:, 0] + b[:, 0] + 1j * (a[:, 1] - b[:, 1])
+
+    tags = [(0, 0), (1, 1), (2, 2), (3, 3)]
+    edges = [(0, 1), (0, 2), (0, 3), (1, 2)]
+
+    fam = builder.SimpleSiteFamily(norbs=1)
+    sites = [fam(*tag) for tag in tags]
+    hops = [(fam(*tags[i]), fam(*tags[j])) for (i, j) in edges]
+
+    syst_simple = builder.Builder(vectorize=False)
+    syst_simple[sites] = onsite
+    syst_simple[hops] = hopping
+    fsyst_simple = syst_simple.finalized()
+
+    syst_vectorized = builder.Builder(vectorize=True)
+    syst_vectorized[sites] = vectorized_onsite
+    syst_vectorized[hops] = vectorized_hopping
+    fsyst_vectorized = syst_vectorized.finalized()
+
+    assert fsyst_vectorized.graph.num_nodes == len(tags)
+    assert fsyst_vectorized.graph.num_edges == 2 * len(edges)
+    assert len(fsyst_vectorized.site_arrays) == 1
+    assert fsyst_vectorized.site_arrays[0] == system.SiteArray(fam, tags)
+
+    assert np.allclose(
+        fsyst_simple.hamiltonian_submatrix(),
+        fsyst_vectorized.hamiltonian_submatrix(),
+    )
+
+    for i in range(len(tags)):
+        site = fsyst_vectorized.sites[i]
+        assert site in sites
+        assert (
+            fsyst_vectorized.hamiltonian(i, i)
+            == fsyst_simple.hamiltonian(i, i))
+
+    for t, h in fsyst_vectorized.graph:
+        assert (
+            fsyst_vectorized.hamiltonian(t, h)
+            == fsyst_simple.hamiltonian(t, h))
+
+    # Test infinite system, including hoppings that go both ways into
+    # the next cell
+    lat = kwant.lattice.square(norbs=1)
+
+    syst_vectorized = builder.Builder(kwant.TranslationalSymmetry((-1, 0)),
+                                      vectorize=True)
+    syst_vectorized[lat(0, 0)] = 4
+    syst_vectorized[lat(0, 1)] = 5
+    syst_vectorized[lat(0, 2)] = vectorized_onsite
+    syst_vectorized[(lat(1, 0), lat(0, 0))] = 1j
+    syst_vectorized[(lat(2, 1), lat(1, 1))] = vectorized_hopping
+    fsyst_vectorized = syst_vectorized.finalized()
+
+    syst_simple = builder.Builder(kwant.TranslationalSymmetry((-1, 0)),
+                                      vectorize=False)
+    syst_simple[lat(0, 0)] = 4
+    syst_simple[lat(0, 1)] = 5
+    syst_simple[lat(0, 2)] = onsite
+    syst_simple[(lat(1, 0), lat(0, 0))] = 1j
+    syst_simple[(lat(2, 1), lat(1, 1))] = hopping
+    fsyst_simple = syst_simple.finalized()
+
+    assert np.allclose(
+        fsyst_vectorized.hamiltonian_submatrix(),
+        fsyst_simple.hamiltonian_submatrix(),
+    )
+    assert np.allclose(
+        fsyst_vectorized.cell_hamiltonian(),
+        fsyst_simple.cell_hamiltonian(),
+    )
+    assert np.allclose(
+        fsyst_vectorized.inter_cell_hopping(),
+        fsyst_simple.inter_cell_hopping(),
+    )
+
+
+def test_vectorized_requires_norbs():
+
+    # Catch deprecation warning for lack of norbs
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        fam = builder.SimpleSiteFamily()
+
+    syst = builder.Builder(vectorize=True)
+    syst[fam(0, 0)] = 1
+
+    raises(ValueError, syst.finalized)
+
+
 def test_dangling():
     def make_system():
         #        1
-- 
GitLab