From bbb864fe3bf1c7febc639896e5a9e1498c308e6a Mon Sep 17 00:00:00 2001
From: Joseph Weston <joseph.weston08@gmail.com>
Date: Mon, 8 May 2017 15:26:18 +0200
Subject: [PATCH] add tests for default parameter values in value functions

The semantics for default parameters are as follows. If a value
function has a parameter that takes a default value, then an
exception is raised if the user ever tries to assign a value
to this parameter via 'params'. These semantics are chosen
to eliminate the possibility that "forgotten" default parameters
are not silently overwritten.
---
 kwant/tests/test_builder.py | 77 +++++++++++++++++++++++++------------
 1 file changed, 52 insertions(+), 25 deletions(-)

diff --git a/kwant/tests/test_builder.py b/kwant/tests/test_builder.py
index ae4b0e46..3f96aef8 100644
--- a/kwant/tests/test_builder.py
+++ b/kwant/tests/test_builder.py
@@ -8,12 +8,15 @@
 
 import warnings
 import pickle
-from random import Random
 import itertools as it
+import functools as ft
+from random import Random
+
+import numpy as np
+import tinyarray as ta
 from pytest import raises, warns
 from numpy.testing import assert_almost_equal
-import tinyarray as ta
-import numpy as np
+
 import kwant
 from kwant import builder
 from kwant._common import ensure_rng
@@ -1111,45 +1114,69 @@ def test_argument_passing():
     chain = kwant.lattice.chain()
 
     # Test for passing parameters to hamiltonian matrix elements
-    def onsite(site, p1, p2=1):
+    def onsite(site, p1, p2):
         return p1 + p2
 
-    def hopping(site1, site2, p1, p2=1):
+    def hopping(site1, site2, p1, p2):
         return p1 - p2
 
-    def fill_syst(syst):
+    def gen_fill_syst(onsite, hopping, syst):
         syst[(chain(i) for i in range(3))] = onsite
         syst[chain.neighbors()] = hopping
         return syst.finalized()
 
+    fill_syst = ft.partial(gen_fill_syst, onsite, hopping)
+
     syst = fill_syst(kwant.Builder())
     inf_syst = fill_syst(kwant.Builder(kwant.TranslationalSymmetry((-3,))))
 
-    args= (2, 1)
-    params = dict(p1=2, p2=1)
+    tests = (
+        syst.hamiltonian_submatrix,
+        inf_syst.cell_hamiltonian,
+        inf_syst.inter_cell_hopping,
+        inf_syst.selfenergy,
+        lambda *args, **kw: inf_syst.modes(*args, **kw)[0].wave_functions,
+    )
 
-    np.testing.assert_array_equal(
-        syst.hamiltonian_submatrix(args=args),
-        syst.hamiltonian_submatrix(params=params))
-    np.testing.assert_array_equal(
-        inf_syst.cell_hamiltonian(args=args),
-        inf_syst.cell_hamiltonian(params=params))
-    np.testing.assert_array_equal(
-        inf_syst.inter_cell_hopping(args=args),
-        inf_syst.inter_cell_hopping(params=params))
-    np.testing.assert_array_equal(
-        inf_syst.selfenergy(args=args),
-        inf_syst.selfenergy(params=params))
-    np.testing.assert_array_equal(
-        inf_syst.modes(args=args)[0].wave_functions,
-        inf_syst.modes(params=params)[0].wave_functions)
+    for test in tests:
+        np.testing.assert_array_equal(
+            test(args=(2, 1)), test(params=dict(p1=2, p2=1)))
 
     # test that mixing 'args' and 'params' raises TypeError
     with raises(TypeError):
-        syst.hamiltonian(0, 0, *args, params=params)
+        syst.hamiltonian(0, 0, *(2, 1), params=dict(p1=2, p2=1))
     with raises(TypeError):
-        inf_syst.hamiltonian(0, 0, *args, params=params)
+        inf_syst.hamiltonian(0, 0, *(2, 1), params=dict(p1=2, p2=1))
+
+    # test that passing parameters without default values works, and that
+    # passing parameters with default values fails
+    def onsite(site, p1, p2=1):
+        return p1 + p2
+
+    def hopping(site, site2, p1, p2=2):
+        return p1 - p2
+
+    fill_syst = ft.partial(gen_fill_syst, onsite, hopping)
 
+    syst = fill_syst(kwant.Builder())
+    inf_syst = fill_syst(kwant.Builder(kwant.TranslationalSymmetry((-3,))))
+
+    tests = (
+        syst.hamiltonian_submatrix,
+        inf_syst.cell_hamiltonian,
+        inf_syst.inter_cell_hopping,
+        inf_syst.selfenergy,
+        lambda *args, **kw: inf_syst.modes(*args, **kw)[0].wave_functions,
+    )
+
+    for test in tests:
+        np.testing.assert_array_equal(
+            test(args=(1,)), test(params=dict(p1=1)))
+
+    # providing value for parameter with default value -- error
+    for test in tests:
+        with raises(ValueError):
+            test(params=dict(p1=1, p2=2))
 
     # Some common, some different args for value functions
     def onsite2(site, a, b):
-- 
GitLab