Skip to content
Snippets Groups Projects
Commit 08cbaaad authored by Jörg Behrmann's avatar Jörg Behrmann Committed by Joseph Weston
Browse files

define ordering for `SiteFamily` and add a test for this

parent 91b58d4a
No related branches found
No related tags found
No related merge requests found
...@@ -12,6 +12,7 @@ __all__ = ['Builder', 'Site', 'SiteFamily', 'SimpleSiteFamily', 'Symmetry', ...@@ -12,6 +12,7 @@ __all__ = ['Builder', 'Site', 'SiteFamily', 'SimpleSiteFamily', 'Symmetry',
import abc import abc
import warnings import warnings
import operator import operator
from functools import total_ordering
from itertools import islice, chain from itertools import islice, chain
import tinyarray as ta import tinyarray as ta
import numpy as np import numpy as np
...@@ -82,6 +83,7 @@ class Site(tuple): ...@@ -82,6 +83,7 @@ class Site(tuple):
return self.family.pos(self.tag) return self.family.pos(self.tag)
@total_ordering
class SiteFamily(metaclass=abc.ABCMeta): class SiteFamily(metaclass=abc.ABCMeta):
"""Abstract base class for site families. """Abstract base class for site families.
...@@ -154,6 +156,12 @@ class SiteFamily(metaclass=abc.ABCMeta): ...@@ -154,6 +156,12 @@ class SiteFamily(metaclass=abc.ABCMeta):
except AttributeError: except AttributeError:
return True return True
def __lt__(self, other):
try:
return self.canonical_repr < other.canonical_repr
except AttributeError:
return False
@abc.abstractmethod @abc.abstractmethod
def normalize_tag(self, tag): def normalize_tag(self, tag):
"""Return a normalized version of the tag. """Return a normalized version of the tag.
......
...@@ -111,6 +111,23 @@ def test_site_families(): ...@@ -111,6 +111,23 @@ def test_site_families():
assert_not_equal(fam, 'a') assert_not_equal(fam, 'a')
def test_site_families_sorting():
fam1 = builder.SimpleSiteFamily('fam1')
fam2 = builder.SimpleSiteFamily('fam2')
rng = Random(123)
tags = [(rng.randint(0,10), rng.randint(0,10)) for i in range(10)]
tags_sorted = [t for t in sorted(tags)]
sites = [fam1(*t) for t in tags] + [fam2(*t) for t in tags]
rng.shuffle(sites)
sites_sorted = ([fam1(*t) for t in tags_sorted] +
[fam2(*t) for t in tags_sorted])
assert_equal(list(sorted(sites)), sites_sorted)
class VerySimpleSymmetry(builder.Symmetry): class VerySimpleSymmetry(builder.Symmetry):
def __init__(self, period): def __init__(self, period):
self.period = period self.period = period
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment