Skip to content
Snippets Groups Projects
Commit 08cbaaad authored by Jörg Behrmann's avatar Jörg Behrmann Committed by Joseph Weston
Browse files

define ordering for `SiteFamily` and add a test for this

parent 91b58d4a
No related branches found
No related tags found
No related merge requests found
......@@ -12,6 +12,7 @@ __all__ = ['Builder', 'Site', 'SiteFamily', 'SimpleSiteFamily', 'Symmetry',
import abc
import warnings
import operator
from functools import total_ordering
from itertools import islice, chain
import tinyarray as ta
import numpy as np
......@@ -82,6 +83,7 @@ class Site(tuple):
return self.family.pos(self.tag)
@total_ordering
class SiteFamily(metaclass=abc.ABCMeta):
"""Abstract base class for site families.
......@@ -154,6 +156,12 @@ class SiteFamily(metaclass=abc.ABCMeta):
except AttributeError:
return True
def __lt__(self, other):
try:
return self.canonical_repr < other.canonical_repr
except AttributeError:
return False
@abc.abstractmethod
def normalize_tag(self, tag):
"""Return a normalized version of the tag.
......
......@@ -111,6 +111,23 @@ def test_site_families():
assert_not_equal(fam, 'a')
def test_site_families_sorting():
fam1 = builder.SimpleSiteFamily('fam1')
fam2 = builder.SimpleSiteFamily('fam2')
rng = Random(123)
tags = [(rng.randint(0,10), rng.randint(0,10)) for i in range(10)]
tags_sorted = [t for t in sorted(tags)]
sites = [fam1(*t) for t in tags] + [fam2(*t) for t in tags]
rng.shuffle(sites)
sites_sorted = ([fam1(*t) for t in tags_sorted] +
[fam2(*t) for t in tags_sorted])
assert_equal(list(sorted(sites)), sites_sorted)
class VerySimpleSymmetry(builder.Symmetry):
def __init__(self, period):
self.period = period
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment