Skip to content
Snippets Groups Projects
Commit 17845d55 authored by Christoph Groth's avatar Christoph Groth
Browse files

builder: gather common code in _FinalizedBuilderMixin

parent f659d101
No related branches found
No related tags found
No related merge requests found
......@@ -267,7 +267,8 @@ nitpick_ignore = [('py:class', 'Warning'), ('py:class', 'Exception'),
('py:class', 'object'), ('py:class', 'tuple'),
('py:class', 'kwant.operator._LocalOperator'),
('py:class', 'numpy.ndarray'),
('py:class', 'kwant.solvers.common.BlockResult')]
('py:class', 'kwant.solvers.common.BlockResult'),
('py:class', 'kwant.builder._FinalizedBuilderMixin')]
# Use custom MathJax CDN, as cdn.mathjax.org will soon shut down
......
......@@ -1590,29 +1590,6 @@ def _raise_user_error(exc, func):
raise UserCodeError(msg.format(func.__name__)) from exc
def discrete_symmetry(self, args=(), *, params=None):
if self._cons_law is not None:
eigvals, eigvecs = self._cons_law
eigvals = eigvals.tocoo(args, params=params)
if not np.allclose(eigvals.data, np.round(eigvals.data)):
raise ValueError("Conservation law must have integer eigenvalues.")
eigvals = np.round(eigvals).tocsr()
# Avoid appearance of zero eigenvalues
eigvals = eigvals + 0.5 * sparse.identity(eigvals.shape[0])
eigvals.eliminate_zeros()
eigvecs = eigvecs.tocoo(args, params=params)
projectors = [eigvecs.dot(eigvals == val)
for val in sorted(np.unique(eigvals.data))]
else:
projectors = None
def evaluate(op):
return None if op is None else op.tocoo(args, params=params)
return DiscreteSymmetry(projectors, *(evaluate(symm) for symm in
self._symmetries))
def _translate_cons_law(cons_law):
"""Translate a conservation law from builder format to something that can
be used to initialize operator.Density.
......@@ -1650,22 +1627,123 @@ def _translate_cons_law(cons_law):
return vals, vecs
def _init_discrete_symmetries(self, builder):
def _operator(op):
class _FinalizedBuilderMixin:
"""Common functionality for all finalized builders"""
def _init_ham_param_maps(self):
"""Find parameters taken by all value functions
"""
ham_param_map = {}
for hams, skip in [(self.onsite_hamiltonians, 1), (self.hoppings, 2)]:
for ham in hams:
if (not callable(ham) or ham is Other or
ham in ham_param_map):
continue
# parameters come in the same order as in the function signature
params, defaults, takes_kwargs = get_parameters(ham)
params = params[skip:] # remove site argument(s)
ham_param_map[ham] = (params, defaults, takes_kwargs)
self._ham_param_map = ham_param_map
def _init_discrete_symmetries(self, builder):
def operator(op):
return Density(self, op, check_hermiticity=False)
if builder.conservation_law is None:
self._cons_law = None
else:
self._cons_law = tuple(map(
_operator, _translate_cons_law(builder.conservation_law)))
self._symmetries = tuple(None if op is None else _operator(op)
a = _translate_cons_law(builder.conservation_law)
self._cons_law = tuple(map(operator, a))
self._symmetries = tuple(None if op is None else operator(op)
for op in [builder.time_reversal,
builder.particle_hole,
builder.chiral])
def hamiltonian(self, i, j, *args, params=None):
if args and params:
raise TypeError("'args' and 'params' are mutually exclusive.")
if i == j:
value = self.onsite_hamiltonians[i]
if callable(value):
site = self.symmetry.to_fd(self.sites[i])
if params:
required, defaults, takes_kw = self._ham_param_map[value]
invalid_params = set(params).intersection(set(defaults))
if invalid_params:
raise ValueError("Parameters {} have default values "
"and may not be set with 'params'"
.format(', '.join(invalid_params)))
if not takes_kw:
params = {pn: params[pn] for pn in required}
try:
value = value(site, **params)
except Exception as exc:
_raise_user_error(exc, value)
else:
try:
value = value(site, *args)
except Exception as exc:
_raise_user_error(exc, value)
else:
edge_id = self.graph.first_edge_id(i, j)
value = self.hoppings[edge_id]
conj = value is Other
if conj:
i, j = j, i
edge_id = self.graph.first_edge_id(i, j)
value = self.hoppings[edge_id]
if callable(value):
sites = self.sites
site_i, site_j = self.symmetry.to_fd(sites[i], sites[j])
if params:
required, defaults, takes_kw = self._ham_param_map[value]
invalid_params = set(params).intersection(set(defaults))
if invalid_params:
raise ValueError("Parameters {} have default values "
"and may not be set with 'params'"
.format(', '.join(invalid_params)))
if not takes_kw:
params = {pn: params[pn] for pn in required}
try:
value = value(site_i, site_j, **params)
except Exception as exc:
_raise_user_error(exc, value)
else:
try:
value = value(site_i, site_j, *args)
except Exception as exc:
_raise_user_error(exc, value)
if conj:
value = herm_conj(value)
return value
def discrete_symmetry(self, args=(), *, params=None):
if self._cons_law is not None:
eigvals, eigvecs = self._cons_law
eigvals = eigvals.tocoo(args, params=params)
if not np.allclose(eigvals.data, np.round(eigvals.data)):
raise ValueError("Conservation law must have integer"
" eigenvalues.")
eigvals = np.round(eigvals).tocsr()
# Avoid appearance of zero eigenvalues
eigvals = eigvals + 0.5 * sparse.identity(eigvals.shape[0])
eigvals.eliminate_zeros()
eigvecs = eigvecs.tocoo(args, params=params)
projectors = [eigvecs.dot(eigvals == val)
for val in sorted(np.unique(eigvals.data))]
else:
projectors = None
class FiniteSystem(system.FiniteSystem):
def evaluate(op):
return None if op is None else op.tocoo(args, params=params)
return DiscreteSymmetry(projectors, *(evaluate(symm) for symm in
self._symmetries))
class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem):
"""Finalized `Builder` with leads.
Usable as input for the solvers in `kwant.solvers`.
......@@ -1730,97 +1808,28 @@ class FiniteSystem(system.FiniteSystem):
lead_interfaces.append(np.array(interface))
#### Find parameters taken by all value functions
onsite_hamiltonians = [builder.H[site][1] for site in sites]
hoppings = [builder._get_edge(sites[tail], sites[head])
for tail, head in g]
onsite_hamiltonians = [builder.H[site][1] for site in sites]
_ham_param_map = {}
for hams, skip in [(onsite_hamiltonians, 1), (hoppings, 2)]:
for ham in hams:
if (not callable(ham) or ham is Other or
ham in _ham_param_map):
continue
# parameters come in the same order as in the function signature
params, defaults, takes_kwargs = get_parameters(ham)
params = params[skip:] # remove site argument(s)
_ham_param_map[ham] = (params, defaults, takes_kwargs)
self.graph = g
self.sites = sites
self.site_ranges = _site_ranges(sites)
self.id_by_site = id_by_site
self.leads = finalized_leads
self.hoppings = hoppings
self.onsite_hamiltonians = onsite_hamiltonians
self._ham_param_map = _ham_param_map
self.lead_interfaces = lead_interfaces
self.symmetry = builder.symmetry
_init_discrete_symmetries(self, builder)
self.leads = finalized_leads
self.lead_interfaces = lead_interfaces
self._init_ham_param_maps()
self._init_discrete_symmetries(builder)
def hamiltonian(self, i, j, *args, params=None):
if args and params:
raise TypeError("'args' and 'params' are mutually exclusive.")
if i == j:
value = self.onsite_hamiltonians[i]
if callable(value):
if params:
required, defaults, takes_kw = self._ham_param_map[value]
invalid_params = set(params).intersection(set(defaults))
if invalid_params:
raise ValueError("Parameters {} have default values "
"and may not be set with 'params'"
.format(', '.join(invalid_params)))
if not takes_kw:
params = {pn: params[pn] for pn in required}
try:
value = value(self.sites[i], **params)
except Exception as exc:
_raise_user_error(exc, value)
else:
try:
value = value(self.sites[i], *args)
except Exception as exc:
_raise_user_error(exc, value)
else:
edge_id = self.graph.first_edge_id(i, j)
value = self.hoppings[edge_id]
conj = value is Other
if conj:
i, j = j, i
edge_id = self.graph.first_edge_id(i, j)
value = self.hoppings[edge_id]
if callable(value):
sites = self.sites
if params:
required, defaults, takes_kw = self._ham_param_map[value]
invalid_params = set(params).intersection(set(defaults))
if invalid_params:
raise ValueError("Parameters {} have default values "
"and may not be set with 'params'"
.format(', '.join(invalid_params)))
if not takes_kw:
params = {pn: params[pn] for pn in required}
try:
value = value(sites[i], sites[j], **params)
except Exception as exc:
_raise_user_error(exc, value)
else:
try:
value = value(sites[i], sites[j], *args)
except Exception as exc:
_raise_user_error(exc, value)
if conj:
value = herm_conj(value)
return value
def pos(self, i):
return self.sites[i].pos
discrete_symmetry = discrete_symmetry
class InfiniteSystem(system.InfiniteSystem):
class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem):
"""Finalized infinite system, extracted from a `Builder`.
Attributes
......@@ -1851,7 +1860,6 @@ class InfiniteSystem(system.InfiniteSystem):
sym = builder.symmetry
assert sym.num_directions == 1
#### For each site of the fundamental domain, determine whether it has
#### neighbors in the previous domain or not.
lsites_with = [] # Fund. domain sites with neighbors in prev. dom
......@@ -1959,90 +1967,24 @@ class InfiniteSystem(system.InfiniteSystem):
tail, head = sym.to_fd(tail, head)
hoppings.append(builder._get_edge(tail, head))
#### Find parameters taken by all value functions
_ham_param_map = {}
for hams, skip in [(onsite_hamiltonians, 1), (hoppings, 2)]:
for ham in hams:
if (not callable(ham) or ham is Other or
ham in _ham_param_map):
continue
# parameters come in the same order as in the function signature
params, defaults, takes_kwargs = get_parameters(ham)
params = params[skip:] # remove site argument(s)
_ham_param_map[ham] = (params, defaults, takes_kwargs)
self.cell_size = cell_size
self.graph = g
self.sites = sites
self.id_by_site = id_by_site
self.site_ranges = _site_ranges(sites)
self.graph = g
self.id_by_site = id_by_site
self.hoppings = hoppings
self.onsite_hamiltonians = onsite_hamiltonians
self._ham_param_map = _ham_param_map
self.symmetry = builder.symmetry
_init_discrete_symmetries(self, builder)
self.cell_size = cell_size
self._init_ham_param_maps()
self._init_discrete_symmetries(builder)
def hamiltonian(self, i, j, *args, params=None):
if args and params:
raise TypeError("'args' and 'params' are mutually exclusive.")
if i == j:
if i >= self.cell_size:
i -= self.cell_size
value = self.onsite_hamiltonians[i]
if callable(value):
site = self.symmetry.to_fd(self.sites[i])
if params:
required, defaults, takes_kw = self._ham_param_map[value]
invalid_params = set(params).intersection(set(defaults))
if invalid_params:
raise ValueError("Parameters {} have default values "
"and may not be set with 'params'"
.format(', '.join(invalid_params)))
if not takes_kw:
params = {pn: params[pn] for pn in required}
try:
value = value(site, **params)
except Exception as exc:
_raise_user_error(exc, value)
else:
try:
value = value(site, *args)
except Exception as exc:
_raise_user_error(exc, value)
else:
edge_id = self.graph.first_edge_id(i, j)
value = self.hoppings[edge_id]
conj = value is Other
if conj:
i, j = j, i
edge_id = self.graph.first_edge_id(i, j)
value = self.hoppings[edge_id]
if callable(value):
sites = self.sites
site_i, site_j = self.symmetry.to_fd(sites[i], sites[j])
if params:
required, defaults, takes_kw = self._ham_param_map[value]
invalid_params = set(params).intersection(set(defaults))
if invalid_params:
raise ValueError("Parameters {} have default values "
"and may not be set with 'params'"
.format(', '.join(invalid_params)))
if not takes_kw:
params = {pn: params[pn] for pn in required}
try:
value = value(site_i, site_j, **params)
except Exception as exc:
_raise_user_error(exc, value)
else:
try:
value = value(site_i, site_j, *args)
except Exception as exc:
_raise_user_error(exc, value)
if conj:
value = herm_conj(value)
return value
cs = self.cell_size
if i == j >= cs:
i -= cs
j -= cs
return super().hamiltonian(i, j, *args, params=params)
def pos(self, i):
return self.sites[i].pos
discrete_symmetry = discrete_symmetry
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment