Commit f05371b2 authored by Christoph Groth's avatar Christoph Groth

merge backwards-compatibility fixes from stable

parents b2071c71 02edf61d
Pipeline #17000 failed with stages
in 46 minutes and 20 seconds
......@@ -219,3 +219,10 @@ image in the fundamental domain.
This change is documented here for completeness. We expect that the vast
majority of users of Kwant will not be affected by it.
.. _whatsnew13-params-api-change:
API change that affects low-level systems
-----------------------------------------
The `~kwant.system.System.hamiltonian` method of low-level systems must now accept a
`params` keyword parameter.
This diff is collapsed.
......@@ -243,17 +243,6 @@ def make_dense_full(ham, args, params, CGraph gr, diag,
return h_sub
def _check_parameters_match(expected_parameters, params):
if params is None:
params = {}
missing = set(expected_parameters) - set(params)
if missing:
msg = ('System is missing required parameters: ',
', '.join(map('"{}"'.format, missing)))
raise TypeError(''.join(msg))
@deprecate_args
@cython.binding(True)
@cython.embedsignature(True)
......@@ -301,9 +290,6 @@ def hamiltonian_submatrix(self, args=(), to_sites=None, from_sites=None,
n = self.graph.num_nodes
matrix = ta.matrix
if not args: # Then perhaps parameters
_check_parameters_match(self.parameters, params)
if from_sites is None:
diag = n * [None]
from_norb = np.empty(n, gint_dtype)
......
......@@ -1868,6 +1868,9 @@ class _FinalizedBuilderMixin:
if param_names is not None: # 'value' is callable
site = self.symmetry.to_fd(self.sites[i])
if params:
# See body of _value_params_pair_cache().
if isinstance(param_names, Exception):
raise param_names
args = map(params.__getitem__, param_names)
try:
value = value(site, *args)
......@@ -1891,6 +1894,9 @@ class _FinalizedBuilderMixin:
sites = self.sites
site_i, site_j = self.symmetry.to_fd(sites[i], sites[j])
if params:
# See body of _value_params_pair_cache().
if isinstance(param_names, Exception):
raise param_names
args = map(params.__getitem__, param_names)
try:
value = value(site_i, site_j, *args)
......@@ -1940,7 +1946,18 @@ def _value_params_pair_cache(nstrip):
if isinstance(value, _Substituted):
entry = value.func, value.params[nstrip:]
elif callable(value):
entry = value, get_parameters(value)[nstrip:]
try:
param_names = get_parameters(value)
except ValueError as ex:
# The parameter names are determined and stored in advance
# for future use. This has failed, but it will only turn
# into a problem if user code ever uses the 'params'
# mechanism. To maintain backwards compatibility, we catch
# and store the exception so that it can be raised whenever
# appropriate.
entry = value, ex
else:
entry = value, param_names[nstrip:]
else:
# None means: value is not callable. (That's faster to check.)
entry = value, None
......@@ -2036,13 +2053,23 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem):
hoppings = [cache(builder._get_edge(sites[tail], sites[head]))
for tail, head in g]
# System parameters are the union of the parameters
# of onsites and hoppings.
# Here 'onsites' and 'hoppings' are pairs whos second element
# is a tuple of parameter names when matrix element is a function,
# and None otherwise.
parameters = frozenset(chain.from_iterable(
p for _, p in chain(onsites, hoppings) if p))
# Compute the union of the parameters of onsites and hoppings. Here,
# 'onsites' and 'hoppings' are pairs whose second element is one of
# three things:
#
# * a tuple of parameter names when the matrix element is a function,
# * 'None' when it is a constant,
# * an exception when the parameter names could not have been
# determined (See body of _value_params_pair_cache()).
parameters = []
for _, names in chain(onsites, hoppings):
if isinstance(names, Exception):
parameters = None
break
if names:
parameters.extend(names)
else:
parameters = frozenset(parameters)
self.graph = g
self.sites = sites
......@@ -2205,13 +2232,23 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem):
tail, head = sym.to_fd(tail, head)
hoppings.append(cache(builder._get_edge(tail, head)))
# System parameters are the union of the parameters
# of onsites and hoppings.
# Here 'onsites' and 'hoppings' are pairs whos second element
# is a tuple of parameter names when matrix element is a function,
# and None otherwise.
parameters = frozenset(chain.from_iterable(
p for _, p in chain(onsites, hoppings) if p))
# Compute the union of the parameters of onsites and hoppings. Here,
# 'onsites' and 'hoppings' are pairs whose second element is one of
# three things:
#
# * a tuple of parameter names when the matrix element is a function,
# * 'None' when it is a constant,
# * an exception when the parameter names could not have been
# determined (See body of _value_params_pair_cache()).
parameters = []
for _, names in chain(onsites, hoppings):
if isinstance(names, Exception):
parameters = None
break
if names:
parameters.extend(names)
else:
parameters = frozenset(parameters)
self.graph = g
self.sites = sites
......
......@@ -9,6 +9,7 @@
import tempfile
import itertools
import numpy as np
from numpy.testing import assert_equal
import tinyarray as ta
import pytest
......@@ -380,3 +381,18 @@ def test_fd_mismatch():
for k in np.linspace(-np.pi, np.pi, 5)]
assert np.allclose(spectrum1, spectrum2)
# There seems no more specific way to only filter KwantDeprecationWarning.
@pytest.mark.filterwarnings('ignore')
def test_args_params_equivalence():
for lat in [kwant.lattice.square(), kwant.lattice.honeycomb(),
kwant.lattice.kagome()]:
syst = kwant.Builder(kwant.TranslationalSymmetry(*lat.prim_vecs))
syst[lat.shape((lambda pos: True), (0, 0))] = 1
syst[lat.neighbors(1)] = 0.1
syst[lat.neighbors(2)] = lambda a, b, param: 0.01
syst = wraparound(syst).finalized()
shs = syst.hamiltonian_submatrix
assert_equal(shs(args=["bla", 1, 2]),
shs(params=dict(param="bla", k_x=1, k_y=2)))
......@@ -165,42 +165,45 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
params[name] = inspect.Parameter(
name, inspect.Parameter.POSITIONAL_ONLY)
# Add all the other parameters, except for the momenta.
# Add all the other parameters (except for the momenta). Setup the
# 'selections'.
selections = []
for val in vals:
if not callable(val):
selections.append(())
continue
val_params = get_parameters(val)[num_sites:]
assert val_params[mnp:] == momenta
val_params = val_params[:mnp]
selections.append((*site_params, *val_params))
for p in val_params:
# Skip parameters that exist in previously added functions,
# and the momenta, which will be placed at the end.
if p in params or p in momenta:
# Skip parameters that exist in previously added functions.
if p in params:
continue
params[p] = inspect.Parameter(
p, inspect.Parameter.POSITIONAL_ONLY)
# Finally, add the momenta.
for k in momenta:
params[k] = inspect.Parameter(
k, inspect.Parameter.POSITIONAL_ONLY)
# Sort values such that ones with the same arguments are bunched.
# Prepare 'val_selection_pairs' that is used in the function 'f' above.
params_keys = list(params.keys())
val_selection_pairs = []
prev_selection = None
argsort = sorted(range(len(selections)), key=selections.__getitem__)
momenta_sel = tuple(range(mnp, 0, 1))
for i in argsort:
selection = selections[i]
if selection and selection != prev_selection:
prev_selection = selection = tuple(
params_keys.index(s) for s in selection)
params_keys.index(s) for s in selection) + momenta_sel
else:
selection = ()
val_selection_pairs.append((vals[i], selection))
# Finally, add the momenta.
for k in momenta:
params[k] = inspect.Parameter(
k, inspect.Parameter.POSITIONAL_ONLY)
f.__signature__ = inspect.Signature(params.values())
return f
......@@ -220,7 +223,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
sym = TranslationalSymmetry(*periods)
momenta.pop(keep)
momenta = tuple(momenta)
mnp = -len(sym.periods) # Used by the bound functions above.
mnp = -len(momenta) # Used by the bound functions above.
# Store the names of the momentum parameters and the symmetry of the
# old Builder (this will be needed for band structure plotting)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment