From 0aa3f324a6328f272728f0872377b29e30e9aa1b Mon Sep 17 00:00:00 2001
From: Christoph Groth <christoph.groth@cea.fr>
Date: Thu, 13 Sep 2018 11:38:22 +0200
Subject: [PATCH] only accept value functions with named positional arguments

and without default values.

Co-authored-by: Joseph Weston <joseph@weston.cloud>
---
 doc/source/code/figure/closed_system.py.diff |  4 +-
 doc/source/code/figure/quantum_well.py.diff  |  2 +-
 kwant/_common.py                             | 56 +++++++++++---------
 kwant/builder.py                             | 47 +++++-----------
 kwant/operator.pyx                           | 45 ++++++++--------
 kwant/tests/test_builder.py                  | 11 ++--
 kwant/tests/test_operator.py                 | 14 +++++
 kwant/tests/test_wraparound.py               | 18 +++----
 kwant/wraparound.py                          | 25 +++------
 9 files changed, 101 insertions(+), 121 deletions(-)

diff --git a/doc/source/code/figure/closed_system.py.diff b/doc/source/code/figure/closed_system.py.diff
index 21f979f9..d70b1cc7 100644
--- a/doc/source/code/figure/closed_system.py.diff
+++ b/doc/source/code/figure/closed_system.py.diff
@@ -1,4 +1,4 @@
-@@ -1,144 +1,162 @@
+@@ -1,144 +1,161 @@
  # Tutorial 2.4.2. Closed systems
  # ==============================
  #
@@ -41,7 +41,7 @@
          rsq = x ** 2 + y ** 2
          return rsq < r ** 2
  
-     def hopx(site1, site2, B=0):
+     def hopx(site1, site2, B):
          # The magnetic field is controlled by the parameter B
          y = site1.pos[1]
          return -t * exp(-1j * B * y)
diff --git a/doc/source/code/figure/quantum_well.py.diff b/doc/source/code/figure/quantum_well.py.diff
index 7b81c866..1c81f1c5 100644
--- a/doc/source/code/figure/quantum_well.py.diff
+++ b/doc/source/code/figure/quantum_well.py.diff
@@ -36,7 +36,7 @@
  #HIDDEN_END_ehso
  
  #HIDDEN_BEGIN_coid
-     def onsite(site, pot=0):
+     def onsite(site, pot):
          return 4 * t + potential(site, pot)
  
      syst[(lat(x, y) for x in range(L) for y in range(W))] = onsite
diff --git a/kwant/_common.py b/kwant/_common.py
index e2732643..71b4f782 100644
--- a/kwant/_common.py
+++ b/kwant/_common.py
@@ -13,7 +13,6 @@ import inspect
 import warnings
 import importlib
 from contextlib import contextmanager
-from collections import namedtuple
 
 __all__ = ['KwantDeprecationWarning', 'UserCodeError']
 
@@ -94,36 +93,41 @@ def reraise_warnings(level=3):
         warnings.warn(warning.message, stacklevel=level)
 
 
-_Params = namedtuple('_Params', ('required', 'defaults', 'takes_kwargs'))
-
-
 def get_parameters(func):
-    """Get the names of the parameters to 'func' and whether it takes kwargs.
+    """Return the names of parameters of a function.
+
+    It is made sure that the function can be called as func(*args) with
+    'args' corresponding to the returned parameter names.
 
     Returns
     -------
-    required : list
-        Names of positional, and keyword only parameters that do not have a
-        default value and that appear in the signature of 'func'.
-    defaults : list
-        Names of parameters that have a default value.
-    takes_kwargs : bool
-        True if 'func' takes '**kwargs'.
+    param_names : list
+        Names of positional parameters that appear in the signature of 'func'.
     """
-    sig = inspect.signature(func)
-    pars = sig.parameters
-
-    # Signature.parameters is an *ordered mapping*
-    required_params = [k for (k, v) in pars.items()
-                       if v.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
-                                     inspect.Parameter.POSITIONAL_ONLY,
-                                     inspect.Parameter.KEYWORD_ONLY)
-                       and v.default is inspect._empty]
-    default_params = [k for (k, v) in pars.items()
-                      if v.default is not inspect._empty]
-    takes_kwargs = any(i.kind is inspect.Parameter.VAR_KEYWORD
-                       for i in pars.values())
-    return _Params(required_params, default_params, takes_kwargs)
+    def error(msg):
+        fname = inspect.getsourcefile(func)
+        try:
+            line = inspect.getsourcelines(func)[1]
+        except OSError:
+            line = '<unknown line>'
+        raise ValueError("{}:\nFile {}, line {}, in {}".format(
+            msg, repr(fname), line, func.__name__))
+
+    P = inspect.Parameter
+    pars = inspect.signature(func).parameters # an *ordered mapping*
+    names = []
+    for k, v in pars.items():
+        if v.kind in (P.POSITIONAL_ONLY, P.POSITIONAL_OR_KEYWORD):
+            if v.default is P.empty:
+                names.append(k)
+            else:
+                error("Arguments of value functions "
+                      "must not have default values")
+        elif v.kind is P.KEYWORD_ONLY:
+            error("Keyword-only arguments are not allowed in value functions")
+        elif v.kind in (P.VAR_POSITIONAL, P.VAR_KEYWORD):
+            error("Value functions must not take *args or **kwargs")
+    return tuple(names)
 
 
 class lazy_import:
diff --git a/kwant/builder.py b/kwant/builder.py
index fbec28ee..9d83d999 100644
--- a/kwant/builder.py
+++ b/kwant/builder.py
@@ -1407,8 +1407,8 @@ class Builder:
 
         # Get parameter names to be substituted for each function,
         # without the 'site' parameter(s)
-        onsite_params = [get_parameters(v).required[1:] for v in onsites]
-        hopping_params = [get_parameters(v).required[2:] for v in hoppings]
+        onsite_params = [get_parameters(v)[1:] for v in onsites]
+        hopping_params = [get_parameters(v)[2:] for v in hoppings]
 
         system_params = set(flatten(chain(onsite_params, hopping_params)))
         nonexistant_params = set(subs.keys()).difference(system_params)
@@ -1868,21 +1868,16 @@ def _translate_cons_law(cons_law):
 class _FinalizedBuilderMixin:
     """Common functionality for all finalized builders"""
 
-    def _init_ham_param_maps(self):
-        """Find parameters taken by all value functions
+    def _init_param_names(self):
+        """For each value function, store the required parameters.
         """
-        ham_param_map = {}
-        for hams, skip in [(self.onsite_hamiltonians, 1), (self.hoppings, 2)]:
-            for ham in hams:
-                if (not callable(ham) or ham is Other or
-                    ham in ham_param_map):
+        pn = {}
+        for values, skip in [(self.onsite_hamiltonians, 1), (self.hoppings, 2)]:
+            for value in values:
+                if not callable(value) or value is Other or value in pn:
                     continue
-                # parameters come in the same order as in the function signature
-                params, defaults, takes_kwargs = get_parameters(ham)
-                params = params[skip:]  # remove site argument(s)
-                ham_param_map[ham] = (params, defaults, takes_kwargs)
-
-        self._ham_param_map = ham_param_map
+                pn[value] = get_parameters(value)[skip:]
+        self._param_names = pn
 
 
     def _init_discrete_symmetries(self, builder):
@@ -1907,14 +1902,7 @@ class _FinalizedBuilderMixin:
             if callable(value):
                 site = self.symmetry.to_fd(self.sites[i])
                 if params:
-                    required, defaults, takes_kw = self._ham_param_map[value]
-                    invalid_params = set(params).intersection(set(defaults))
-                    if invalid_params:
-                        raise ValueError("Parameters {} have default values "
-                                         "and may not be set with 'params'"
-                                         .format(', '.join(invalid_params)))
-                    assert not takes_kw
-                    args = map(params.__getitem__, required)
+                    args = map(params.__getitem__, self._param_names[value])
                 try:
                     value = value(site, *args)
                 except Exception as exc:
@@ -1931,14 +1919,7 @@ class _FinalizedBuilderMixin:
                 sites = self.sites
                 site_i, site_j = self.symmetry.to_fd(sites[i], sites[j])
                 if params:
-                    required, defaults, takes_kw = self._ham_param_map[value]
-                    invalid_params = set(params).intersection(set(defaults))
-                    if invalid_params:
-                        raise ValueError("Parameters {} have default values "
-                                         "and may not be set with 'params'"
-                                         .format(', '.join(invalid_params)))
-                    assert not takes_kw
-                    args = map(params.__getitem__, required)
+                    args = map(params.__getitem__, self._param_names[value])
                 try:
                     value = value(site_i, site_j, *args)
                 except Exception as exc:
@@ -2049,7 +2030,7 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem):
         self.symmetry = builder.symmetry
         self.leads = finalized_leads
         self.lead_interfaces = lead_interfaces
-        self._init_ham_param_maps()
+        self._init_param_names()
         self._init_discrete_symmetries(builder)
 
 
@@ -2203,7 +2184,7 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem):
         self.onsite_hamiltonians = onsite_hamiltonians
         self.symmetry = builder.symmetry
         self.cell_size = cell_size
-        self._init_ham_param_maps()
+        self._init_param_names()
         self._init_discrete_symmetries(builder)
 
 
diff --git a/kwant/operator.pyx b/kwant/operator.pyx
index 2787b281..e0483d60 100644
--- a/kwant/operator.pyx
+++ b/kwant/operator.pyx
@@ -241,8 +241,8 @@ class _FunctionalOnsite:
         self.onsite = onsite
         self.sites = sites
 
-    def __call__(self, site_id, *args, **kwargs):
-        return self.onsite(self.sites[site_id], *args, **kwargs)
+    def __call__(self, site_id, *args):
+        return self.onsite(self.sites[site_id], *args)
 
 
 class _DictOnsite(_FunctionalOnsite):
@@ -257,13 +257,11 @@ def _normalize_onsite(syst, onsite, check_hermiticity):
     If `onsite` is a function or a mapping (dictionary) then a function
     is returned.
     """
-    parameter_info =  ((), (), False)
+    param_names = ()
 
     if callable(onsite):
         # make 'onsite' compatible with hamiltonian value functions
-        required, defaults, takes_kwargs = get_parameters(onsite)
-        required = required[1:]  # skip 'site' parameter
-        parameter_info = (tuple(required), defaults, takes_kwargs)
+        param_names = get_parameters(onsite)[1:]
         try:
             _onsite = _FunctionalOnsite(onsite, syst.sites)
         except AttributeError:
@@ -301,7 +299,7 @@ def _normalize_onsite(syst, onsite, check_hermiticity):
                    'different numbers of orbitals on different sites')
             raise ValueError(msg)
 
-    return _onsite, parameter_info
+    return _onsite, param_names
 
 
 cdef class BlockSparseMatrix:
@@ -434,7 +432,7 @@ cdef class _LocalOperator:
     """
 
     cdef public int check_hermiticity, sum
-    cdef public object syst, onsite, _onsite_params_info
+    cdef public object syst, onsite, _onsite_param_names
     cdef public gint[:, :]  where, _site_ranges
     cdef public BlockSparseMatrix _bound_onsite, _bound_hamiltonian
 
@@ -448,8 +446,8 @@ cdef class _LocalOperator:
                              'the site families (lattices).')
 
         self.syst = syst
-        self.onsite, self._onsite_params_info = \
-            _normalize_onsite(syst, onsite, check_hermiticity)
+        self.onsite, self._onsite_param_names = _normalize_onsite(
+            syst, onsite, check_hermiticity)
         self.check_hermiticity = check_hermiticity
         self.sum = sum
         self._site_ranges = np.asarray(syst.site_ranges, dtype=gint_dtype)
@@ -598,7 +596,7 @@ cdef class _LocalOperator:
         q = cls.__new__(cls)
         q.syst = self.syst
         q.onsite = self.onsite
-        q._onsite_params_info = self._onsite_params_info
+        q._onsite_param_names = self._onsite_param_names
         q.where = self.where
         q.sum = self.sum
         q._site_ranges = self._site_ranges
@@ -638,23 +636,22 @@ cdef class _LocalOperator:
         """Evaluate the onsite matrices on all elements of `where`"""
         assert callable(self.onsite)
         assert not (args and params)
-        params = params or {}
         matrix = ta.matrix
         onsite = self.onsite
         check_hermiticity = self.check_hermiticity
 
-        required, defaults, takes_kw = self._onsite_params_info
-        invalid_params = set(params).intersection(set(defaults))
-        if invalid_params:
-            raise ValueError("Parameters {} have default values "
-                             "and may not be set with 'params'"
-                             .format(', '.join(invalid_params)))
-
-        if params and not takes_kw:
-            params = {pn: params[pn] for pn in required}
+        if params:
+            try:
+                args = tuple(params[pn] for pn in self._onsite_param_names)
+            except KeyError:
+                missing = [p for p in self._onsite_param_names
+                           if p not in params]
+                msg = ('Operator is missing required arguments: ',
+                       ', '.join(map('"{}"'.format, missing)))
+                raise TypeError(''.join(msg))
 
         def get_onsite(a, a_norbs, b, b_norbs):
-            mat = matrix(onsite(a, *args, **params), complex)
+            mat = matrix(onsite(a, *args), complex)
             _check_onsite(mat, a_norbs, check_hermiticity)
             return mat
 
@@ -679,14 +676,14 @@ cdef class _LocalOperator:
     def __getstate__(self):
         return (
             (self.check_hermiticity, self.sum),
-            (self.syst, self.onsite, self._onsite_params_info),
+            (self.syst, self.onsite, self._onsite_param_names),
             tuple(map(np.asarray, (self.where, self._site_ranges))),
             (self._bound_onsite, self._bound_hamiltonian),
         )
 
     def __setstate__(self, state):
         ((self.check_hermiticity, self.sum),
-         (self.syst, self.onsite, self._onsite_params_info),
+         (self.syst, self.onsite, self._onsite_param_names),
          (self.where, self._site_ranges),
          (self._bound_onsite, self._bound_hamiltonian),
         ) = state
diff --git a/kwant/tests/test_builder.py b/kwant/tests/test_builder.py
index 37c9a2bf..df62331f 100644
--- a/kwant/tests/test_builder.py
+++ b/kwant/tests/test_builder.py
@@ -1210,10 +1210,10 @@ def test_argument_passing():
 
     # test that passing parameters without default values works, and that
     # passing parameters with default values fails
-    def onsite(site, p1, p2=1):
+    def onsite(site, p1, p2):
         return p1 + p2
 
-    def hopping(site, site2, p1, p2=2):
+    def hopping(site, site2, p1, p2):
         return p1 - p2
 
     fill_syst = ft.partial(gen_fill_syst, onsite, hopping)
@@ -1231,12 +1231,7 @@ def test_argument_passing():
 
     for test in tests:
         np.testing.assert_array_equal(
-            test(args=(1,)), test(params=dict(p1=1)))
-
-    # providing value for parameter with default value -- error
-    for test in tests:
-        with raises(ValueError):
-            test(params=dict(p1=1, p2=2))
+            test(args=(1, 2)), test(params=dict(p1=1, p2=2)))
 
     # Some common, some different args for value functions
     def onsite2(site, a, b):
diff --git a/kwant/tests/test_operator.py b/kwant/tests/test_operator.py
index e38e1177..a9545bf8 100644
--- a/kwant/tests/test_operator.py
+++ b/kwant/tests/test_operator.py
@@ -420,6 +420,20 @@ def test_arg_passing(A):
 
     wf = np.ones(len(fsyst.sites))
 
+    # test missing params
+    op = A(fsyst, onsite=lambda x, a, b: 1)
+    params = dict(a=1)
+    with raises(TypeError) as exc:
+        op(wf, params=params)
+    with raises(TypeError) as exc:
+        op.act(wf, params=params)
+    with raises(TypeError) as exc:
+        op.bind(params=params)
+    if hasattr(op, 'tocoo'):
+        with raises(TypeError) as exc:
+            op.tocoo(params=params)
+
+
     op = A(fsyst)
     canonical_args = (1, 2)
     params = dict(a=1, b=2)
diff --git a/kwant/tests/test_wraparound.py b/kwant/tests/test_wraparound.py
index b9953ab1..d126b470 100644
--- a/kwant/tests/test_wraparound.py
+++ b/kwant/tests/test_wraparound.py
@@ -132,26 +132,26 @@ def test_signatures():
 
     ## testing
 
-    momenta = ['k_x', 'k_y']
+    momenta = ('k_x', 'k_y')
 
     onsites = [
         (lat(-2, 0), momenta),
-        (lat(-1, 0), ['E1', 't1'] + momenta),
-        (lat(0, 0), ['E2', 't2'] + momenta),
+        (lat(-1, 0), ('E1', 't1') + momenta),
+        (lat(0, 0), ('E2', 't2') + momenta),
     ]
 
     for site, params_should_be in onsites:
-        params, _, takes_kwargs = get_parameters(wrapped_syst[site])
+        params = get_parameters(wrapped_syst[site])
         assert params[1:] == params_should_be
 
     hoppings = [
         ((lat(-2, 0), lat(-1, 0)), momenta),
-        ((lat(-2, 0), lat(0, 0)), ['t3'] + momenta),
-        ((lat(-1, 0), lat(0, 0)), ['t4', 't5'] + momenta),
+        ((lat(-2, 0), lat(0, 0)), ('t3',) + momenta),
+        ((lat(-1, 0), lat(0, 0)), ('t4', 't5') + momenta),
     ]
 
     for hopping, params_should_be in hoppings:
-        params, _, takes_kwargs = get_parameters(wrapped_syst[hopping])
+        params = get_parameters(wrapped_syst[hopping])
         assert params[2:] == params_should_be
 
 
@@ -174,8 +174,8 @@ def test_symmetry():
             new = getattr(wrapped, attr)
             orig = getattr(syst, attr)
             if callable(orig):
-                params, _, _ = get_parameters(new)
-                assert params[1:] == ['k_x', 'k_y']
+                params = get_parameters(new)
+                assert params[1:] == ('k_x', 'k_y')
                 assert np.all(orig(None) == new(None, None, None))
             else:
                 assert np.all(orig == new)
diff --git a/kwant/wraparound.py b/kwant/wraparound.py
index ec854ec1..d707c265 100644
--- a/kwant/wraparound.py
+++ b/kwant/wraparound.py
@@ -110,10 +110,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
             return val(a, *args[:mnp])
 
         assert callable(val)
-        params, defaults, takes_kwargs = get_parameters(val)
-        assert not defaults
-        assert not takes_kwargs
-        _set_signature(f, params + momenta)
+        _set_signature(f, get_parameters(val) + momenta)
         return f
 
     @_memoize
@@ -125,12 +122,9 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
             pv = phase * v
             return pv + herm_conj(pv)
 
-        params = ['_site0']
+        params = ('_site0',)
         if callable(val):
-            p, defaults, takes_kwargs = get_parameters(val)
-            assert not defaults
-            assert not takes_kwargs
-            params += p[2:]     # cut off both site parameters
+            params += get_parameters(val)[2:] # cut off both site parameters
         _set_signature(f, params + momenta)
         return f
 
@@ -142,12 +136,9 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
             v = val(a, sym.act(elem, b), *args[:mnp]) if callable(val) else val
             return phase * v
 
-        params = ['_site0', '_site1']
+        params = ('_site0', '_site1')
         if callable(val):
-            p, defaults, takes_kwargs = get_parameters(val)
-            assert not defaults
-            assert not takes_kwargs
-            params += p[2:]  # cut off site parameters
+            params += get_parameters(val)[2:] # cut off site parameters
         _set_signature(f, params + momenta)
         return f
 
@@ -180,10 +171,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
             if not callable(val):
                 selections.append(())
                 continue
-            val_params, defaults, val_takes_kwargs = get_parameters(val)
-            assert not defaults
-            assert not val_takes_kwargs
-            val_params = val_params[num_sites:]
+            val_params = get_parameters(val)[num_sites:]
             selections.append((*site_params, *val_params))
             for p in val_params:
                 # Skip parameters that exist in previously added functions,
@@ -231,6 +219,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
         ret = WrappedBuilder(TranslationalSymmetry(periods.pop(keep)))
         sym = TranslationalSymmetry(*periods)
         momenta.pop(keep)
+    momenta = tuple(momenta)
     mnp = -len(sym.periods)      # Used by the bound functions above.
 
     # Store the names of the momentum parameters and the symmetry of the
-- 
GitLab