From e5d30c69c2418da29ef984711713e9cb8333e52d Mon Sep 17 00:00:00 2001
From: Joseph Weston <joseph@weston.cloud>
Date: Mon, 9 Sep 2019 12:06:06 +0200
Subject: [PATCH] add SiteArray and update SiteFamily to support them

We remove the ABC metaclass as now 'normalize_tag' and 'normalize_tags'
are cyclically defined, and subclasses must redefine at least 1.
---
 doc/source/reference/kwant.system.rst |  1 +
 kwant/system.py                       | 88 ++++++++++++++++++++++++---
 2 files changed, 81 insertions(+), 8 deletions(-)

diff --git a/doc/source/reference/kwant.system.rst b/doc/source/reference/kwant.system.rst
index 6d6dbec6..357765c9 100644
--- a/doc/source/reference/kwant.system.rst
+++ b/doc/source/reference/kwant.system.rst
@@ -25,6 +25,7 @@ Sites
    :toctree: generated/
 
    Site
+   SiteArray
    SiteFamily
 
 Systems
diff --git a/kwant/system.py b/kwant/system.py
index 5498ead7..b08f27cf 100644
--- a/kwant/system.py
+++ b/kwant/system.py
@@ -93,8 +93,62 @@ class Site(tuple):
         return self.family.pos(self.tag)
 
 
+class SiteArray:
+    """An array of sites, members of a `SiteFamily`.
+
+    Parameters
+    ----------
+    family : an instance of `SiteFamily`
+        The 'type' of the sites.
+    tags : a sequence of python objects
+        Sequence of unique identifiers of the sites within the
+        site array family, typically vectors of integers.
+
+    Raises
+    ------
+    ValueError
+        If `tags` are not proper tags for `family`.
+
+    See Also
+    --------
+    kwant.system.Site
+    """
+
+    def __init__(self, family, tags):
+        self.family = family
+        try:
+            tags = family.normalize_tags(tags)
+        except (TypeError, ValueError) as e:
+            msg = 'Tags {0} are not allowed for site family {1}: {2}'
+            raise type(e)(msg.format(repr(tags), repr(family), e.args[0]))
+        self.tags = tags
+
+    def __repr__(self):
+        return 'SiteArray({0}, {1})'.format(repr(self.family), repr(self.tags))
+
+    def __str__(self):
+        sf = self.family
+        return ('<SiteArray {0} of {1}>'
+                .format(self.tags, sf.name if sf.name else sf))
+
+    def __len__(self):
+        return len(self.tags)
+
+    def __eq__(self, other):
+        if not isinstance(other, SiteArray):
+            raise NotImplementedError()
+        return self.family == other.family and np.all(self.tags == other.tags)
+
+    def positions(self):
+        """Real space position of the site.
+
+        This relies on ``family`` having a ``pos`` method (see `SiteFamily`).
+        """
+        return self.family.positions(self.tags)
+
+
 @total_ordering
-class SiteFamily(metaclass=abc.ABCMeta):
+class SiteFamily:
     """Abstract base class for site families.
 
     Site families are the 'type' of `Site` objects.  Within a family, individual
@@ -112,12 +166,16 @@ class SiteFamily(metaclass=abc.ABCMeta):
     the number of orbitals is not specified.
 
 
-    All site families must define the method `normalize_tag` which brings a tag
-    to the standard format for this site family.
+    All site families must define either 'normalize_tag' or 'normalize_tags',
+    which brings a tag (or, in the latter case, a sequence of tags) to the
+    standard format for this site family.
 
-    Site families that are intended for use with plotting should also provide a
-    method `pos(tag)`, which returns a vector with real-space coordinates of the
-    site belonging to this family with a given tag.
+    Site families may also implement methods ``pos(tag)`` and
+    ``positions(tags)``, which return a vector of realspace coordinates or an
+    array of vectors of realspace coordinates of the site(s) belonging to this
+    family with the given tag(s). These methods are used in plotting routines.
+    ``positions(tags)`` should return an array with shape ``(N, M)`` where
+    ``N`` is the length of ``tags``, and ``M`` is the realspace dimension.
 
     If the ``norbs`` of a site family are provided, and sites of this family
     are used to populate a `~kwant.builder.Builder`, then the associated
@@ -144,6 +202,13 @@ class SiteFamily(metaclass=abc.ABCMeta):
             norbs = int(norbs)
         self.norbs = norbs
 
+    def __init_subclass__(cls, **kwargs):
+        super().__init_subclass__(**kwargs)
+        if (cls.normalize_tag is SiteFamily.normalize_tag
+            and cls.normalize_tags is SiteFamily.normalize_tags):
+            raise TypeError("Must redefine either 'normalize_tag' or "
+                            "'normalize_tags'")
+
     def __repr__(self):
         return self.canonical_repr
 
@@ -175,13 +240,20 @@ class SiteFamily(metaclass=abc.ABCMeta):
         # to compare it to something non-comparable anyway.
         return self.canonical_repr < other.canonical_repr
 
-    @abc.abstractmethod
     def normalize_tag(self, tag):
         """Return a normalized version of the tag.
 
         Raises TypeError or ValueError if the tag is not acceptable.
         """
-        pass
+        tag, = self.normalize_tags([tag])
+        return tag
+
+    def normalize_tags(self, tags):
+        """Return a normalized version of the tags.
+
+        Raises TypeError or ValueError if the tags are not acceptable.
+        """
+        return np.array([self.normalize_tag(tag) for tag in tags])
 
     def __call__(self, *tag):
         """
-- 
GitLab