From 8f411318fadf16b2f9f127862b7a108dd2129168 Mon Sep 17 00:00:00 2001
From: Kostas Vilkelis <kostasvilkelis@gmail.com>
Date: Mon, 13 May 2024 00:39:46 +0200
Subject: [PATCH] simplify builder to tb; introduce tb to builder

---
 meanfi/kwant_helper/utils.py | 93 +++++++++++++++++++++++++++++-------
 meanfi/tests/test_kwant.py   | 39 +++++++++++++++
 2 files changed, 114 insertions(+), 18 deletions(-)
 create mode 100644 meanfi/tests/test_kwant.py

diff --git a/meanfi/kwant_helper/utils.py b/meanfi/kwant_helper/utils.py
index 784d9ad..baf61a5 100644
--- a/meanfi/kwant_helper/utils.py
+++ b/meanfi/kwant_helper/utils.py
@@ -4,6 +4,7 @@ from typing import Callable
 import numpy as np
 from scipy.sparse import coo_array
 import kwant
+from kwant.builder import Site
 import kwant.lattice
 import kwant.builder
 
@@ -57,10 +58,7 @@ def builder_to_tb(
         data = np.array(val).flatten()
         onsite_value = coo_array((data, (row, col)), shape=tb_shape).toarray()
 
-        if onsite_idx in h_0:
-            h_0[onsite_idx] += onsite_value
-        else:
-            h_0[onsite_idx] = onsite_value
+        h_0[onsite_idx] = h_0.get(onsite_idx, 0) + onsite_value
 
     for (site1, site2), val in builder.hopping_value_pairs():
         site2_dom = builder.symmetry.which(site2)
@@ -87,28 +85,87 @@ def builder_to_tb(
 
         hop_key = tuple(site2_dom)
         hop_key_back = tuple(-site2_dom)
-        if hop_key in h_0:
-            h_0[hop_key] += hopping_value
-            if np.linalg.norm(site2_dom) == 0:
-                h_0[hop_key] += hopping_value.conj().T
-            else:
-                h_0[hop_key_back] += hopping_value.conj().T
-        else:
-            h_0[hop_key] = hopping_value
-            if np.linalg.norm(site2_dom) == 0:
-                h_0[hop_key] += hopping_value.conj().T
-            else:
-                h_0[hop_key_back] = hopping_value.conj().T
+        h_0[hop_key] = h_0.get(hop_key, 0) + hopping_value
+        h_0[hop_key_back] = h_0.get(hop_key_back, 0) + hopping_value.T.conj()
 
     if return_data:
         data = {}
-        data["norbs"] = norbs_list
-        data["positions"] = [site.pos for site in sites_list]
+        data["periods"] = prim_vecs
+        data["sites"] = sites_list
         return h_0, data
     else:
         return h_0
 
 
+def tb_to_builder(
+    h_0: _tb_type, sites_list: list[Site, ...], periods: np.ndarray
+) -> kwant.builder.Builder:
+    """
+    Construct a `kwant.builder.Builder` from a tight-binding dictionary.
+
+    Parameters
+    ----------
+    h_0 :
+        Tight-binding dictionary.
+    sites_list :
+        List of sites in the builder's unit cell.
+    periods :
+        2d array with periods of the translational symmetry.
+
+    Returns
+    -------
+    :
+        `kwant.builder.Builder` that corresponds to the tight-binding dictionary.
+    """
+
+    builder = kwant.Builder(kwant.TranslationalSymmetry(*periods))
+    onsite_idx = tuple([0] * len(list(h_0)[0]))
+
+    norbs_list = [site.family.norbs for site in sites_list]
+    norbs_list = [1 if norbs is None else norbs for norbs in norbs_list]
+
+    def site_to_tbIdxs(site):
+        site_idx = sites_list.index(site)
+        return (np.sum(norbs_list[:site_idx]) + range(norbs_list[site_idx])).astype(int)
+
+    # assemble the sites first
+    for site in sites_list:
+        tb_idxs = site_to_tbIdxs(site)
+        value = h_0[onsite_idx][
+            tb_idxs[0] : tb_idxs[-1] + 1, tb_idxs[0] : tb_idxs[-1] + 1
+        ]
+        builder[site] = value
+
+    # connect hoppings within the unit-cell
+    for site1, site2 in product(sites_list, sites_list):
+        if site1 == site2:
+            continue
+        tb_idxs1 = site_to_tbIdxs(site1)
+        tb_idxs2 = site_to_tbIdxs(site2)
+        value = h_0[onsite_idx][
+            tb_idxs1[0] : tb_idxs1[-1] + 1, tb_idxs2[0] : tb_idxs2[-1] + 1
+        ]
+        if np.all(value == 0):
+            continue
+        builder[(site1, site2)] = value
+
+    # connect hoppings between unit-cells
+    for key in h_0:
+        if key == onsite_idx:
+            continue
+        for site1, site2_fd in product(sites_list, sites_list):
+            site2 = builder.symmetry.act(key, site2_fd)
+            tb_idxs1 = site_to_tbIdxs(site1)
+            tb_idxs2 = site_to_tbIdxs(site2_fd)
+            value = h_0[key][
+                tb_idxs1[0] : tb_idxs1[-1] + 1, tb_idxs2[0] : tb_idxs2[-1] + 1
+            ]
+            if np.all(value == 0):
+                continue
+            builder[(site1, site2)] = value
+    return builder
+
+
 def build_interacting_syst(
     builder: kwant.builder.Builder,
     lattice: kwant.lattice.Polyatomic,
diff --git a/meanfi/tests/test_kwant.py b/meanfi/tests/test_kwant.py
new file mode 100644
index 0000000..0064782
--- /dev/null
+++ b/meanfi/tests/test_kwant.py
@@ -0,0 +1,39 @@
+import numpy as np
+import pytest
+import kwant
+
+from meanfi.kwant_helper.utils import builder_to_tb, tb_to_builder
+from meanfi.tb.utils import generate_tb_keys, guess_tb
+from meanfi.tb.tb import compare_dicts
+
+repeat_number = 3
+
+
+@pytest.mark.parametrize("seed", range(repeat_number))
+def test_kwant_conversion(seed):
+    """Test the gap prediction for the Hubbard model."""
+    np.random.seed(seed)
+    ndim = np.random.randint(1, 3)
+    cutoff = np.random.randint(1, 3)
+    sites_in_cell = np.random.randint(1, 4)
+    ndof_per_site = np.random.randint(1, 5)
+    keyList = generate_tb_keys(cutoff, ndim)
+
+    # set a dummy lattice to read sites from
+    lattice = kwant.lattice.general(np.eye(ndim), norbs=ndof_per_site)
+    dummy_tb = kwant.Builder(
+        kwant.TranslationalSymmetry(*sites_in_cell * lattice.prim_vecs)
+    )
+    for site in range(sites_in_cell):
+        dummy_tb[lattice(site, *[0 for _ in range(ndim - 1)])] = (
+            np.eye(ndof_per_site) * 2
+        )
+
+    # generate random and generate builder from it
+    random_tb = guess_tb(keyList, ndof_per_site * sites_in_cell)
+    random_builder = tb_to_builder(
+        random_tb, list(dummy_tb.sites()), dummy_tb.symmetry.periods
+    )
+    # convert builder back to tb and compare
+    random_builder_tb = builder_to_tb(random_builder)
+    compare_dicts(random_tb, random_builder_tb)
-- 
GitLab