From d0e37a1167eb66c686ebfef94cf53a9d69698305 Mon Sep 17 00:00:00 2001 From: Rafal Skolasinski <r.j.skolasinski@gmail.com> Date: Mon, 6 Mar 2017 23:09:03 +0100 Subject: [PATCH] fix that commutative momentum operators are ignored by discretizer, closes #99 --- kwant/continuum/discretizer.py | 17 +++---- kwant/continuum/tests/test_discretizer.py | 61 +++++++++++++---------- 2 files changed, 44 insertions(+), 34 deletions(-) diff --git a/kwant/continuum/discretizer.py b/kwant/continuum/discretizer.py index af7cd1c8..7cfe55db 100644 --- a/kwant/continuum/discretizer.py +++ b/kwant/continuum/discretizer.py @@ -21,6 +21,7 @@ from ._common import matrix_monomials ################ Globals variables and definitions _wf = sympy.Function('_internal_unique_name', commutative=False) +_momentum_operators = {s.name: s for s in momentum_operators} _position_operators = {s.name: s for s in position_operators} _displacements = {s: sympy.Symbol('_internal_a_{}'.format(s)) for s in 'xyz'} @@ -103,15 +104,15 @@ def discretize_symbolic(hamiltonian, discrete_coordinates=None, substitutions=No if not isinstance(hamiltonian, (sympy.Expr, sympy.matrices.MatrixBase)): hamiltonian = sympify(hamiltonian, substitutions) - atoms = hamiltonian.atoms(sympy.Symbol) - if not all('a' != s.name for s in atoms): + atoms_names = [s.name for s in hamiltonian.atoms(sympy.Symbol)] + if any( s == 'a' for s in atoms_names): raise ValueError("'a' is a symbol used internally to represent " "lattice spacing; please use a different symbol.") hamiltonian = sympy.expand(hamiltonian) if discrete_coordinates is None: - used_momenta = set(momentum_operators) & set(atoms) - discrete_coordinates = {k.name[-1] for k in used_momenta} + used_momenta = set(_momentum_operators) & set(atoms_names) + discrete_coordinates = {k[-1] for k in used_momenta} else: discrete_coordinates = set(discrete_coordinates) if not discrete_coordinates <= set('xyz'): @@ -341,14 +342,12 @@ def _discretize_expression(expression, discrete_coordinates): output[tuple(offset[n].tolist())] += c.subs(subs) return dict(output) - # main function body starts here - if (isinstance(expression, (int, float, sympy.Integer, sympy.Float)) - or not expression.atoms(sympy.Symbol) & set(momentum_operators)): + # if there are no momenta in the expression, then it is an onsite + atoms_names = [s.name for s in expression.atoms(sympy.Symbol)] + if not set(_momentum_operators) & set(atoms_names): n = len(discrete_coordinates) return {(0,) * n: expression} - assert isinstance(expression, sympy.Expr) - # make sure we have list of summands summands = expression.as_ordered_terms() diff --git a/kwant/continuum/tests/test_discretizer.py b/kwant/continuum/tests/test_discretizer.py index 014c37b3..2b6a69ef 100644 --- a/kwant/continuum/tests/test_discretizer.py +++ b/kwant/continuum/tests/test_discretizer.py @@ -8,6 +8,7 @@ from ..discretizer import _wf import inspect from functools import wraps +import pytest def swallows_extra_kwargs(f): @@ -42,52 +43,56 @@ A, B = sympy.symbols('A B', commutative=False) ns = {'A': A, 'B': B, 'a_x': ax, 'a_y': ay, 'az': az, 'x': x, 'y': y, 'z': z} -def test_reading_coordinates(): +@pytest.mark.parametrize('commutative', [True, False]) +def test_reading_coordinates(commutative): + kx, ky, kz = sympy.symbols('k_x k_y k_z', commutative=commutative) + test = { - kx**2 : {'x'}, - kx**2 + ky**2 : {'x', 'y'}, - kx**2 + ky**2 + kz**2 : {'x', 'y', 'z'}, - ky**2 + kz**2 : {'y', 'z'}, - kz**2 : {'z'}, - kx * A(x,y) * kx : {'x'}, - kx**2 + kz * B(y) : {'x', 'z'}, + kx**2 : ['x'], + kx**2 + ky**2 : ['x', 'y'], + kx**2 + ky**2 + kz**2 : ['x', 'y', 'z'], + ky**2 + kz**2 : ['y', 'z'], + kz**2 : ['z'], + kx * A(x,y) * kx : ['x'], + kx**2 + kz * B(y) : ['x', 'z'], } for inp, out in test.items(): ham, got = discretize_symbolic(inp) - assert all(d in out for d in got),\ - "Should be: _split_factors({})=={}. Not {}".format(inp, out, got) + assert got == out def test_reading_coordinates_matrix(): test = [ - (sympy.Matrix([kx**2]) , {'x'}), - (sympy.Matrix([kx**2 + ky**2]) , {'x', 'y'}), - (sympy.Matrix([kx**2 + ky**2 + kz**2]) , {'x', 'y', 'z'}), - (sympy.Matrix([ky**2 + kz**2]) , {'y', 'z'}), - (sympy.Matrix([kz**2]) , {'z'}), - (sympy.Matrix([kx * A(x,y) * kx]) , {'x'}), - (sympy.Matrix([kx**2 + kz * B(y)]) , {'x', 'z'}), + (sympy.Matrix([sympy.sympify('k_x**2')]) , ['x']), + (sympy.Matrix([kx**2]) , ['x']), + (sympy.Matrix([kx**2 + ky**2]) , ['x', 'y']), + (sympy.Matrix([kx**2 + ky**2 + kz**2]) , ['x', 'y', 'z']), + (sympy.Matrix([ky**2 + kz**2]) , ['y', 'z']), + (sympy.Matrix([kz**2]) , ['z']), + (sympy.Matrix([kx * A(x,y) * kx]) , ['x']), + (sympy.Matrix([kx**2 + kz * B(y)]) , ['x', 'z']), ] for inp, out in test: ham, got = discretize_symbolic(inp) - assert all(d in out for d in got),\ - "Should be: _split_factors({})=={}. Not {}".format(inp, out, got) + assert got == out def test_reading_different_matrix_types(): test = [ - (sympy.MutableMatrix([kx**2]) , {'x'}), - (sympy.ImmutableMatrix([kx**2]) , {'x'}), - (sympy.MutableDenseMatrix([kx**2]) , {'x'}), - (sympy.ImmutableDenseMatrix([kx**2]) , {'x'}), + (sympy.MutableMatrix([kx**2]) , ['x']), + (sympy.ImmutableMatrix([kx**2]) , ['x']), + (sympy.MutableDenseMatrix([kx**2]) , ['x']), + (sympy.ImmutableDenseMatrix([kx**2]) , ['x']), ] for inp, out in test: ham, got = discretize_symbolic(inp) - assert all(d in out for d in got),\ + assert got == out,\ "Should be: _split_factors({})=={}. Not {}".format(inp, out, got) -def test_simple_derivations(): +@pytest.mark.parametrize('commutative', [True, False]) +def test_simple_derivations(commutative): + kx, ky, kz = sympy.symbols('k_x k_y k_z', commutative=commutative) test = { kx**2 : {(0,): 2/a**2, (1,): -1/a**2}, kx**2 + ky**2 : {(0, 1): -1/a**2, (0, 0): 4/a**2, @@ -97,6 +102,8 @@ def test_simple_derivations(): ky**2 + kz**2 : {(0, 1): -1/a**2, (0, 0): 4/a**2, (1, 0): -1/a**2}, kz**2 : {(0,): 2/a**2, (1,): -1/a**2}, + } + non_commutative_test = { kx * A(x,y) * kx : {(1, ): -A(a/2 + x, y)/a**2, (0, ): A(-a/2 + x, y)/a**2 + A(a/2 + x, y)/a**2}, kx**2 + kz * B(y) : {(1, 0): -1/a**2, (0, 1): -I*B(y)/(2*a), @@ -108,6 +115,10 @@ def test_simple_derivations(): kx * (A(x) + B(x)) : {(0,): 0, (1,): -I*A(a + x)/(2*a) - I*B(a + x)/(2*a)}, } + + if not commutative: + test.update(non_commutative_test) + for inp, out in test.items(): got, _ = discretize_symbolic(inp) assert got == out -- GitLab