Skip to content
Snippets Groups Projects
Commit e7fa76a2 authored by Joseph Weston's avatar Joseph Weston
Browse files

merge branch 'improve_continuum'

make 'monomials' preserve order of operators,
remove some unimportant (redundant) tests,
and improve the API of 'monomials', which is
not yet part of the public API.

Closes #135

See merge request !151
parents aa96bfc0 eb66bfe8
No related branches found
No related tags found
No related merge requests found
......@@ -6,10 +6,8 @@
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.
import functools
import keyword
from collections import defaultdict
from operator import mul
import numpy as np
......@@ -205,85 +203,70 @@ def make_commutative(expr, *symbols):
return expr
def monomials(expr, *gens):
def monomials(expr, gens=None):
"""Parse ``expr`` into monomials in the symbols in ``gens``.
Parameters
----------
expr: sympy.Expr or sympy.Matrix
Input expression that will be parsed into monomials.
gens: sequence of sympy.Symbol objects
Generators used to separate input ``expr`` into monomials.
Sympy expression to be parsed into monomials.
gens: sequence of sympy.Symbol objects or strings (optional)
Generators of monomials. If unset it will default to all
symbols used in ``expr``.
Returns
-------
dictionary (generator: monomial)
Note
----
All generators will be substituted with its commutative version using
`kwant.continuum.make_commutative`` function.
Example
-------
>>> expr = kwant.continuum.sympify("A * (x**2 + y) + B * x + C")
>>> monomials(expr, gens=('x', 'y'))
{1: C, x: B, x**2: A, y: A}
"""
if gens is None:
gens = expr.atoms(sympy.Symbol)
else:
gens = [sympify(g) for g in gens]
if not isinstance(expr, sympy.MatrixBase):
return _expression_monomials(expr, *gens)
return _expression_monomials(expr, gens)
else:
output = defaultdict(lambda: sympy.zeros(*expr.shape))
for (i, j), e in np.ndenumerate(expr):
mons = _expression_monomials(e, *gens)
mons = _expression_monomials(e, gens)
for key, val in mons.items():
output[key][i, j] += val
return dict(output)
def _expression_monomials(expression, *gens):
"""Parse ``expression`` into monomials in the symbols in ``gens``.
def _expression_monomials(expr, gens):
"""Parse ``expr`` into monomials in the symbols in ``gens``.
Example
Parameters
----------
expr: sympy.Expr
Sympy expr to be parsed.
gens: sequence of sympy.Symbol
Generators of monomials.
Returns
-------
>>> expr = A * (x**2 + y) + B * x + C
>>> _expression_monomials(expr, x, y)
{1: C, x**2: A, y: A, x: B}
dictionary (generator: monomial)
"""
f_args = [f.args for f in expression.atoms(AppliedUndef, sympy.Function)]
f_args = [i for s in f_args for i in s]
if set(gens) & set(f_args):
raise ValueError('Functions in "expression" cannot contain any of '
'"gens" as their argument.')
expression = make_commutative(expression, *gens)
gens = [make_commutative(g, g) for g in gens]
expr = sympy.expand(expr)
output = defaultdict(lambda: sympy.Integer(0))
for summand in expr.as_ordered_terms():
key = []
val = []
for factor in summand.as_ordered_factors():
symbol, exponent = factor.as_base_exp()
if symbol in gens:
key.append(factor)
else:
val.append(factor)
output[sympy.Mul(*key)] += sympy.Mul(*val)
expression = sympy.expand(expression)
summands = expression.as_ordered_terms()
output = defaultdict(int)
for summand in summands:
key = [sympy.Integer(1)]
if summand in gens:
key.append(summand)
elif isinstance(summand, sympy.Pow):
if summand.args[0] in gens:
key.append(summand)
else:
for arg in summand.args:
if arg in gens:
key.append(arg)
if isinstance(arg, sympy.Pow):
if arg.args[0] in gens:
key.append(arg)
key = functools.reduce(mul, key)
val = summand.xreplace({g: sympy.S.One for g in gens})
### to not create key
if val != 0:
output[key] += val
new_expression = sum(k * v for k, v in output.items())
assert sympy.expand(expression) == sympy.expand(new_expression)
return dict(output)
......
......@@ -517,7 +517,7 @@ def _return_string(expr, coords):
if isinstance(expr, sympy.matrices.MatrixBase):
# express matrix return values in terms of sums of known matrices,
# which will be assigned to '_cache_n' in the function body.
mons = monomials(expr, *expr.atoms(sympy.Symbol))
mons = monomials(expr, expr.atoms(sympy.Symbol))
mons = {k: cache(v) for k, v in mons.items()}
mons = ["{} * {}".format(_print_sympy(k), _print_sympy(v))
for k, v in mons.items()]
......
......@@ -6,9 +6,6 @@
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.
from functools import reduce
from operator import mul
import pytest
import tinyarray as ta
......@@ -86,83 +83,56 @@ def test_sympify_mix_symbol_and_matrx(input_expr, output_expr, subs):
assert sympify(input_expr, locals=subs) == output_expr
A, B, non_x = sympy.symbols('A B x', commutative=False)
x, y = sympy.symbols('x y')
A, B, x = sympy.symbols('A B x', commutative=False)
com_x, com_y = sympy.symbols('x y')
expr1 = non_x*A*non_x + x**2 * A * x + B*non_x**2
expr1 = x*A*x + x**2 * A * x + B*x**2
matr = sympy.Matrix([[expr1, expr1+A*non_x], [0, -expr1]])
res_mat = sympy.Matrix([[x**3*A + x**2*A + x**2*B, x**3*A + x**2*A + x**2*B + x*A],
[0, -x**3*A - x**2*A - x**2*B]])
matr_com = sympy.Matrix([[expr1, expr1+A*x], [0, -expr1]])
res_mat = sympy.Matrix([[com_x**3*A + com_x**2*A + com_x**2*B, com_x**3*A + com_x**2*A + com_x**2*B + com_x*A],
[0, -com_x**3*A - com_x**2*A - com_x**2*B]])
def test_make_commutative():
assert make_commutative(expr1, x) == make_commutative(expr1, non_x)
assert make_commutative(expr1, x) == x**3*A + x**2*A + x**2*B
assert make_commutative(matr, x) == res_mat
expr2 = non_x*A*non_x + x**2 * A*2 * x + B*non_x/2 + non_x*B/2 + x + A + non_x + x/A
def test_monomials():
f, g, a, b = sympy.symbols('f g a b')
assert monomials(expr2, x) == {x**3: 2*A, 1: A, x: 2 + A**(-1) + B, x**2: A}
assert monomials(expr1, x) == {x**2: A + B, x**3: A}
assert monomials(x, x) == {x: 1}
assert monomials(x**2, x) == {x**2: 1}
assert monomials(x**2 + x, x) == {x: 1, x**2: 1}
assert monomials(x**2 + x + A**2, x) == {x: 1, x**2: 1, 1: A**2}
assert monomials(x * f(a, b), x) == {x: f(a, b)}
expr = x * f(a) + y * g(b)
out = {y: g(b), x: f(a)}
assert monomials(expr, x, y) == out
expr = 1 + x + A*x + 2*x + x**2 + A*x**2 + non_x*A*non_x
out = {1: 1, x: 3 + A, x**2: 2 * A + 1}
assert monomials(expr, x) == out
expr = 1 + x * (3 + A) + x**2 * (1 + A)
out = {1: 1, x: 3 + A, x**2: 1 * A + 1}
assert monomials(expr, x) == out
with pytest.raises(ValueError):
monomials(f(x), x)
with pytest.raises(ValueError):
monomials(f(a), a)
def legacy_monomials(expr, *gens):
"""This was my first implementation. Unfortunately it is very slow.
It is used to test correctness of new monomials function.
"""
expr = make_commutative(expr, x)
R = sympy.ring(gens, sympy.EX, sympy.lex)[0]
expr = R(expr)
output = {}
for power, coeff in zip(expr.monoms(), expr.coeffs()):
key = reduce(mul, [sympy.Symbol(k.name)**n for k, n in zip(gens, power)])
output[key] = sympy.expand(coeff.as_expr())
return output
def test_monomials_with_reference_function():
assert legacy_monomials(expr2, x) == monomials(expr2, x)
assert make_commutative(expr1, com_x) == make_commutative(expr1, x)
assert make_commutative(expr1, com_x) == com_x**3*A + com_x**2*A + com_x**2*B
assert make_commutative(matr_com, com_x) == res_mat
matr_monomials = sympify("[[x+y, a*x**2 + b*y], [y, x]]")
x, y, z = position_operators
a, b = sympy.symbols('a, b')
@pytest.mark.parametrize('expr, gens, output', [
(x * a(x) * x + x**2 * a, None, {x**2: a(x), a*x**2: 1}),
(x * a(x) * x + x**2 * a, [x], {x**2: a(x) + a}),
(x**2, [x], {x**2: 1}),
(2 * x + 3 * x**2, [x], {x: 2, x**2: 3}),
(2 * x + 3 * x**2, 'x', {x: 2, x**2: 3}),
(a * x**2 + 2 * b * x**2, 'x', {x**2: a + 2 * b}),
(x**2 * (a + 2 * b) , 'x', {x**2: a + 2 * b}),
(2 * x * y + 3 * y * x, 'xy', {x*y: 2, y*x: 3}),
(2 * x * a + 3 * b, 'ab', {a: 2*x, b: 3}),
(matr_monomials, None, {
x: sympy.Matrix([[1, 0], [0, 1]]),
b*y: sympy.Matrix([[0, 1], [0, 0]]),
a*x**2: sympy.Matrix([[0, 1], [0, 0]]),
y: sympy.Matrix([[1, 0], [1, 0]])
}),
(matr_monomials, [x], {
x: sympy.Matrix([[1, 0], [0, 1]]),
1: sympy.Matrix([[y, b*y], [y, 0]]),
x**2: sympy.Matrix([[0, a], [0, 0]])
}),
(matr_monomials, [x, y], {
x: sympy.Matrix([[1, 0], [0, 1]]),
x**2: sympy.Matrix([[0, a], [0, 0]]),
y: sympy.Matrix([[1, b], [1, 0]])
}),
])
def test_monomials(expr, gens, output):
assert monomials(expr, gens) == output
def test_matrix_monomials():
out = {
x**2: sympy.Matrix([[A + B, A + B],[0, -A - B]]),
x: sympy.Matrix([[0, A], [0, 0]]),
x**3: sympy.Matrix([[A, A], [0, -A]]),
}
mons = monomials(matr, x)
assert mons == out
@pytest.mark.parametrize("e, should_be, kwargs", [
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment