From 0aa3f324a6328f272728f0872377b29e30e9aa1b Mon Sep 17 00:00:00 2001 From: Christoph Groth <christoph.groth@cea.fr> Date: Thu, 13 Sep 2018 11:38:22 +0200 Subject: [PATCH] only accept value functions with named positional arguments and without default values. Co-authored-by: Joseph Weston <joseph@weston.cloud> --- doc/source/code/figure/closed_system.py.diff | 4 +- doc/source/code/figure/quantum_well.py.diff | 2 +- kwant/_common.py | 56 +++++++++++--------- kwant/builder.py | 47 +++++----------- kwant/operator.pyx | 45 ++++++++-------- kwant/tests/test_builder.py | 11 ++-- kwant/tests/test_operator.py | 14 +++++ kwant/tests/test_wraparound.py | 18 +++---- kwant/wraparound.py | 25 +++------ 9 files changed, 101 insertions(+), 121 deletions(-) diff --git a/doc/source/code/figure/closed_system.py.diff b/doc/source/code/figure/closed_system.py.diff index 21f979f9..d70b1cc7 100644 --- a/doc/source/code/figure/closed_system.py.diff +++ b/doc/source/code/figure/closed_system.py.diff @@ -1,4 +1,4 @@ -@@ -1,144 +1,162 @@ +@@ -1,144 +1,161 @@ # Tutorial 2.4.2. Closed systems # ============================== # @@ -41,7 +41,7 @@ rsq = x ** 2 + y ** 2 return rsq < r ** 2 - def hopx(site1, site2, B=0): + def hopx(site1, site2, B): # The magnetic field is controlled by the parameter B y = site1.pos[1] return -t * exp(-1j * B * y) diff --git a/doc/source/code/figure/quantum_well.py.diff b/doc/source/code/figure/quantum_well.py.diff index 7b81c866..1c81f1c5 100644 --- a/doc/source/code/figure/quantum_well.py.diff +++ b/doc/source/code/figure/quantum_well.py.diff @@ -36,7 +36,7 @@ #HIDDEN_END_ehso #HIDDEN_BEGIN_coid - def onsite(site, pot=0): + def onsite(site, pot): return 4 * t + potential(site, pot) syst[(lat(x, y) for x in range(L) for y in range(W))] = onsite diff --git a/kwant/_common.py b/kwant/_common.py index e2732643..71b4f782 100644 --- a/kwant/_common.py +++ b/kwant/_common.py @@ -13,7 +13,6 @@ import inspect import warnings import importlib from contextlib import contextmanager -from collections import namedtuple __all__ = ['KwantDeprecationWarning', 'UserCodeError'] @@ -94,36 +93,41 @@ def reraise_warnings(level=3): warnings.warn(warning.message, stacklevel=level) -_Params = namedtuple('_Params', ('required', 'defaults', 'takes_kwargs')) - - def get_parameters(func): - """Get the names of the parameters to 'func' and whether it takes kwargs. + """Return the names of parameters of a function. + + It is made sure that the function can be called as func(*args) with + 'args' corresponding to the returned parameter names. Returns ------- - required : list - Names of positional, and keyword only parameters that do not have a - default value and that appear in the signature of 'func'. - defaults : list - Names of parameters that have a default value. - takes_kwargs : bool - True if 'func' takes '**kwargs'. + param_names : list + Names of positional parameters that appear in the signature of 'func'. """ - sig = inspect.signature(func) - pars = sig.parameters - - # Signature.parameters is an *ordered mapping* - required_params = [k for (k, v) in pars.items() - if v.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.POSITIONAL_ONLY, - inspect.Parameter.KEYWORD_ONLY) - and v.default is inspect._empty] - default_params = [k for (k, v) in pars.items() - if v.default is not inspect._empty] - takes_kwargs = any(i.kind is inspect.Parameter.VAR_KEYWORD - for i in pars.values()) - return _Params(required_params, default_params, takes_kwargs) + def error(msg): + fname = inspect.getsourcefile(func) + try: + line = inspect.getsourcelines(func)[1] + except OSError: + line = '<unknown line>' + raise ValueError("{}:\nFile {}, line {}, in {}".format( + msg, repr(fname), line, func.__name__)) + + P = inspect.Parameter + pars = inspect.signature(func).parameters # an *ordered mapping* + names = [] + for k, v in pars.items(): + if v.kind in (P.POSITIONAL_ONLY, P.POSITIONAL_OR_KEYWORD): + if v.default is P.empty: + names.append(k) + else: + error("Arguments of value functions " + "must not have default values") + elif v.kind is P.KEYWORD_ONLY: + error("Keyword-only arguments are not allowed in value functions") + elif v.kind in (P.VAR_POSITIONAL, P.VAR_KEYWORD): + error("Value functions must not take *args or **kwargs") + return tuple(names) class lazy_import: diff --git a/kwant/builder.py b/kwant/builder.py index fbec28ee..9d83d999 100644 --- a/kwant/builder.py +++ b/kwant/builder.py @@ -1407,8 +1407,8 @@ class Builder: # Get parameter names to be substituted for each function, # without the 'site' parameter(s) - onsite_params = [get_parameters(v).required[1:] for v in onsites] - hopping_params = [get_parameters(v).required[2:] for v in hoppings] + onsite_params = [get_parameters(v)[1:] for v in onsites] + hopping_params = [get_parameters(v)[2:] for v in hoppings] system_params = set(flatten(chain(onsite_params, hopping_params))) nonexistant_params = set(subs.keys()).difference(system_params) @@ -1868,21 +1868,16 @@ def _translate_cons_law(cons_law): class _FinalizedBuilderMixin: """Common functionality for all finalized builders""" - def _init_ham_param_maps(self): - """Find parameters taken by all value functions + def _init_param_names(self): + """For each value function, store the required parameters. """ - ham_param_map = {} - for hams, skip in [(self.onsite_hamiltonians, 1), (self.hoppings, 2)]: - for ham in hams: - if (not callable(ham) or ham is Other or - ham in ham_param_map): + pn = {} + for values, skip in [(self.onsite_hamiltonians, 1), (self.hoppings, 2)]: + for value in values: + if not callable(value) or value is Other or value in pn: continue - # parameters come in the same order as in the function signature - params, defaults, takes_kwargs = get_parameters(ham) - params = params[skip:] # remove site argument(s) - ham_param_map[ham] = (params, defaults, takes_kwargs) - - self._ham_param_map = ham_param_map + pn[value] = get_parameters(value)[skip:] + self._param_names = pn def _init_discrete_symmetries(self, builder): @@ -1907,14 +1902,7 @@ class _FinalizedBuilderMixin: if callable(value): site = self.symmetry.to_fd(self.sites[i]) if params: - required, defaults, takes_kw = self._ham_param_map[value] - invalid_params = set(params).intersection(set(defaults)) - if invalid_params: - raise ValueError("Parameters {} have default values " - "and may not be set with 'params'" - .format(', '.join(invalid_params))) - assert not takes_kw - args = map(params.__getitem__, required) + args = map(params.__getitem__, self._param_names[value]) try: value = value(site, *args) except Exception as exc: @@ -1931,14 +1919,7 @@ class _FinalizedBuilderMixin: sites = self.sites site_i, site_j = self.symmetry.to_fd(sites[i], sites[j]) if params: - required, defaults, takes_kw = self._ham_param_map[value] - invalid_params = set(params).intersection(set(defaults)) - if invalid_params: - raise ValueError("Parameters {} have default values " - "and may not be set with 'params'" - .format(', '.join(invalid_params))) - assert not takes_kw - args = map(params.__getitem__, required) + args = map(params.__getitem__, self._param_names[value]) try: value = value(site_i, site_j, *args) except Exception as exc: @@ -2049,7 +2030,7 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem): self.symmetry = builder.symmetry self.leads = finalized_leads self.lead_interfaces = lead_interfaces - self._init_ham_param_maps() + self._init_param_names() self._init_discrete_symmetries(builder) @@ -2203,7 +2184,7 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem): self.onsite_hamiltonians = onsite_hamiltonians self.symmetry = builder.symmetry self.cell_size = cell_size - self._init_ham_param_maps() + self._init_param_names() self._init_discrete_symmetries(builder) diff --git a/kwant/operator.pyx b/kwant/operator.pyx index 2787b281..e0483d60 100644 --- a/kwant/operator.pyx +++ b/kwant/operator.pyx @@ -241,8 +241,8 @@ class _FunctionalOnsite: self.onsite = onsite self.sites = sites - def __call__(self, site_id, *args, **kwargs): - return self.onsite(self.sites[site_id], *args, **kwargs) + def __call__(self, site_id, *args): + return self.onsite(self.sites[site_id], *args) class _DictOnsite(_FunctionalOnsite): @@ -257,13 +257,11 @@ def _normalize_onsite(syst, onsite, check_hermiticity): If `onsite` is a function or a mapping (dictionary) then a function is returned. """ - parameter_info = ((), (), False) + param_names = () if callable(onsite): # make 'onsite' compatible with hamiltonian value functions - required, defaults, takes_kwargs = get_parameters(onsite) - required = required[1:] # skip 'site' parameter - parameter_info = (tuple(required), defaults, takes_kwargs) + param_names = get_parameters(onsite)[1:] try: _onsite = _FunctionalOnsite(onsite, syst.sites) except AttributeError: @@ -301,7 +299,7 @@ def _normalize_onsite(syst, onsite, check_hermiticity): 'different numbers of orbitals on different sites') raise ValueError(msg) - return _onsite, parameter_info + return _onsite, param_names cdef class BlockSparseMatrix: @@ -434,7 +432,7 @@ cdef class _LocalOperator: """ cdef public int check_hermiticity, sum - cdef public object syst, onsite, _onsite_params_info + cdef public object syst, onsite, _onsite_param_names cdef public gint[:, :] where, _site_ranges cdef public BlockSparseMatrix _bound_onsite, _bound_hamiltonian @@ -448,8 +446,8 @@ cdef class _LocalOperator: 'the site families (lattices).') self.syst = syst - self.onsite, self._onsite_params_info = \ - _normalize_onsite(syst, onsite, check_hermiticity) + self.onsite, self._onsite_param_names = _normalize_onsite( + syst, onsite, check_hermiticity) self.check_hermiticity = check_hermiticity self.sum = sum self._site_ranges = np.asarray(syst.site_ranges, dtype=gint_dtype) @@ -598,7 +596,7 @@ cdef class _LocalOperator: q = cls.__new__(cls) q.syst = self.syst q.onsite = self.onsite - q._onsite_params_info = self._onsite_params_info + q._onsite_param_names = self._onsite_param_names q.where = self.where q.sum = self.sum q._site_ranges = self._site_ranges @@ -638,23 +636,22 @@ cdef class _LocalOperator: """Evaluate the onsite matrices on all elements of `where`""" assert callable(self.onsite) assert not (args and params) - params = params or {} matrix = ta.matrix onsite = self.onsite check_hermiticity = self.check_hermiticity - required, defaults, takes_kw = self._onsite_params_info - invalid_params = set(params).intersection(set(defaults)) - if invalid_params: - raise ValueError("Parameters {} have default values " - "and may not be set with 'params'" - .format(', '.join(invalid_params))) - - if params and not takes_kw: - params = {pn: params[pn] for pn in required} + if params: + try: + args = tuple(params[pn] for pn in self._onsite_param_names) + except KeyError: + missing = [p for p in self._onsite_param_names + if p not in params] + msg = ('Operator is missing required arguments: ', + ', '.join(map('"{}"'.format, missing))) + raise TypeError(''.join(msg)) def get_onsite(a, a_norbs, b, b_norbs): - mat = matrix(onsite(a, *args, **params), complex) + mat = matrix(onsite(a, *args), complex) _check_onsite(mat, a_norbs, check_hermiticity) return mat @@ -679,14 +676,14 @@ cdef class _LocalOperator: def __getstate__(self): return ( (self.check_hermiticity, self.sum), - (self.syst, self.onsite, self._onsite_params_info), + (self.syst, self.onsite, self._onsite_param_names), tuple(map(np.asarray, (self.where, self._site_ranges))), (self._bound_onsite, self._bound_hamiltonian), ) def __setstate__(self, state): ((self.check_hermiticity, self.sum), - (self.syst, self.onsite, self._onsite_params_info), + (self.syst, self.onsite, self._onsite_param_names), (self.where, self._site_ranges), (self._bound_onsite, self._bound_hamiltonian), ) = state diff --git a/kwant/tests/test_builder.py b/kwant/tests/test_builder.py index 37c9a2bf..df62331f 100644 --- a/kwant/tests/test_builder.py +++ b/kwant/tests/test_builder.py @@ -1210,10 +1210,10 @@ def test_argument_passing(): # test that passing parameters without default values works, and that # passing parameters with default values fails - def onsite(site, p1, p2=1): + def onsite(site, p1, p2): return p1 + p2 - def hopping(site, site2, p1, p2=2): + def hopping(site, site2, p1, p2): return p1 - p2 fill_syst = ft.partial(gen_fill_syst, onsite, hopping) @@ -1231,12 +1231,7 @@ def test_argument_passing(): for test in tests: np.testing.assert_array_equal( - test(args=(1,)), test(params=dict(p1=1))) - - # providing value for parameter with default value -- error - for test in tests: - with raises(ValueError): - test(params=dict(p1=1, p2=2)) + test(args=(1, 2)), test(params=dict(p1=1, p2=2))) # Some common, some different args for value functions def onsite2(site, a, b): diff --git a/kwant/tests/test_operator.py b/kwant/tests/test_operator.py index e38e1177..a9545bf8 100644 --- a/kwant/tests/test_operator.py +++ b/kwant/tests/test_operator.py @@ -420,6 +420,20 @@ def test_arg_passing(A): wf = np.ones(len(fsyst.sites)) + # test missing params + op = A(fsyst, onsite=lambda x, a, b: 1) + params = dict(a=1) + with raises(TypeError) as exc: + op(wf, params=params) + with raises(TypeError) as exc: + op.act(wf, params=params) + with raises(TypeError) as exc: + op.bind(params=params) + if hasattr(op, 'tocoo'): + with raises(TypeError) as exc: + op.tocoo(params=params) + + op = A(fsyst) canonical_args = (1, 2) params = dict(a=1, b=2) diff --git a/kwant/tests/test_wraparound.py b/kwant/tests/test_wraparound.py index b9953ab1..d126b470 100644 --- a/kwant/tests/test_wraparound.py +++ b/kwant/tests/test_wraparound.py @@ -132,26 +132,26 @@ def test_signatures(): ## testing - momenta = ['k_x', 'k_y'] + momenta = ('k_x', 'k_y') onsites = [ (lat(-2, 0), momenta), - (lat(-1, 0), ['E1', 't1'] + momenta), - (lat(0, 0), ['E2', 't2'] + momenta), + (lat(-1, 0), ('E1', 't1') + momenta), + (lat(0, 0), ('E2', 't2') + momenta), ] for site, params_should_be in onsites: - params, _, takes_kwargs = get_parameters(wrapped_syst[site]) + params = get_parameters(wrapped_syst[site]) assert params[1:] == params_should_be hoppings = [ ((lat(-2, 0), lat(-1, 0)), momenta), - ((lat(-2, 0), lat(0, 0)), ['t3'] + momenta), - ((lat(-1, 0), lat(0, 0)), ['t4', 't5'] + momenta), + ((lat(-2, 0), lat(0, 0)), ('t3',) + momenta), + ((lat(-1, 0), lat(0, 0)), ('t4', 't5') + momenta), ] for hopping, params_should_be in hoppings: - params, _, takes_kwargs = get_parameters(wrapped_syst[hopping]) + params = get_parameters(wrapped_syst[hopping]) assert params[2:] == params_should_be @@ -174,8 +174,8 @@ def test_symmetry(): new = getattr(wrapped, attr) orig = getattr(syst, attr) if callable(orig): - params, _, _ = get_parameters(new) - assert params[1:] == ['k_x', 'k_y'] + params = get_parameters(new) + assert params[1:] == ('k_x', 'k_y') assert np.all(orig(None) == new(None, None, None)) else: assert np.all(orig == new) diff --git a/kwant/wraparound.py b/kwant/wraparound.py index ec854ec1..d707c265 100644 --- a/kwant/wraparound.py +++ b/kwant/wraparound.py @@ -110,10 +110,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'): return val(a, *args[:mnp]) assert callable(val) - params, defaults, takes_kwargs = get_parameters(val) - assert not defaults - assert not takes_kwargs - _set_signature(f, params + momenta) + _set_signature(f, get_parameters(val) + momenta) return f @_memoize @@ -125,12 +122,9 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'): pv = phase * v return pv + herm_conj(pv) - params = ['_site0'] + params = ('_site0',) if callable(val): - p, defaults, takes_kwargs = get_parameters(val) - assert not defaults - assert not takes_kwargs - params += p[2:] # cut off both site parameters + params += get_parameters(val)[2:] # cut off both site parameters _set_signature(f, params + momenta) return f @@ -142,12 +136,9 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'): v = val(a, sym.act(elem, b), *args[:mnp]) if callable(val) else val return phase * v - params = ['_site0', '_site1'] + params = ('_site0', '_site1') if callable(val): - p, defaults, takes_kwargs = get_parameters(val) - assert not defaults - assert not takes_kwargs - params += p[2:] # cut off site parameters + params += get_parameters(val)[2:] # cut off site parameters _set_signature(f, params + momenta) return f @@ -180,10 +171,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'): if not callable(val): selections.append(()) continue - val_params, defaults, val_takes_kwargs = get_parameters(val) - assert not defaults - assert not val_takes_kwargs - val_params = val_params[num_sites:] + val_params = get_parameters(val)[num_sites:] selections.append((*site_params, *val_params)) for p in val_params: # Skip parameters that exist in previously added functions, @@ -231,6 +219,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'): ret = WrappedBuilder(TranslationalSymmetry(periods.pop(keep))) sym = TranslationalSymmetry(*periods) momenta.pop(keep) + momenta = tuple(momenta) mnp = -len(sym.periods) # Used by the bound functions above. # Store the names of the momentum parameters and the symmetry of the -- GitLab