Commit 1f830ce0 authored by Joseph Weston's avatar Joseph Weston

Merge branch 'params' into 'master'

Simplifications and optimizations of how parameters are handled

Closes #228

See merge request kwant/kwant!243
parents 7462481d d2145780
Pipeline #12213 passed with stages
in 16 minutes and 35 seconds
@@ -1,144 +1,162 @@
@@ -1,144 +1,161 @@
# Tutorial 2.4.2. Closed systems
# ==============================
#
......@@ -41,7 +41,7 @@
rsq = x ** 2 + y ** 2
return rsq < r ** 2
def hopx(site1, site2, B=0):
def hopx(site1, site2, B):
# The magnetic field is controlled by the parameter B
y = site1.pos[1]
return -t * exp(-1j * B * y)
......
......@@ -36,7 +36,7 @@
#HIDDEN_END_ehso
#HIDDEN_BEGIN_coid
def onsite(site, pot=0):
def onsite(site, pot):
return 4 * t + potential(site, pot)
syst[(lat(x, y) for x in range(L) for y in range(W))] = onsite
......
......@@ -24,30 +24,59 @@ would be washed out by the presence of the peak. Now `~kwant.plotter.map`
employs a heuristic for setting the colorscale when there are outliers,
and will emit a warning when this is detected.
System parameter names can be modified
--------------------------------------
After the introduction of ``Builder.fill`` it has become common to construct
System parameter substitution
-----------------------------
After the introduction of ``Builder.fill`` it has become possible to construct
Kwant systems by first creating a "model" system with high symmetry and then
filling a lower symmetry system with this model. Often, however, you want
to use different parameter values in different parts of your system. In
filling a lower symmetry system with this model. Often, however, one wants
to use different parameter values in different parts of a system. In
previous versions of Kwant this was difficult to achieve.
Builders now have a method ``subs`` that makes it easy to substitute different
names for parameters. For example if you have a Builder ``model`` that has
a parameter ``V``, and you wish to have different values for ``V`` in your
scattering region and leads you could do the following::
Builders now have a method ``substitute`` that makes it easy to substitute
different names for parameters. For example if a builder ``model``
has a parameter ``V``, and one wishes to have different values for ``V`` in
the scattering region and leads, one could do the following::
syst = kwant.Builder()
syst.fill(model.subs(V='V_dot', ...))
syst.fill(model.substitute(V='V_dot', ...))
lead = kwant.Builder()
lead.fill(model.subs(V='V_lead'), ...)
lead.fill(model.substitute(V='V_lead'), ...)
syst.attach_lead(lead)
fsyst = syst.finalized()
kwant.smatrix(syst, params=dict(V_dot=0, V_lead=1))
Value functions no longer accept default values for parameters
--------------------------------------------------------------
Using value functions with default values for parameters can be
problematic, especially when re-using value functions between simulations.
When parameters have default values it is easy to forget that such a
parameter exists at all, because it is not necessary to provide them explicitly
to functions that use the Kwant system. This means that other value functions
might be introduced that also depend on the same parameter,
but in an inconsistent way (e.g. a parameter 'phi' that is a superconducting
phase in one value function, but a peierls phase in another). This leads
to bugs that are confusing and hard to track down.
Concretely, the above means that the following no longer works::
syst = kwant.Builder()
# Parameter 't' has a default value of 1
def onsite(site, V, t=1):
return V = 2 * t
def hopping(site_a, site_b, t=1):
return -t
syst[...] = onsite
syst[...] = hopping
# Raises ValueError
syst = syst.finalized()
Interpolated density plots
--------------------------
A new function `~kwant.plotter.density` has been added that can be used to
......
......@@ -13,7 +13,6 @@ import inspect
import warnings
import importlib
from contextlib import contextmanager
from collections import namedtuple
__all__ = ['KwantDeprecationWarning', 'UserCodeError']
......@@ -94,35 +93,41 @@ def reraise_warnings(level=3):
warnings.warn(warning.message, stacklevel=level)
_Params = namedtuple('_Params', ('required', 'defaults', 'takes_kwargs'))
def get_parameters(func):
"""Get the names of the parameters to 'func' and whether it takes kwargs.
"""Return the names of parameters of a function.
It is made sure that the function can be called as func(*args) with
'args' corresponding to the returned parameter names.
Returns
-------
required : list
Names of positional, and keyword only parameters that do not have a
default value and that appear in the signature of 'func'.
defaults : list
Names of parameters that have a default value.
takes_kwargs : bool
True if 'func' takes '**kwargs'.
param_names : list
Names of positional parameters that appear in the signature of 'func'.
"""
sig = inspect.signature(func)
pars = sig.parameters
# Signature.parameters is an *ordered mapping*
required_params = [k for (k, v) in pars.items()
if v.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY)
and v.default is inspect._empty]
default_params = [k for (k, v) in pars.items()
if v.default is not inspect._empty]
takes_kwargs = any(i.kind is inspect.Parameter.VAR_KEYWORD
for i in pars.values())
return _Params(required_params, default_params, takes_kwargs)
def error(msg):
fname = inspect.getsourcefile(func)
try:
line = inspect.getsourcelines(func)[1]
except OSError:
line = '<unknown line>'
raise ValueError("{}:\nFile {}, line {}, in {}".format(
msg, repr(fname), line, func.__name__))
P = inspect.Parameter
pars = inspect.signature(func).parameters # an *ordered mapping*
names = []
for k, v in pars.items():
if v.kind in (P.POSITIONAL_ONLY, P.POSITIONAL_OR_KEYWORD):
if v.default is P.empty:
names.append(k)
else:
error("Arguments of value functions "
"must not have default values")
elif v.kind is P.KEYWORD_ONLY:
error("Keyword-only arguments are not allowed in value functions")
elif v.kind in (P.VAR_POSITIONAL, P.VAR_KEYWORD):
error("Value functions must not take *args or **kwargs")
return tuple(names)
class lazy_import:
......
This diff is collapsed.
......@@ -241,8 +241,8 @@ class _FunctionalOnsite:
self.onsite = onsite
self.sites = sites
def __call__(self, site_id, *args, **kwargs):
return self.onsite(self.sites[site_id], *args, **kwargs)
def __call__(self, site_id, *args):
return self.onsite(self.sites[site_id], *args)
class _DictOnsite(_FunctionalOnsite):
......@@ -257,13 +257,11 @@ def _normalize_onsite(syst, onsite, check_hermiticity):
If `onsite` is a function or a mapping (dictionary) then a function
is returned.
"""
parameter_info = ((), (), False)
param_names = ()
if callable(onsite):
# make 'onsite' compatible with hamiltonian value functions
required, defaults, takes_kwargs = get_parameters(onsite)
required = required[1:] # skip 'site' parameter
parameter_info = (tuple(required), defaults, takes_kwargs)
param_names = get_parameters(onsite)[1:]
try:
_onsite = _FunctionalOnsite(onsite, syst.sites)
except AttributeError:
......@@ -301,7 +299,7 @@ def _normalize_onsite(syst, onsite, check_hermiticity):
'different numbers of orbitals on different sites')
raise ValueError(msg)
return _onsite, parameter_info
return _onsite, param_names
cdef class BlockSparseMatrix:
......@@ -434,7 +432,7 @@ cdef class _LocalOperator:
"""
cdef public int check_hermiticity, sum
cdef public object syst, onsite, _onsite_params_info
cdef public object syst, onsite, _onsite_param_names
cdef public gint[:, :] where, _site_ranges
cdef public BlockSparseMatrix _bound_onsite, _bound_hamiltonian
......@@ -448,8 +446,8 @@ cdef class _LocalOperator:
'the site families (lattices).')
self.syst = syst
self.onsite, self._onsite_params_info = \
_normalize_onsite(syst, onsite, check_hermiticity)
self.onsite, self._onsite_param_names = _normalize_onsite(
syst, onsite, check_hermiticity)
self.check_hermiticity = check_hermiticity
self.sum = sum
self._site_ranges = np.asarray(syst.site_ranges, dtype=gint_dtype)
......@@ -598,7 +596,7 @@ cdef class _LocalOperator:
q = cls.__new__(cls)
q.syst = self.syst
q.onsite = self.onsite
q._onsite_params_info = self._onsite_params_info
q._onsite_param_names = self._onsite_param_names
q.where = self.where
q.sum = self.sum
q._site_ranges = self._site_ranges
......@@ -638,23 +636,22 @@ cdef class _LocalOperator:
"""Evaluate the onsite matrices on all elements of `where`"""
assert callable(self.onsite)
assert not (args and params)
params = params or {}
matrix = ta.matrix
onsite = self.onsite
check_hermiticity = self.check_hermiticity
required, defaults, takes_kw = self._onsite_params_info
invalid_params = set(params).intersection(set(defaults))
if invalid_params:
raise ValueError("Parameters {} have default values "
"and may not be set with 'params'"
.format(', '.join(invalid_params)))
if params and not takes_kw:
params = {pn: params[pn] for pn in required}
if params:
try:
args = tuple(params[pn] for pn in self._onsite_param_names)
except KeyError:
missing = [p for p in self._onsite_param_names
if p not in params]
msg = ('Operator is missing required arguments: ',
', '.join(map('"{}"'.format, missing)))
raise TypeError(''.join(msg))
def get_onsite(a, a_norbs, b, b_norbs):
mat = matrix(onsite(a, *args, **params), complex)
mat = matrix(onsite(a, *args), complex)
_check_onsite(mat, a_norbs, check_hermiticity)
return mat
......@@ -679,14 +676,14 @@ cdef class _LocalOperator:
def __getstate__(self):
return (
(self.check_hermiticity, self.sum),
(self.syst, self.onsite, self._onsite_params_info),
(self.syst, self.onsite, self._onsite_param_names),
tuple(map(np.asarray, (self.where, self._site_ranges))),
(self._bound_onsite, self._bound_hamiltonian),
)
def __setstate__(self, state):
((self.check_hermiticity, self.sum),
(self.syst, self.onsite, self._onsite_params_info),
(self.syst, self.onsite, self._onsite_param_names),
(self.where, self._site_ranges),
(self._bound_onsite, self._bound_hamiltonian),
) = state
......
......@@ -294,7 +294,7 @@ def check_onsite(fsyst, sites, subset=False, check_values=True):
site = fsyst.sites[node].tag
freq[site] = freq.get(site, 0) + 1
if check_values and site in sites:
assert fsyst.onsite_hamiltonians[node] is sites[site]
assert fsyst.onsites[node][0] is sites[site]
if not subset:
# Check that all sites of `fsyst` are in `sites`.
for site in freq.keys():
......@@ -310,7 +310,7 @@ def check_hoppings(fsyst, hops):
tail, head = edge
tail = fsyst.sites[tail].tag
head = fsyst.sites[head].tag
value = fsyst.hoppings[edge_id]
value = fsyst.hoppings[edge_id][0]
if value is builder.Other:
assert (head, tail) in hops
else:
......@@ -1208,12 +1208,16 @@ def test_argument_passing():
with raises(TypeError):
inf_syst.hamiltonian(0, 0, *(2, 1), params=dict(p1=2, p2=1))
# test that missing any parameters raises TypeError
with raises(TypeError):
syst.hamiltonian(0, 0, params=dict(fake=10))
# test that passing parameters without default values works, and that
# passing parameters with default values fails
def onsite(site, p1, p2=1):
def onsite(site, p1, p2):
return p1 + p2
def hopping(site, site2, p1, p2=2):
def hopping(site, site2, p1, p2):
return p1 - p2
fill_syst = ft.partial(gen_fill_syst, onsite, hopping)
......@@ -1231,12 +1235,7 @@ def test_argument_passing():
for test in tests:
np.testing.assert_array_equal(
test(args=(1,)), test(params=dict(p1=1)))
# providing value for parameter with default value -- error
for test in tests:
with raises(ValueError):
test(params=dict(p1=1, p2=2))
test(args=(1, 2)), test(params=dict(p1=1, p2=2)))
# Some common, some different args for value functions
def onsite2(site, a, b):
......@@ -1264,42 +1263,31 @@ def test_argument_passing():
def test_parameter_substitution():
Subs = builder.ParameterSubstitution
subs = builder._substitute_params
def f(x, y):
return (('x', x), ('y', y))
# 'f' already has a parameter 'y'
assert raises(ValueError, Subs, f, dict(x='y'))
# 'f' takes no parameter 'a'
assert raises(ValueError, Subs, f, dict(a='x'))
assert raises(ValueError, subs, f, dict(x='y'))
# reverse argument order
g = Subs(f, dict(x='y', y='x'))
# Swap argument names.
g = subs(f, dict(x='y', y='x'))
assert g(1, 2) == f(1, 2)
assert g(y=1, x=2) == f(x=1, y=2)
# reverse again
h = Subs(g, dict(x='y', y='x'))
# Swap again.
h = subs(g, dict(x='y', y='x'))
assert h(1, 2) == f(1, 2)
assert h(x=1, y=2) == f(x=1, y=2)
# don't nest wrappers inside each other
assert h.function is f
# composing maps
g = Subs(f, dict(x='a'))
h = Subs(g, dict(a='b'))
assert h(b=1, y=2) == f(x=1, y=2)
assert h.func is f
# different names
g = Subs(f, dict(x='a', y='b'))
# Try different names.
g = subs(f, dict(x='a', y='b'))
assert g(1, 2) == f(1, 2)
assert g(a=1, b=2) == f(x=1, y=2)
assert g(1, b=2) == f(1, y=2)
# Can be used in sets/dicts
g = Subs(f, dict(x='a'))
h = Subs(f, dict(x='a'))
# Can substitutions be used in sets/dicts?
g = subs(f, dict(x='a'))
h = subs(f, dict(x='a'))
assert len(set([f, g, h])) == 2
......@@ -1327,33 +1315,31 @@ def test_subs():
return syst.finalized().hamiltonian_submatrix(params=kwargs)
syst = make_system()
# parameter name not an identifier
raises(ValueError, syst.subs, a='not-an-identifier?')
# substituting a paramter that doesn't exist produces a warning
warns(RuntimeWarning, syst.subs, fakeparam='yes')
warns(RuntimeWarning, syst.substitute, fakeparam='yes')
# name clash in value functions
raises(ValueError, syst.subs, b='a')
raises(ValueError, syst.subs, b='c')
raises(ValueError, syst.subs, a='site')
raises(ValueError, syst.subs, c='sitea')
# cannot call 'subs' on systems with attached leads, because
raises(ValueError, syst.substitute, b='a')
raises(ValueError, syst.substitute, b='c')
raises(ValueError, syst.substitute, a='site')
raises(ValueError, syst.substitute, c='sitea')
# cannot call 'substitute' on systems with attached leads, because
# it is not clear whether the substitutions should propagate
# into the leads too.
syst = make_system()
lead = make_system(kwant.TranslationalSymmetry((-1,)), n=1)
syst.attach_lead(lead)
raises(ValueError, syst.subs, a='d')
raises(ValueError, syst.substitute, a='d')
# test basic substitutions
syst = make_system()
expected = hamiltonian(syst, a=1, b=2, c=3)
# 1 level of substitutions
sub_syst = syst.subs(a='d', b='e')
sub_syst = syst.substitute(a='d', b='e')
assert np.allclose(hamiltonian(sub_syst, d=1, e=2, c=3), expected)
# 2 levels of substitution
sub_sub_syst = sub_syst.subs(d='g', c='h')
sub_sub_syst = sub_syst.substitute(d='g', c='h')
assert np.allclose(hamiltonian(sub_sub_syst, g=1, e=2, h=3), expected)
# very confusing but technically valid. 'a' does not appear in 'hopping',
# so the signature of 'onsite' is valid.
sub_syst = syst.subs(a='sitea')
sub_syst = syst.substitute(a='sitea')
assert np.allclose(hamiltonian(sub_syst, sitea=1, b=2, c=3), expected)
......@@ -420,6 +420,20 @@ def test_arg_passing(A):
wf = np.ones(len(fsyst.sites))
# test missing params
op = A(fsyst, onsite=lambda x, a, b: 1)
params = dict(a=1)
with raises(TypeError) as exc:
op(wf, params=params)
with raises(TypeError) as exc:
op.act(wf, params=params)
with raises(TypeError) as exc:
op.bind(params=params)
if hasattr(op, 'tocoo'):
with raises(TypeError) as exc:
op.tocoo(params=params)
op = A(fsyst)
canonical_args = (1, 2)
params = dict(a=1, b=2)
......
......@@ -27,7 +27,7 @@ def test_hamiltonian_submatrix():
mat = syst2.hamiltonian_submatrix()
assert mat.shape == (3, 3)
# Sorting is required due to unknown compression order of builder.
perm = np.argsort(syst2.onsite_hamiltonians)
perm = np.argsort([os[0] for os in syst2.onsites])
mat_should_be = np.array([[0, 1j, 0], [-1j, 0.5, 2j], [0, -2j, 1]])
mat = mat[perm, :]
......
# Copyright 2011-2017 Kwant authors.
# Copyright 2011-2018 Kwant authors.
#
# This file is part of Kwant. It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
......@@ -87,39 +87,21 @@ def test_value_types(k=(-1.1, 0.5), E=2, t=1):
H_alt = syst.hamiltonian_submatrix(k, sparse=False)
np.testing.assert_equal(H_alt, H)
# test when Hamiltonian value functions take extra parameters and
# have compatible signatures (can be passed with 'args')
onsites = [
lambda a, E, t: E,
lambda a, E, *args: E,
lambda a, *args: args[0],
]
hoppings = [
lambda a, b, E, t: t,
lambda a, b, E, *args: args[0],
lambda a, b, *args: args[1],
]
args = (E1, t1) + k
for E2, t2 in itertools.product(onsites, hoppings):
syst = wraparound(_simple_syst(lat, E2, t2, sym)).finalized()
H_alt = syst.hamiltonian_submatrix(args, sparse=False)
np.testing.assert_equal(H_alt, H)
# test when hamiltonian value functions take extra parameters and
# have incompatible signaures (must be passed with 'params')
# have incompatible signatures (must be passed with 'params')
onsites = [
lambda a, E: E,
lambda a, **kwargs: kwargs['E'],
lambda a, *, E: E,
lambda a, E, t: E,
lambda a, t, E: E,
]
hoppings = [
lambda a, b, t: t,
lambda a, b, **kwargs: kwargs['t'],
lambda a, b, *, t: t,
lambda a, b, t, E: t,
lambda a, b, E, t: t,
]
params = dict(E=E1, t=t1, **dict(zip(['k_x', 'k_y'], k)))
params = dict(E=E1, t=t1, k_x=k[0], k_y=k[1])
for E2, t2 in itertools.product(onsites, hoppings):
syst = wraparound(_simple_syst(lat, E2, t2, sym)).finalized()
H_alt = syst.hamiltonian_submatrix(params=params,
sparse=False)
H_alt = syst.hamiltonian_submatrix(params=params, sparse=False)
np.testing.assert_equal(H_alt, H)
......@@ -133,8 +115,8 @@ def test_signatures():
syst[lat(-1, 0)] = lambda a, E1: E1
syst[(lat(-1, 0), lat(-1, 1))] = lambda a, b, t1: t1
#
syst[lat(0, 0)] = lambda a, E2, **kwargs: E2
syst[(lat(0, 0), lat(0, 1))] = lambda a, b, t2, **kwargs: t2
syst[lat(0, 0)] = lambda a, E2: E2
syst[(lat(0, 0), lat(0, 1))] = lambda a, b, t2: t2
# hoppings that will be bound as hoppings
syst[(lat(-2, 0), lat(-1, 0))] = -1
......@@ -143,36 +125,34 @@ def test_signatures():
syst[(lat(-2, 0), lat(0, 0))] = -1
syst[(lat(-2, 0), lat(3, 0))] = lambda a, b, t3: t3
syst[(lat(-1, 0), lat(0, 0))] = lambda a, b, t4, **kwargs: t4
syst[(lat(-1, 0), lat(0, 0))] = lambda a, b, t4: t4
syst[(lat(-1, 0), lat(3, 0))] = lambda a, b, t5: t5
wrapped_syst = wraparound(syst)
## testing
momenta = ['k_x', 'k_y']
momenta = ('k_x', 'k_y')
onsites = [
(lat(-2, 0), momenta, False),
(lat(-1, 0), ['E1', 't1'] + momenta, False),
(lat(0, 0), ['E2', 't2'] + momenta, True),
(lat(-2, 0), momenta),
(lat(-1, 0), ('E1', 't1') + momenta),
(lat(0, 0), ('E2', 't2') + momenta),
]
for site, params_should_be, should_take_kwargs in onsites:
params, _, takes_kwargs = get_parameters(wrapped_syst[site])
for site, params_should_be in onsites:
params = get_parameters(wrapped_syst[site])
assert params[1:] == params_should_be
assert takes_kwargs == should_take_kwargs
hoppings = [
((lat(-2, 0), lat(-1, 0)), momenta, False),
((lat(-2, 0), lat(0, 0)), ['t3'] + momenta, False),
((lat(-1, 0), lat(0, 0)), ['t4', 't5'] + momenta, True),
((lat(-2, 0), lat(-1, 0)), momenta),
((lat(-2, 0), lat(0, 0)), ('t3',) + momenta),
((lat(-1, 0), lat(0, 0)), ('t4', 't5') + momenta),
]
for hopping, params_should_be, should_take_kwargs in hoppings:
params, _, takes_kwargs = get_parameters(wrapped_syst[hopping])
for hopping, params_should_be in hoppings:
params = get_parameters(wrapped_syst[hopping])
assert params[2:] == params_should_be
assert takes_kwargs == should_take_kwargs
def test_symmetry():
......@@ -194,8 +174,8 @@ def test_symmetry():
new = getattr(wrapped, attr)
orig = getattr(syst, attr)
if callable(orig):
params, _, _ = get_parameters(new)
assert params[1:] == ['k_x', 'k_y']
params = get_parameters(new)
assert params[1:] == ('k_x', 'k_y')
assert np.all(orig(None) == new(None, None, None))
else:
assert np.all(orig == new)
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment