From b4e2e1bdafcc803c9c734bde930e8ff6e2519496 Mon Sep 17 00:00:00 2001
From: Kostas Vilkelis <kostasvilkelis@gmail.com>
Date: Sat, 11 May 2024 01:34:24 +0200
Subject: [PATCH] resolve bug in param parser

---
 meanfi/kwant_helper/utils.py | 24 +++++++++++++++---------
 1 file changed, 15 insertions(+), 9 deletions(-)

diff --git a/meanfi/kwant_helper/utils.py b/meanfi/kwant_helper/utils.py
index 336d932..784d9ad 100644
--- a/meanfi/kwant_helper/utils.py
+++ b/meanfi/kwant_helper/utils.py
@@ -42,21 +42,19 @@ def builder_to_tb(
     onsite_idx = tuple([0] * dims)
     h_0 = {}
 
-    def _parse_val(val):
+    for site, val in builder.site_value_pairs():
+        site_idx = sites_list.index(site)
+        tb_idx = np.sum(norbs_list[:site_idx]) + range(norbs_list[site_idx])
+        row, col = np.array([*product(tb_idx, tb_idx)]).T
+
         if callable(val):
             param_keys = val.__code__.co_varnames[1:]
             try:
                 val = val(site, *[params[key] for key in param_keys])
             except KeyError as key:
                 raise KeyError(f"Parameter {key} not found in params.")
-        return val
-
-    for site, val in builder.site_value_pairs():
-        site_idx = sites_list.index(site)
-        tb_idx = np.sum(norbs_list[:site_idx]) + range(norbs_list[site_idx])
-        row, col = np.array([*product(tb_idx, tb_idx)]).T
 
-        data = np.array(_parse_val(val)).flatten()
+        data = np.array(val).flatten()
         onsite_value = coo_array((data, (row, col)), shape=tb_shape).toarray()
 
         if onsite_idx in h_0:
@@ -76,7 +74,15 @@ def builder_to_tb(
             np.sum(norbs_list[:site2_idx]) + range(norbs_list[site2_idx]),
         ]
         row, col = np.array([*product(tb_idx1, tb_idx2)]).T
-        data = np.array(_parse_val(val)).flatten()
+
+        if callable(val):
+            param_keys = val.__code__.co_varnames[2:]
+            try:
+                val = val(site1, site2, *[params[key] for key in param_keys])
+            except KeyError as key:
+                raise KeyError(f"Parameter {key} not found in params.")
+
+        data = np.array(val).flatten()
         hopping_value = coo_array((data, (row, col)), shape=tb_shape).toarray()
 
         hop_key = tuple(site2_dom)
-- 
GitLab