Commit c2e922b0 authored by Tómas's avatar Tómas Committed by Joseph Weston

validate always returns list, improve a test

parent c9908074
......@@ -141,11 +141,11 @@ class DiscreteSymmetry:
Returns
-------
broken_symmetries : list or ``None``
broken_symmetries : list
List of strings, the names of symmetries broken by the
matrix: any combination of "Conservation law", "Time reversal",
"Particle-hole", "Chiral". If no symmetries are broken, returns
None.
an empty list.
"""
# Extra transposes are to enforse sparse dot product in case matrix is
# dense.
......@@ -173,10 +173,7 @@ class DiscreteSymmetry:
commutator = commutator - sign * cond_conj(matrix, conj)
if np.linalg.norm(commutator.data) > 1e-8:
broken_symmetries.append(name)
if not len(broken_symmetries):
return None
else:
return broken_symmetries
return broken_symmetries
def __getitem__(self, item):
return (self.projectors, self.time_reversal,
......
......@@ -152,21 +152,21 @@ def test_validate():
sym = DiscreteSymmetry(projectors=[csr(np.array([[1], [0]])),
csr(np.array([[0], [1]]))])
assert sym.validate(csr(np.array([[0], [1]]))) == ['Conservation law']
assert sym.validate(np.array([[1], [0]])) is None
assert sym.validate(np.eye(2)) is None
assert sym.validate(np.array([[1], [0]])) == []
assert sym.validate(np.eye(2)) == []
assert sym.validate(1 - np.eye(2)) == ['Conservation law']
sym = DiscreteSymmetry(particle_hole=sparse.identity(2))
assert sym.validate(1j * sparse.identity(2)) is None
assert sym.validate(1j * sparse.identity(2)) == []
assert sym.validate(sparse.identity(2)) == ['Particle-hole']
sym = DiscreteSymmetry(time_reversal=sparse.identity(2))
assert sym.validate(sparse.identity(2)) is None
assert sym.validate(sparse.identity(2)) == []
assert sym.validate(1j * sparse.identity(2)) == ['Time reversal']
sym = DiscreteSymmetry(chiral=csr(np.diag((1, -1))))
assert sym.validate(np.eye(2)) == ['Chiral']
assert sym.validate(1 - np.eye(2)) is None
assert sym.validate(1 - np.eye(2)) == []
def random_onsite_hop(n, rng=0):
......@@ -178,7 +178,13 @@ def random_onsite_hop(n, rng=0):
def test_validate_commutator():
symm_class = ['AI', 'AII', 'D', 'C', 'AIII']
symm_class = ['AI', 'AII', 'D', 'C', 'AIII', 'BDI']
sym_dict = {'AI': ['Time reversal'],
'AII': ['Time reversal'],
'D': ['Particle-hole'],
'C': ['Particle-hole'],
'AIII': ['Chiral'],
'BDI': ['Time reversal', 'Particle-hole', 'Chiral']}
n = 10
rng = 10
for sym in symm_class:
......@@ -201,6 +207,7 @@ def test_validate_commutator():
disc_symm = DiscreteSymmetry(particle_hole=p_mat,
time_reversal=t_mat,
chiral=c_mat)
assert disc_symm.validate(h) == None
assert disc_symm.validate(h) == []
a = random_onsite_hop(n, rng=rng)[1]
assert len(disc_symm.validate(a))
for symmetry in disc_symm.validate(a):
assert symmetry in sym_dict[sym]
......@@ -249,9 +249,7 @@ class InfiniteSystem(System, metaclass=abc.ABCMeta):
symmetries = self.discrete_symmetry(args, params=params)
# Check whether each symmetry is broken.
# If a symmetry is broken, it is ignored in the computation.
broken = {symmetry for item in (symmetries.validate(ham),
symmetries.validate(hop))
if item is not None for symmetry in item}
broken = set(symmetries.validate(ham) + symmetries.validate(hop))
for name in broken:
warnings.warn('Hamiltonian breaks ' + name +
', ignoring the symmetry in the computation.')
......@@ -304,13 +302,8 @@ class InfiniteSystem(System, metaclass=abc.ABCMeta):
symmetries = self.discrete_symmetry(args=args, params=params)
ham = self.cell_hamiltonian(args=args, sparse=True, params=params)
hop = self.inter_cell_hopping(args=args, sparse=True, params=params)
broken = list({symmetry for item in (symmetries.validate(ham),
symmetries.validate(hop))
if item is not None for symmetry in item})
if len(broken):
return broken
else:
return broken
broken = set(symmetries.validate(ham) + symmetries.validate(hop))
return list(broken)
class PrecalculatedLead:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment