diff --git a/meanfi/kwant_helper/utils.py b/meanfi/kwant_helper/utils.py index cb7e3a543cee488c88984cb6dd5e5d6966d90f00..336d932e816c7712ac5bbc0ba5cfc956aafb0d54 100644 --- a/meanfi/kwant_helper/utils.py +++ b/meanfi/kwant_helper/utils.py @@ -1,5 +1,3 @@ -import inspect -from copy import copy from itertools import product from typing import Callable @@ -33,99 +31,73 @@ def builder_to_tb( : Data with sites and number of orbitals. Only if `return_data=True`. """ - builder = copy(builder) - # Extract information from builder - dims = len(builder.symmetry.periods) + prim_vecs = builder.symmetry.periods + dims = len(prim_vecs) + sites_list = [*builder.sites()] + norbs_list = [site.family.norbs for site in builder.sites()] + norbs_list = [1 if norbs is None else norbs for norbs in norbs_list] + + tb_norbs = sum(norbs_list) + tb_shape = (tb_norbs, tb_norbs) onsite_idx = tuple([0] * dims) h_0 = {} - sites_list = [*builder.sites()] - norbs_list = [site[0].norbs for site in builder.sites()] - positions_list = [site[0].pos for site in builder.sites()] - norbs_tot = sum(norbs_list) - # Extract onsite and hopping matrices. - # Based on `kwant.wraparound.wraparound` - # Onsite matrices + + def _parse_val(val): + if callable(val): + param_keys = val.__code__.co_varnames[1:] + try: + val = val(site, *[params[key] for key in param_keys]) + except KeyError as key: + raise KeyError(f"Parameter {key} not found in params.") + return val + for site, val in builder.site_value_pairs(): - site = builder.symmetry.to_fd(site) - atom = sites_list.index(site) - row = np.sum(norbs_list[:atom]) + range(norbs_list[atom]) - col = copy(row) - row, col = np.array([*product(row, col)]).T - try: - _params = {} - for arg in inspect.getfullargspec(val).args: - if arg in params: - _params[arg] = params[arg] - val = val(site, **_params) - data = val.flatten() - except Exception: - data = val.flatten() + site_idx = sites_list.index(site) + tb_idx = np.sum(norbs_list[:site_idx]) + range(norbs_list[site_idx]) + row, col = np.array([*product(tb_idx, tb_idx)]).T + + data = np.array(_parse_val(val)).flatten() + onsite_value = coo_array((data, (row, col)), shape=tb_shape).toarray() + if onsite_idx in h_0: - h_0[onsite_idx] += coo_array( - (data, (row, col)), shape=(norbs_tot, norbs_tot) - ).toarray() + h_0[onsite_idx] += onsite_value else: - h_0[onsite_idx] = coo_array( - (data, (row, col)), shape=(norbs_tot, norbs_tot) - ).toarray() - # Hopping matrices - for hop, val in builder.hopping_value_pairs(): - a, b = hop - b_dom = builder.symmetry.which(b) - b_fd = builder.symmetry.to_fd(b) - atoms = np.array([sites_list.index(a), sites_list.index(b_fd)]) - row, col = [ - np.sum(norbs_list[: atoms[0]]) + range(norbs_list[atoms[0]]), - np.sum(norbs_list[: atoms[1]]) + range(norbs_list[atoms[1]]), + h_0[onsite_idx] = onsite_value + + for (site1, site2), val in builder.hopping_value_pairs(): + site2_dom = builder.symmetry.which(site2) + site2_fd = builder.symmetry.to_fd(site2) + + site1_idx, site2_idx = np.array( + [sites_list.index(site1), sites_list.index(site2_fd)] + ) + tb_idx1, tb_idx2 = [ + np.sum(norbs_list[:site1_idx]) + range(norbs_list[site1_idx]), + np.sum(norbs_list[:site2_idx]) + range(norbs_list[site2_idx]), ] - row, col = np.array([*product(row, col)]).T - try: - _params = {} - for arg in inspect.getfullargspec(val).args: - if arg in params: - _params[arg] = params[arg] - val = val(a, b, **_params) - data = val.flatten() - except Exception: - data = val.flatten() - if tuple(b_dom) in h_0: - h_0[tuple(b_dom)] += coo_array( - (data, (row, col)), shape=(norbs_tot, norbs_tot) - ).toarray() - if np.linalg.norm(b_dom) == 0: - h_0[tuple(b_dom)] += ( - coo_array((data, (row, col)), shape=(norbs_tot, norbs_tot)) - .toarray() - .T.conj() - ) + row, col = np.array([*product(tb_idx1, tb_idx2)]).T + data = np.array(_parse_val(val)).flatten() + hopping_value = coo_array((data, (row, col)), shape=tb_shape).toarray() + + hop_key = tuple(site2_dom) + hop_key_back = tuple(-site2_dom) + if hop_key in h_0: + h_0[hop_key] += hopping_value + if np.linalg.norm(site2_dom) == 0: + h_0[hop_key] += hopping_value.conj().T else: - # Hopping vector in the opposite direction - h_0[tuple(-b_dom)] += ( - coo_array((data, (row, col)), shape=(norbs_tot, norbs_tot)) - .toarray() - .T.conj() - ) + h_0[hop_key_back] += hopping_value.conj().T else: - h_0[tuple(b_dom)] = coo_array( - (data, (row, col)), shape=(norbs_tot, norbs_tot) - ).toarray() - if np.linalg.norm(b_dom) == 0: - h_0[tuple(b_dom)] += ( - coo_array((data, (row, col)), shape=(norbs_tot, norbs_tot)) - .toarray() - .T.conj() - ) + h_0[hop_key] = hopping_value + if np.linalg.norm(site2_dom) == 0: + h_0[hop_key] += hopping_value.conj().T else: - h_0[tuple(-b_dom)] = ( - coo_array((data, (row, col)), shape=(norbs_tot, norbs_tot)) - .toarray() - .T.conj() - ) + h_0[hop_key_back] = hopping_value.conj().T if return_data: data = {} data["norbs"] = norbs_list - data["positions"] = positions_list + data["positions"] = [site.pos for site in sites_list] return h_0, data else: return h_0