Commit 1f830ce0 authored by Joseph Weston's avatar Joseph Weston

Merge branch 'params' into 'master'

Simplifications and optimizations of how parameters are handled

Closes #228

See merge request !243
parents 7462481d d2145780
Pipeline #12213 passed with stages
in 16 minutes and 35 seconds
@@ -1,144 +1,162 @@
@@ -1,144 +1,161 @@
# Tutorial 2.4.2. Closed systems
# ==============================
#
......@@ -41,7 +41,7 @@
rsq = x ** 2 + y ** 2
return rsq < r ** 2
def hopx(site1, site2, B=0):
def hopx(site1, site2, B):
# The magnetic field is controlled by the parameter B
y = site1.pos[1]
return -t * exp(-1j * B * y)
......
......@@ -36,7 +36,7 @@
#HIDDEN_END_ehso
#HIDDEN_BEGIN_coid
def onsite(site, pot=0):
def onsite(site, pot):
return 4 * t + potential(site, pot)
syst[(lat(x, y) for x in range(L) for y in range(W))] = onsite
......
......@@ -24,30 +24,59 @@ would be washed out by the presence of the peak. Now `~kwant.plotter.map`
employs a heuristic for setting the colorscale when there are outliers,
and will emit a warning when this is detected.
System parameter names can be modified
--------------------------------------
After the introduction of ``Builder.fill`` it has become common to construct
System parameter substitution
-----------------------------
After the introduction of ``Builder.fill`` it has become possible to construct
Kwant systems by first creating a "model" system with high symmetry and then
filling a lower symmetry system with this model. Often, however, you want
to use different parameter values in different parts of your system. In
filling a lower symmetry system with this model. Often, however, one wants
to use different parameter values in different parts of a system. In
previous versions of Kwant this was difficult to achieve.
Builders now have a method ``subs`` that makes it easy to substitute different
names for parameters. For example if you have a Builder ``model`` that has
a parameter ``V``, and you wish to have different values for ``V`` in your
scattering region and leads you could do the following::
Builders now have a method ``substitute`` that makes it easy to substitute
different names for parameters. For example if a builder ``model``
has a parameter ``V``, and one wishes to have different values for ``V`` in
the scattering region and leads, one could do the following::
syst = kwant.Builder()
syst.fill(model.subs(V='V_dot', ...))
syst.fill(model.substitute(V='V_dot', ...))
lead = kwant.Builder()
lead.fill(model.subs(V='V_lead'), ...)
lead.fill(model.substitute(V='V_lead'), ...)
syst.attach_lead(lead)
fsyst = syst.finalized()
kwant.smatrix(syst, params=dict(V_dot=0, V_lead=1))
Value functions no longer accept default values for parameters
--------------------------------------------------------------
Using value functions with default values for parameters can be
problematic, especially when re-using value functions between simulations.
When parameters have default values it is easy to forget that such a
parameter exists at all, because it is not necessary to provide them explicitly
to functions that use the Kwant system. This means that other value functions
might be introduced that also depend on the same parameter,
but in an inconsistent way (e.g. a parameter 'phi' that is a superconducting
phase in one value function, but a peierls phase in another). This leads
to bugs that are confusing and hard to track down.
Concretely, the above means that the following no longer works::
syst = kwant.Builder()
# Parameter 't' has a default value of 1
def onsite(site, V, t=1):
return V = 2 * t
def hopping(site_a, site_b, t=1):
return -t
syst[...] = onsite
syst[...] = hopping
# Raises ValueError
syst = syst.finalized()
Interpolated density plots
--------------------------
A new function `~kwant.plotter.density` has been added that can be used to
......
......@@ -13,7 +13,6 @@ import inspect
import warnings
import importlib
from contextlib import contextmanager
from collections import namedtuple
__all__ = ['KwantDeprecationWarning', 'UserCodeError']
......@@ -94,35 +93,41 @@ def reraise_warnings(level=3):
warnings.warn(warning.message, stacklevel=level)
_Params = namedtuple('_Params', ('required', 'defaults', 'takes_kwargs'))
def get_parameters(func):
"""Get the names of the parameters to 'func' and whether it takes kwargs.
"""Return the names of parameters of a function.
It is made sure that the function can be called as func(*args) with
'args' corresponding to the returned parameter names.
Returns
-------
required : list
Names of positional, and keyword only parameters that do not have a
default value and that appear in the signature of 'func'.
defaults : list
Names of parameters that have a default value.
takes_kwargs : bool
True if 'func' takes '**kwargs'.
param_names : list
Names of positional parameters that appear in the signature of 'func'.
"""
sig = inspect.signature(func)
pars = sig.parameters
# Signature.parameters is an *ordered mapping*
required_params = [k for (k, v) in pars.items()
if v.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY)
and v.default is inspect._empty]
default_params = [k for (k, v) in pars.items()
if v.default is not inspect._empty]
takes_kwargs = any(i.kind is inspect.Parameter.VAR_KEYWORD
for i in pars.values())
return _Params(required_params, default_params, takes_kwargs)
def error(msg):
fname = inspect.getsourcefile(func)
try:
line = inspect.getsourcelines(func)[1]
except OSError:
line = '<unknown line>'
raise ValueError("{}:\nFile {}, line {}, in {}".format(
msg, repr(fname), line, func.__name__))
P = inspect.Parameter
pars = inspect.signature(func).parameters # an *ordered mapping*
names = []
for k, v in pars.items():
if v.kind in (P.POSITIONAL_ONLY, P.POSITIONAL_OR_KEYWORD):
if v.default is P.empty:
names.append(k)
else:
error("Arguments of value functions "
"must not have default values")
elif v.kind is P.KEYWORD_ONLY:
error("Keyword-only arguments are not allowed in value functions")
elif v.kind in (P.VAR_POSITIONAL, P.VAR_KEYWORD):
error("Value functions must not take *args or **kwargs")
return tuple(names)
class lazy_import:
......
......@@ -11,7 +11,7 @@ import warnings
import operator
import collections
import copy
from functools import total_ordering, wraps
from functools import total_ordering, wraps, update_wrapper
from itertools import islice, chain
import inspect
import tinyarray as ta
......@@ -514,6 +514,10 @@ class HermConjOfFunc:
def __call__(self, i, j, *args, **kwargs):
return herm_conj(self.function(j, i, *args, **kwargs))
@property
def __signature__(self):
return inspect.signature(self.function)
################ Leads
......@@ -710,100 +714,57 @@ def _site_ranges(sites):
return site_ranges
def _substitute_sig(func, substitutions):
"""Substitute different parameter names into a function signature.
Parameters
----------
func : callable
The function whose signature we wish to copy
substitutions : dict or iterable
Mapping from old parameter names to new. Will be
fed to 'dict'.
Returns
-------
inspect.Signature
"""
substitutions = dict(substitutions) # Copy because we later destroy it
def new_name(name):
return substitutions.pop(name, name)
sig = inspect.signature(func)
new_params = [param.replace(name=new_name(name))
for name, param in sig.parameters.items()]
if substitutions:
raise ValueError('More substitutions than available parameters.')
return sig.replace(parameters=new_params)
def _compose_maps(f, g):
"""Compose the maps f and g from left to right
Examples
--------
>>> _compose_maps(
... dict(x='a', y='b'),
... dict(a='c', z='e'))
{'x': 'c', 'y': 'b', 'z': 'e'}
"""
composed = dict(g)
for fk, fv in f.items():
if fv in g:
del composed[fv]
composed[fk] = g[fv]
else:
composed[fk] = fv
class _Substituted:
"""Proxy that renames function parameters."""
# eliminate identity maps k -> k
return dict((k, v) for k, v in composed.items() if k != v)
def __init__(self, func, params):
self.func = func
self.params = params
update_wrapper(self, func)
def __eq__(self, other):
if not isinstance(other, _Substituted):
return False
return (self.func == other.func and self.params == other.params)
def _invert_map(substitutions, arguments):
pmap = {new: old for old, new in substitutions}
return {pmap.get(param, param): value
for param, value in arguments.items()}
def __hash__(self):
return hash((self.func, self.params))
@property
def __signature__(self):
return inspect.Signature(
[inspect.Parameter(name, inspect.Parameter.POSITIONAL_ONLY)
for name in self.params])
class ParameterSubstitution:
"""Proxy that renames function parameters."""
__slots__ = ('function', 'substitutions', '__signature__')
def __call__(self, *args):
return self.func(*args)
def __init__(self, function, substitutions):
if isinstance(function, ParameterSubstitution):
self.function = function.function
substitutions = _compose_maps(dict(function.substitutions),
substitutions)
else:
self.function = function
self.substitutions = tuple(sorted(substitutions.items()))
self.__signature__ = _substitute_sig(self.function, self.substitutions)
def __eq__(self, other):
if not isinstance(other, ParameterSubstitution):
return False
return ((self.function, self.substitutions) ==
(other.function, other.substitutions))
def _substitute_params(func, subs):
"""Substitute 'params' from 'subs' into 'func'."""
assert callable(func)
def __hash__(self):
return hash((self.function, self.substitutions))
if isinstance(func, _Substituted):
old_params = func.params
old_func = func.func
else:
old_params = get_parameters(func)
old_func = func
def __call__(self, *args, **kwargs):
arguments = self.__signature__.bind(*args, **kwargs).arguments
return self.function(**_invert_map(self.substitutions, arguments))
params = tuple(subs.get(p, p) for p in old_params)
duplicates = [p for p, count in collections.Counter(params).items()
if count > 1]
if duplicates:
msg = ('Cannot rename parameters ',
','.join('"{}"'.format(d) for d in duplicates),
': parameters with the same name exist')
raise ValueError(''.join(msg))
def _substitute_parameters(value_func, relevant_params, subs):
"""Substitute 'relevant_params' from 'subs' into 'value_func'."""
assert callable(value_func)
relevant_subs = {n: subs[n] for n in relevant_params if n in subs}
if not relevant_subs:
return value_func
return ParameterSubstitution(value_func, relevant_subs)
if params == old_params:
return func
else:
return _Substituted(old_func, params)
class Builder:
......@@ -1379,32 +1340,29 @@ class Builder:
self.update(other)
return self
def subs(self, **subs):
def substitute(self, **subs):
"""Return a copy of this Builder with modified parameter names.
Notes
-----
The value functions of the returned Builder take longer to
evaluate due to the parameter renaming. This overhead may be
a significant fraction of the total time if the original value
function is particularly quick to evaluate.
"""
# Construct the a copy of the system with new value functions.
if self.leads:
raise ValueError("For simplicity, 'subsitute' is limited "
"to builders without leads. Use 'substitute' "
"before attaching leads to avoid this error.")
# Get value *functions* only
onsites = list(set(
onsite for _, onsite in self.site_value_pairs()
if callable(onsite)
))
if callable(onsite)))
hoppings = list(set(
hop for _, hop in self.hopping_value_pairs()
if callable(hop)
))
if callable(hop)))
flatten = chain.from_iterable
# Get parameter names to be substituted for each function,
# without the 'site' parameter(s)
onsite_params = [get_parameters(v).required[1:] for v in onsites]
hopping_params = [get_parameters(v).required[2:] for v in hoppings]
onsite_params = [get_parameters(v)[1:] for v in onsites]
hopping_params = [get_parameters(v)[2:] for v in hoppings]
system_params = set(flatten(chain(onsite_params, hopping_params)))
nonexistant_params = set(subs.keys()).difference(system_params)
......@@ -1416,17 +1374,9 @@ class Builder:
# Precompute map from old onsite/hopping value functions to ones
# with substituted parameters.
value_map = {
value: _substitute_parameters(value, params, subs)
for value, params in chain(zip(onsites, onsite_params),
zip(hoppings, hopping_params))
}
value_map = {value: _substitute_params(value, subs)
for value in chain(onsites, hoppings)}
# Construct the a copy of the system with new value functions.
if self.leads:
raise ValueError("Using 'subs' on a Builder with attached leads "
"is ambiguous. Consider using 'subs' before "
"attaching leads.")
result = copy.copy(self)
# if we don't assign a new list we will inadvertantly add leads to
# the reversed system if we add leads to *this* system
......@@ -1439,8 +1389,7 @@ class Builder:
for tail, hvhv in self.H.items():
result.H[tail] = list(flatten(
(head, value_map.get(value, value))
for head, value in interleave(hvhv)
))
for head, value in interleave(hvhv)))
return result
......@@ -1864,23 +1813,6 @@ def _translate_cons_law(cons_law):
class _FinalizedBuilderMixin:
"""Common functionality for all finalized builders"""
def _init_ham_param_maps(self):
"""Find parameters taken by all value functions
"""
ham_param_map = {}
for hams, skip in [(self.onsite_hamiltonians, 1), (self.hoppings, 2)]:
for ham in hams:
if (not callable(ham) or ham is Other or
ham in ham_param_map):
continue
# parameters come in the same order as in the function signature
params, defaults, takes_kwargs = get_parameters(ham)
params = params[skip:] # remove site argument(s)
ham_param_map[ham] = (params, defaults, takes_kwargs)
self._ham_param_map = ham_param_map
def _init_discrete_symmetries(self, builder):
def operator(op):
return Density(self, op, check_hermiticity=False)
......@@ -1899,56 +1831,44 @@ class _FinalizedBuilderMixin:
if args and params:
raise TypeError("'args' and 'params' are mutually exclusive.")
if i == j:
value = self.onsite_hamiltonians[i]
if callable(value):
value, param_names = self.onsites[i]
if param_names is not None: # 'value' is callable
site = self.symmetry.to_fd(self.sites[i])
if params:
required, defaults, takes_kw = self._ham_param_map[value]
invalid_params = set(params).intersection(set(defaults))
if invalid_params:
raise ValueError("Parameters {} have default values "
"and may not be set with 'params'"
.format(', '.join(invalid_params)))
if not takes_kw:
params = {pn: params[pn] for pn in required}
try:
value = value(site, **params)
except Exception as exc:
_raise_user_error(exc, value)
else:
try:
value = value(site, *args)
except Exception as exc:
_raise_user_error(exc, value)
args = map(params.__getitem__, param_names)
try:
value = value(site, *args)
except Exception as exc:
if isinstance(exc, KeyError) and params:
missing = [p for p in param_names if p not in params]
if missing:
msg = ('System is missing required arguments: ',
', '.join(map('"{}"'.format, missing)))
raise TypeError(''.join(msg))
_raise_user_error(exc, value)
else:
edge_id = self.graph.first_edge_id(i, j)
value = self.hoppings[edge_id]
value, param_names = self.hoppings[edge_id]
conj = value is Other
if conj:
i, j = j, i
edge_id = self.graph.first_edge_id(i, j)
value = self.hoppings[edge_id]
if callable(value):
value, param_names = self.hoppings[edge_id]
if param_names is not None: # 'value' is callable
sites = self.sites
site_i, site_j = self.symmetry.to_fd(sites[i], sites[j])
if params:
required, defaults, takes_kw = self._ham_param_map[value]
invalid_params = set(params).intersection(set(defaults))
if invalid_params:
raise ValueError("Parameters {} have default values "
"and may not be set with 'params'"
.format(', '.join(invalid_params)))
if not takes_kw:
params = {pn: params[pn] for pn in required}
try:
value = value(site_i, site_j, **params)
except Exception as exc:
_raise_user_error(exc, value)
else:
try:
value = value(site_i, site_j, *args)
except Exception as exc:
_raise_user_error(exc, value)
args = map(params.__getitem__, param_names)
try:
value = value(site_i, site_j, *args)
except Exception as exc:
if isinstance(exc, KeyError) and params:
missing = [p for p in param_names if p not in params]
if missing:
msg = ('System is missing required arguments: ',
', '.join(map('"{}"'.format, missing)))
raise TypeError(''.join(msg))
_raise_user_error(exc, value)
if conj:
value = herm_conj(value)
return value
......@@ -1977,6 +1897,26 @@ class _FinalizedBuilderMixin:
self._symmetries))
# The same (value, parameters) pair will be used for many sites/hoppings,
# so we cache it to avoid wasting extra memory.
def _value_params_pair_cache(nstrip):
def get(value):
entry = cache.get(id(value))
if entry is None:
if isinstance(value, _Substituted):
entry = value.func, value.params[nstrip:]
elif callable(value):
entry = value, get_parameters(value)[nstrip:]
else:
# None means: value is not callable. (That's faster to check.)
entry = value, None
cache[id(value)] = entry
return entry
assert nstrip in [1, 2]
cache = {}
return get
class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem):
"""Finalized `Builder` with leads.
......@@ -2042,20 +1982,24 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem):
lead_interfaces.append(np.array(interface))
onsite_hamiltonians = [builder.H[site][1] for site in sites]
hoppings = [builder._get_edge(sites[tail], sites[head])
for tail, head in g]
# Because many onsites/hoppings share the same (value, parameter)
# pairs, we keep them in a cache so that we only store a given pair
# in memory *once*. This is a similar idea to interning strings.
cache = _value_params_pair_cache(1)
onsites = [cache(builder.H[site][1]) for site in sites]
cache = _value_params_pair_cache(2)
hoppings = [cache(builder._get_edge(sites[tail], sites[head]))
for tail, head in g]
self.graph = g
self.sites = sites
self.site_ranges = _site_ranges(sites)
self.id_by_site = id_by_site
self.hoppings = hoppings
self.onsite_hamiltonians = onsite_hamiltonians
self.onsites = onsites
self.symmetry = builder.symmetry
self.leads = finalized_leads
self.lead_interfaces = lead_interfaces
self._init_ham_param_maps()
self._init_discrete_symmetries(builder)
......@@ -2164,12 +2108,17 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem):
for site_id, site in enumerate(sites):
id_by_site[site] = site_id
# In the following, because many onsites/hoppings share the same
# (value, parameter) pairs, we keep them in 'cache' so that we only
# store a given pair in memory *once*. This is like interning strings.
#### Make graph and extract onsite Hamiltonians.
cache = _value_params_pair_cache(1)
g = graph.Graph()
g.num_nodes = len(sites) # Some sites could not appear in any edge.
onsite_hamiltonians = []
onsites = []
for tail_id, tail in enumerate(sites[:cell_size]):
onsite_hamiltonians.append(builder.H[tail][1])
onsites.append(cache(builder.H[tail][1]))
for head in builder._out_neighbors(tail):
head_id = id_by_site.get(head)
if head_id is None:
......@@ -2191,6 +2140,7 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem):
g = g.compressed()
#### Extract hoppings.
cache = _value_params_pair_cache(2)
hoppings = []
for tail_id, head_id in g:
tail = sites[tail_id]
......@@ -2199,17 +2149,16 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem):
# The tail belongs to the previous domain. Find the
# corresponding hopping with the tail in the fund. domain.
tail, head = sym.to_fd(tail, head)
hoppings.append(builder._get_edge(tail, head))
hoppings.append(cache(builder._get_edge(tail, head)))
self.graph = g
self.sites = sites
self.site_ranges = _site_ranges(sites)
self.id_by_site = id_by_site
self.hoppings = hoppings
self.onsite_hamiltonians = onsite_hamiltonians
self.onsites = onsites
self.symmetry = builder.symmetry
self.cell_size = cell_size
self._init_ham_param_maps()
self._init_discrete_symmetries(builder)
......
......@@ -241,8 +241,8 @@ class _FunctionalOnsite:
self.onsite = onsite
self.sites = sites
def __call__(self, site_id, *args, **kwargs):
return self.onsite(self.sites[site_id], *args, **kwargs)
def __call__(self, site_id, *args):
return self.onsite(self.sites[site_id], *args)
class _DictOnsite(_FunctionalOnsite):
......@@ -257,13 +257,11 @@ def _normalize_onsite(syst, onsite, check_hermiticity):
If `onsite` is a function or a mapping (dictionary) then a function
is returned.
"""
parameter_info = ((), (), False)
param_names = ()
if callable(onsite):
# make 'onsite' compatible with hamiltonian value functions
required, defaults, takes_kwargs = get_parameters(onsite)