From 790a6f928f7b2326451945fc909489dfdec643ca Mon Sep 17 00:00:00 2001 From: Joseph Weston <joseph@weston.cloud> Date: Tue, 6 Mar 2018 18:07:21 +0100 Subject: [PATCH] add Builder.subs method and tests --- kwant/builder.py | 69 ++++++++++++++++++++++++++++++++++++- kwant/tests/test_builder.py | 58 ++++++++++++++++++++++++++++++- 2 files changed, 125 insertions(+), 2 deletions(-) diff --git a/kwant/builder.py b/kwant/builder.py index eb2f0a97..7702ccd4 100644 --- a/kwant/builder.py +++ b/kwant/builder.py @@ -21,7 +21,8 @@ from . import system, graph, KwantDeprecationWarning, UserCodeError from .linalg import lll from .operator import Density from .physics import DiscreteSymmetry -from ._common import ensure_isinstance, get_parameters, reraise_warnings +from ._common import (ensure_isinstance, get_parameters, reraise_warnings, + interleave) __all__ = ['Builder', 'Site', 'SiteFamily', 'SimpleSiteFamily', 'Symmetry', @@ -796,6 +797,15 @@ class ParameterSubstitution: return self.function(**_invert_map(self.substitutions, arguments)) +def _substitute_parameters(value_func, relevant_params, subs): + """Substitute 'relevant_params' from 'subs' into 'value_func'.""" + assert callable(value_func) + relevant_subs = {n: subs[n] for n in relevant_params if n in subs} + if not relevant_subs: + return value_func + return ParameterSubstitution(value_func, relevant_subs) + + class Builder: """A tight binding system defined on a graph. @@ -1369,6 +1379,63 @@ class Builder: self.update(other) return self + def subs(self, **subs): + """Return a copy of this Builder with modified parameter names.""" + # Get value *functions* only + onsites = list(set( + onsite for _, onsite in self.site_value_pairs() + if callable(onsite) + )) + hoppings = list(set( + hop for _, hop in self.hopping_value_pairs() + if callable(hop) + )) + + flatten = chain.from_iterable + + # Get parameter names to be substituted for each function, + # without the 'site' parameter(s) + onsite_params = [get_parameters(v).required[1:] for v in onsites] + hopping_params = [get_parameters(v).required[2:] for v in hoppings] + + system_params = set(flatten(chain(onsite_params, hopping_params))) + nonexistant_params = set(subs.keys()).difference(system_params) + if nonexistant_params: + msg = ('Parameters {} are not used by any onsite or hopping ' + 'value function in this system.' + ).format(nonexistant_params) + warnings.warn(msg, RuntimeWarning, stacklevel=2) + + # Precompute map from old onsite/hopping value functions to ones + # with substituted parameters. + value_map = { + value: _substitute_parameters(value, params, subs) + for value, params in chain(zip(onsites, onsite_params), + zip(hoppings, hopping_params)) + } + + # Construct the a copy of the system with new value functions. + if self.leads: + raise ValueError("Using 'subs' on a Builder with attached leads " + "is ambiguous. Consider using 'subs' before " + "attaching leads.") + result = copy.copy(self) + # if we don't assign a new list we will inadvertantly add leads to + # the reversed system if we add leads to *this* system + # (because we only shallow copy) + result.leads = [] + # Copy the 'H' dictionary, mapping old values to new ones using + # 'value_map'. If a value does not appear in the map then it means + # that the old value should be used. + result.H = {} + for tail, hvhv in self.H.items(): + result.H[tail] = list(flatten( + (head, value_map.get(value, value)) + for head, value in interleave(hvhv) + )) + + return result + def fill(self, template, shape, start, *, max_sites=10**7): """Populate builder using another one as a template. diff --git a/kwant/tests/test_builder.py b/kwant/tests/test_builder.py index 46bf3478..e4ecd59f 100644 --- a/kwant/tests/test_builder.py +++ b/kwant/tests/test_builder.py @@ -1,4 +1,4 @@ -# Copyright 2011-2016 Kwant authors. +# Copyright 2011-2018 Kwant authors. # # This file is part of Kwant. It is subject to the license terms in the file # LICENSE.rst found in the top-level directory of this distribution and at @@ -1301,3 +1301,59 @@ def test_parameter_substitution(): g = Subs(f, dict(x='a')) h = Subs(f, dict(x='a')) assert len(set([f, g, h])) == 2 + + +def test_subs(): + + # Simple case + + def onsite(site, a, b): + salt = str(a) + str(b) + return kwant.digest.uniform(site.tag, salt=salt) + + def hopping(sitea, siteb, b, c): + salt = str(b) + str(c) + return kwant.digest.uniform(ta.array((sitea.tag, siteb.tag)), salt=salt) + + lat = kwant.lattice.chain() + + def make_system(sym=kwant.builder.NoSymmetry(), n=3): + syst = kwant.Builder(sym) + syst[(lat(i) for i in range(n))] = onsite + syst[lat.neighbors()] = hopping + return syst + + def hamiltonian(syst, **kwargs): + return syst.finalized().hamiltonian_submatrix(params=kwargs) + + syst = make_system() + # parameter name not an identifier + raises(ValueError, syst.subs, a='not-an-identifier?') + # substituting a paramter that doesn't exist produces a warning + warns(RuntimeWarning, syst.subs, fakeparam='yes') + # name clash in value functions + raises(ValueError, syst.subs, b='a') + raises(ValueError, syst.subs, b='c') + raises(ValueError, syst.subs, a='site') + raises(ValueError, syst.subs, c='sitea') + # cannot call 'subs' on systems with attached leads, because + # it is not clear whether the substitutions should propagate + # into the leads too. + syst = make_system() + lead = make_system(kwant.TranslationalSymmetry((-1,)), n=1) + syst.attach_lead(lead) + raises(ValueError, syst.subs, a='d') + + # test basic substitutions + syst = make_system() + expected = hamiltonian(syst, a=1, b=2, c=3) + # 1 level of substitutions + sub_syst = syst.subs(a='d', b='e') + assert np.allclose(hamiltonian(sub_syst, d=1, e=2, c=3), expected) + # 2 levels of substitution + sub_sub_syst = sub_syst.subs(d='g', c='h') + assert np.allclose(hamiltonian(sub_sub_syst, g=1, e=2, h=3), expected) + # very confusing but technically valid. 'a' does not appear in 'hopping', + # so the signature of 'onsite' is valid. + sub_syst = syst.subs(a='sitea') + assert np.allclose(hamiltonian(sub_syst, sitea=1, b=2, c=3), expected) -- GitLab