Skip to content
Snippets Groups Projects
Commit 3e950e7c authored by Joseph Weston's avatar Joseph Weston
Browse files

add consistency checks when computing onsite and hamiltonian data

We get rid of '_is_herm_conj' in favour of '_is_hermitian', replace
'_check_ham' with '_check_hams' (which works on vectorized values)
and add '_check_onsites' (which also works on vectorized values).
Replace absolute value calculation with call to 'cabs' from 'complex.h'
parent 51d678bf
No related branches found
No related tags found
No related merge requests found
......@@ -4,17 +4,24 @@ from .graph.defs import gint_dtype
cdef gint _bisect(gint[:] a, gint x)
cdef int _is_herm_conj(complex[:, :] a, complex[:, :] b,
double atol=*, double rtol=*) except -1
cdef int _is_hermitian(
complex[:, :] a, double atol=*, double rtol=*
) except -1
cdef int _is_hermitian_3d(
complex[:, :, :] a, double atol=*, double rtol=*
) except -1
cdef _select(gint[:, :] arr, gint[:] indexes)
cdef int _check_onsite(complex[:, :] M, gint norbs,
int check_hermiticity) except -1
cdef int _check_ham(complex[:, :] H, ham, args, params,
gint a, gint a_norbs, gint b, gint b_norbs,
int check_hermiticity) except -1
cdef int _check_onsites(complex[:, :, :] M, gint norbs,
int check_hermiticity) except -1
cdef int _check_hams(complex[:, :, :] H, gint to_norbs, gint from_norbs,
int check_hermiticity) except -1
cdef void _get_orbs(gint[:, :] site_ranges, gint site,
gint *start_orb, gint *norbs)
......
......@@ -21,6 +21,9 @@ from scipy.sparse import coo_matrix
from libc cimport math
cdef extern from "complex.h":
double cabs(double complex)
from .graph.core cimport EdgeIterator
from .graph.core import DisabledFeatureError, NodeDoesNotExistError
from .graph.defs cimport gint
......@@ -51,32 +54,61 @@ cdef gint _bisect(gint[:] a, gint x):
@cython.boundscheck(False)
@cython.wraparound(False)
cdef int _is_herm_conj(complex[:, :] a, complex[:, :] b,
double atol=1e-300, double rtol=1e-13) except -1:
"Return True if `a` is the Hermitian conjugate of `b`."
assert a.shape[0] == b.shape[1]
assert a.shape[1] == b.shape[0]
cdef int _is_hermitian(
complex[:, :] a, double atol=1e-300, double rtol=1e-13
) except -1:
"Return True if 'a' is Hermitian"
if a.shape[0] != a.shape[1]:
return False
# compute max(a)
cdef double tmp, max_a = 0
cdef gint i, j
cdef gint i, j, k
for i in range(a.shape[0]):
for j in range(a.shape[1]):
tmp = a[i, j].real * a[i, j].real + a[i, j].imag * a[i, j].imag
tmp = cabs(a[i, j])
if tmp > max_a:
max_a = tmp
max_a = math.sqrt(max_a)
cdef double tol = rtol * max_a + atol
cdef complex ctmp
for i in range(a.shape[0]):
for j in range(a.shape[1]):
ctmp = a[i, j] - b[j, i].conjugate()
tmp = ctmp.real * ctmp.real + ctmp.imag * ctmp.imag
for j in range(i, a.shape[1]):
tmp = cabs(a[i, j] - a[j, i].conjugate())
if tmp > tol:
return False
return True
@cython.boundscheck(False)
@cython.wraparound(False)
cdef int _is_hermitian_3d(
complex[:, :, :] a, double atol=1e-300, double rtol=1e-13
) except -1:
"Return True if 'a' is Hermitian"
if a.shape[1] != a.shape[2]:
return False
# compute max(a)
cdef double tmp, max_a = 0
cdef gint i, j, k
for k in range(a.shape[0]):
for i in range(a.shape[1]):
for j in range(a.shape[2]):
tmp = cabs(a[k, i, j])
if tmp > max_a:
max_a = tmp
max_a = math.sqrt(max_a)
cdef double tol = rtol * max_a + atol
for k in range(a.shape[0]):
for i in range(a.shape[1]):
for j in range(i, a.shape[2]):
tmp = cabs(a[k, i, j] - a[k, j, i].conjugate())
if tmp > tol:
return False
return True
@cython.boundscheck(False)
......@@ -107,22 +139,28 @@ cdef int _check_onsite(complex[:, :] M, gint norbs,
raise UserCodeError('Onsite matrix is not square')
if M.shape[0] != norbs:
raise UserCodeError(_shape_msg.format('Onsite'))
if check_hermiticity and not _is_herm_conj(M, M):
if check_hermiticity and not _is_hermitian(M):
raise ValueError(_herm_msg.format('Onsite'))
return 0
cdef int _check_ham(complex[:, :] H, ham, args, params,
gint a, gint a_norbs, gint b, gint b_norbs,
int check_hermiticity) except -1:
"Check Hamiltonian matrix for correct shape and hermiticity."
if H.shape[0] != a_norbs and H.shape[1] != b_norbs:
cdef int _check_onsites(complex[:, :, :] M, gint norbs,
int check_hermiticity) except -1:
"Check onsite matrix for correct shape and hermiticity."
if M.shape[1] != M.shape[2]:
raise UserCodeError('Onsite matrix is not square')
if M.shape[1] != norbs:
raise UserCodeError(_shape_msg.format('Onsite'))
if check_hermiticity and not _is_hermitian_3d(M):
raise ValueError(_herm_msg.format('Onsite'))
return 0
cdef int _check_hams(complex[:, :, :] H, gint to_norbs, gint from_norbs,
int check_hermiticity) except -1:
if H.shape[1] != to_norbs or H.shape[2] != from_norbs:
raise UserCodeError(_shape_msg.format('Hamiltonian'))
if check_hermiticity:
# call the "partner" element if we are not on the diagonal
H_conj = H if a == b else ta.matrix(ham(b, a, *args, params=params),
complex)
if not _is_herm_conj(H_conj, H):
if check_hermiticity and not _is_hermitian_3d(H):
raise ValueError(_herm_msg.format('Hamiltonian'))
return 0
......@@ -901,6 +939,7 @@ cdef class _LocalOperator:
# All sites selected by 'which' are part of the same site family.
site_offsets = _select(self.where, which)[:, 0] - start_site
data = self.onsite(sr, site_offsets, *args)
_check_onsites(data, norbs, self.check_hermiticity)
return data
matrix_elements = _make_matrix_elements(eval_onsite, self._terms)
......@@ -929,6 +968,13 @@ cdef class _LocalOperator:
args=args, params=params)
if herm_conj:
data = data.conjugate().transpose(0, 2, 1)
# Checks for data consistency
(to_sr, from_sr), _ = syst.subgraphs[syst.terms[term_id].subgraph]
to_norbs = syst.site_ranges[to_sr][1]
from_norbs = syst.site_ranges[from_sr][1]
if herm_conj:
to_norbs, from_norbs = from_norbs, to_norbs
_check_hams(data, to_norbs, from_norbs, is_onsite and check_hermiticity)
return data
......@@ -948,6 +994,11 @@ cdef class _LocalOperator:
for i in which
]
data = _normalize_matrix_blocks(data, len(which))
# Checks for data consistency
(to_sr, from_sr) = term_id
to_norbs = syst.site_ranges[to_sr][1]
from_norbs = syst.site_ranges[from_sr][1]
_check_hams(data, to_norbs, from_norbs, is_onsite and check_hermiticity)
return data
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment