From e5d30c69c2418da29ef984711713e9cb8333e52d Mon Sep 17 00:00:00 2001 From: Joseph Weston <joseph@weston.cloud> Date: Mon, 9 Sep 2019 12:06:06 +0200 Subject: [PATCH] add SiteArray and update SiteFamily to support them We remove the ABC metaclass as now 'normalize_tag' and 'normalize_tags' are cyclically defined, and subclasses must redefine at least 1. --- doc/source/reference/kwant.system.rst | 1 + kwant/system.py | 88 ++++++++++++++++++++++++--- 2 files changed, 81 insertions(+), 8 deletions(-) diff --git a/doc/source/reference/kwant.system.rst b/doc/source/reference/kwant.system.rst index 6d6dbec6..357765c9 100644 --- a/doc/source/reference/kwant.system.rst +++ b/doc/source/reference/kwant.system.rst @@ -25,6 +25,7 @@ Sites :toctree: generated/ Site + SiteArray SiteFamily Systems diff --git a/kwant/system.py b/kwant/system.py index 5498ead7..b08f27cf 100644 --- a/kwant/system.py +++ b/kwant/system.py @@ -93,8 +93,62 @@ class Site(tuple): return self.family.pos(self.tag) +class SiteArray: + """An array of sites, members of a `SiteFamily`. + + Parameters + ---------- + family : an instance of `SiteFamily` + The 'type' of the sites. + tags : a sequence of python objects + Sequence of unique identifiers of the sites within the + site array family, typically vectors of integers. + + Raises + ------ + ValueError + If `tags` are not proper tags for `family`. + + See Also + -------- + kwant.system.Site + """ + + def __init__(self, family, tags): + self.family = family + try: + tags = family.normalize_tags(tags) + except (TypeError, ValueError) as e: + msg = 'Tags {0} are not allowed for site family {1}: {2}' + raise type(e)(msg.format(repr(tags), repr(family), e.args[0])) + self.tags = tags + + def __repr__(self): + return 'SiteArray({0}, {1})'.format(repr(self.family), repr(self.tags)) + + def __str__(self): + sf = self.family + return ('<SiteArray {0} of {1}>' + .format(self.tags, sf.name if sf.name else sf)) + + def __len__(self): + return len(self.tags) + + def __eq__(self, other): + if not isinstance(other, SiteArray): + raise NotImplementedError() + return self.family == other.family and np.all(self.tags == other.tags) + + def positions(self): + """Real space position of the site. + + This relies on ``family`` having a ``pos`` method (see `SiteFamily`). + """ + return self.family.positions(self.tags) + + @total_ordering -class SiteFamily(metaclass=abc.ABCMeta): +class SiteFamily: """Abstract base class for site families. Site families are the 'type' of `Site` objects. Within a family, individual @@ -112,12 +166,16 @@ class SiteFamily(metaclass=abc.ABCMeta): the number of orbitals is not specified. - All site families must define the method `normalize_tag` which brings a tag - to the standard format for this site family. + All site families must define either 'normalize_tag' or 'normalize_tags', + which brings a tag (or, in the latter case, a sequence of tags) to the + standard format for this site family. - Site families that are intended for use with plotting should also provide a - method `pos(tag)`, which returns a vector with real-space coordinates of the - site belonging to this family with a given tag. + Site families may also implement methods ``pos(tag)`` and + ``positions(tags)``, which return a vector of realspace coordinates or an + array of vectors of realspace coordinates of the site(s) belonging to this + family with the given tag(s). These methods are used in plotting routines. + ``positions(tags)`` should return an array with shape ``(N, M)`` where + ``N`` is the length of ``tags``, and ``M`` is the realspace dimension. If the ``norbs`` of a site family are provided, and sites of this family are used to populate a `~kwant.builder.Builder`, then the associated @@ -144,6 +202,13 @@ class SiteFamily(metaclass=abc.ABCMeta): norbs = int(norbs) self.norbs = norbs + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + if (cls.normalize_tag is SiteFamily.normalize_tag + and cls.normalize_tags is SiteFamily.normalize_tags): + raise TypeError("Must redefine either 'normalize_tag' or " + "'normalize_tags'") + def __repr__(self): return self.canonical_repr @@ -175,13 +240,20 @@ class SiteFamily(metaclass=abc.ABCMeta): # to compare it to something non-comparable anyway. return self.canonical_repr < other.canonical_repr - @abc.abstractmethod def normalize_tag(self, tag): """Return a normalized version of the tag. Raises TypeError or ValueError if the tag is not acceptable. """ - pass + tag, = self.normalize_tags([tag]) + return tag + + def normalize_tags(self, tags): + """Return a normalized version of the tags. + + Raises TypeError or ValueError if the tags are not acceptable. + """ + return np.array([self.normalize_tag(tag) for tag in tags]) def __call__(self, *tag): """ -- GitLab