diff --git a/kwant/_common.py b/kwant/_common.py index d3d6c46acf36e3f8733e33765fd2ce3c7095887a..3af56a12b2618f2cb7f7ba550e0adf13223a6ced 100644 --- a/kwant/_common.py +++ b/kwant/_common.py @@ -92,9 +92,11 @@ def get_parameters(func): Returns ------- - names : list - Positional, keyword and keyword only parameter names in the order - that they appear in the signature of 'func'. + required_params : list + Names of positional, and keyword only parameters that do not have a + default value and that appear in the signature of 'func'. + default_params : list + Names of parameters that have a default value. takes_kwargs : bool True if 'func' takes '**kwargs'. """ @@ -102,9 +104,12 @@ def get_parameters(func): pars = sig.parameters # Signature.parameters is an *ordered mapping* - names = [k for (k, v) in pars.items() - if v.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY)] + required_params = [k for (k, v) in pars.items() + if v.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY) + and v.default is inspect._empty] + default_params = [k for (k, v) in pars.items() + if v.default is not inspect._empty] takes_kwargs = any(i.kind is inspect.Parameter.VAR_KEYWORD for i in pars.values()) - return names, takes_kwargs + return required_params, default_params, takes_kwargs diff --git a/kwant/builder.py b/kwant/builder.py index bd94cb12b11dd0f65f76b97941af2fec9a2518c1..c97f145c4e761f614d3a5da9738ffa2e261f32c4 100644 --- a/kwant/builder.py +++ b/kwant/builder.py @@ -1620,9 +1620,9 @@ class Builder: ham in _ham_param_map): continue # parameters come in the same order as in the function signature - params, takes_kwargs = get_parameters(ham) + params, defaults, takes_kwargs = get_parameters(ham) params = params[skip:] # remove site argument(s) - _ham_param_map[ham] = (params, takes_kwargs) + _ham_param_map[ham] = (params, defaults, takes_kwargs) #### Assemble and return result. result = FiniteSystem() @@ -1766,9 +1766,9 @@ class Builder: ham in _ham_param_map): continue # parameters come in the same order as in the function signature - params, takes_kwargs = get_parameters(ham) + params, defaults, takes_kwargs = get_parameters(ham) params = params[skip:] # remove site argument(s) - _ham_param_map[ham] = (params, takes_kwargs) + _ham_param_map[ham] = (params, defaults, takes_kwargs) #### Assemble and return result. result = InfiniteSystem() @@ -1891,11 +1891,16 @@ class FiniteSystem(system.FiniteSystem): value = self.onsite_hamiltonians[i] if callable(value): if params: - param_names, takes_kwargs = self._ham_param_map[value] - if not takes_kwargs: - params = {pn: params[pn] for pn in param_names} + required, defaults, takes_kw = self._ham_param_map[value] + invalid_params = set(params).intersection(set(defaults)) + if invalid_params: + raise ValueError("Parameters {} have default values " + "and may not be set with 'params'" + .format(', '.join(invalid_params))) + if not takes_kw: + params = {pn: params[pn] for pn in required} try: - value = value(self.sites[i], **params) + value = value(self.sites[i], **params) except Exception as exc: _raise_user_error(exc, value) else: @@ -1914,9 +1919,14 @@ class FiniteSystem(system.FiniteSystem): if callable(value): sites = self.sites if params: - param_names, takes_kwargs = self._ham_param_map[value] - if not takes_kwargs: - params = {pn: params[pn] for pn in param_names} + required, defaults, takes_kw = self._ham_param_map[value] + invalid_params = set(params).intersection(set(defaults)) + if invalid_params: + raise ValueError("Parameters {} have default values " + "and may not be set with 'params'" + .format(', '.join(invalid_params))) + if not takes_kw: + params = {pn: params[pn] for pn in required} try: value = value(sites[i], sites[j], **params) except Exception as exc: @@ -1971,9 +1981,14 @@ class InfiniteSystem(system.InfiniteSystem): if callable(value): site = self.symmetry.to_fd(self.sites[i]) if params: - param_names, takes_kwargs = self._ham_param_map[value] - if not takes_kwargs: - params = {pn: params[pn] for pn in param_names} + required, defaults, takes_kw = self._ham_param_map[value] + invalid_params = set(params).intersection(set(defaults)) + if invalid_params: + raise ValueError("Parameters {} have default values " + "and may not be set with 'params'" + .format(', '.join(invalid_params))) + if not takes_kw: + params = {pn: params[pn] for pn in required} try: value = value(site, **params) except Exception as exc: @@ -1995,9 +2010,14 @@ class InfiniteSystem(system.InfiniteSystem): sites = self.sites site_i, site_j = self.symmetry.to_fd(sites[i], sites[j]) if params: - param_names, takes_kwargs = self._ham_param_map[value] - if not takes_kwargs: - params = {pn: params[pn] for pn in param_names} + required, defaults, takes_kw = self._ham_param_map[value] + invalid_params = set(params).intersection(set(defaults)) + if invalid_params: + raise ValueError("Parameters {} have default values " + "and may not be set with 'params'" + .format(', '.join(invalid_params))) + if not takes_kw: + params = {pn: params[pn] for pn in required} try: value = value(site_i, site_j, **params) except Exception as exc: diff --git a/kwant/operator.pyx b/kwant/operator.pyx index b1d167f1a791f3148bce11f60920fe9c22924a24..4ef1371eca0384ce0d9ab02a252d5862974db1ae 100644 --- a/kwant/operator.pyx +++ b/kwant/operator.pyx @@ -257,13 +257,13 @@ def _normalize_onsite(syst, onsite, check_hermiticity): If `onsite` is a function or a mapping (dictionary) then a function is returned. """ - parameter_info = ((), False) + parameter_info = ((), (), False) if callable(onsite): # make 'onsite' compatible with hamiltonian value functions - parameters, takes_kwargs = get_parameters(onsite) - parameters = parameters[1:] # skip 'site' parameter - parameter_info = (tuple(parameters), takes_kwargs) + required, defaults, takes_kwargs = get_parameters(onsite) + required = required[1:] # skip 'site' parameter + parameter_info = (tuple(required), defaults, takes_kwargs) try: _onsite = _FunctionalOnsite(onsite, syst.sites) except AttributeError: @@ -628,9 +628,15 @@ cdef class _LocalOperator: onsite = self.onsite check_hermiticity = self.check_hermiticity - param_names, takes_kwargs = self._onsite_params_info - if params and not takes_kwargs: - params = {pn: params[pn] for pn in param_names} + required, defaults, takes_kw = self._onsite_params_info + invalid_params = set(params).intersection(set(defaults)) + if invalid_params: + raise ValueError("Parameters {} have default values " + "and may not be set with 'params'" + .format(', '.join(invalid_params))) + + if params and not takes_kw: + params = {pn: params[pn] for pn in required} def get_onsite(a, a_norbs, b, b_norbs): mat = matrix(onsite(a, *args, **params), complex) diff --git a/kwant/tests/test_builder.py b/kwant/tests/test_builder.py index ae4b0e4630eea7402523cdf9dc184d39f6a6a5bd..3f96aef8a37b890a769aab0a1df3a9cf31a94dc7 100644 --- a/kwant/tests/test_builder.py +++ b/kwant/tests/test_builder.py @@ -8,12 +8,15 @@ import warnings import pickle -from random import Random import itertools as it +import functools as ft +from random import Random + +import numpy as np +import tinyarray as ta from pytest import raises, warns from numpy.testing import assert_almost_equal -import tinyarray as ta -import numpy as np + import kwant from kwant import builder from kwant._common import ensure_rng @@ -1111,45 +1114,69 @@ def test_argument_passing(): chain = kwant.lattice.chain() # Test for passing parameters to hamiltonian matrix elements - def onsite(site, p1, p2=1): + def onsite(site, p1, p2): return p1 + p2 - def hopping(site1, site2, p1, p2=1): + def hopping(site1, site2, p1, p2): return p1 - p2 - def fill_syst(syst): + def gen_fill_syst(onsite, hopping, syst): syst[(chain(i) for i in range(3))] = onsite syst[chain.neighbors()] = hopping return syst.finalized() + fill_syst = ft.partial(gen_fill_syst, onsite, hopping) + syst = fill_syst(kwant.Builder()) inf_syst = fill_syst(kwant.Builder(kwant.TranslationalSymmetry((-3,)))) - args= (2, 1) - params = dict(p1=2, p2=1) + tests = ( + syst.hamiltonian_submatrix, + inf_syst.cell_hamiltonian, + inf_syst.inter_cell_hopping, + inf_syst.selfenergy, + lambda *args, **kw: inf_syst.modes(*args, **kw)[0].wave_functions, + ) - np.testing.assert_array_equal( - syst.hamiltonian_submatrix(args=args), - syst.hamiltonian_submatrix(params=params)) - np.testing.assert_array_equal( - inf_syst.cell_hamiltonian(args=args), - inf_syst.cell_hamiltonian(params=params)) - np.testing.assert_array_equal( - inf_syst.inter_cell_hopping(args=args), - inf_syst.inter_cell_hopping(params=params)) - np.testing.assert_array_equal( - inf_syst.selfenergy(args=args), - inf_syst.selfenergy(params=params)) - np.testing.assert_array_equal( - inf_syst.modes(args=args)[0].wave_functions, - inf_syst.modes(params=params)[0].wave_functions) + for test in tests: + np.testing.assert_array_equal( + test(args=(2, 1)), test(params=dict(p1=2, p2=1))) # test that mixing 'args' and 'params' raises TypeError with raises(TypeError): - syst.hamiltonian(0, 0, *args, params=params) + syst.hamiltonian(0, 0, *(2, 1), params=dict(p1=2, p2=1)) with raises(TypeError): - inf_syst.hamiltonian(0, 0, *args, params=params) + inf_syst.hamiltonian(0, 0, *(2, 1), params=dict(p1=2, p2=1)) + + # test that passing parameters without default values works, and that + # passing parameters with default values fails + def onsite(site, p1, p2=1): + return p1 + p2 + + def hopping(site, site2, p1, p2=2): + return p1 - p2 + + fill_syst = ft.partial(gen_fill_syst, onsite, hopping) + syst = fill_syst(kwant.Builder()) + inf_syst = fill_syst(kwant.Builder(kwant.TranslationalSymmetry((-3,)))) + + tests = ( + syst.hamiltonian_submatrix, + inf_syst.cell_hamiltonian, + inf_syst.inter_cell_hopping, + inf_syst.selfenergy, + lambda *args, **kw: inf_syst.modes(*args, **kw)[0].wave_functions, + ) + + for test in tests: + np.testing.assert_array_equal( + test(args=(1,)), test(params=dict(p1=1))) + + # providing value for parameter with default value -- error + for test in tests: + with raises(ValueError): + test(params=dict(p1=1, p2=2)) # Some common, some different args for value functions def onsite2(site, a, b): diff --git a/kwant/tests/test_wraparound.py b/kwant/tests/test_wraparound.py index 9642f43f3b70204a5407a2c83c5d4738c6e2475d..c1233a97cb71aef6e31b240bac616081b6089f0d 100644 --- a/kwant/tests/test_wraparound.py +++ b/kwant/tests/test_wraparound.py @@ -159,7 +159,7 @@ def test_signatures(): ] for site, params_should_be, should_take_kwargs in onsites: - params, takes_kwargs = get_parameters(wrapped_syst[site]) + params, _, takes_kwargs = get_parameters(wrapped_syst[site]) assert params[1:] == params_should_be assert takes_kwargs == should_take_kwargs @@ -170,7 +170,7 @@ def test_signatures(): ] for hopping, params_should_be, should_take_kwargs in hoppings: - params, takes_kwargs = get_parameters(wrapped_syst[hopping]) + params, _, takes_kwargs = get_parameters(wrapped_syst[hopping]) assert params[2:] == params_should_be assert takes_kwargs == should_take_kwargs @@ -194,7 +194,7 @@ def test_symmetry(): new = getattr(wrapped, attr) orig = getattr(syst, attr) if callable(orig): - params, _ = get_parameters(new) + params, _, _ = get_parameters(new) assert params[1:] == ['k_x', 'k_y'] assert np.all(orig(None) == new(None, None, None)) else: diff --git a/kwant/wraparound.py b/kwant/wraparound.py index 4d7f9471af50ad7f35bdddac6e41176bb41eca01..2e1cb1dd1b63b2f7cd051a9687616724b81a6759 100644 --- a/kwant/wraparound.py +++ b/kwant/wraparound.py @@ -129,7 +129,7 @@ def wraparound(builder, keep=None, *, coordinate_names=('x', 'y', 'z')): else: return val(a, *args[:mnp]) - params, takes_kwargs = get_parameters(val) + params, defaults, takes_kwargs = get_parameters(val) extra_params = params[1:] return _modify_signature(f, params + momenta, takes_kwargs) @@ -156,7 +156,7 @@ def wraparound(builder, keep=None, *, coordinate_names=('x', 'y', 'z')): params, takes_kwargs = ['_site0'], False if callable(val): - p, takes_kwargs = get_parameters(val) + p, defaults, takes_kwargs = get_parameters(val) extra_params = p[2:] # cut off both site parameters params += extra_params @@ -184,7 +184,7 @@ def wraparound(builder, keep=None, *, coordinate_names=('x', 'y', 'z')): params, takes_kwargs = ['_site0', '_site1'], False if callable(val): - p, takes_kwargs = get_parameters(val) + p, defaults, takes_kwargs = get_parameters(val) extra_params = p[2:] # cut off site parameters params += extra_params @@ -209,7 +209,7 @@ def wraparound(builder, keep=None, *, coordinate_names=('x', 'y', 'z')): name, inspect.Parameter.POSITIONAL_OR_KEYWORD) # now all the other parameters, except for the momenta for val in filter(callable, vals): - val_params, val_takes_kwargs = get_parameters(val) + val_params, defaults, val_takes_kwargs = get_parameters(val) val_params = val_params[num_sites:] # remove site parameters takes_kwargs = takes_kwargs or val_takes_kwargs for p in val_params: