diff --git a/kwant/continuum/_common.py b/kwant/continuum/_common.py index f614f2bba0f9da0dec6f67664193c9c50597c9b8..249e805e723a452aaf8e5bfe63210248bb07117c 100644 --- a/kwant/continuum/_common.py +++ b/kwant/continuum/_common.py @@ -217,9 +217,12 @@ def _expression_monomials(expression, *gens): >>> _expression_monomials(expr, x, y) {1: C, x**2: A, y: A, x: B} """ - if expression.atoms(AppliedUndef): - raise NotImplementedError('Getting monomials of expressions containing ' - 'functions is not implemented.') + f_args = [f.args for f in expression.atoms(AppliedUndef, sympy.Function)] + f_args = [i for s in f_args for i in s] + + if set(gens) & set(f_args): + raise ValueError('Functions in "expression" cannot contain any of ' + '"gens" as their argument.') expression = make_commutative(expression, *gens) gens = [make_commutative(g, g) for g in gens] @@ -246,7 +249,7 @@ def _expression_monomials(expression, *gens): key.append(arg) key = functools.reduce(mul, key) - val = summand.xreplace({g: 1 for g in gens}) + val = summand.xreplace({g: sympy.S.One for g in gens}) ### to not create key if val != 0: @@ -254,7 +257,6 @@ def _expression_monomials(expression, *gens): new_expression = sum(k * v for k, v in output.items()) assert sympy.expand(expression) == sympy.expand(new_expression) - return dict(output) diff --git a/kwant/continuum/tests/test_common.py b/kwant/continuum/tests/test_common.py index 313f3193148a1551d7607e638dccb1e028fac2e2..4b1c557ac1f3cc5bfd33197d001ba81ba9b62271 100644 --- a/kwant/continuum/tests/test_common.py +++ b/kwant/continuum/tests/test_common.py @@ -81,7 +81,7 @@ def test_sympify_mix_symbol_and_matrx(input_expr, output_expr, subs): A, B, non_x = sympy.symbols('A B x', commutative=False) -x = sympy.Symbol('x') +x, y = sympy.symbols('x y') expr1 = non_x*A*non_x + x**2 * A * x + B*non_x**2 @@ -100,12 +100,19 @@ expr2 = non_x*A*non_x + x**2 * A*2 * x + B*non_x/2 + non_x*B/2 + x + A + non_x + def test_monomials(): + f, g, a, b = sympy.symbols('f g a b') + assert monomials(expr2, x) == {x**3: 2*A, 1: A, x: 2 + A**(-1) + B, x**2: A} assert monomials(expr1, x) == {x**2: A + B, x**3: A} assert monomials(x, x) == {x: 1} assert monomials(x**2, x) == {x**2: 1} assert monomials(x**2 + x, x) == {x: 1, x**2: 1} assert monomials(x**2 + x + A**2, x) == {x: 1, x**2: 1, 1: A**2} + assert monomials(x * f(a, b), x) == {x: f(a, b)} + + expr = x * f(a) + y * g(b) + out = {y: g(b), x: f(a)} + assert monomials(expr, x, y) == out expr = 1 + x + A*x + 2*x + x**2 + A*x**2 + non_x*A*non_x out = {1: 1, x: 3 + A, x**2: 2 * A + 1} @@ -115,6 +122,12 @@ def test_monomials(): out = {1: 1, x: 3 + A, x**2: 1 * A + 1} assert monomials(expr, x) == out + with pytest.raises(ValueError): + monomials(f(x), x) + + with pytest.raises(ValueError): + monomials(f(a), a) + def legacy_monomials(expr, *gens): """This was my first implementation. Unfortunately it is very slow.