diff --git a/meanfi/kwant_helper/utils.py b/meanfi/kwant_helper/utils.py index 336d932e816c7712ac5bbc0ba5cfc956aafb0d54..784d9ad4e696aec66a3e81b3eec72fa6bb25b76e 100644 --- a/meanfi/kwant_helper/utils.py +++ b/meanfi/kwant_helper/utils.py @@ -42,21 +42,19 @@ def builder_to_tb( onsite_idx = tuple([0] * dims) h_0 = {} - def _parse_val(val): + for site, val in builder.site_value_pairs(): + site_idx = sites_list.index(site) + tb_idx = np.sum(norbs_list[:site_idx]) + range(norbs_list[site_idx]) + row, col = np.array([*product(tb_idx, tb_idx)]).T + if callable(val): param_keys = val.__code__.co_varnames[1:] try: val = val(site, *[params[key] for key in param_keys]) except KeyError as key: raise KeyError(f"Parameter {key} not found in params.") - return val - - for site, val in builder.site_value_pairs(): - site_idx = sites_list.index(site) - tb_idx = np.sum(norbs_list[:site_idx]) + range(norbs_list[site_idx]) - row, col = np.array([*product(tb_idx, tb_idx)]).T - data = np.array(_parse_val(val)).flatten() + data = np.array(val).flatten() onsite_value = coo_array((data, (row, col)), shape=tb_shape).toarray() if onsite_idx in h_0: @@ -76,7 +74,15 @@ def builder_to_tb( np.sum(norbs_list[:site2_idx]) + range(norbs_list[site2_idx]), ] row, col = np.array([*product(tb_idx1, tb_idx2)]).T - data = np.array(_parse_val(val)).flatten() + + if callable(val): + param_keys = val.__code__.co_varnames[2:] + try: + val = val(site1, site2, *[params[key] for key in param_keys]) + except KeyError as key: + raise KeyError(f"Parameter {key} not found in params.") + + data = np.array(val).flatten() hopping_value = coo_array((data, (row, col)), shape=tb_shape).toarray() hop_key = tuple(site2_dom)