From 451e0b7111dce7baaeadf78b75a65ef248e63b6b Mon Sep 17 00:00:00 2001
From: antoniolrm <am@antoniomanesco.org>
Date: Fri, 17 May 2024 17:30:02 +0200
Subject: [PATCH] add supercell test with callables

---
 meanfi/tests/test_kwant.py | 81 +++++++++++++++++++++++++++++++++++++-
 1 file changed, 80 insertions(+), 1 deletion(-)

diff --git a/meanfi/tests/test_kwant.py b/meanfi/tests/test_kwant.py
index 53b2c83..a39c289 100644
--- a/meanfi/tests/test_kwant.py
+++ b/meanfi/tests/test_kwant.py
@@ -2,6 +2,8 @@ import numpy as np
 import pytest
 import kwant
 
+import itertools as it
+
 from meanfi.kwant_helper.utils import builder_to_tb, tb_to_builder
 from meanfi.tb.utils import generate_tb_keys, guess_tb
 from meanfi.tb.tb import compare_dicts
@@ -40,5 +42,82 @@ def test_kwant_conversion(seed):
         random_tb, list(dummy_tb.sites()), dummy_tb.symmetry.periods
     )
     # convert builder back to tb and compare
-    random_builder_tb = builder_to_tb(random_builder)
+    random_builder_tb = builder_to_tb(random_builder, return_data=True)
     compare_dicts(random_tb, random_builder_tb)
+
+
+@pytest.mark.parametrize("seed", range(repeat_number))
+def test_kwant_supercell(seed):
+    np.random.seed(seed)
+    ndim = np.random.randint(1, 3)
+    cutoff = np.random.randint(1, 3)
+    sites_in_cell = np.random.randint(1, 4)
+    ndof_per_site = [np.random.randint(1, 4) for site in range(sites_in_cell)]
+    keyList = generate_tb_keys(cutoff, ndim)
+    n_cells = np.random.randint(1, 4)
+
+    vecs = np.random.rand(ndim, ndim)
+
+    # set a dummy lattice to read sites from
+    lattice = kwant.lattice.general(
+        vecs,
+        basis=np.random.rand(sites_in_cell, ndim) @ vecs,
+        norbs=ndof_per_site,
+    )
+
+    def random_matrix_kwant_digest(n, m, k):
+        matrix = np.zeros((n, m))
+        for i in zip(it.product(range(n), range(m))):
+            matrix[*i[0]] = kwant.digest.uniform(str(n * m * np.prod(i[0]) + k))
+        return matrix
+
+    def onsite(site, alpha, beta):
+        n = site.family.norbs
+        amplitude = alpha * random_matrix_kwant_digest(n, n, 0)
+        phase = 1j * 2 * np.pi * beta * random_matrix_kwant_digest(n, n, 1)
+        onsite_matrix = amplitude * phase
+        onsite_matrix += onsite_matrix.conj().T
+        return onsite_matrix
+
+    def hopping(site1, site2, gamma, delta):
+        n1 = site1.family.norbs
+        n2 = site2.family.norbs
+        amplitude = gamma * random_matrix_kwant_digest(n1, n2, 0)
+        phase = 1j * 2 * np.pi * delta * random_matrix_kwant_digest(n1, n2, 1)
+        hopping_matrix = amplitude * phase
+        return hopping_matrix
+
+    random_builder = kwant.Builder(
+        kwant.TranslationalSymmetry(*n_cells * lattice.prim_vecs)
+    )
+    for i, sublattice in enumerate(lattice.sublattices):
+        random_builder[lattice.shape(lambda pos: True, tuple(ndim * [0]))] = onsite
+    random_builder[lattice.neighbors()] = hopping
+
+    params_num = np.random.rand(4)
+    params = dict(
+        alpha=params_num[0],
+        beta=params_num[1],
+        gamma=params_num[2],
+        delta=params_num[3],
+    )
+
+    random_tb, data = builder_to_tb(random_builder, params=params, return_data=True)
+    random_builder_test = tb_to_builder(random_tb, data["sites"], data["periods"])
+    for site_pair in zip(it.product(data["sites"], data["sites"])):
+        site1, site2 = site_pair[0]
+        if site1 == site2:
+            assert np.isclose(
+                random_builder[site1](site=site1, alpha=params["alpha"], beta=params["beta"]),
+                random_builder_test[site1],
+            ).all()
+        else:
+            try:
+                assert np.isclose(
+                    random_builder[site1, site2](site1, site2, gamma=params["gamma"], delta=params["delta"]),
+                    random_builder_test[site1, site2],
+                ).all()
+            except KeyError:
+                continue
+            except:
+                raise
\ No newline at end of file
-- 
GitLab