Skip to content
Snippets Groups Projects
Commit 1e8a3ebc authored by Joseph Weston's avatar Joseph Weston
Browse files

fix bug where `sum` was not passed correctly to operator constructors

Add tests to make sure we explicity check for `sum == True`
parent 554d0fb1
No related branches found
No related tags found
No related merge requests found
......@@ -519,6 +519,7 @@ cdef class _LocalOperator:
q.syst = self.syst
q.onsite = self.onsite
q.where = self.where
q.sum = self.sum
q._site_ranges = self._site_ranges
q.check_hermiticity = self.check_hermiticity
if callable(self.onsite):
......@@ -617,7 +618,7 @@ cdef class Density(_LocalOperator):
check_hermiticity=True, sum=False):
where = _normalize_site_where(syst, where)
super().__init__(syst, onsite, where,
check_hermiticity=check_hermiticity)
check_hermiticity=check_hermiticity, sum=sum)
@cython.boundscheck(False)
@cython.wraparound(False)
......@@ -710,7 +711,7 @@ cdef class Current(_LocalOperator):
check_hermiticity=True, sum=False):
where = _normalize_hopping_where(syst, where)
super().__init__(syst, onsite, where,
check_hermiticity=check_hermiticity)
check_hermiticity=check_hermiticity, sum=sum)
@cython.embedsignature
def bind(self, args=()):
......@@ -828,7 +829,7 @@ cdef class Source(_LocalOperator):
check_hermiticity=True, sum=False):
where = _normalize_site_where(syst, where)
super().__init__(syst, onsite, where,
check_hermiticity=check_hermiticity)
check_hermiticity=check_hermiticity, sum=sum)
@cython.embedsignature
def bind(self, args=()):
......
......@@ -148,6 +148,10 @@ def test_operator_construction():
A = ops.Current(fsyst, where=where)
assert all((a, b) in fwhere_list for a, b in A.where)
# test that `sum` is passed to constructors correctly
for A in opservables:
A(fsyst, sum=True).sum == True
def _test(A, bra, ket=None, per_el_val=None, reduced_val=None, args=()):
if per_el_val is not None:
......@@ -161,11 +165,13 @@ def _test(A, bra, ket=None, per_el_val=None, reduced_val=None, args=()):
act_val = np.dot(bra.conj(), A.act(ket, args=args))
inner_val = np.sum(A(bra, ket, args=args))
# check also when sum is done internally by operator
A.sum = True
try:
sum_reset = A.sum
A.sum = True
sum_inner_val = A(bra, ket, args=args)
assert inner_val == sum_inner_val
finally:
A.sum = False
A.sum = sum_reset
assert np.isclose(act_val, inner_val)
assert np.isclose(sum_inner_val, inner_val)
......@@ -179,12 +185,14 @@ def test_opservables_finite():
ev, wfs = la.eigh(fsyst.hamiltonian_submatrix())
Q = ops.Density(fsyst)
Qtot = ops.Density(fsyst, sum=True)
J = ops.Current(fsyst)
K = ops.Source(fsyst)
for wf in wfs.T: # wfs[:, i] is i'th eigenvector
for i, wf in enumerate(wfs.T): # wfs[:, i] is i'th eigenvector
assert np.allclose(Q.act(wf), wf) # this operation is identity
_test(Q, wf, reduced_val=1) # eigenvectors are normalized
_test(Qtot, wf, per_el_val=1) # eigenvectors are normalized
_test(J, wf, per_el_val=0) # time-reversal symmetry: no current
_test(K, wf, per_el_val=0) # onsite commutes with hamiltonian
......@@ -206,12 +214,14 @@ def test_opservables_finite():
ev, wfs = la.eigh(fsyst.hamiltonian_submatrix())
Q = ops.Density(fsyst)
Qtot = ops.Density(fsyst, sum=True)
J = ops.Current(fsyst)
K = ops.Source(fsyst)
for wf in wfs.T: # wfs[:, i] is i'th eigenvector
assert np.allclose(Q.act(wf), wf) # this operation is identity
_test(Q, wf, reduced_val=1) # eigenvectors are normalized
_test(Qtot, wf, per_el_val=1) # eigenvectors are normalized
_test(J, wf, per_el_val=0) # time-reversal symmetry: no current
_test(K, wf, per_el_val=0) # onsite commutes with hamiltonian
......@@ -253,10 +263,12 @@ def test_opservables_scattering():
fsyst = syst.finalized()
# currents on the left and right of the disordered region
J_right = ops.Current(fsyst, where=[(lat(N, j), lat(N-1, j))
for j in range(3)])
J_left = ops.Current(fsyst, where=[(lat(0, j), lat(-1, j))
for j in range(3)])
right_interface = [(lat(N, j), lat(N-1, j)) for j in range(3)]
left_interface = [(lat(0, j), lat(-1, j)) for j in range(3)]
J_right = ops.Current(fsyst, where=right_interface)
J_right_tot = ops.Current(fsyst, where=right_interface, sum=True)
J_left = ops.Current(fsyst, where=left_interface)
J_left_tot = ops.Current(fsyst, where=left_interface, sum=True)
smatrix = kwant.smatrix(fsyst, energy=1.0)
t = smatrix.submatrix(1, 0).T # want to iterate over the columns
......@@ -264,7 +276,9 @@ def test_opservables_scattering():
wfs = kwant.wave_function(fsyst, energy=1.0)(0)
for rv, tv, wf in zip(r, t, wfs):
_test(J_right, wf, reduced_val=np.sum(np.abs(tv)**2))
_test(J_right_tot, wf, per_el_val=np.sum(np.abs(tv)**2))
_test(J_left, wf, reduced_val=(1 - np.sum(np.abs(rv)**2)))
_test(J_left_tot, wf, per_el_val=(1 - np.sum(np.abs(rv)**2)))
def test_opservables_spin():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment