Skip to content
Snippets Groups Projects
Commit cfedec56 authored by Joseph Weston's avatar Joseph Weston
Browse files

improve error message in _normalize_matrix_blocks

parent 6710d4b9
No related branches found
No related tags found
No related merge requests found
Pipeline #25464 passed
......@@ -759,17 +759,23 @@ def _normalize_matrix_blocks(blocks, expected_shape, *, calling_function=None):
raise ValueError(
"Matrix elements declared with incompatible shapes."
) from None
original_shape = blocks.shape
was_broadcast = True # Did the shape get broadcasted to a more general one?
if len(blocks.shape) == 0: # scalar → broadcast to vector of 1x1 matrices
blocks = np.tile(blocks, (expected_shape[0], 1, 1))
elif len(blocks.shape) == 1: # vector → interpret as vector of 1x1 matrices
blocks = blocks.reshape(-1, 1, 1)
elif len(blocks.shape) == 2: # matrix → broadcast to vector of matrices
blocks = np.tile(blocks, (expected_shape[0], 1, 1))
else:
was_broadcast = False
if blocks.shape != expected_shape:
msg = (
"Expected values of shape {}, but received values of shape {}"
.format(expected_shape, blocks.shape),
"(broadcasted from shape {})".format(original_shape)
if was_broadcast else "",
"when evaluating {}".format(calling_function.__name__)
if callable(calling_function) else "",
)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment