Skip to content
Snippets Groups Projects

Builder fixes

Merged Kostas Vilkelis requested to merge builder_fixes into main
Files
3
+ 130
92
import inspect
from copy import copy
from itertools import product
from typing import Callable
from typing import Callable, Optional
import inspect
import numpy as np
from scipy.sparse import coo_array
import kwant
from kwant.builder import Site
import kwant.lattice
import kwant.builder
from meanfi.tb.tb import _tb_type
@@ -33,109 +34,145 @@ def builder_to_tb(
:
Data with sites and number of orbitals. Only if `return_data=True`.
"""
builder = copy(builder)
# Extract information from builder
dims = len(builder.symmetry.periods)
prim_vecs = builder.symmetry.periods
dims = len(prim_vecs)
sites_list = [*builder.sites()]
norbs_list = [site.family.norbs for site in builder.sites()]
norbs_list = [1 if norbs is None else norbs for norbs in norbs_list]
tb_norbs = sum(norbs_list)
tb_shape = (tb_norbs, tb_norbs)
onsite_idx = tuple([0] * dims)
h_0 = {}
sites_list = [*builder.sites()]
norbs_list = [site[0].norbs for site in builder.sites()]
positions_list = [site[0].pos for site in builder.sites()]
norbs_tot = sum(norbs_list)
# Extract onsite and hopping matrices.
# Based on `kwant.wraparound.wraparound`
# Onsite matrices
for site, val in builder.site_value_pairs():
site = builder.symmetry.to_fd(site)
atom = sites_list.index(site)
row = np.sum(norbs_list[:atom]) + range(norbs_list[atom])
col = copy(row)
row, col = np.array([*product(row, col)]).T
try:
_params = {}
for arg in inspect.getfullargspec(val).args:
if arg in params:
_params[arg] = params[arg]
val = val(site, **_params)
data = val.flatten()
except Exception:
data = val.flatten()
if onsite_idx in h_0:
h_0[onsite_idx] += coo_array(
(data, (row, col)), shape=(norbs_tot, norbs_tot)
).toarray()
else:
h_0[onsite_idx] = coo_array(
(data, (row, col)), shape=(norbs_tot, norbs_tot)
).toarray()
# Hopping matrices
for hop, val in builder.hopping_value_pairs():
a, b = hop
b_dom = builder.symmetry.which(b)
b_fd = builder.symmetry.to_fd(b)
atoms = np.array([sites_list.index(a), sites_list.index(b_fd)])
row, col = [
np.sum(norbs_list[: atoms[0]]) + range(norbs_list[atoms[0]]),
np.sum(norbs_list[: atoms[1]]) + range(norbs_list[atoms[1]]),
site_idx = sites_list.index(site)
tb_idx = np.sum(norbs_list[:site_idx]) + range(norbs_list[site_idx])
row, col = np.array([*product(tb_idx, tb_idx)]).T
if callable(val):
param_keys = inspect.getfullargspec(val).args[1:]
try:
val = val(site, *[params[key] for key in param_keys])
except KeyError as key:
raise KeyError(f"Parameter {key} not found in params.")
data = np.array(val).flatten()
onsite_value = coo_array((data, (row, col)), shape=tb_shape).toarray()
h_0[onsite_idx] = h_0.get(onsite_idx, 0) + onsite_value
for (site1, site2), val in builder.hopping_value_pairs():
site2_dom = builder.symmetry.which(site2)
site2_fd = builder.symmetry.to_fd(site2)
site1_idx, site2_idx = np.array(
[sites_list.index(site1), sites_list.index(site2_fd)]
)
tb_idx1, tb_idx2 = [
np.sum(norbs_list[:site1_idx]) + range(norbs_list[site1_idx]),
np.sum(norbs_list[:site2_idx]) + range(norbs_list[site2_idx]),
]
row, col = np.array([*product(row, col)]).T
try:
_params = {}
for arg in inspect.getfullargspec(val).args:
if arg in params:
_params[arg] = params[arg]
val = val(a, b, **_params)
data = val.flatten()
except Exception:
data = val.flatten()
if tuple(b_dom) in h_0:
h_0[tuple(b_dom)] += coo_array(
(data, (row, col)), shape=(norbs_tot, norbs_tot)
).toarray()
if np.linalg.norm(b_dom) == 0:
h_0[tuple(b_dom)] += (
coo_array((data, (row, col)), shape=(norbs_tot, norbs_tot))
.toarray()
.T.conj()
)
else:
# Hopping vector in the opposite direction
h_0[tuple(-b_dom)] += (
coo_array((data, (row, col)), shape=(norbs_tot, norbs_tot))
.toarray()
.T.conj()
)
else:
h_0[tuple(b_dom)] = coo_array(
(data, (row, col)), shape=(norbs_tot, norbs_tot)
).toarray()
if np.linalg.norm(b_dom) == 0:
h_0[tuple(b_dom)] += (
coo_array((data, (row, col)), shape=(norbs_tot, norbs_tot))
.toarray()
.T.conj()
)
else:
h_0[tuple(-b_dom)] = (
coo_array((data, (row, col)), shape=(norbs_tot, norbs_tot))
.toarray()
.T.conj()
)
row, col = np.array([*product(tb_idx1, tb_idx2)]).T
if callable(val):
param_keys = inspect.getfullargspec(val).args[2:]
try:
val = val(site1, site2, *[params[key] for key in param_keys])
except KeyError as key:
raise KeyError(f"Parameter {key} not found in params.")
data = np.array(val).flatten()
hopping_value = coo_array((data, (row, col)), shape=tb_shape).toarray()
hop_key = tuple(site2_dom)
hop_key_back = tuple(-site2_dom)
h_0[hop_key] = h_0.get(hop_key, 0) + hopping_value
h_0[hop_key_back] = h_0.get(hop_key_back, 0) + hopping_value.T.conj()
if return_data:
data = {}
data["norbs"] = norbs_list
data["positions"] = positions_list
data["periods"] = prim_vecs
data["sites"] = sites_list
return h_0, data
else:
return h_0
def tb_to_builder(
h_0: _tb_type, sites_list: list[Site, ...], periods: np.ndarray
) -> kwant.builder.Builder:
"""
Construct a `kwant.builder.Builder` from a tight-binding dictionary.
Parameters
----------
h_0 :
Tight-binding dictionary.
sites_list :
List of sites in the builder's unit cell.
periods :
2d array with periods of the translational symmetry.
Returns
-------
:
`kwant.builder.Builder` that corresponds to the tight-binding dictionary.
"""
builder = kwant.Builder(kwant.TranslationalSymmetry(*periods))
onsite_idx = tuple([0] * len(list(h_0)[0]))
norbs_list = [site.family.norbs for site in sites_list]
norbs_list = [1 if norbs is None else norbs for norbs in norbs_list]
def site_to_tbIdxs(site):
site_idx = sites_list.index(site)
return (np.sum(norbs_list[:site_idx]) + range(norbs_list[site_idx])).astype(int)
# assemble the sites first
for site in sites_list:
tb_idxs = site_to_tbIdxs(site)
value = h_0[onsite_idx][
tb_idxs[0] : tb_idxs[-1] + 1, tb_idxs[0] : tb_idxs[-1] + 1
]
builder[site] = value
# connect hoppings within the unit-cell
for site1, site2 in product(sites_list, sites_list):
if site1 == site2:
continue
tb_idxs1 = site_to_tbIdxs(site1)
tb_idxs2 = site_to_tbIdxs(site2)
value = h_0[onsite_idx][
tb_idxs1[0] : tb_idxs1[-1] + 1, tb_idxs2[0] : tb_idxs2[-1] + 1
]
if np.all(value == 0):
continue
builder[(site1, site2)] = value
# connect hoppings between unit-cells
for key in h_0:
if key == onsite_idx:
continue
for site1, site2_fd in product(sites_list, sites_list):
site2 = builder.symmetry.act(key, site2_fd)
tb_idxs1 = site_to_tbIdxs(site1)
tb_idxs2 = site_to_tbIdxs(site2_fd)
value = h_0[key][
tb_idxs1[0] : tb_idxs1[-1] + 1, tb_idxs2[0] : tb_idxs2[-1] + 1
]
if np.all(value == 0):
continue
builder[(site1, site2)] = value
return builder
def build_interacting_syst(
builder: kwant.builder.Builder,
lattice: kwant.lattice.Polyatomic,
func_onsite: Callable,
func_hop: Callable,
func_hop: Optional[Callable] = None,
max_neighbor: int = 1,
) -> kwant.builder.Builder:
"""
@@ -164,6 +201,7 @@ def build_interacting_syst(
kwant.lattice.TranslationalSymmetry(*builder.symmetry.periods)
)
int_builder[builder.sites()] = func_onsite
for neighbors in range(max_neighbor):
int_builder[lattice.neighbors(neighbors + 1)] = func_hop
if func_hop is not None:
for neighbors in range(max_neighbor):
int_builder[lattice.neighbors(neighbors + 1)] = func_hop
return int_builder
Loading