Skip to content
Snippets Groups Projects
Commit 27b7f67e authored by Kostas Vilkelis's avatar Kostas Vilkelis :flamingo:
Browse files

simplify tb from kwant build; rm copy and other redundant imports

parent 26e212c6
No related branches found
No related tags found
1 merge request!8Builder fixes
Pipeline #180295 failed
This commit is part of merge request !8. Comments created here will be created in the context of that merge request.
import inspect
from copy import copy
from itertools import product
from typing import Callable
......@@ -33,99 +31,73 @@ def builder_to_tb(
:
Data with sites and number of orbitals. Only if `return_data=True`.
"""
builder = copy(builder)
# Extract information from builder
dims = len(builder.symmetry.periods)
prim_vecs = builder.symmetry.periods
dims = len(prim_vecs)
sites_list = [*builder.sites()]
norbs_list = [site.family.norbs for site in builder.sites()]
norbs_list = [1 if norbs is None else norbs for norbs in norbs_list]
tb_norbs = sum(norbs_list)
tb_shape = (tb_norbs, tb_norbs)
onsite_idx = tuple([0] * dims)
h_0 = {}
sites_list = [*builder.sites()]
norbs_list = [site[0].norbs for site in builder.sites()]
positions_list = [site[0].pos for site in builder.sites()]
norbs_tot = sum(norbs_list)
# Extract onsite and hopping matrices.
# Based on `kwant.wraparound.wraparound`
# Onsite matrices
def _parse_val(val):
if callable(val):
param_keys = val.__code__.co_varnames[1:]
try:
val = val(site, *[params[key] for key in param_keys])
except KeyError as key:
raise KeyError(f"Parameter {key} not found in params.")
return val
for site, val in builder.site_value_pairs():
site = builder.symmetry.to_fd(site)
atom = sites_list.index(site)
row = np.sum(norbs_list[:atom]) + range(norbs_list[atom])
col = copy(row)
row, col = np.array([*product(row, col)]).T
try:
_params = {}
for arg in inspect.getfullargspec(val).args:
if arg in params:
_params[arg] = params[arg]
val = val(site, **_params)
data = val.flatten()
except Exception:
data = val.flatten()
site_idx = sites_list.index(site)
tb_idx = np.sum(norbs_list[:site_idx]) + range(norbs_list[site_idx])
row, col = np.array([*product(tb_idx, tb_idx)]).T
data = np.array(_parse_val(val)).flatten()
onsite_value = coo_array((data, (row, col)), shape=tb_shape).toarray()
if onsite_idx in h_0:
h_0[onsite_idx] += coo_array(
(data, (row, col)), shape=(norbs_tot, norbs_tot)
).toarray()
h_0[onsite_idx] += onsite_value
else:
h_0[onsite_idx] = coo_array(
(data, (row, col)), shape=(norbs_tot, norbs_tot)
).toarray()
# Hopping matrices
for hop, val in builder.hopping_value_pairs():
a, b = hop
b_dom = builder.symmetry.which(b)
b_fd = builder.symmetry.to_fd(b)
atoms = np.array([sites_list.index(a), sites_list.index(b_fd)])
row, col = [
np.sum(norbs_list[: atoms[0]]) + range(norbs_list[atoms[0]]),
np.sum(norbs_list[: atoms[1]]) + range(norbs_list[atoms[1]]),
h_0[onsite_idx] = onsite_value
for (site1, site2), val in builder.hopping_value_pairs():
site2_dom = builder.symmetry.which(site2)
site2_fd = builder.symmetry.to_fd(site2)
site1_idx, site2_idx = np.array(
[sites_list.index(site1), sites_list.index(site2_fd)]
)
tb_idx1, tb_idx2 = [
np.sum(norbs_list[:site1_idx]) + range(norbs_list[site1_idx]),
np.sum(norbs_list[:site2_idx]) + range(norbs_list[site2_idx]),
]
row, col = np.array([*product(row, col)]).T
try:
_params = {}
for arg in inspect.getfullargspec(val).args:
if arg in params:
_params[arg] = params[arg]
val = val(a, b, **_params)
data = val.flatten()
except Exception:
data = val.flatten()
if tuple(b_dom) in h_0:
h_0[tuple(b_dom)] += coo_array(
(data, (row, col)), shape=(norbs_tot, norbs_tot)
).toarray()
if np.linalg.norm(b_dom) == 0:
h_0[tuple(b_dom)] += (
coo_array((data, (row, col)), shape=(norbs_tot, norbs_tot))
.toarray()
.T.conj()
)
row, col = np.array([*product(tb_idx1, tb_idx2)]).T
data = np.array(_parse_val(val)).flatten()
hopping_value = coo_array((data, (row, col)), shape=tb_shape).toarray()
hop_key = tuple(site2_dom)
hop_key_back = tuple(-site2_dom)
if hop_key in h_0:
h_0[hop_key] += hopping_value
if np.linalg.norm(site2_dom) == 0:
h_0[hop_key] += hopping_value.conj().T
else:
# Hopping vector in the opposite direction
h_0[tuple(-b_dom)] += (
coo_array((data, (row, col)), shape=(norbs_tot, norbs_tot))
.toarray()
.T.conj()
)
h_0[hop_key_back] += hopping_value.conj().T
else:
h_0[tuple(b_dom)] = coo_array(
(data, (row, col)), shape=(norbs_tot, norbs_tot)
).toarray()
if np.linalg.norm(b_dom) == 0:
h_0[tuple(b_dom)] += (
coo_array((data, (row, col)), shape=(norbs_tot, norbs_tot))
.toarray()
.T.conj()
)
h_0[hop_key] = hopping_value
if np.linalg.norm(site2_dom) == 0:
h_0[hop_key] += hopping_value.conj().T
else:
h_0[tuple(-b_dom)] = (
coo_array((data, (row, col)), shape=(norbs_tot, norbs_tot))
.toarray()
.T.conj()
)
h_0[hop_key_back] = hopping_value.conj().T
if return_data:
data = {}
data["norbs"] = norbs_list
data["positions"] = positions_list
data["positions"] = [site.pos for site in sites_list]
return h_0, data
else:
return h_0
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment