Commit 34bcd910 authored by Christoph Groth's avatar Christoph Groth Committed by Joseph Weston

wraparound: use positional args exclusively

This requires changes to other modules, but these were kept to a
minimum.  The most important one is that
builder._FinalizedBuilderMixin.hamiltonian() now always calls value
functions using *args.
parent 7462481d
......@@ -116,6 +116,7 @@ def get_parameters(func):
# Signature.parameters is an *ordered mapping*
required_params = [k for (k, v) in pars.items()
if v.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.KEYWORD_ONLY)
and v.default is inspect._empty]
default_params = [k for (k, v) in pars.items()
......
......@@ -514,6 +514,10 @@ class HermConjOfFunc:
def __call__(self, i, j, *args, **kwargs):
return herm_conj(self.function(j, i, *args, **kwargs))
@property
def __signature__(self):
return inspect.signature(self.function)
################ Leads
......@@ -1909,17 +1913,12 @@ class _FinalizedBuilderMixin:
raise ValueError("Parameters {} have default values "
"and may not be set with 'params'"
.format(', '.join(invalid_params)))
if not takes_kw:
params = {pn: params[pn] for pn in required}
try:
value = value(site, **params)
except Exception as exc:
_raise_user_error(exc, value)
else:
try:
value = value(site, *args)
except Exception as exc:
_raise_user_error(exc, value)
assert not takes_kw
args = map(params.__getitem__, required)
try:
value = value(site, *args)
except Exception as exc:
_raise_user_error(exc, value)
else:
edge_id = self.graph.first_edge_id(i, j)
value = self.hoppings[edge_id]
......@@ -1938,17 +1937,12 @@ class _FinalizedBuilderMixin:
raise ValueError("Parameters {} have default values "
"and may not be set with 'params'"
.format(', '.join(invalid_params)))
if not takes_kw:
params = {pn: params[pn] for pn in required}
try:
value = value(site_i, site_j, **params)
except Exception as exc:
_raise_user_error(exc, value)
else:
try:
value = value(site_i, site_j, *args)
except Exception as exc:
_raise_user_error(exc, value)
assert not takes_kw
args = map(params.__getitem__, required)
try:
value = value(site_i, site_j, *args)
except Exception as exc:
_raise_user_error(exc, value)
if conj:
value = herm_conj(value)
return value
......
# Copyright 2011-2017 Kwant authors.
# Copyright 2011-2018 Kwant authors.
#
# This file is part of Kwant. It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
......@@ -87,39 +87,21 @@ def test_value_types(k=(-1.1, 0.5), E=2, t=1):
H_alt = syst.hamiltonian_submatrix(k, sparse=False)
np.testing.assert_equal(H_alt, H)
# test when Hamiltonian value functions take extra parameters and
# have compatible signatures (can be passed with 'args')
onsites = [
lambda a, E, t: E,
lambda a, E, *args: E,
lambda a, *args: args[0],
]
hoppings = [
lambda a, b, E, t: t,
lambda a, b, E, *args: args[0],
lambda a, b, *args: args[1],
]
args = (E1, t1) + k
for E2, t2 in itertools.product(onsites, hoppings):
syst = wraparound(_simple_syst(lat, E2, t2, sym)).finalized()
H_alt = syst.hamiltonian_submatrix(args, sparse=False)
np.testing.assert_equal(H_alt, H)
# test when hamiltonian value functions take extra parameters and
# have incompatible signaures (must be passed with 'params')
# have incompatible signatures (must be passed with 'params')
onsites = [
lambda a, E: E,
lambda a, **kwargs: kwargs['E'],
lambda a, *, E: E,
lambda a, E, t: E,
lambda a, t, E: E,
]
hoppings = [
lambda a, b, t: t,
lambda a, b, **kwargs: kwargs['t'],
lambda a, b, *, t: t,
lambda a, b, t, E: t,
lambda a, b, E, t: t,
]
params = dict(E=E1, t=t1, **dict(zip(['k_x', 'k_y'], k)))
params = dict(E=E1, t=t1, k_x=k[0], k_y=k[1])
for E2, t2 in itertools.product(onsites, hoppings):
syst = wraparound(_simple_syst(lat, E2, t2, sym)).finalized()
H_alt = syst.hamiltonian_submatrix(params=params,
sparse=False)
H_alt = syst.hamiltonian_submatrix(params=params, sparse=False)
np.testing.assert_equal(H_alt, H)
......@@ -133,8 +115,8 @@ def test_signatures():
syst[lat(-1, 0)] = lambda a, E1: E1
syst[(lat(-1, 0), lat(-1, 1))] = lambda a, b, t1: t1
#
syst[lat(0, 0)] = lambda a, E2, **kwargs: E2
syst[(lat(0, 0), lat(0, 1))] = lambda a, b, t2, **kwargs: t2
syst[lat(0, 0)] = lambda a, E2: E2
syst[(lat(0, 0), lat(0, 1))] = lambda a, b, t2: t2
# hoppings that will be bound as hoppings
syst[(lat(-2, 0), lat(-1, 0))] = -1
......@@ -143,7 +125,7 @@ def test_signatures():
syst[(lat(-2, 0), lat(0, 0))] = -1
syst[(lat(-2, 0), lat(3, 0))] = lambda a, b, t3: t3
syst[(lat(-1, 0), lat(0, 0))] = lambda a, b, t4, **kwargs: t4
syst[(lat(-1, 0), lat(0, 0))] = lambda a, b, t4: t4
syst[(lat(-1, 0), lat(3, 0))] = lambda a, b, t5: t5
wrapped_syst = wraparound(syst)
......@@ -153,26 +135,24 @@ def test_signatures():
momenta = ['k_x', 'k_y']
onsites = [
(lat(-2, 0), momenta, False),
(lat(-1, 0), ['E1', 't1'] + momenta, False),
(lat(0, 0), ['E2', 't2'] + momenta, True),
(lat(-2, 0), momenta),
(lat(-1, 0), ['E1', 't1'] + momenta),
(lat(0, 0), ['E2', 't2'] + momenta),
]
for site, params_should_be, should_take_kwargs in onsites:
for site, params_should_be in onsites:
params, _, takes_kwargs = get_parameters(wrapped_syst[site])
assert params[1:] == params_should_be
assert takes_kwargs == should_take_kwargs
hoppings = [
((lat(-2, 0), lat(-1, 0)), momenta, False),
((lat(-2, 0), lat(0, 0)), ['t3'] + momenta, False),
((lat(-1, 0), lat(0, 0)), ['t4', 't5'] + momenta, True),
((lat(-2, 0), lat(-1, 0)), momenta),
((lat(-2, 0), lat(0, 0)), ['t3'] + momenta),
((lat(-1, 0), lat(0, 0)), ['t4', 't5'] + momenta),
]
for hopping, params_should_be, should_take_kwargs in hoppings:
for hopping, params_should_be in hoppings:
params, _, takes_kwargs = get_parameters(wrapped_syst[hopping])
assert params[2:] == params_should_be
assert takes_kwargs == should_take_kwargs
def test_symmetry():
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment