diff --git a/kwant/builder.py b/kwant/builder.py index f60a7712286a97a25f49cd8509a1d1f1e4a1ca24..89e4377e363e129cd602746d7935fae5f501e66b 100644 --- a/kwant/builder.py +++ b/kwant/builder.py @@ -317,6 +317,22 @@ class Symmetry(metaclass=abc.ABCMeta): return False return True + @abc.abstractmethod + def subgroup(self, *generators): + """Return the subgroup generated by a sequence of group elements.""" + pass + + @abc.abstractmethod + def isstrictsupergroup(self, other): + """Test whether other symmetry is a strict supergroup.""" + pass + + def issubgroup(self, other): + """Test whether other symmetry is a subgroup.""" + return other.isstrictsupergroup(self) + + __le__ = issubgroup + class NoSymmetry(Symmetry): """A symmetry with a trivial symmetry group.""" @@ -334,6 +350,8 @@ class NoSymmetry(Symmetry): def num_directions(self): return 0 + periods = () + _empty_array = ta.array((), int) def which(self, site): @@ -350,6 +368,15 @@ class NoSymmetry(Symmetry): def in_fd(self, site): return True + def subgroup(self, *generators): + if any(generators): + raise ValueError('Generators must be empty for NoSymmetry.') + return NoSymmetry(generators) + + def isstrictsupergroup(self, other): + return False + + ################ Hopping kinds diff --git a/kwant/lattice.py b/kwant/lattice.py index 5bc1e88f18d1d023a13b5f57e4a35c57d5ccbdaf..5b2ac3d9db00e353ebd1edb40c3bcef86e094d74 100644 --- a/kwant/lattice.py +++ b/kwant/lattice.py @@ -542,6 +542,38 @@ class TranslationalSymmetry(builder.Symmetry): self.site_family_data = {} self.is_reversed = False + def subgroup(self, *generators): + """Return the subgroup generated by a sequence of group elements. + + Parameters + ---------- + *generators: sequence of int + Each generator must have length ``self.num_directions``. + """ + generators = ta.array(generators) + if generators.dtype != int: + raise ValueError('Generators must be sequences of integers.') + return TranslationalSymmetry(*ta.dot(generators, self.periods)) + + def isstrictsupergroup(self, other): + if isinstance(other, builder.NoSymmetry): + return True + elif not isinstance(other, TranslationalSymmetry): + raise ValueError("Unknown symmetry type.") + + if other.periods.shape[1] != self.periods.shape[1]: + return False # Mismatch of spatial dimensionalities. + + inv = np.linalg.pinv(self.periods) + factors = np.dot(other.periods, inv) + # Absolute tolerance is correct in the following since we want an error + # relative to the closest integer. + if not (np.allclose(factors, np.round(factors), rtol=0, atol=1e-8) and + np.allclose(ta.dot(factors, self.periods), other.periods)): + return False + else: + return True + def add_site_family(self, fam, other_vectors=None): """ Select a fundamental domain for site family and cache associated data. diff --git a/kwant/tests/test_builder.py b/kwant/tests/test_builder.py index 65dcf1503c1d9b69a7bbddb2b38e5431adaf4b43..bfe757edea5c757292622a05865e8c23b86a24c9 100644 --- a/kwant/tests/test_builder.py +++ b/kwant/tests/test_builder.py @@ -125,6 +125,22 @@ class VerySimpleSymmetry(builder.Symmetry): def num_directions(self): return 1 + def isstrictsupergroup(self, other): + if isinstance(other, builder.NoSymmetry): + return True + elif isinstance(other, VerySimpleSymmetry): + return not other.period % self.period + else: + return False + + def subgroup(self, *generators): + generators = ta.array(generators) + assert generators.shape == (1, 1) + if generators.dtype != int: + raise ValueError('Generators must be sequences of integers.') + g = generators[0, 0] + return VerySimpleSymmetry(g * self.period) + def which(self, site): return ta.array((site.tag[0] // self.period,), int) diff --git a/kwant/tests/test_lattice.py b/kwant/tests/test_lattice.py index fa4196dce75b4caff8e8b14accc14a8d11c2bac7..45337c1074eb0b929789ce603f1f17a4af5290aa 100644 --- a/kwant/tests/test_lattice.py +++ b/kwant/tests/test_lattice.py @@ -229,3 +229,28 @@ def test_norbs(): assert lat != lat1 assert lat != lat2 assert lat1 != lat2 + + +def test_symmetry_subgroup(): + rng = np.random.RandomState(0) + ## test whether actual subgroups are detected as such + vecs = rng.randn(3, 3) + sym1 = lattice.TranslationalSymmetry(*vecs) + assert sym1 >= sym1 + assert sym1 >= builder.NoSymmetry() + assert sym1 >= lattice.TranslationalSymmetry(2 * vecs[0], + 3 * vecs[1] + 4 * vecs[2]) + assert not sym1 <= lattice.TranslationalSymmetry(*(0.8 * vecs)) + + ## test subgroup creation + for dim in range(1, 4): + generators = rng.randint(10, size=(dim, 3)) + assert sym1.subgroup(*generators) <= sym1 + + # generators are not linearly independent + with raises(ValueError): + sym1.subgroup(*rng.randint(10, size=(4, 3))) + + # generators are not integer sequences + with raises(ValueError): + sym1.subgroup(*rng.rand(1, 3))