Commit 0aa3f324 authored by Christoph Groth's avatar Christoph Groth Committed by Joseph Weston

only accept value functions with named positional arguments

and without default values.
Co-authored-by: Joseph Weston's avatarJoseph Weston <joseph@weston.cloud>
parent 34bcd910
@@ -1,144 +1,162 @@
@@ -1,144 +1,161 @@
# Tutorial 2.4.2. Closed systems
# ==============================
#
......@@ -41,7 +41,7 @@
rsq = x ** 2 + y ** 2
return rsq < r ** 2
def hopx(site1, site2, B=0):
def hopx(site1, site2, B):
# The magnetic field is controlled by the parameter B
y = site1.pos[1]
return -t * exp(-1j * B * y)
......
......@@ -36,7 +36,7 @@
#HIDDEN_END_ehso
#HIDDEN_BEGIN_coid
def onsite(site, pot=0):
def onsite(site, pot):
return 4 * t + potential(site, pot)
syst[(lat(x, y) for x in range(L) for y in range(W))] = onsite
......
......@@ -13,7 +13,6 @@ import inspect
import warnings
import importlib
from contextlib import contextmanager
from collections import namedtuple
__all__ = ['KwantDeprecationWarning', 'UserCodeError']
......@@ -94,36 +93,41 @@ def reraise_warnings(level=3):
warnings.warn(warning.message, stacklevel=level)
_Params = namedtuple('_Params', ('required', 'defaults', 'takes_kwargs'))
def get_parameters(func):
"""Get the names of the parameters to 'func' and whether it takes kwargs.
"""Return the names of parameters of a function.
It is made sure that the function can be called as func(*args) with
'args' corresponding to the returned parameter names.
Returns
-------
required : list
Names of positional, and keyword only parameters that do not have a
default value and that appear in the signature of 'func'.
defaults : list
Names of parameters that have a default value.
takes_kwargs : bool
True if 'func' takes '**kwargs'.
param_names : list
Names of positional parameters that appear in the signature of 'func'.
"""
sig = inspect.signature(func)
pars = sig.parameters
# Signature.parameters is an *ordered mapping*
required_params = [k for (k, v) in pars.items()
if v.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.KEYWORD_ONLY)
and v.default is inspect._empty]
default_params = [k for (k, v) in pars.items()
if v.default is not inspect._empty]
takes_kwargs = any(i.kind is inspect.Parameter.VAR_KEYWORD
for i in pars.values())
return _Params(required_params, default_params, takes_kwargs)
def error(msg):
fname = inspect.getsourcefile(func)
try:
line = inspect.getsourcelines(func)[1]
except OSError:
line = '<unknown line>'
raise ValueError("{}:\nFile {}, line {}, in {}".format(
msg, repr(fname), line, func.__name__))
P = inspect.Parameter
pars = inspect.signature(func).parameters # an *ordered mapping*
names = []
for k, v in pars.items():
if v.kind in (P.POSITIONAL_ONLY, P.POSITIONAL_OR_KEYWORD):
if v.default is P.empty:
names.append(k)
else:
error("Arguments of value functions "
"must not have default values")
elif v.kind is P.KEYWORD_ONLY:
error("Keyword-only arguments are not allowed in value functions")
elif v.kind in (P.VAR_POSITIONAL, P.VAR_KEYWORD):
error("Value functions must not take *args or **kwargs")
return tuple(names)
class lazy_import:
......
......@@ -1407,8 +1407,8 @@ class Builder:
# Get parameter names to be substituted for each function,
# without the 'site' parameter(s)
onsite_params = [get_parameters(v).required[1:] for v in onsites]
hopping_params = [get_parameters(v).required[2:] for v in hoppings]
onsite_params = [get_parameters(v)[1:] for v in onsites]
hopping_params = [get_parameters(v)[2:] for v in hoppings]
system_params = set(flatten(chain(onsite_params, hopping_params)))
nonexistant_params = set(subs.keys()).difference(system_params)
......@@ -1868,21 +1868,16 @@ def _translate_cons_law(cons_law):
class _FinalizedBuilderMixin:
"""Common functionality for all finalized builders"""
def _init_ham_param_maps(self):
"""Find parameters taken by all value functions
def _init_param_names(self):
"""For each value function, store the required parameters.
"""
ham_param_map = {}
for hams, skip in [(self.onsite_hamiltonians, 1), (self.hoppings, 2)]:
for ham in hams:
if (not callable(ham) or ham is Other or
ham in ham_param_map):
pn = {}
for values, skip in [(self.onsite_hamiltonians, 1), (self.hoppings, 2)]:
for value in values:
if not callable(value) or value is Other or value in pn:
continue
# parameters come in the same order as in the function signature
params, defaults, takes_kwargs = get_parameters(ham)
params = params[skip:] # remove site argument(s)
ham_param_map[ham] = (params, defaults, takes_kwargs)
self._ham_param_map = ham_param_map
pn[value] = get_parameters(value)[skip:]
self._param_names = pn
def _init_discrete_symmetries(self, builder):
......@@ -1907,14 +1902,7 @@ class _FinalizedBuilderMixin:
if callable(value):
site = self.symmetry.to_fd(self.sites[i])
if params:
required, defaults, takes_kw = self._ham_param_map[value]
invalid_params = set(params).intersection(set(defaults))
if invalid_params:
raise ValueError("Parameters {} have default values "
"and may not be set with 'params'"
.format(', '.join(invalid_params)))
assert not takes_kw
args = map(params.__getitem__, required)
args = map(params.__getitem__, self._param_names[value])
try:
value = value(site, *args)
except Exception as exc:
......@@ -1931,14 +1919,7 @@ class _FinalizedBuilderMixin:
sites = self.sites
site_i, site_j = self.symmetry.to_fd(sites[i], sites[j])
if params:
required, defaults, takes_kw = self._ham_param_map[value]
invalid_params = set(params).intersection(set(defaults))
if invalid_params:
raise ValueError("Parameters {} have default values "
"and may not be set with 'params'"
.format(', '.join(invalid_params)))
assert not takes_kw
args = map(params.__getitem__, required)
args = map(params.__getitem__, self._param_names[value])
try:
value = value(site_i, site_j, *args)
except Exception as exc:
......@@ -2049,7 +2030,7 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem):
self.symmetry = builder.symmetry
self.leads = finalized_leads
self.lead_interfaces = lead_interfaces
self._init_ham_param_maps()
self._init_param_names()
self._init_discrete_symmetries(builder)
......@@ -2203,7 +2184,7 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem):
self.onsite_hamiltonians = onsite_hamiltonians
self.symmetry = builder.symmetry
self.cell_size = cell_size
self._init_ham_param_maps()
self._init_param_names()
self._init_discrete_symmetries(builder)
......
......@@ -241,8 +241,8 @@ class _FunctionalOnsite:
self.onsite = onsite
self.sites = sites
def __call__(self, site_id, *args, **kwargs):
return self.onsite(self.sites[site_id], *args, **kwargs)
def __call__(self, site_id, *args):
return self.onsite(self.sites[site_id], *args)
class _DictOnsite(_FunctionalOnsite):
......@@ -257,13 +257,11 @@ def _normalize_onsite(syst, onsite, check_hermiticity):
If `onsite` is a function or a mapping (dictionary) then a function
is returned.
"""
parameter_info = ((), (), False)
param_names = ()
if callable(onsite):
# make 'onsite' compatible with hamiltonian value functions
required, defaults, takes_kwargs = get_parameters(onsite)
required = required[1:] # skip 'site' parameter
parameter_info = (tuple(required), defaults, takes_kwargs)
param_names = get_parameters(onsite)[1:]
try:
_onsite = _FunctionalOnsite(onsite, syst.sites)
except AttributeError:
......@@ -301,7 +299,7 @@ def _normalize_onsite(syst, onsite, check_hermiticity):
'different numbers of orbitals on different sites')
raise ValueError(msg)
return _onsite, parameter_info
return _onsite, param_names
cdef class BlockSparseMatrix:
......@@ -434,7 +432,7 @@ cdef class _LocalOperator:
"""
cdef public int check_hermiticity, sum
cdef public object syst, onsite, _onsite_params_info
cdef public object syst, onsite, _onsite_param_names
cdef public gint[:, :] where, _site_ranges
cdef public BlockSparseMatrix _bound_onsite, _bound_hamiltonian
......@@ -448,8 +446,8 @@ cdef class _LocalOperator:
'the site families (lattices).')
self.syst = syst
self.onsite, self._onsite_params_info = \
_normalize_onsite(syst, onsite, check_hermiticity)
self.onsite, self._onsite_param_names = _normalize_onsite(
syst, onsite, check_hermiticity)
self.check_hermiticity = check_hermiticity
self.sum = sum
self._site_ranges = np.asarray(syst.site_ranges, dtype=gint_dtype)
......@@ -598,7 +596,7 @@ cdef class _LocalOperator:
q = cls.__new__(cls)
q.syst = self.syst
q.onsite = self.onsite
q._onsite_params_info = self._onsite_params_info
q._onsite_param_names = self._onsite_param_names
q.where = self.where
q.sum = self.sum
q._site_ranges = self._site_ranges
......@@ -638,23 +636,22 @@ cdef class _LocalOperator:
"""Evaluate the onsite matrices on all elements of `where`"""
assert callable(self.onsite)
assert not (args and params)
params = params or {}
matrix = ta.matrix
onsite = self.onsite
check_hermiticity = self.check_hermiticity
required, defaults, takes_kw = self._onsite_params_info
invalid_params = set(params).intersection(set(defaults))
if invalid_params:
raise ValueError("Parameters {} have default values "
"and may not be set with 'params'"
.format(', '.join(invalid_params)))
if params and not takes_kw:
params = {pn: params[pn] for pn in required}
if params:
try:
args = tuple(params[pn] for pn in self._onsite_param_names)
except KeyError:
missing = [p for p in self._onsite_param_names
if p not in params]
msg = ('Operator is missing required arguments: ',
', '.join(map('"{}"'.format, missing)))
raise TypeError(''.join(msg))
def get_onsite(a, a_norbs, b, b_norbs):
mat = matrix(onsite(a, *args, **params), complex)
mat = matrix(onsite(a, *args), complex)
_check_onsite(mat, a_norbs, check_hermiticity)
return mat
......@@ -679,14 +676,14 @@ cdef class _LocalOperator:
def __getstate__(self):
return (
(self.check_hermiticity, self.sum),
(self.syst, self.onsite, self._onsite_params_info),
(self.syst, self.onsite, self._onsite_param_names),
tuple(map(np.asarray, (self.where, self._site_ranges))),
(self._bound_onsite, self._bound_hamiltonian),
)
def __setstate__(self, state):
((self.check_hermiticity, self.sum),
(self.syst, self.onsite, self._onsite_params_info),
(self.syst, self.onsite, self._onsite_param_names),
(self.where, self._site_ranges),
(self._bound_onsite, self._bound_hamiltonian),
) = state
......
......@@ -1210,10 +1210,10 @@ def test_argument_passing():
# test that passing parameters without default values works, and that
# passing parameters with default values fails
def onsite(site, p1, p2=1):
def onsite(site, p1, p2):
return p1 + p2
def hopping(site, site2, p1, p2=2):
def hopping(site, site2, p1, p2):
return p1 - p2
fill_syst = ft.partial(gen_fill_syst, onsite, hopping)
......@@ -1231,12 +1231,7 @@ def test_argument_passing():
for test in tests:
np.testing.assert_array_equal(
test(args=(1,)), test(params=dict(p1=1)))
# providing value for parameter with default value -- error
for test in tests:
with raises(ValueError):
test(params=dict(p1=1, p2=2))
test(args=(1, 2)), test(params=dict(p1=1, p2=2)))
# Some common, some different args for value functions
def onsite2(site, a, b):
......
......@@ -420,6 +420,20 @@ def test_arg_passing(A):
wf = np.ones(len(fsyst.sites))
# test missing params
op = A(fsyst, onsite=lambda x, a, b: 1)
params = dict(a=1)
with raises(TypeError) as exc:
op(wf, params=params)
with raises(TypeError) as exc:
op.act(wf, params=params)
with raises(TypeError) as exc:
op.bind(params=params)
if hasattr(op, 'tocoo'):
with raises(TypeError) as exc:
op.tocoo(params=params)
op = A(fsyst)
canonical_args = (1, 2)
params = dict(a=1, b=2)
......
......@@ -132,26 +132,26 @@ def test_signatures():
## testing
momenta = ['k_x', 'k_y']
momenta = ('k_x', 'k_y')
onsites = [
(lat(-2, 0), momenta),
(lat(-1, 0), ['E1', 't1'] + momenta),
(lat(0, 0), ['E2', 't2'] + momenta),
(lat(-1, 0), ('E1', 't1') + momenta),
(lat(0, 0), ('E2', 't2') + momenta),
]
for site, params_should_be in onsites:
params, _, takes_kwargs = get_parameters(wrapped_syst[site])
params = get_parameters(wrapped_syst[site])
assert params[1:] == params_should_be
hoppings = [
((lat(-2, 0), lat(-1, 0)), momenta),
((lat(-2, 0), lat(0, 0)), ['t3'] + momenta),
((lat(-1, 0), lat(0, 0)), ['t4', 't5'] + momenta),
((lat(-2, 0), lat(0, 0)), ('t3',) + momenta),
((lat(-1, 0), lat(0, 0)), ('t4', 't5') + momenta),
]
for hopping, params_should_be in hoppings:
params, _, takes_kwargs = get_parameters(wrapped_syst[hopping])
params = get_parameters(wrapped_syst[hopping])
assert params[2:] == params_should_be
......@@ -174,8 +174,8 @@ def test_symmetry():
new = getattr(wrapped, attr)
orig = getattr(syst, attr)
if callable(orig):
params, _, _ = get_parameters(new)
assert params[1:] == ['k_x', 'k_y']
params = get_parameters(new)
assert params[1:] == ('k_x', 'k_y')
assert np.all(orig(None) == new(None, None, None))
else:
assert np.all(orig == new)
......
......@@ -110,10 +110,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
return val(a, *args[:mnp])
assert callable(val)
params, defaults, takes_kwargs = get_parameters(val)
assert not defaults
assert not takes_kwargs
_set_signature(f, params + momenta)
_set_signature(f, get_parameters(val) + momenta)
return f
@_memoize
......@@ -125,12 +122,9 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
pv = phase * v
return pv + herm_conj(pv)
params = ['_site0']
params = ('_site0',)
if callable(val):
p, defaults, takes_kwargs = get_parameters(val)
assert not defaults
assert not takes_kwargs
params += p[2:] # cut off both site parameters
params += get_parameters(val)[2:] # cut off both site parameters
_set_signature(f, params + momenta)
return f
......@@ -142,12 +136,9 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
v = val(a, sym.act(elem, b), *args[:mnp]) if callable(val) else val
return phase * v
params = ['_site0', '_site1']
params = ('_site0', '_site1')
if callable(val):
p, defaults, takes_kwargs = get_parameters(val)
assert not defaults
assert not takes_kwargs
params += p[2:] # cut off site parameters
params += get_parameters(val)[2:] # cut off site parameters
_set_signature(f, params + momenta)
return f
......@@ -180,10 +171,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
if not callable(val):
selections.append(())
continue
val_params, defaults, val_takes_kwargs = get_parameters(val)
assert not defaults
assert not val_takes_kwargs
val_params = val_params[num_sites:]
val_params = get_parameters(val)[num_sites:]
selections.append((*site_params, *val_params))
for p in val_params:
# Skip parameters that exist in previously added functions,
......@@ -231,6 +219,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
ret = WrappedBuilder(TranslationalSymmetry(periods.pop(keep)))
sym = TranslationalSymmetry(*periods)
momenta.pop(keep)
momenta = tuple(momenta)
mnp = -len(sym.periods) # Used by the bound functions above.
# Store the names of the momentum parameters and the symmetry of the
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment