Skip to content
Snippets Groups Projects
Commit cfa38b5c authored by Joseph Weston's avatar Joseph Weston
Browse files

merge change of semantics for default parameters

The semantics for default parameters are as follows. If a value function has a
parameter that takes a default value, then an exception is raised if the user
ever tries to assign a value to this parameter via 'params'. These semantics
are chosen to eliminate the possibility that "forgotten" default parameters are
not silently overwritten.
parents 714a38b7 c11c43c8
No related branches found
No related tags found
No related merge requests found
......@@ -92,9 +92,11 @@ def get_parameters(func):
Returns
-------
names : list
Positional, keyword and keyword only parameter names in the order
that they appear in the signature of 'func'.
required_params : list
Names of positional, and keyword only parameters that do not have a
default value and that appear in the signature of 'func'.
default_params : list
Names of parameters that have a default value.
takes_kwargs : bool
True if 'func' takes '**kwargs'.
"""
......@@ -102,9 +104,12 @@ def get_parameters(func):
pars = sig.parameters
# Signature.parameters is an *ordered mapping*
names = [k for (k, v) in pars.items()
if v.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY)]
required_params = [k for (k, v) in pars.items()
if v.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY)
and v.default is inspect._empty]
default_params = [k for (k, v) in pars.items()
if v.default is not inspect._empty]
takes_kwargs = any(i.kind is inspect.Parameter.VAR_KEYWORD
for i in pars.values())
return names, takes_kwargs
return required_params, default_params, takes_kwargs
......@@ -1620,9 +1620,9 @@ class Builder:
ham in _ham_param_map):
continue
# parameters come in the same order as in the function signature
params, takes_kwargs = get_parameters(ham)
params, defaults, takes_kwargs = get_parameters(ham)
params = params[skip:] # remove site argument(s)
_ham_param_map[ham] = (params, takes_kwargs)
_ham_param_map[ham] = (params, defaults, takes_kwargs)
#### Assemble and return result.
result = FiniteSystem()
......@@ -1766,9 +1766,9 @@ class Builder:
ham in _ham_param_map):
continue
# parameters come in the same order as in the function signature
params, takes_kwargs = get_parameters(ham)
params, defaults, takes_kwargs = get_parameters(ham)
params = params[skip:] # remove site argument(s)
_ham_param_map[ham] = (params, takes_kwargs)
_ham_param_map[ham] = (params, defaults, takes_kwargs)
#### Assemble and return result.
result = InfiniteSystem()
......@@ -1891,11 +1891,16 @@ class FiniteSystem(system.FiniteSystem):
value = self.onsite_hamiltonians[i]
if callable(value):
if params:
param_names, takes_kwargs = self._ham_param_map[value]
if not takes_kwargs:
params = {pn: params[pn] for pn in param_names}
required, defaults, takes_kw = self._ham_param_map[value]
invalid_params = set(params).intersection(set(defaults))
if invalid_params:
raise ValueError("Parameters {} have default values "
"and may not be set with 'params'"
.format(', '.join(invalid_params)))
if not takes_kw:
params = {pn: params[pn] for pn in required}
try:
value = value(self.sites[i], **params)
value = value(self.sites[i], **params)
except Exception as exc:
_raise_user_error(exc, value)
else:
......@@ -1914,9 +1919,14 @@ class FiniteSystem(system.FiniteSystem):
if callable(value):
sites = self.sites
if params:
param_names, takes_kwargs = self._ham_param_map[value]
if not takes_kwargs:
params = {pn: params[pn] for pn in param_names}
required, defaults, takes_kw = self._ham_param_map[value]
invalid_params = set(params).intersection(set(defaults))
if invalid_params:
raise ValueError("Parameters {} have default values "
"and may not be set with 'params'"
.format(', '.join(invalid_params)))
if not takes_kw:
params = {pn: params[pn] for pn in required}
try:
value = value(sites[i], sites[j], **params)
except Exception as exc:
......@@ -1971,9 +1981,14 @@ class InfiniteSystem(system.InfiniteSystem):
if callable(value):
site = self.symmetry.to_fd(self.sites[i])
if params:
param_names, takes_kwargs = self._ham_param_map[value]
if not takes_kwargs:
params = {pn: params[pn] for pn in param_names}
required, defaults, takes_kw = self._ham_param_map[value]
invalid_params = set(params).intersection(set(defaults))
if invalid_params:
raise ValueError("Parameters {} have default values "
"and may not be set with 'params'"
.format(', '.join(invalid_params)))
if not takes_kw:
params = {pn: params[pn] for pn in required}
try:
value = value(site, **params)
except Exception as exc:
......@@ -1995,9 +2010,14 @@ class InfiniteSystem(system.InfiniteSystem):
sites = self.sites
site_i, site_j = self.symmetry.to_fd(sites[i], sites[j])
if params:
param_names, takes_kwargs = self._ham_param_map[value]
if not takes_kwargs:
params = {pn: params[pn] for pn in param_names}
required, defaults, takes_kw = self._ham_param_map[value]
invalid_params = set(params).intersection(set(defaults))
if invalid_params:
raise ValueError("Parameters {} have default values "
"and may not be set with 'params'"
.format(', '.join(invalid_params)))
if not takes_kw:
params = {pn: params[pn] for pn in required}
try:
value = value(site_i, site_j, **params)
except Exception as exc:
......
......@@ -257,13 +257,13 @@ def _normalize_onsite(syst, onsite, check_hermiticity):
If `onsite` is a function or a mapping (dictionary) then a function
is returned.
"""
parameter_info = ((), False)
parameter_info = ((), (), False)
if callable(onsite):
# make 'onsite' compatible with hamiltonian value functions
parameters, takes_kwargs = get_parameters(onsite)
parameters = parameters[1:] # skip 'site' parameter
parameter_info = (tuple(parameters), takes_kwargs)
required, defaults, takes_kwargs = get_parameters(onsite)
required = required[1:] # skip 'site' parameter
parameter_info = (tuple(required), defaults, takes_kwargs)
try:
_onsite = _FunctionalOnsite(onsite, syst.sites)
except AttributeError:
......@@ -628,9 +628,15 @@ cdef class _LocalOperator:
onsite = self.onsite
check_hermiticity = self.check_hermiticity
param_names, takes_kwargs = self._onsite_params_info
if params and not takes_kwargs:
params = {pn: params[pn] for pn in param_names}
required, defaults, takes_kw = self._onsite_params_info
invalid_params = set(params).intersection(set(defaults))
if invalid_params:
raise ValueError("Parameters {} have default values "
"and may not be set with 'params'"
.format(', '.join(invalid_params)))
if params and not takes_kw:
params = {pn: params[pn] for pn in required}
def get_onsite(a, a_norbs, b, b_norbs):
mat = matrix(onsite(a, *args, **params), complex)
......
......@@ -8,12 +8,15 @@
import warnings
import pickle
from random import Random
import itertools as it
import functools as ft
from random import Random
import numpy as np
import tinyarray as ta
from pytest import raises, warns
from numpy.testing import assert_almost_equal
import tinyarray as ta
import numpy as np
import kwant
from kwant import builder
from kwant._common import ensure_rng
......@@ -1111,45 +1114,69 @@ def test_argument_passing():
chain = kwant.lattice.chain()
# Test for passing parameters to hamiltonian matrix elements
def onsite(site, p1, p2=1):
def onsite(site, p1, p2):
return p1 + p2
def hopping(site1, site2, p1, p2=1):
def hopping(site1, site2, p1, p2):
return p1 - p2
def fill_syst(syst):
def gen_fill_syst(onsite, hopping, syst):
syst[(chain(i) for i in range(3))] = onsite
syst[chain.neighbors()] = hopping
return syst.finalized()
fill_syst = ft.partial(gen_fill_syst, onsite, hopping)
syst = fill_syst(kwant.Builder())
inf_syst = fill_syst(kwant.Builder(kwant.TranslationalSymmetry((-3,))))
args= (2, 1)
params = dict(p1=2, p2=1)
tests = (
syst.hamiltonian_submatrix,
inf_syst.cell_hamiltonian,
inf_syst.inter_cell_hopping,
inf_syst.selfenergy,
lambda *args, **kw: inf_syst.modes(*args, **kw)[0].wave_functions,
)
np.testing.assert_array_equal(
syst.hamiltonian_submatrix(args=args),
syst.hamiltonian_submatrix(params=params))
np.testing.assert_array_equal(
inf_syst.cell_hamiltonian(args=args),
inf_syst.cell_hamiltonian(params=params))
np.testing.assert_array_equal(
inf_syst.inter_cell_hopping(args=args),
inf_syst.inter_cell_hopping(params=params))
np.testing.assert_array_equal(
inf_syst.selfenergy(args=args),
inf_syst.selfenergy(params=params))
np.testing.assert_array_equal(
inf_syst.modes(args=args)[0].wave_functions,
inf_syst.modes(params=params)[0].wave_functions)
for test in tests:
np.testing.assert_array_equal(
test(args=(2, 1)), test(params=dict(p1=2, p2=1)))
# test that mixing 'args' and 'params' raises TypeError
with raises(TypeError):
syst.hamiltonian(0, 0, *args, params=params)
syst.hamiltonian(0, 0, *(2, 1), params=dict(p1=2, p2=1))
with raises(TypeError):
inf_syst.hamiltonian(0, 0, *args, params=params)
inf_syst.hamiltonian(0, 0, *(2, 1), params=dict(p1=2, p2=1))
# test that passing parameters without default values works, and that
# passing parameters with default values fails
def onsite(site, p1, p2=1):
return p1 + p2
def hopping(site, site2, p1, p2=2):
return p1 - p2
fill_syst = ft.partial(gen_fill_syst, onsite, hopping)
syst = fill_syst(kwant.Builder())
inf_syst = fill_syst(kwant.Builder(kwant.TranslationalSymmetry((-3,))))
tests = (
syst.hamiltonian_submatrix,
inf_syst.cell_hamiltonian,
inf_syst.inter_cell_hopping,
inf_syst.selfenergy,
lambda *args, **kw: inf_syst.modes(*args, **kw)[0].wave_functions,
)
for test in tests:
np.testing.assert_array_equal(
test(args=(1,)), test(params=dict(p1=1)))
# providing value for parameter with default value -- error
for test in tests:
with raises(ValueError):
test(params=dict(p1=1, p2=2))
# Some common, some different args for value functions
def onsite2(site, a, b):
......
......@@ -159,7 +159,7 @@ def test_signatures():
]
for site, params_should_be, should_take_kwargs in onsites:
params, takes_kwargs = get_parameters(wrapped_syst[site])
params, _, takes_kwargs = get_parameters(wrapped_syst[site])
assert params[1:] == params_should_be
assert takes_kwargs == should_take_kwargs
......@@ -170,7 +170,7 @@ def test_signatures():
]
for hopping, params_should_be, should_take_kwargs in hoppings:
params, takes_kwargs = get_parameters(wrapped_syst[hopping])
params, _, takes_kwargs = get_parameters(wrapped_syst[hopping])
assert params[2:] == params_should_be
assert takes_kwargs == should_take_kwargs
......@@ -194,7 +194,7 @@ def test_symmetry():
new = getattr(wrapped, attr)
orig = getattr(syst, attr)
if callable(orig):
params, _ = get_parameters(new)
params, _, _ = get_parameters(new)
assert params[1:] == ['k_x', 'k_y']
assert np.all(orig(None) == new(None, None, None))
else:
......
......@@ -129,7 +129,7 @@ def wraparound(builder, keep=None, *, coordinate_names=('x', 'y', 'z')):
else:
return val(a, *args[:mnp])
params, takes_kwargs = get_parameters(val)
params, defaults, takes_kwargs = get_parameters(val)
extra_params = params[1:]
return _modify_signature(f, params + momenta, takes_kwargs)
......@@ -156,7 +156,7 @@ def wraparound(builder, keep=None, *, coordinate_names=('x', 'y', 'z')):
params, takes_kwargs = ['_site0'], False
if callable(val):
p, takes_kwargs = get_parameters(val)
p, defaults, takes_kwargs = get_parameters(val)
extra_params = p[2:] # cut off both site parameters
params += extra_params
......@@ -184,7 +184,7 @@ def wraparound(builder, keep=None, *, coordinate_names=('x', 'y', 'z')):
params, takes_kwargs = ['_site0', '_site1'], False
if callable(val):
p, takes_kwargs = get_parameters(val)
p, defaults, takes_kwargs = get_parameters(val)
extra_params = p[2:] # cut off site parameters
params += extra_params
......@@ -209,7 +209,7 @@ def wraparound(builder, keep=None, *, coordinate_names=('x', 'y', 'z')):
name, inspect.Parameter.POSITIONAL_OR_KEYWORD)
# now all the other parameters, except for the momenta
for val in filter(callable, vals):
val_params, val_takes_kwargs = get_parameters(val)
val_params, defaults, val_takes_kwargs = get_parameters(val)
val_params = val_params[num_sites:] # remove site parameters
takes_kwargs = takes_kwargs or val_takes_kwargs
for p in val_params:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment