From ac9b4369c8262b6462953c1c0251da3cd9c27f59 Mon Sep 17 00:00:00 2001 From: Anton Akhmerov <anton.akhmerov@gmail.com> Date: Sun, 17 Dec 2023 00:14:05 +0100 Subject: [PATCH] raise a proper error in check_symmetry closes #44 --- qsymm/hamiltonian_generator.py | 17 +++++++++++++---- qsymm/tests/test_hamiltonian_generator.py | 9 +++++++++ 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/qsymm/hamiltonian_generator.py b/qsymm/hamiltonian_generator.py index ebda7e9..5e01e13 100644 --- a/qsymm/hamiltonian_generator.py +++ b/qsymm/hamiltonian_generator.py @@ -353,19 +353,28 @@ def check_symmetry(family, symmetries, num_digits=None): In the case that the input family has been rounded, num_digits should be the number of significant digits to which the family was rounded. - """ + Raises + ------ + ValueError + If the family does not satisfy the symmetry. + """ + def fail(): + raise ValueError(f'Member {member} does not satisfy symmetry {symmetry}.') for symmetry in symmetries: # Iterate over all members of the family for member in family: if isinstance(symmetry, PointGroupElement): if num_digits is None: - assert symmetry.apply(member) == member + if symmetry.apply(member) != member: + fail() else: - assert symmetry.apply(member).around(num_digits) == member.around(num_digits) + if symmetry.apply(member).around(num_digits) != member.around(num_digits): + fail() elif isinstance(symmetry, ContinuousGroupGenerator): # Continous symmetry if applying it returns zero - assert symmetry.apply(member) == {} + if symmetry.apply(member) != {}: + fail() def constrain_family(symmetries, family, sparse_linalg=False): diff --git a/qsymm/tests/test_hamiltonian_generator.py b/qsymm/tests/test_hamiltonian_generator.py index 3563fd5..28e574b 100644 --- a/qsymm/tests/test_hamiltonian_generator.py +++ b/qsymm/tests/test_hamiltonian_generator.py @@ -1,3 +1,5 @@ +from pytest import raises + import sympy import numpy as np import scipy.linalg as la @@ -58,6 +60,13 @@ def test_check_symmetry(): else: # Symmetry commutes assert np.allclose(Left - Right, 0) + # Test correctly raising a ValueError + with raises(ValueError): + check_symmetry( + [Model({1: np.eye(2)}, momenta=["k_x"])], + [chiral(1, np.diag([1, -1]))] + ) + def test_bloch_generator(): """Square lattice with time reversal and rotation symmetry, such that all hoppings are real. """ -- GitLab