From f57b86124bc214afc0d8e321240a0fee4a44ed33 Mon Sep 17 00:00:00 2001 From: Rafal Skolasinski <r.j.skolasinski@gmail.com> Date: Thu, 4 May 2017 17:09:36 +0200 Subject: [PATCH] fix bug of repeated substitutions Calling "kwant.continuum.discretize('A * k_x', substitutions={'A': 'A + B'})" ends with performing substitutions twice. Fixed by removing redundant second substitution. Fixes issue #121. --- kwant/continuum/discretizer.py | 2 +- kwant/continuum/tests/test_discretizer.py | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/kwant/continuum/discretizer.py b/kwant/continuum/discretizer.py index c10c6e07..511be0d2 100644 --- a/kwant/continuum/discretizer.py +++ b/kwant/continuum/discretizer.py @@ -78,7 +78,7 @@ def discretize(hamiltonian, discrete_coordinates=None, *, grid_spacing=1, substitutions=substitutions, verbose=verbose) return build_discretized(tb, coords, grid_spacing=grid_spacing, - substitutions=substitutions, verbose=verbose) + verbose=verbose) def discretize_symbolic(hamiltonian, discrete_coordinates=None, *, diff --git a/kwant/continuum/tests/test_discretizer.py b/kwant/continuum/tests/test_discretizer.py index 1592dac4..b6d7222e 100644 --- a/kwant/continuum/tests/test_discretizer.py +++ b/kwant/continuum/tests/test_discretizer.py @@ -139,6 +139,7 @@ def test_simple_derivations(commutative): @pytest.mark.parametrize('e_to_subs, e, subs', [ + ('A * k_x', '(A + B) * k_x', {'A': 'A + B'}), ('k_x', 'k_x + k_y', {'k_x': 'k_x + k_y'}), ('k_x**2 + V', 'k_x**2 + V + V_0', {'V': 'V + V_0'}), ('k_x**2 + A + C', 'k_x**2 + B + 5', {'A': 'B + 5', 'C': 0}), @@ -393,6 +394,19 @@ def test_numeric_functions_basic_string(): assert +1j * p['t'] == builder[lat(1), lat(0)](None, None, **p) +@pytest.mark.parametrize('e_to_subs, e, subs', [ + ('A * k_x + V', '(A + B) * k_x + A + B', {'A': 'A + B', 'V': 'A + B'}), +]) +def test_numeric_functions_with_subs(e_to_subs, e, subs): + p = {'A': 1, 'B': 2} + builder_direct = discretize(e) + builder_subs = discretize(e_to_subs, substitutions=subs) + + lat = next(iter(builder_direct.sites()))[0] + assert builder_direct[lat(0)](None, **p) == builder_subs[lat(0)](None, **p) + + + def test_numeric_functions_advance(): hams = [ kx**2, -- GitLab