Skip to content
Snippets Groups Projects
Commit d0e37a11 authored by Rafal Skolasinski's avatar Rafal Skolasinski
Browse files

fix that commutative momentum operators are ignored by discretizer, closes #99

parent fa3256d5
No related branches found
No related tags found
No related merge requests found
......@@ -21,6 +21,7 @@ from ._common import matrix_monomials
################ Globals variables and definitions
_wf = sympy.Function('_internal_unique_name', commutative=False)
_momentum_operators = {s.name: s for s in momentum_operators}
_position_operators = {s.name: s for s in position_operators}
_displacements = {s: sympy.Symbol('_internal_a_{}'.format(s)) for s in 'xyz'}
......@@ -103,15 +104,15 @@ def discretize_symbolic(hamiltonian, discrete_coordinates=None, substitutions=No
if not isinstance(hamiltonian, (sympy.Expr, sympy.matrices.MatrixBase)):
hamiltonian = sympify(hamiltonian, substitutions)
atoms = hamiltonian.atoms(sympy.Symbol)
if not all('a' != s.name for s in atoms):
atoms_names = [s.name for s in hamiltonian.atoms(sympy.Symbol)]
if any( s == 'a' for s in atoms_names):
raise ValueError("'a' is a symbol used internally to represent "
"lattice spacing; please use a different symbol.")
hamiltonian = sympy.expand(hamiltonian)
if discrete_coordinates is None:
used_momenta = set(momentum_operators) & set(atoms)
discrete_coordinates = {k.name[-1] for k in used_momenta}
used_momenta = set(_momentum_operators) & set(atoms_names)
discrete_coordinates = {k[-1] for k in used_momenta}
else:
discrete_coordinates = set(discrete_coordinates)
if not discrete_coordinates <= set('xyz'):
......@@ -341,14 +342,12 @@ def _discretize_expression(expression, discrete_coordinates):
output[tuple(offset[n].tolist())] += c.subs(subs)
return dict(output)
# main function body starts here
if (isinstance(expression, (int, float, sympy.Integer, sympy.Float))
or not expression.atoms(sympy.Symbol) & set(momentum_operators)):
# if there are no momenta in the expression, then it is an onsite
atoms_names = [s.name for s in expression.atoms(sympy.Symbol)]
if not set(_momentum_operators) & set(atoms_names):
n = len(discrete_coordinates)
return {(0,) * n: expression}
assert isinstance(expression, sympy.Expr)
# make sure we have list of summands
summands = expression.as_ordered_terms()
......
......@@ -8,6 +8,7 @@ from ..discretizer import _wf
import inspect
from functools import wraps
import pytest
def swallows_extra_kwargs(f):
......@@ -42,52 +43,56 @@ A, B = sympy.symbols('A B', commutative=False)
ns = {'A': A, 'B': B, 'a_x': ax, 'a_y': ay, 'az': az, 'x': x, 'y': y, 'z': z}
def test_reading_coordinates():
@pytest.mark.parametrize('commutative', [True, False])
def test_reading_coordinates(commutative):
kx, ky, kz = sympy.symbols('k_x k_y k_z', commutative=commutative)
test = {
kx**2 : {'x'},
kx**2 + ky**2 : {'x', 'y'},
kx**2 + ky**2 + kz**2 : {'x', 'y', 'z'},
ky**2 + kz**2 : {'y', 'z'},
kz**2 : {'z'},
kx * A(x,y) * kx : {'x'},
kx**2 + kz * B(y) : {'x', 'z'},
kx**2 : ['x'],
kx**2 + ky**2 : ['x', 'y'],
kx**2 + ky**2 + kz**2 : ['x', 'y', 'z'],
ky**2 + kz**2 : ['y', 'z'],
kz**2 : ['z'],
kx * A(x,y) * kx : ['x'],
kx**2 + kz * B(y) : ['x', 'z'],
}
for inp, out in test.items():
ham, got = discretize_symbolic(inp)
assert all(d in out for d in got),\
"Should be: _split_factors({})=={}. Not {}".format(inp, out, got)
assert got == out
def test_reading_coordinates_matrix():
test = [
(sympy.Matrix([kx**2]) , {'x'}),
(sympy.Matrix([kx**2 + ky**2]) , {'x', 'y'}),
(sympy.Matrix([kx**2 + ky**2 + kz**2]) , {'x', 'y', 'z'}),
(sympy.Matrix([ky**2 + kz**2]) , {'y', 'z'}),
(sympy.Matrix([kz**2]) , {'z'}),
(sympy.Matrix([kx * A(x,y) * kx]) , {'x'}),
(sympy.Matrix([kx**2 + kz * B(y)]) , {'x', 'z'}),
(sympy.Matrix([sympy.sympify('k_x**2')]) , ['x']),
(sympy.Matrix([kx**2]) , ['x']),
(sympy.Matrix([kx**2 + ky**2]) , ['x', 'y']),
(sympy.Matrix([kx**2 + ky**2 + kz**2]) , ['x', 'y', 'z']),
(sympy.Matrix([ky**2 + kz**2]) , ['y', 'z']),
(sympy.Matrix([kz**2]) , ['z']),
(sympy.Matrix([kx * A(x,y) * kx]) , ['x']),
(sympy.Matrix([kx**2 + kz * B(y)]) , ['x', 'z']),
]
for inp, out in test:
ham, got = discretize_symbolic(inp)
assert all(d in out for d in got),\
"Should be: _split_factors({})=={}. Not {}".format(inp, out, got)
assert got == out
def test_reading_different_matrix_types():
test = [
(sympy.MutableMatrix([kx**2]) , {'x'}),
(sympy.ImmutableMatrix([kx**2]) , {'x'}),
(sympy.MutableDenseMatrix([kx**2]) , {'x'}),
(sympy.ImmutableDenseMatrix([kx**2]) , {'x'}),
(sympy.MutableMatrix([kx**2]) , ['x']),
(sympy.ImmutableMatrix([kx**2]) , ['x']),
(sympy.MutableDenseMatrix([kx**2]) , ['x']),
(sympy.ImmutableDenseMatrix([kx**2]) , ['x']),
]
for inp, out in test:
ham, got = discretize_symbolic(inp)
assert all(d in out for d in got),\
assert got == out,\
"Should be: _split_factors({})=={}. Not {}".format(inp, out, got)
def test_simple_derivations():
@pytest.mark.parametrize('commutative', [True, False])
def test_simple_derivations(commutative):
kx, ky, kz = sympy.symbols('k_x k_y k_z', commutative=commutative)
test = {
kx**2 : {(0,): 2/a**2, (1,): -1/a**2},
kx**2 + ky**2 : {(0, 1): -1/a**2, (0, 0): 4/a**2,
......@@ -97,6 +102,8 @@ def test_simple_derivations():
ky**2 + kz**2 : {(0, 1): -1/a**2, (0, 0): 4/a**2,
(1, 0): -1/a**2},
kz**2 : {(0,): 2/a**2, (1,): -1/a**2},
}
non_commutative_test = {
kx * A(x,y) * kx : {(1, ): -A(a/2 + x, y)/a**2,
(0, ): A(-a/2 + x, y)/a**2 + A(a/2 + x, y)/a**2},
kx**2 + kz * B(y) : {(1, 0): -1/a**2, (0, 1): -I*B(y)/(2*a),
......@@ -108,6 +115,10 @@ def test_simple_derivations():
kx * (A(x) + B(x)) : {(0,): 0,
(1,): -I*A(a + x)/(2*a) - I*B(a + x)/(2*a)},
}
if not commutative:
test.update(non_commutative_test)
for inp, out in test.items():
got, _ = discretize_symbolic(inp)
assert got == out
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment