Commit c2da972a authored by Christoph Groth's avatar Christoph Groth Committed by Joseph Weston

optimize builder parameter substitution

Co-authored-by: Joseph Weston's avatarJoseph Weston <>
parent 0aa3f324
......@@ -11,7 +11,7 @@ import warnings
import operator
import collections
import copy
from functools import total_ordering, wraps
from functools import total_ordering, wraps, update_wrapper
from itertools import islice, chain
import inspect
import tinyarray as ta
......@@ -714,100 +714,57 @@ def _site_ranges(sites):
return site_ranges
def _substitute_sig(func, substitutions):
"""Substitute different parameter names into a function signature.
func : callable
The function whose signature we wish to copy
substitutions : dict or iterable
Mapping from old parameter names to new. Will be
fed to 'dict'.
substitutions = dict(substitutions) # Copy because we later destroy it
def new_name(name):
return substitutions.pop(name, name)
sig = inspect.signature(func)
new_params = [param.replace(name=new_name(name))
for name, param in sig.parameters.items()]
if substitutions:
raise ValueError('More substitutions than available parameters.')
return sig.replace(parameters=new_params)
def _compose_maps(f, g):
"""Compose the maps f and g from left to right
>>> _compose_maps(
... dict(x='a', y='b'),
... dict(a='c', z='e'))
{'x': 'c', 'y': 'b', 'z': 'e'}
composed = dict(g)
for fk, fv in f.items():
if fv in g:
del composed[fv]
composed[fk] = g[fv]
composed[fk] = fv
class _Substituted:
"""Proxy that renames function parameters."""
# eliminate identity maps k -> k
return dict((k, v) for k, v in composed.items() if k != v)
def __init__(self, func, params):
self.func = func
self.params = params
update_wrapper(self, func)
def __eq__(self, other):
if not isinstance(other, _Substituted):
return False
return (self.func == other.func and self.params == other.params)
def _invert_map(substitutions, arguments):
pmap = {new: old for old, new in substitutions}
return {pmap.get(param, param): value
for param, value in arguments.items()}
def __hash__(self):
return hash((self.func, self.params))
def __signature__(self):
return inspect.Signature(
[inspect.Parameter(name, inspect.Parameter.POSITIONAL_ONLY)
for name in self.params])
class ParameterSubstitution:
"""Proxy that renames function parameters."""
__slots__ = ('function', 'substitutions', '__signature__')
def __call__(self, *args):
return self.func(*args)
def __init__(self, function, substitutions):
if isinstance(function, ParameterSubstitution):
self.function = function.function
substitutions = _compose_maps(dict(function.substitutions),
self.function = function
self.substitutions = tuple(sorted(substitutions.items()))
self.__signature__ = _substitute_sig(self.function, self.substitutions)
def __eq__(self, other):
if not isinstance(other, ParameterSubstitution):
return False
return ((self.function, self.substitutions) ==
(other.function, other.substitutions))
def _substitute_params(func, subs):
"""Substitute 'params' from 'subs' into 'func'."""
assert callable(func)
def __hash__(self):
return hash((self.function, self.substitutions))
if isinstance(func, _Substituted):
old_params = func.params
old_func = func.func
old_params = get_parameters(func)
old_func = func
def __call__(self, *args, **kwargs):
arguments = self.__signature__.bind(*args, **kwargs).arguments
return self.function(**_invert_map(self.substitutions, arguments))
params = tuple(subs.get(p, p) for p in old_params)
duplicates = [p for p, count in collections.Counter(params).items()
if count > 1]
if duplicates:
msg = ('Cannot rename parameters ',
','.join('"{}"'.format(d) for d in duplicates),
': parameters with the same name exist')
raise ValueError(''.join(msg))
def _substitute_parameters(value_func, relevant_params, subs):
"""Substitute 'relevant_params' from 'subs' into 'value_func'."""
assert callable(value_func)
relevant_subs = {n: subs[n] for n in relevant_params if n in subs}
if not relevant_subs:
return value_func
return ParameterSubstitution(value_func, relevant_subs)
if params == old_params:
return func
return _Substituted(old_func, params)
class Builder:
......@@ -1385,23 +1342,14 @@ class Builder:
def subs(self, **subs):
"""Return a copy of this Builder with modified parameter names.
The value functions of the returned Builder take longer to
evaluate due to the parameter renaming. This overhead may be
a significant fraction of the total time if the original value
function is particularly quick to evaluate.
# Get value *functions* only
onsites = list(set(
onsite for _, onsite in self.site_value_pairs()
if callable(onsite)
if callable(onsite)))
hoppings = list(set(
hop for _, hop in self.hopping_value_pairs()
if callable(hop)
if callable(hop)))
flatten = chain.from_iterable
......@@ -1420,11 +1368,8 @@ class Builder:
# Precompute map from old onsite/hopping value functions to ones
# with substituted parameters.
value_map = {
value: _substitute_parameters(value, params, subs)
for value, params in chain(zip(onsites, onsite_params),
zip(hoppings, hopping_params))
value_map = {value: _substitute_params(value, subs)
for value in chain(onsites, hoppings)}
# Construct the a copy of the system with new value functions.
if self.leads:
......@@ -1443,8 +1388,7 @@ class Builder:
for tail, hvhv in self.H.items():
result.H[tail] = list(flatten(
(head, value_map.get(value, value))
for head, value in interleave(hvhv)
for head, value in interleave(hvhv)))
return result
......@@ -1259,42 +1259,31 @@ def test_argument_passing():
def test_parameter_substitution():
Subs = builder.ParameterSubstitution
subs = builder._substitute_params
def f(x, y):
return (('x', x), ('y', y))
# 'f' already has a parameter 'y'
assert raises(ValueError, Subs, f, dict(x='y'))
# 'f' takes no parameter 'a'
assert raises(ValueError, Subs, f, dict(a='x'))
assert raises(ValueError, subs, f, dict(x='y'))
# reverse argument order
g = Subs(f, dict(x='y', y='x'))
# Swap argument names.
g = subs(f, dict(x='y', y='x'))
assert g(1, 2) == f(1, 2)
assert g(y=1, x=2) == f(x=1, y=2)
# reverse again
h = Subs(g, dict(x='y', y='x'))
# Swap again.
h = subs(g, dict(x='y', y='x'))
assert h(1, 2) == f(1, 2)
assert h(x=1, y=2) == f(x=1, y=2)
# don't nest wrappers inside each other
assert h.function is f
assert h.func is f
# composing maps
g = Subs(f, dict(x='a'))
h = Subs(g, dict(a='b'))
assert h(b=1, y=2) == f(x=1, y=2)
# different names
g = Subs(f, dict(x='a', y='b'))
# Try different names.
g = subs(f, dict(x='a', y='b'))
assert g(1, 2) == f(1, 2)
assert g(a=1, b=2) == f(x=1, y=2)
assert g(1, b=2) == f(1, y=2)
# Can be used in sets/dicts
g = Subs(f, dict(x='a'))
h = Subs(f, dict(x='a'))
# Can substitutions be used in sets/dicts?
g = subs(f, dict(x='a'))
h = subs(f, dict(x='a'))
assert len(set([f, g, h])) == 2
......@@ -1322,8 +1311,6 @@ def test_subs():
return syst.finalized().hamiltonian_submatrix(params=kwargs)
syst = make_system()
# parameter name not an identifier
raises(ValueError, syst.subs, a='not-an-identifier?')
# substituting a paramter that doesn't exist produces a warning
warns(RuntimeWarning, syst.subs, fakeparam='yes')
# name clash in value functions
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment