From b4e2e1bdafcc803c9c734bde930e8ff6e2519496 Mon Sep 17 00:00:00 2001 From: Kostas Vilkelis <kostasvilkelis@gmail.com> Date: Sat, 11 May 2024 01:34:24 +0200 Subject: [PATCH] resolve bug in param parser --- meanfi/kwant_helper/utils.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/meanfi/kwant_helper/utils.py b/meanfi/kwant_helper/utils.py index 336d932..784d9ad 100644 --- a/meanfi/kwant_helper/utils.py +++ b/meanfi/kwant_helper/utils.py @@ -42,21 +42,19 @@ def builder_to_tb( onsite_idx = tuple([0] * dims) h_0 = {} - def _parse_val(val): + for site, val in builder.site_value_pairs(): + site_idx = sites_list.index(site) + tb_idx = np.sum(norbs_list[:site_idx]) + range(norbs_list[site_idx]) + row, col = np.array([*product(tb_idx, tb_idx)]).T + if callable(val): param_keys = val.__code__.co_varnames[1:] try: val = val(site, *[params[key] for key in param_keys]) except KeyError as key: raise KeyError(f"Parameter {key} not found in params.") - return val - - for site, val in builder.site_value_pairs(): - site_idx = sites_list.index(site) - tb_idx = np.sum(norbs_list[:site_idx]) + range(norbs_list[site_idx]) - row, col = np.array([*product(tb_idx, tb_idx)]).T - data = np.array(_parse_val(val)).flatten() + data = np.array(val).flatten() onsite_value = coo_array((data, (row, col)), shape=tb_shape).toarray() if onsite_idx in h_0: @@ -76,7 +74,15 @@ def builder_to_tb( np.sum(norbs_list[:site2_idx]) + range(norbs_list[site2_idx]), ] row, col = np.array([*product(tb_idx1, tb_idx2)]).T - data = np.array(_parse_val(val)).flatten() + + if callable(val): + param_keys = val.__code__.co_varnames[2:] + try: + val = val(site1, site2, *[params[key] for key in param_keys]) + except KeyError as key: + raise KeyError(f"Parameter {key} not found in params.") + + data = np.array(val).flatten() hopping_value = coo_array((data, (row, col)), shape=tb_shape).toarray() hop_key = tuple(site2_dom) -- GitLab