From 1e8a3ebc582c87bdca37052598453496b19b6d4f Mon Sep 17 00:00:00 2001
From: Joseph Weston <joseph.weston08@gmail.com>
Date: Tue, 17 Jan 2017 11:10:21 +0100
Subject: [PATCH] fix bug where `sum` was not passed correctly to operator
 constructors

Add tests to make sure we explicity check for `sum == True`
---
 kwant/operator.pyx           |  7 ++++---
 kwant/tests/test_operator.py | 28 +++++++++++++++++++++-------
 2 files changed, 25 insertions(+), 10 deletions(-)

diff --git a/kwant/operator.pyx b/kwant/operator.pyx
index 211267cc..8ac18b36 100644
--- a/kwant/operator.pyx
+++ b/kwant/operator.pyx
@@ -519,6 +519,7 @@ cdef class _LocalOperator:
         q.syst = self.syst
         q.onsite = self.onsite
         q.where = self.where
+        q.sum = self.sum
         q._site_ranges = self._site_ranges
         q.check_hermiticity = self.check_hermiticity
         if callable(self.onsite):
@@ -617,7 +618,7 @@ cdef class Density(_LocalOperator):
                  check_hermiticity=True, sum=False):
         where = _normalize_site_where(syst, where)
         super().__init__(syst, onsite, where,
-                        check_hermiticity=check_hermiticity)
+                         check_hermiticity=check_hermiticity, sum=sum)
 
     @cython.boundscheck(False)
     @cython.wraparound(False)
@@ -710,7 +711,7 @@ cdef class Current(_LocalOperator):
                  check_hermiticity=True, sum=False):
         where = _normalize_hopping_where(syst, where)
         super().__init__(syst, onsite, where,
-                         check_hermiticity=check_hermiticity)
+                         check_hermiticity=check_hermiticity, sum=sum)
 
     @cython.embedsignature
     def bind(self, args=()):
@@ -828,7 +829,7 @@ cdef class Source(_LocalOperator):
                  check_hermiticity=True, sum=False):
         where = _normalize_site_where(syst, where)
         super().__init__(syst, onsite, where,
-                         check_hermiticity=check_hermiticity)
+                         check_hermiticity=check_hermiticity, sum=sum)
 
     @cython.embedsignature
     def bind(self, args=()):
diff --git a/kwant/tests/test_operator.py b/kwant/tests/test_operator.py
index 3b5aa2a4..347465bd 100644
--- a/kwant/tests/test_operator.py
+++ b/kwant/tests/test_operator.py
@@ -148,6 +148,10 @@ def test_operator_construction():
     A = ops.Current(fsyst, where=where)
     assert all((a, b) in fwhere_list for a, b in A.where)
 
+    # test that `sum` is passed to constructors correctly
+    for A in opservables:
+        A(fsyst, sum=True).sum == True
+
 
 def _test(A, bra, ket=None, per_el_val=None, reduced_val=None, args=()):
     if per_el_val is not None:
@@ -161,11 +165,13 @@ def _test(A, bra, ket=None, per_el_val=None, reduced_val=None, args=()):
     act_val = np.dot(bra.conj(), A.act(ket, args=args))
     inner_val = np.sum(A(bra, ket, args=args))
     # check also when sum is done internally by operator
-    A.sum = True
     try:
+        sum_reset = A.sum
+        A.sum = True
         sum_inner_val = A(bra, ket, args=args)
+        assert inner_val == sum_inner_val
     finally:
-        A.sum = False
+        A.sum = sum_reset
 
     assert np.isclose(act_val, inner_val)
     assert np.isclose(sum_inner_val, inner_val)
@@ -179,12 +185,14 @@ def test_opservables_finite():
     ev, wfs = la.eigh(fsyst.hamiltonian_submatrix())
 
     Q = ops.Density(fsyst)
+    Qtot = ops.Density(fsyst, sum=True)
     J = ops.Current(fsyst)
     K = ops.Source(fsyst)
 
-    for wf in wfs.T:  # wfs[:, i] is i'th eigenvector
+    for i, wf in enumerate(wfs.T):  # wfs[:, i] is i'th eigenvector
         assert np.allclose(Q.act(wf), wf)  # this operation is identity
         _test(Q, wf, reduced_val=1)  # eigenvectors are normalized
+        _test(Qtot, wf, per_el_val=1)  # eigenvectors are normalized
         _test(J, wf, per_el_val=0)  # time-reversal symmetry: no current
         _test(K, wf, per_el_val=0)  # onsite commutes with hamiltonian
 
@@ -206,12 +214,14 @@ def test_opservables_finite():
     ev, wfs = la.eigh(fsyst.hamiltonian_submatrix())
 
     Q = ops.Density(fsyst)
+    Qtot = ops.Density(fsyst, sum=True)
     J = ops.Current(fsyst)
     K = ops.Source(fsyst)
 
     for wf in wfs.T:  # wfs[:, i] is i'th eigenvector
         assert np.allclose(Q.act(wf), wf)  # this operation is identity
         _test(Q, wf, reduced_val=1)  # eigenvectors are normalized
+        _test(Qtot, wf, per_el_val=1)  # eigenvectors are normalized
         _test(J, wf, per_el_val=0)  # time-reversal symmetry: no current
         _test(K, wf, per_el_val=0)  # onsite commutes with hamiltonian
 
@@ -253,10 +263,12 @@ def test_opservables_scattering():
     fsyst = syst.finalized()
 
     # currents on the left and right of the disordered region
-    J_right = ops.Current(fsyst, where=[(lat(N, j), lat(N-1, j))
-                                        for j in range(3)])
-    J_left = ops.Current(fsyst, where=[(lat(0, j), lat(-1, j))
-                                       for j in range(3)])
+    right_interface = [(lat(N, j), lat(N-1, j)) for j in range(3)]
+    left_interface = [(lat(0, j), lat(-1, j)) for j in range(3)]
+    J_right = ops.Current(fsyst, where=right_interface)
+    J_right_tot = ops.Current(fsyst, where=right_interface, sum=True)
+    J_left = ops.Current(fsyst, where=left_interface)
+    J_left_tot = ops.Current(fsyst, where=left_interface, sum=True)
 
     smatrix = kwant.smatrix(fsyst, energy=1.0)
     t = smatrix.submatrix(1, 0).T  # want to iterate over the columns
@@ -264,7 +276,9 @@ def test_opservables_scattering():
     wfs = kwant.wave_function(fsyst, energy=1.0)(0)
     for rv, tv, wf in zip(r, t, wfs):
         _test(J_right, wf, reduced_val=np.sum(np.abs(tv)**2))
+        _test(J_right_tot, wf, per_el_val=np.sum(np.abs(tv)**2))
         _test(J_left, wf, reduced_val=(1 - np.sum(np.abs(rv)**2)))
+        _test(J_left_tot, wf, per_el_val=(1 - np.sum(np.abs(rv)**2)))
 
 
 def test_opservables_spin():
-- 
GitLab