Skip to content
Snippets Groups Projects
Commit 4b74beca authored by Joseph Weston's avatar Joseph Weston
Browse files

add option to '_normalize_matrix_blocks' for more informative errors

We allow the function that produced the data to be passed to
'_normalize_matrix_blocks', and any error messages are augmented
with this information. This may prove useful for debugging.
parent 23574cf5
Branches
No related tags found
No related merge requests found
...@@ -2127,7 +2127,8 @@ class _VectorizedFinalizedBuilderMixin(_FinalizedBuilderMixin): ...@@ -2127,7 +2127,8 @@ class _VectorizedFinalizedBuilderMixin(_FinalizedBuilderMixin):
to_family.norbs, to_family.norbs,
from_family.norbs if not is_onsite else to_family.norbs, from_family.norbs if not is_onsite else to_family.norbs,
) )
ham = system._normalize_matrix_blocks(ham, expected_shape) ham = system._normalize_matrix_blocks(ham, expected_shape,
calling_function=val)
return ham return ham
......
...@@ -334,7 +334,8 @@ class _FunctionalOnsite: ...@@ -334,7 +334,8 @@ class _FunctionalOnsite:
_raise_user_error(exc, self.onsite, vectorized=False) _raise_user_error(exc, self.onsite, vectorized=False)
expected_shape = (len(site_offsets), norbs, norbs) expected_shape = (len(site_offsets), norbs, norbs)
return _normalize_matrix_blocks(ret, expected_shape) return _normalize_matrix_blocks(ret, expected_shape,
calling_function=self.onsite)
class _VectorizedFunctionalOnsite: class _VectorizedFunctionalOnsite:
...@@ -353,7 +354,8 @@ class _VectorizedFunctionalOnsite: ...@@ -353,7 +354,8 @@ class _VectorizedFunctionalOnsite:
_raise_user_error(exc, self.onsite, vectorized=True) _raise_user_error(exc, self.onsite, vectorized=True)
expected_shape = (len(sites), sites.family.norbs, sites.family.norbs) expected_shape = (len(sites), sites.family.norbs, sites.family.norbs)
return _normalize_matrix_blocks(ret, expected_shape) return _normalize_matrix_blocks(ret, expected_shape,
calling_function=self.onsite)
class _FunctionalOnsiteNoTransform: class _FunctionalOnsiteNoTransform:
...@@ -371,7 +373,8 @@ class _FunctionalOnsiteNoTransform: ...@@ -371,7 +373,8 @@ class _FunctionalOnsiteNoTransform:
_raise_user_error(exc, self.onsite, vectorized=False) _raise_user_error(exc, self.onsite, vectorized=False)
expected_shape = (len(site_offsets), norbs, norbs) expected_shape = (len(site_offsets), norbs, norbs)
return _normalize_matrix_blocks(ret, expected_shape) return _normalize_matrix_blocks(ret, expected_shape,
calling_function=self.onsite)
class _DictOnsite: class _DictOnsite:
......
...@@ -742,13 +742,16 @@ def is_vectorized(syst): ...@@ -742,13 +742,16 @@ def is_vectorized(syst):
return isinstance(syst, (FiniteVectorizedSystem, InfiniteVectorizedSystem)) return isinstance(syst, (FiniteVectorizedSystem, InfiniteVectorizedSystem))
def _normalize_matrix_blocks(blocks, expected_shape): def _normalize_matrix_blocks(blocks, expected_shape, *, calling_function=None):
"""Normalize a sequence of matrices into a single 3D numpy array """Normalize a sequence of matrices into a single 3D numpy array
Parameters Parameters
---------- ----------
blocks : sequence of complex array-like blocks : sequence of complex array-like
expected_shape : (int, int, int) expected_shape : (int, int, int)
calling_function : callable (optional)
The function that produced 'blocks'. If provided, used to give
a more helpful error message if 'blocks' is not of the correct shape.
""" """
try: try:
blocks = np.asarray(blocks, dtype=complex) blocks = np.asarray(blocks, dtype=complex)
...@@ -766,9 +769,11 @@ def _normalize_matrix_blocks(blocks, expected_shape): ...@@ -766,9 +769,11 @@ def _normalize_matrix_blocks(blocks, expected_shape):
if blocks.shape != expected_shape: if blocks.shape != expected_shape:
msg = ( msg = (
"Expected values of shape {}, but received values of shape {}" "Expected values of shape {}, but received values of shape {}"
.format(expected_shape, blocks.shape) .format(expected_shape, blocks.shape),
"when evaluating {}".format(calling_function.__name__)
if callable(calling_function) else "",
) )
raise ValueError(msg) raise ValueError(" ".join(msg))
return blocks return blocks
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment