Commit 790a6f92 authored by Joseph Weston's avatar Joseph Weston
Browse files

add Builder.subs method and tests

parent 7bb1bab2
......@@ -21,7 +21,8 @@ from . import system, graph, KwantDeprecationWarning, UserCodeError
from .linalg import lll
from .operator import Density
from .physics import DiscreteSymmetry
from ._common import ensure_isinstance, get_parameters, reraise_warnings
from ._common import (ensure_isinstance, get_parameters, reraise_warnings,
__all__ = ['Builder', 'Site', 'SiteFamily', 'SimpleSiteFamily', 'Symmetry',
......@@ -796,6 +797,15 @@ class ParameterSubstitution:
return self.function(**_invert_map(self.substitutions, arguments))
def _substitute_parameters(value_func, relevant_params, subs):
"""Substitute 'relevant_params' from 'subs' into 'value_func'."""
assert callable(value_func)
relevant_subs = {n: subs[n] for n in relevant_params if n in subs}
if not relevant_subs:
return value_func
return ParameterSubstitution(value_func, relevant_subs)
class Builder:
"""A tight binding system defined on a graph.
......@@ -1369,6 +1379,63 @@ class Builder:
return self
def subs(self, **subs):
"""Return a copy of this Builder with modified parameter names."""
# Get value *functions* only
onsites = list(set(
onsite for _, onsite in self.site_value_pairs()
if callable(onsite)
hoppings = list(set(
hop for _, hop in self.hopping_value_pairs()
if callable(hop)
flatten = chain.from_iterable
# Get parameter names to be substituted for each function,
# without the 'site' parameter(s)
onsite_params = [get_parameters(v).required[1:] for v in onsites]
hopping_params = [get_parameters(v).required[2:] for v in hoppings]
system_params = set(flatten(chain(onsite_params, hopping_params)))
nonexistant_params = set(subs.keys()).difference(system_params)
if nonexistant_params:
msg = ('Parameters {} are not used by any onsite or hopping '
'value function in this system.'
warnings.warn(msg, RuntimeWarning, stacklevel=2)
# Precompute map from old onsite/hopping value functions to ones
# with substituted parameters.
value_map = {
value: _substitute_parameters(value, params, subs)
for value, params in chain(zip(onsites, onsite_params),
zip(hoppings, hopping_params))
# Construct the a copy of the system with new value functions.
if self.leads:
raise ValueError("Using 'subs' on a Builder with attached leads "
"is ambiguous. Consider using 'subs' before "
"attaching leads.")
result = copy.copy(self)
# if we don't assign a new list we will inadvertantly add leads to
# the reversed system if we add leads to *this* system
# (because we only shallow copy)
result.leads = []
# Copy the 'H' dictionary, mapping old values to new ones using
# 'value_map'. If a value does not appear in the map then it means
# that the old value should be used.
result.H = {}
for tail, hvhv in self.H.items():
result.H[tail] = list(flatten(
(head, value_map.get(value, value))
for head, value in interleave(hvhv)
return result
def fill(self, template, shape, start, *, max_sites=10**7):
"""Populate builder using another one as a template.
# Copyright 2011-2016 Kwant authors.
# Copyright 2011-2018 Kwant authors.
# This file is part of Kwant. It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
......@@ -1301,3 +1301,59 @@ def test_parameter_substitution():
g = Subs(f, dict(x='a'))
h = Subs(f, dict(x='a'))
assert len(set([f, g, h])) == 2
def test_subs():
# Simple case
def onsite(site, a, b):
salt = str(a) + str(b)
return kwant.digest.uniform(site.tag, salt=salt)
def hopping(sitea, siteb, b, c):
salt = str(b) + str(c)
return kwant.digest.uniform(ta.array((sitea.tag, siteb.tag)), salt=salt)
lat = kwant.lattice.chain()
def make_system(sym=kwant.builder.NoSymmetry(), n=3):
syst = kwant.Builder(sym)
syst[(lat(i) for i in range(n))] = onsite
syst[lat.neighbors()] = hopping
return syst
def hamiltonian(syst, **kwargs):
return syst.finalized().hamiltonian_submatrix(params=kwargs)
syst = make_system()
# parameter name not an identifier
raises(ValueError, syst.subs, a='not-an-identifier?')
# substituting a paramter that doesn't exist produces a warning
warns(RuntimeWarning, syst.subs, fakeparam='yes')
# name clash in value functions
raises(ValueError, syst.subs, b='a')
raises(ValueError, syst.subs, b='c')
raises(ValueError, syst.subs, a='site')
raises(ValueError, syst.subs, c='sitea')
# cannot call 'subs' on systems with attached leads, because
# it is not clear whether the substitutions should propagate
# into the leads too.
syst = make_system()
lead = make_system(kwant.TranslationalSymmetry((-1,)), n=1)
raises(ValueError, syst.subs, a='d')
# test basic substitutions
syst = make_system()
expected = hamiltonian(syst, a=1, b=2, c=3)
# 1 level of substitutions
sub_syst = syst.subs(a='d', b='e')
assert np.allclose(hamiltonian(sub_syst, d=1, e=2, c=3), expected)
# 2 levels of substitution
sub_sub_syst = sub_syst.subs(d='g', c='h')
assert np.allclose(hamiltonian(sub_sub_syst, g=1, e=2, h=3), expected)
# very confusing but technically valid. 'a' does not appear in 'hopping',
# so the signature of 'onsite' is valid.
sub_syst = syst.subs(a='sitea')
assert np.allclose(hamiltonian(sub_syst, sitea=1, b=2, c=3), expected)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment