From ac9b4369c8262b6462953c1c0251da3cd9c27f59 Mon Sep 17 00:00:00 2001
From: Anton Akhmerov <anton.akhmerov@gmail.com>
Date: Sun, 17 Dec 2023 00:14:05 +0100
Subject: [PATCH] raise a proper error in check_symmetry

closes #44
---
 qsymm/hamiltonian_generator.py            | 17 +++++++++++++----
 qsymm/tests/test_hamiltonian_generator.py |  9 +++++++++
 2 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/qsymm/hamiltonian_generator.py b/qsymm/hamiltonian_generator.py
index ebda7e9..5e01e13 100644
--- a/qsymm/hamiltonian_generator.py
+++ b/qsymm/hamiltonian_generator.py
@@ -353,19 +353,28 @@ def check_symmetry(family, symmetries, num_digits=None):
         In the case that the input family has been rounded, num_digits
         should be the number of significant digits to which the family
         was rounded.
-    """
 
+    Raises
+    ------
+    ValueError
+        If the family does not satisfy the symmetry.
+    """
+    def fail():
+        raise ValueError(f'Member {member} does not satisfy symmetry {symmetry}.')
     for symmetry in symmetries:
         # Iterate over all members of the family
         for member in family:
             if isinstance(symmetry, PointGroupElement):
                 if num_digits is None:
-                    assert symmetry.apply(member) == member
+                    if symmetry.apply(member) != member:
+                        fail()
                 else:
-                    assert symmetry.apply(member).around(num_digits) == member.around(num_digits)
+                    if symmetry.apply(member).around(num_digits) != member.around(num_digits):
+                        fail()
             elif isinstance(symmetry, ContinuousGroupGenerator):
                 # Continous symmetry if applying it returns zero
-                assert symmetry.apply(member) == {}
+                if symmetry.apply(member) != {}:
+                    fail()
 
 
 def constrain_family(symmetries, family, sparse_linalg=False):
diff --git a/qsymm/tests/test_hamiltonian_generator.py b/qsymm/tests/test_hamiltonian_generator.py
index 3563fd5..28e574b 100644
--- a/qsymm/tests/test_hamiltonian_generator.py
+++ b/qsymm/tests/test_hamiltonian_generator.py
@@ -1,3 +1,5 @@
+from pytest import raises
+
 import sympy
 import numpy as np
 import scipy.linalg as la
@@ -58,6 +60,13 @@ def test_check_symmetry():
                     else: # Symmetry commutes
                         assert np.allclose(Left - Right, 0)
 
+    # Test correctly raising a ValueError
+    with raises(ValueError):
+        check_symmetry(
+            [Model({1: np.eye(2)}, momenta=["k_x"])],
+            [chiral(1, np.diag([1, -1]))]
+    )
+
 
 def test_bloch_generator():
     """Square lattice with time reversal and rotation symmetry, such that all hoppings are real. """
-- 
GitLab