From d0e37a1167eb66c686ebfef94cf53a9d69698305 Mon Sep 17 00:00:00 2001
From: Rafal Skolasinski <r.j.skolasinski@gmail.com>
Date: Mon, 6 Mar 2017 23:09:03 +0100
Subject: [PATCH] fix that commutative momentum operators are ignored by
 discretizer, closes #99

---
 kwant/continuum/discretizer.py            | 17 +++----
 kwant/continuum/tests/test_discretizer.py | 61 +++++++++++++----------
 2 files changed, 44 insertions(+), 34 deletions(-)

diff --git a/kwant/continuum/discretizer.py b/kwant/continuum/discretizer.py
index af7cd1c8..7cfe55db 100644
--- a/kwant/continuum/discretizer.py
+++ b/kwant/continuum/discretizer.py
@@ -21,6 +21,7 @@ from ._common import matrix_monomials
 ################ Globals variables and definitions
 
 _wf = sympy.Function('_internal_unique_name', commutative=False)
+_momentum_operators = {s.name: s for s in momentum_operators}
 _position_operators = {s.name: s for s in position_operators}
 _displacements = {s: sympy.Symbol('_internal_a_{}'.format(s)) for s in 'xyz'}
 
@@ -103,15 +104,15 @@ def discretize_symbolic(hamiltonian, discrete_coordinates=None, substitutions=No
     if not isinstance(hamiltonian, (sympy.Expr, sympy.matrices.MatrixBase)):
         hamiltonian = sympify(hamiltonian, substitutions)
 
-    atoms = hamiltonian.atoms(sympy.Symbol)
-    if not all('a' != s.name for s in atoms):
+    atoms_names = [s.name for s in hamiltonian.atoms(sympy.Symbol)]
+    if any( s == 'a' for s in atoms_names):
         raise ValueError("'a' is a symbol used internally to represent "
                          "lattice spacing; please use a different symbol.")
 
     hamiltonian = sympy.expand(hamiltonian)
     if discrete_coordinates is None:
-        used_momenta = set(momentum_operators) & set(atoms)
-        discrete_coordinates = {k.name[-1] for k in used_momenta}
+        used_momenta = set(_momentum_operators) & set(atoms_names)
+        discrete_coordinates = {k[-1] for k in used_momenta}
     else:
         discrete_coordinates = set(discrete_coordinates)
         if not discrete_coordinates <= set('xyz'):
@@ -341,14 +342,12 @@ def _discretize_expression(expression, discrete_coordinates):
             output[tuple(offset[n].tolist())] += c.subs(subs)
         return dict(output)
 
-    # main function body starts here
-    if (isinstance(expression, (int, float, sympy.Integer, sympy.Float))
-            or not expression.atoms(sympy.Symbol) & set(momentum_operators)):
+    # if there are no momenta in the expression, then it is an onsite
+    atoms_names = [s.name for s in expression.atoms(sympy.Symbol)]
+    if not set(_momentum_operators) & set(atoms_names):
         n = len(discrete_coordinates)
         return {(0,) * n: expression}
 
-    assert isinstance(expression, sympy.Expr)
-
     # make sure we have list of summands
     summands = expression.as_ordered_terms()
 
diff --git a/kwant/continuum/tests/test_discretizer.py b/kwant/continuum/tests/test_discretizer.py
index 014c37b3..2b6a69ef 100644
--- a/kwant/continuum/tests/test_discretizer.py
+++ b/kwant/continuum/tests/test_discretizer.py
@@ -8,6 +8,7 @@ from ..discretizer import  _wf
 
 import inspect
 from functools import wraps
+import pytest
 
 
 def swallows_extra_kwargs(f):
@@ -42,52 +43,56 @@ A, B = sympy.symbols('A B', commutative=False)
 ns = {'A': A, 'B': B, 'a_x': ax, 'a_y': ay, 'az': az, 'x': x, 'y': y, 'z': z}
 
 
-def test_reading_coordinates():
+@pytest.mark.parametrize('commutative', [True, False])
+def test_reading_coordinates(commutative):
+    kx, ky, kz = sympy.symbols('k_x k_y k_z', commutative=commutative)
+
     test = {
-        kx**2                         : {'x'},
-        kx**2 + ky**2                 : {'x', 'y'},
-        kx**2 + ky**2 + kz**2         : {'x', 'y', 'z'},
-        ky**2 + kz**2                 : {'y', 'z'},
-        kz**2                         : {'z'},
-        kx * A(x,y) * kx              : {'x'},
-        kx**2 + kz * B(y)             : {'x', 'z'},
+        kx**2                         : ['x'],
+        kx**2 + ky**2                 : ['x', 'y'],
+        kx**2 + ky**2 + kz**2         : ['x', 'y', 'z'],
+        ky**2 + kz**2                 : ['y', 'z'],
+        kz**2                         : ['z'],
+        kx * A(x,y) * kx              : ['x'],
+        kx**2 + kz * B(y)             : ['x', 'z'],
     }
     for inp, out in test.items():
         ham, got = discretize_symbolic(inp)
-        assert all(d in out for d in got),\
-            "Should be: _split_factors({})=={}. Not {}".format(inp, out, got)
+        assert got == out
 
 
 def test_reading_coordinates_matrix():
     test = [
-        (sympy.Matrix([kx**2])                        , {'x'}),
-        (sympy.Matrix([kx**2 + ky**2])                , {'x', 'y'}),
-        (sympy.Matrix([kx**2 + ky**2 + kz**2])        , {'x', 'y', 'z'}),
-        (sympy.Matrix([ky**2 + kz**2])                , {'y', 'z'}),
-        (sympy.Matrix([kz**2])                        , {'z'}),
-        (sympy.Matrix([kx * A(x,y) * kx])             , {'x'}),
-        (sympy.Matrix([kx**2 + kz * B(y)])            , {'x', 'z'}),
+        (sympy.Matrix([sympy.sympify('k_x**2')])        , ['x']),
+        (sympy.Matrix([kx**2])                        , ['x']),
+        (sympy.Matrix([kx**2 + ky**2])                , ['x', 'y']),
+        (sympy.Matrix([kx**2 + ky**2 + kz**2])        , ['x', 'y', 'z']),
+        (sympy.Matrix([ky**2 + kz**2])                , ['y', 'z']),
+        (sympy.Matrix([kz**2])                        , ['z']),
+        (sympy.Matrix([kx * A(x,y) * kx])             , ['x']),
+        (sympy.Matrix([kx**2 + kz * B(y)])            , ['x', 'z']),
     ]
     for inp, out in test:
         ham, got = discretize_symbolic(inp)
-        assert all(d in out for d in got),\
-            "Should be: _split_factors({})=={}. Not {}".format(inp, out, got)
+        assert got == out
 
 
 def test_reading_different_matrix_types():
     test = [
-        (sympy.MutableMatrix([kx**2])                    , {'x'}),
-        (sympy.ImmutableMatrix([kx**2])                  , {'x'}),
-        (sympy.MutableDenseMatrix([kx**2])               , {'x'}),
-        (sympy.ImmutableDenseMatrix([kx**2])             , {'x'}),
+        (sympy.MutableMatrix([kx**2])                    , ['x']),
+        (sympy.ImmutableMatrix([kx**2])                  , ['x']),
+        (sympy.MutableDenseMatrix([kx**2])               , ['x']),
+        (sympy.ImmutableDenseMatrix([kx**2])             , ['x']),
     ]
     for inp, out in test:
         ham, got = discretize_symbolic(inp)
-        assert all(d in out for d in got),\
+        assert got == out,\
             "Should be: _split_factors({})=={}. Not {}".format(inp, out, got)
 
 
-def test_simple_derivations():
+@pytest.mark.parametrize('commutative', [True, False])
+def test_simple_derivations(commutative):
+    kx, ky, kz = sympy.symbols('k_x k_y k_z', commutative=commutative)
     test = {
         kx**2                   : {(0,): 2/a**2, (1,): -1/a**2},
         kx**2 + ky**2           : {(0, 1): -1/a**2, (0, 0): 4/a**2,
@@ -97,6 +102,8 @@ def test_simple_derivations():
         ky**2 + kz**2           : {(0, 1): -1/a**2, (0, 0): 4/a**2,
                                    (1, 0): -1/a**2},
         kz**2                   : {(0,): 2/a**2, (1,): -1/a**2},
+    }
+    non_commutative_test = {
         kx * A(x,y) * kx        : {(1, ): -A(a/2 + x, y)/a**2,
                                   (0, ): A(-a/2 + x, y)/a**2 + A(a/2 + x, y)/a**2},
         kx**2 + kz * B(y)       : {(1, 0): -1/a**2, (0, 1): -I*B(y)/(2*a),
@@ -108,6 +115,10 @@ def test_simple_derivations():
         kx * (A(x) + B(x))      : {(0,): 0,
                                    (1,): -I*A(a + x)/(2*a) - I*B(a + x)/(2*a)},
     }
+
+    if not commutative:
+        test.update(non_commutative_test)
+
     for inp, out in test.items():
         got, _ = discretize_symbolic(inp)
         assert got == out
-- 
GitLab