diff --git a/kwant/operator.pyx b/kwant/operator.pyx index 211267cc4439fa367c94aad04e53fe1b83550deb..8ac18b369341c8d8ea175845cd5346066449b738 100644 --- a/kwant/operator.pyx +++ b/kwant/operator.pyx @@ -519,6 +519,7 @@ cdef class _LocalOperator: q.syst = self.syst q.onsite = self.onsite q.where = self.where + q.sum = self.sum q._site_ranges = self._site_ranges q.check_hermiticity = self.check_hermiticity if callable(self.onsite): @@ -617,7 +618,7 @@ cdef class Density(_LocalOperator): check_hermiticity=True, sum=False): where = _normalize_site_where(syst, where) super().__init__(syst, onsite, where, - check_hermiticity=check_hermiticity) + check_hermiticity=check_hermiticity, sum=sum) @cython.boundscheck(False) @cython.wraparound(False) @@ -710,7 +711,7 @@ cdef class Current(_LocalOperator): check_hermiticity=True, sum=False): where = _normalize_hopping_where(syst, where) super().__init__(syst, onsite, where, - check_hermiticity=check_hermiticity) + check_hermiticity=check_hermiticity, sum=sum) @cython.embedsignature def bind(self, args=()): @@ -828,7 +829,7 @@ cdef class Source(_LocalOperator): check_hermiticity=True, sum=False): where = _normalize_site_where(syst, where) super().__init__(syst, onsite, where, - check_hermiticity=check_hermiticity) + check_hermiticity=check_hermiticity, sum=sum) @cython.embedsignature def bind(self, args=()): diff --git a/kwant/tests/test_operator.py b/kwant/tests/test_operator.py index 3b5aa2a41a2a9475844554eb7b4808bb0d16b08e..347465bd273e9aa53b0437874ce87faaaa432cfd 100644 --- a/kwant/tests/test_operator.py +++ b/kwant/tests/test_operator.py @@ -148,6 +148,10 @@ def test_operator_construction(): A = ops.Current(fsyst, where=where) assert all((a, b) in fwhere_list for a, b in A.where) + # test that `sum` is passed to constructors correctly + for A in opservables: + A(fsyst, sum=True).sum == True + def _test(A, bra, ket=None, per_el_val=None, reduced_val=None, args=()): if per_el_val is not None: @@ -161,11 +165,13 @@ def _test(A, bra, ket=None, per_el_val=None, reduced_val=None, args=()): act_val = np.dot(bra.conj(), A.act(ket, args=args)) inner_val = np.sum(A(bra, ket, args=args)) # check also when sum is done internally by operator - A.sum = True try: + sum_reset = A.sum + A.sum = True sum_inner_val = A(bra, ket, args=args) + assert inner_val == sum_inner_val finally: - A.sum = False + A.sum = sum_reset assert np.isclose(act_val, inner_val) assert np.isclose(sum_inner_val, inner_val) @@ -179,12 +185,14 @@ def test_opservables_finite(): ev, wfs = la.eigh(fsyst.hamiltonian_submatrix()) Q = ops.Density(fsyst) + Qtot = ops.Density(fsyst, sum=True) J = ops.Current(fsyst) K = ops.Source(fsyst) - for wf in wfs.T: # wfs[:, i] is i'th eigenvector + for i, wf in enumerate(wfs.T): # wfs[:, i] is i'th eigenvector assert np.allclose(Q.act(wf), wf) # this operation is identity _test(Q, wf, reduced_val=1) # eigenvectors are normalized + _test(Qtot, wf, per_el_val=1) # eigenvectors are normalized _test(J, wf, per_el_val=0) # time-reversal symmetry: no current _test(K, wf, per_el_val=0) # onsite commutes with hamiltonian @@ -206,12 +214,14 @@ def test_opservables_finite(): ev, wfs = la.eigh(fsyst.hamiltonian_submatrix()) Q = ops.Density(fsyst) + Qtot = ops.Density(fsyst, sum=True) J = ops.Current(fsyst) K = ops.Source(fsyst) for wf in wfs.T: # wfs[:, i] is i'th eigenvector assert np.allclose(Q.act(wf), wf) # this operation is identity _test(Q, wf, reduced_val=1) # eigenvectors are normalized + _test(Qtot, wf, per_el_val=1) # eigenvectors are normalized _test(J, wf, per_el_val=0) # time-reversal symmetry: no current _test(K, wf, per_el_val=0) # onsite commutes with hamiltonian @@ -253,10 +263,12 @@ def test_opservables_scattering(): fsyst = syst.finalized() # currents on the left and right of the disordered region - J_right = ops.Current(fsyst, where=[(lat(N, j), lat(N-1, j)) - for j in range(3)]) - J_left = ops.Current(fsyst, where=[(lat(0, j), lat(-1, j)) - for j in range(3)]) + right_interface = [(lat(N, j), lat(N-1, j)) for j in range(3)] + left_interface = [(lat(0, j), lat(-1, j)) for j in range(3)] + J_right = ops.Current(fsyst, where=right_interface) + J_right_tot = ops.Current(fsyst, where=right_interface, sum=True) + J_left = ops.Current(fsyst, where=left_interface) + J_left_tot = ops.Current(fsyst, where=left_interface, sum=True) smatrix = kwant.smatrix(fsyst, energy=1.0) t = smatrix.submatrix(1, 0).T # want to iterate over the columns @@ -264,7 +276,9 @@ def test_opservables_scattering(): wfs = kwant.wave_function(fsyst, energy=1.0)(0) for rv, tv, wf in zip(r, t, wfs): _test(J_right, wf, reduced_val=np.sum(np.abs(tv)**2)) + _test(J_right_tot, wf, per_el_val=np.sum(np.abs(tv)**2)) _test(J_left, wf, reduced_val=(1 - np.sum(np.abs(rv)**2))) + _test(J_left_tot, wf, per_el_val=(1 - np.sum(np.abs(rv)**2))) def test_opservables_spin():