From 903bf515dc33576f344369afba056b5d5055db73 Mon Sep 17 00:00:00 2001
From: antoniolrm <am@antoniomanesco.org>
Date: Fri, 17 May 2024 16:13:42 +0200
Subject: [PATCH] simplify test to avoid repeated families

---
 meanfi/tests/test_kwant.py | 23 +++++++++++------------
 1 file changed, 11 insertions(+), 12 deletions(-)

diff --git a/meanfi/tests/test_kwant.py b/meanfi/tests/test_kwant.py
index 9b947fa..53b2c83 100644
--- a/meanfi/tests/test_kwant.py
+++ b/meanfi/tests/test_kwant.py
@@ -16,27 +16,26 @@ def test_kwant_conversion(seed):
     ndim = np.random.randint(1, 3)
     cutoff = np.random.randint(1, 3)
     sites_in_cell = np.random.randint(1, 4)
-    ndof_per_site = [np.random.randint(1, 5) for site in range(sites_in_cell)]
+    ndof_per_site = [np.random.randint(1, 4) for site in range(sites_in_cell)]
     keyList = generate_tb_keys(cutoff, ndim)
-    n_cells = np.random.randint(4)
+
+    vecs = np.random.rand(ndim, ndim)
 
     # set a dummy lattice to read sites from
     lattice = kwant.lattice.general(
-        np.random.rand(ndim, ndim),
-        basis=np.random.rand(sites_in_cell, ndim),
+        vecs,
+        basis=np.random.rand(sites_in_cell, ndim) @ vecs,
         norbs=ndof_per_site,
     )
 
-    dummy_tb = kwant.Builder(kwant.TranslationalSymmetry(*n_cells * lattice.prim_vecs))
-    for site in range(sites_in_cell):
-        for i, sublattice in enumerate(lattice.sublattices):
-            for n in range(n_cells):
-                dummy_tb[sublattice(site, *[n for _ in range(ndim - 1)])] = (
-                    np.eye(ndof_per_site[i]) * 2
-                )
+    dummy_tb = kwant.Builder(kwant.TranslationalSymmetry(*lattice.prim_vecs))
+    for i, sublattice in enumerate(lattice.sublattices):
+        dummy_tb[lattice.shape(lambda pos: True, tuple(ndim * [0]))] = np.eye(
+            ndof_per_site[i]
+        )
 
     # generate random and generate builder from it
-    random_tb = guess_tb(keyList, sum(ndof_per_site) * sites_in_cell * n_cells)
+    random_tb = guess_tb(keyList, sum(ndof_per_site))
     random_builder = tb_to_builder(
         random_tb, list(dummy_tb.sites()), dummy_tb.symmetry.periods
     )
-- 
GitLab