diff --git a/kwant/continuum/_common.py b/kwant/continuum/_common.py
index f614f2bba0f9da0dec6f67664193c9c50597c9b8..249e805e723a452aaf8e5bfe63210248bb07117c 100644
--- a/kwant/continuum/_common.py
+++ b/kwant/continuum/_common.py
@@ -217,9 +217,12 @@ def _expression_monomials(expression, *gens):
         >>> _expression_monomials(expr, x, y)
         {1: C, x**2: A, y: A, x: B}
     """
-    if expression.atoms(AppliedUndef):
-        raise NotImplementedError('Getting monomials of expressions containing '
-                                  'functions is not implemented.')
+    f_args = [f.args for f in expression.atoms(AppliedUndef, sympy.Function)]
+    f_args = [i for s in f_args for i in s]
+
+    if set(gens) & set(f_args):
+        raise ValueError('Functions in "expression" cannot contain any of '
+                         '"gens" as their argument.')
 
     expression = make_commutative(expression, *gens)
     gens = [make_commutative(g, g) for g in gens]
@@ -246,7 +249,7 @@ def _expression_monomials(expression, *gens):
                         key.append(arg)
 
         key = functools.reduce(mul, key)
-        val = summand.xreplace({g: 1 for g in gens})
+        val = summand.xreplace({g: sympy.S.One for g in gens})
 
         ### to not create key
         if val != 0:
@@ -254,7 +257,6 @@ def _expression_monomials(expression, *gens):
 
     new_expression = sum(k * v for k, v in output.items())
     assert sympy.expand(expression) == sympy.expand(new_expression)
-
     return dict(output)
 
 
diff --git a/kwant/continuum/tests/test_common.py b/kwant/continuum/tests/test_common.py
index 313f3193148a1551d7607e638dccb1e028fac2e2..4b1c557ac1f3cc5bfd33197d001ba81ba9b62271 100644
--- a/kwant/continuum/tests/test_common.py
+++ b/kwant/continuum/tests/test_common.py
@@ -81,7 +81,7 @@ def test_sympify_mix_symbol_and_matrx(input_expr, output_expr, subs):
 
 
 A, B, non_x = sympy.symbols('A B x', commutative=False)
-x = sympy.Symbol('x')
+x, y = sympy.symbols('x y')
 
 expr1 = non_x*A*non_x + x**2 * A * x + B*non_x**2
 
@@ -100,12 +100,19 @@ expr2 = non_x*A*non_x + x**2 * A*2 * x + B*non_x/2 + non_x*B/2 + x + A + non_x +
 
 
 def test_monomials():
+    f, g, a, b = sympy.symbols('f g a b')
+
     assert monomials(expr2, x) == {x**3: 2*A, 1: A, x: 2 + A**(-1) + B, x**2: A}
     assert monomials(expr1, x) == {x**2: A + B, x**3: A}
     assert monomials(x, x) == {x: 1}
     assert monomials(x**2, x) == {x**2: 1}
     assert monomials(x**2 + x, x) == {x: 1, x**2: 1}
     assert monomials(x**2 + x + A**2, x) == {x: 1, x**2: 1, 1: A**2}
+    assert monomials(x * f(a, b), x) == {x: f(a, b)}
+
+    expr = x * f(a) + y * g(b)
+    out = {y: g(b), x: f(a)}
+    assert monomials(expr, x, y) == out
 
     expr = 1 + x + A*x + 2*x + x**2 + A*x**2 + non_x*A*non_x
     out = {1: 1, x: 3 + A, x**2: 2 * A + 1}
@@ -115,6 +122,12 @@ def test_monomials():
     out = {1: 1, x: 3 + A, x**2: 1 * A + 1}
     assert monomials(expr, x) == out
 
+    with pytest.raises(ValueError):
+        monomials(f(x), x)
+
+    with pytest.raises(ValueError):
+        monomials(f(a), a)
+
 
 def legacy_monomials(expr, *gens):
     """This was my first implementation. Unfortunately it is very slow.