From 790a6f928f7b2326451945fc909489dfdec643ca Mon Sep 17 00:00:00 2001
From: Joseph Weston <joseph@weston.cloud>
Date: Tue, 6 Mar 2018 18:07:21 +0100
Subject: [PATCH] add Builder.subs method and tests

---
 kwant/builder.py            | 69 ++++++++++++++++++++++++++++++++++++-
 kwant/tests/test_builder.py | 58 ++++++++++++++++++++++++++++++-
 2 files changed, 125 insertions(+), 2 deletions(-)

diff --git a/kwant/builder.py b/kwant/builder.py
index eb2f0a97..7702ccd4 100644
--- a/kwant/builder.py
+++ b/kwant/builder.py
@@ -21,7 +21,8 @@ from . import system, graph, KwantDeprecationWarning, UserCodeError
 from .linalg import lll
 from .operator import Density
 from .physics import DiscreteSymmetry
-from ._common import ensure_isinstance, get_parameters, reraise_warnings
+from ._common import (ensure_isinstance, get_parameters, reraise_warnings,
+                      interleave)
 
 
 __all__ = ['Builder', 'Site', 'SiteFamily', 'SimpleSiteFamily', 'Symmetry',
@@ -796,6 +797,15 @@ class ParameterSubstitution:
         return self.function(**_invert_map(self.substitutions, arguments))
 
 
+def _substitute_parameters(value_func, relevant_params, subs):
+    """Substitute 'relevant_params' from 'subs' into 'value_func'."""
+    assert callable(value_func)
+    relevant_subs = {n: subs[n] for n in relevant_params if n in subs}
+    if not relevant_subs:
+        return value_func
+    return ParameterSubstitution(value_func, relevant_subs)
+
+
 class Builder:
     """A tight binding system defined on a graph.
 
@@ -1369,6 +1379,63 @@ class Builder:
         self.update(other)
         return self
 
+    def subs(self, **subs):
+        """Return a copy of this Builder with modified parameter names."""
+        # Get value *functions* only
+        onsites = list(set(
+            onsite for _, onsite in self.site_value_pairs()
+            if callable(onsite)
+        ))
+        hoppings = list(set(
+            hop for _, hop in self.hopping_value_pairs()
+            if callable(hop)
+        ))
+
+        flatten = chain.from_iterable
+
+        # Get parameter names to be substituted for each function,
+        # without the 'site' parameter(s)
+        onsite_params = [get_parameters(v).required[1:] for v in onsites]
+        hopping_params = [get_parameters(v).required[2:] for v in hoppings]
+
+        system_params = set(flatten(chain(onsite_params, hopping_params)))
+        nonexistant_params = set(subs.keys()).difference(system_params)
+        if nonexistant_params:
+            msg = ('Parameters {} are not used by any onsite or hopping '
+                   'value function in this system.'
+                  ).format(nonexistant_params)
+            warnings.warn(msg, RuntimeWarning, stacklevel=2)
+
+        # Precompute map from old onsite/hopping value functions to ones
+        # with substituted parameters.
+        value_map = {
+            value: _substitute_parameters(value, params, subs)
+            for value, params in chain(zip(onsites, onsite_params),
+                                       zip(hoppings, hopping_params))
+        }
+
+        # Construct the a copy of the system with new value functions.
+        if self.leads:
+            raise ValueError("Using 'subs' on a Builder with attached leads "
+                             "is ambiguous. Consider using 'subs' before "
+                             "attaching leads.")
+        result = copy.copy(self)
+        # if we don't assign a new list we will inadvertantly add leads to
+        # the reversed system if we add leads to *this* system
+        # (because we only shallow copy)
+        result.leads = []
+        # Copy the 'H' dictionary, mapping old values to new ones using
+        # 'value_map'. If a value does not appear in the map then it means
+        # that the old value should be used.
+        result.H = {}
+        for tail, hvhv in self.H.items():
+            result.H[tail] = list(flatten(
+                (head, value_map.get(value, value))
+                for head, value in interleave(hvhv)
+            ))
+
+        return result
+
     def fill(self, template, shape, start, *, max_sites=10**7):
         """Populate builder using another one as a template.
 
diff --git a/kwant/tests/test_builder.py b/kwant/tests/test_builder.py
index 46bf3478..e4ecd59f 100644
--- a/kwant/tests/test_builder.py
+++ b/kwant/tests/test_builder.py
@@ -1,4 +1,4 @@
-# Copyright 2011-2016 Kwant authors.
+# Copyright 2011-2018 Kwant authors.
 #
 # This file is part of Kwant.  It is subject to the license terms in the file
 # LICENSE.rst found in the top-level directory of this distribution and at
@@ -1301,3 +1301,59 @@ def test_parameter_substitution():
     g = Subs(f, dict(x='a'))
     h = Subs(f, dict(x='a'))
     assert len(set([f, g, h])) == 2
+
+
+def test_subs():
+
+    # Simple case
+
+    def onsite(site, a, b):
+        salt = str(a) + str(b)
+        return kwant.digest.uniform(site.tag, salt=salt)
+
+    def hopping(sitea, siteb, b, c):
+        salt = str(b) + str(c)
+        return kwant.digest.uniform(ta.array((sitea.tag, siteb.tag)), salt=salt)
+
+    lat = kwant.lattice.chain()
+
+    def make_system(sym=kwant.builder.NoSymmetry(), n=3):
+        syst = kwant.Builder(sym)
+        syst[(lat(i) for i in range(n))] = onsite
+        syst[lat.neighbors()] = hopping
+        return syst
+
+    def hamiltonian(syst, **kwargs):
+        return syst.finalized().hamiltonian_submatrix(params=kwargs)
+
+    syst = make_system()
+    # parameter name not an identifier
+    raises(ValueError, syst.subs, a='not-an-identifier?')
+    # substituting a paramter that doesn't exist produces a warning
+    warns(RuntimeWarning, syst.subs, fakeparam='yes')
+    # name clash in value functions
+    raises(ValueError, syst.subs, b='a')
+    raises(ValueError, syst.subs, b='c')
+    raises(ValueError, syst.subs, a='site')
+    raises(ValueError, syst.subs, c='sitea')
+    # cannot call 'subs' on systems with attached leads, because
+    # it is not clear whether the substitutions should propagate
+    # into the leads too.
+    syst = make_system()
+    lead = make_system(kwant.TranslationalSymmetry((-1,)), n=1)
+    syst.attach_lead(lead)
+    raises(ValueError, syst.subs, a='d')
+
+    # test basic substitutions
+    syst = make_system()
+    expected = hamiltonian(syst, a=1, b=2, c=3)
+    # 1 level of substitutions
+    sub_syst = syst.subs(a='d', b='e')
+    assert np.allclose(hamiltonian(sub_syst, d=1, e=2, c=3), expected)
+    # 2 levels of substitution
+    sub_sub_syst = sub_syst.subs(d='g', c='h')
+    assert np.allclose(hamiltonian(sub_sub_syst, g=1, e=2, h=3), expected)
+    # very confusing but technically valid. 'a' does not appear in 'hopping',
+    # so the signature of 'onsite' is valid.
+    sub_syst = syst.subs(a='sitea')
+    assert np.allclose(hamiltonian(sub_syst, sitea=1, b=2, c=3), expected)
-- 
GitLab