From 27b7f67e3e10c7df66fb2a10f3aa948d54364bdf Mon Sep 17 00:00:00 2001
From: Kostas Vilkelis <kostasvilkelis@gmail.com>
Date: Fri, 10 May 2024 16:54:50 +0200
Subject: [PATCH] simplify tb from kwant build; rm copy and other redundant
 imports

---
 meanfi/kwant_helper/utils.py | 136 ++++++++++++++---------------------
 1 file changed, 54 insertions(+), 82 deletions(-)

diff --git a/meanfi/kwant_helper/utils.py b/meanfi/kwant_helper/utils.py
index cb7e3a5..336d932 100644
--- a/meanfi/kwant_helper/utils.py
+++ b/meanfi/kwant_helper/utils.py
@@ -1,5 +1,3 @@
-import inspect
-from copy import copy
 from itertools import product
 from typing import Callable
 
@@ -33,99 +31,73 @@ def builder_to_tb(
     :
         Data with sites and number of orbitals. Only if `return_data=True`.
     """
-    builder = copy(builder)
-    # Extract information from builder
-    dims = len(builder.symmetry.periods)
+    prim_vecs = builder.symmetry.periods
+    dims = len(prim_vecs)
+    sites_list = [*builder.sites()]
+    norbs_list = [site.family.norbs for site in builder.sites()]
+    norbs_list = [1 if norbs is None else norbs for norbs in norbs_list]
+
+    tb_norbs = sum(norbs_list)
+    tb_shape = (tb_norbs, tb_norbs)
     onsite_idx = tuple([0] * dims)
     h_0 = {}
-    sites_list = [*builder.sites()]
-    norbs_list = [site[0].norbs for site in builder.sites()]
-    positions_list = [site[0].pos for site in builder.sites()]
-    norbs_tot = sum(norbs_list)
-    # Extract onsite and hopping matrices.
-    # Based on `kwant.wraparound.wraparound`
-    # Onsite matrices
+
+    def _parse_val(val):
+        if callable(val):
+            param_keys = val.__code__.co_varnames[1:]
+            try:
+                val = val(site, *[params[key] for key in param_keys])
+            except KeyError as key:
+                raise KeyError(f"Parameter {key} not found in params.")
+        return val
+
     for site, val in builder.site_value_pairs():
-        site = builder.symmetry.to_fd(site)
-        atom = sites_list.index(site)
-        row = np.sum(norbs_list[:atom]) + range(norbs_list[atom])
-        col = copy(row)
-        row, col = np.array([*product(row, col)]).T
-        try:
-            _params = {}
-            for arg in inspect.getfullargspec(val).args:
-                if arg in params:
-                    _params[arg] = params[arg]
-            val = val(site, **_params)
-            data = val.flatten()
-        except Exception:
-            data = val.flatten()
+        site_idx = sites_list.index(site)
+        tb_idx = np.sum(norbs_list[:site_idx]) + range(norbs_list[site_idx])
+        row, col = np.array([*product(tb_idx, tb_idx)]).T
+
+        data = np.array(_parse_val(val)).flatten()
+        onsite_value = coo_array((data, (row, col)), shape=tb_shape).toarray()
+
         if onsite_idx in h_0:
-            h_0[onsite_idx] += coo_array(
-                (data, (row, col)), shape=(norbs_tot, norbs_tot)
-            ).toarray()
+            h_0[onsite_idx] += onsite_value
         else:
-            h_0[onsite_idx] = coo_array(
-                (data, (row, col)), shape=(norbs_tot, norbs_tot)
-            ).toarray()
-    # Hopping matrices
-    for hop, val in builder.hopping_value_pairs():
-        a, b = hop
-        b_dom = builder.symmetry.which(b)
-        b_fd = builder.symmetry.to_fd(b)
-        atoms = np.array([sites_list.index(a), sites_list.index(b_fd)])
-        row, col = [
-            np.sum(norbs_list[: atoms[0]]) + range(norbs_list[atoms[0]]),
-            np.sum(norbs_list[: atoms[1]]) + range(norbs_list[atoms[1]]),
+            h_0[onsite_idx] = onsite_value
+
+    for (site1, site2), val in builder.hopping_value_pairs():
+        site2_dom = builder.symmetry.which(site2)
+        site2_fd = builder.symmetry.to_fd(site2)
+
+        site1_idx, site2_idx = np.array(
+            [sites_list.index(site1), sites_list.index(site2_fd)]
+        )
+        tb_idx1, tb_idx2 = [
+            np.sum(norbs_list[:site1_idx]) + range(norbs_list[site1_idx]),
+            np.sum(norbs_list[:site2_idx]) + range(norbs_list[site2_idx]),
         ]
-        row, col = np.array([*product(row, col)]).T
-        try:
-            _params = {}
-            for arg in inspect.getfullargspec(val).args:
-                if arg in params:
-                    _params[arg] = params[arg]
-            val = val(a, b, **_params)
-            data = val.flatten()
-        except Exception:
-            data = val.flatten()
-        if tuple(b_dom) in h_0:
-            h_0[tuple(b_dom)] += coo_array(
-                (data, (row, col)), shape=(norbs_tot, norbs_tot)
-            ).toarray()
-            if np.linalg.norm(b_dom) == 0:
-                h_0[tuple(b_dom)] += (
-                    coo_array((data, (row, col)), shape=(norbs_tot, norbs_tot))
-                    .toarray()
-                    .T.conj()
-                )
+        row, col = np.array([*product(tb_idx1, tb_idx2)]).T
+        data = np.array(_parse_val(val)).flatten()
+        hopping_value = coo_array((data, (row, col)), shape=tb_shape).toarray()
+
+        hop_key = tuple(site2_dom)
+        hop_key_back = tuple(-site2_dom)
+        if hop_key in h_0:
+            h_0[hop_key] += hopping_value
+            if np.linalg.norm(site2_dom) == 0:
+                h_0[hop_key] += hopping_value.conj().T
             else:
-                # Hopping vector in the opposite direction
-                h_0[tuple(-b_dom)] += (
-                    coo_array((data, (row, col)), shape=(norbs_tot, norbs_tot))
-                    .toarray()
-                    .T.conj()
-                )
+                h_0[hop_key_back] += hopping_value.conj().T
         else:
-            h_0[tuple(b_dom)] = coo_array(
-                (data, (row, col)), shape=(norbs_tot, norbs_tot)
-            ).toarray()
-            if np.linalg.norm(b_dom) == 0:
-                h_0[tuple(b_dom)] += (
-                    coo_array((data, (row, col)), shape=(norbs_tot, norbs_tot))
-                    .toarray()
-                    .T.conj()
-                )
+            h_0[hop_key] = hopping_value
+            if np.linalg.norm(site2_dom) == 0:
+                h_0[hop_key] += hopping_value.conj().T
             else:
-                h_0[tuple(-b_dom)] = (
-                    coo_array((data, (row, col)), shape=(norbs_tot, norbs_tot))
-                    .toarray()
-                    .T.conj()
-                )
+                h_0[hop_key_back] = hopping_value.conj().T
 
     if return_data:
         data = {}
         data["norbs"] = norbs_list
-        data["positions"] = positions_list
+        data["positions"] = [site.pos for site in sites_list]
         return h_0, data
     else:
         return h_0
-- 
GitLab