From 85695ef5eaa82401e9fc1dbce6363c0a8323954e Mon Sep 17 00:00:00 2001
From: Joseph Weston <joseph@weston.cloud>
Date: Mon, 9 Sep 2019 12:02:36 +0200
Subject: [PATCH] update TranslationalSymmetry to work with SiteArrays

Now 'which' and 'act' (and, by extension, 'to_fd') now accept
SiteArrays as well as Sites.
---
 kwant/builder.py | 60 +++++++++++++++++++++++++++++++++++++++++-------
 kwant/lattice.py | 44 +++++++++++++++++++++++++++--------
 2 files changed, 87 insertions(+), 17 deletions(-)

diff --git a/kwant/builder.py b/kwant/builder.py
index e28f63f4..16bb4f46 100644
--- a/kwant/builder.py
+++ b/kwant/builder.py
@@ -139,19 +139,44 @@ class Symmetry(metaclass=abc.ABCMeta):
     def which(self, site):
         """Calculate the domain of the site.
 
-        Return the group element whose action on a certain site from the
-        fundamental domain will result in the given ``site``.
+        Parameters
+        ----------
+        site : `~kwant.system.Site` or `~kwant.system.SiteArray`
+
+        Returns
+        -------
+        group_element : tuple or sequence of tuples
+            A single tuple if ``site`` is a Site, or a sequence of tuples if
+            ``site`` is a SiteArray.  The group element(s) whose action
+            on a certain site(s) from the fundamental domain will result
+            in the given ``site``.
         """
         pass
 
     @abc.abstractmethod
     def act(self, element, a, b=None):
-        """Act with a symmetry group element on a site or hopping."""
+        """Act with symmetry group element(s) on site(s) or hopping(s).
+
+        Parameters
+        ----------
+        element : tuple or sequence of tuples
+            Group element(s) with which to act on the provided site(s)
+            or hopping(s)
+        a, b : `~kwant.system.Site` or `~kwant.system.SiteArray`
+            If Site then ``element`` is a single tuple, if SiteArray then
+            ``element`` is a sequence of tuples. If only ``a`` is provided then
+            ``element`` acts on the site(s) of ``a``. If ``b`` is also provided
+            then ``element`` acts on the hopping(s) ``(a, b)``.
+        """
         pass
 
     def to_fd(self, a, b=None):
         """Map a site or hopping to the fundamental domain.
 
+        Parameters
+        ----------
+        a, b : `~kwant.system.Site` or `~kwant.system.SiteArray`
+
         If ``b`` is None, return a site equivalent to ``a`` within the
         fundamental domain.  Otherwise, return a hopping equivalent to ``(a,
         b)`` but where the first element belongs to the fundamental domain.
@@ -161,11 +186,30 @@ class Symmetry(metaclass=abc.ABCMeta):
         return self.act(-self.which(a), a, b)
 
     def in_fd(self, site):
-        """Tell whether ``site`` lies within the fundamental domain."""
-        for d in self.which(site):
-            if d != 0:
-                return False
-        return True
+        """Tell whether ``site`` lies within the fundamental domain.
+
+        Parameters
+        ----------
+        site : `~kwant.system.Site` or `~kwant.system.SiteArray`
+
+        Returns
+        -------
+        in_fd : bool or sequence of bool
+            single bool if ``site`` is a Site, or a sequence of
+            bool if ``site`` is a SiteArray. In the latter case
+            we return whether each site in the SiteArray is in
+            the fundamental domain.
+        """
+        if isinstance(site, Site):
+            for d in self.which(site):
+                if d != 0:
+                    return False
+            return True
+        elif isinstance(site, SiteArray):
+            which = self.which(site)
+            return np.logical_and.reduce(which != 0, axis=1)
+        else:
+            raise TypeError("'site' must be a Site or SiteArray")
 
     @abc.abstractmethod
     def subgroup(self, *generators):
diff --git a/kwant/lattice.py b/kwant/lattice.py
index b6db9a14..e2557796 100644
--- a/kwant/lattice.py
+++ b/kwant/lattice.py
@@ -698,26 +698,48 @@ class TranslationalSymmetry(builder.Symmetry):
 
     def which(self, site):
         det_x_inv_m_part, det_m = self._get_site_family_data(site.family)[-2:]
-        result = ta.dot(det_x_inv_m_part, site.tag) // det_m
+        if isinstance(site, system.Site):
+            result = ta.dot(det_x_inv_m_part, site.tag) // det_m
+        elif isinstance(site, system.SiteArray):
+            result = np.dot(det_x_inv_m_part, site.tags.transpose()) // det_m
+        else:
+            raise TypeError("'site' must be a Site or a SiteArray")
+
         return -result if self.is_reversed else result
 
     def act(self, element, a, b=None):
-        element = ta.array(element)
-        if element.dtype is not int:
+        is_site = isinstance(a, system.Site)
+        # Tinyarray for small arrays (single site) else numpy
+        array_mod = ta if is_site else np
+        element = array_mod.array(element)
+        if not np.issubdtype(element.dtype, np.integer):
             raise ValueError("group element must be a tuple of integers")
+        if (len(element.shape) == 2 and is_site):
+            raise ValueError("must provide a single group element when "
+                             "acting on single sites.")
+        if (len(element.shape) == 1 and not is_site):
+            raise ValueError("must provide a sequence of group elements "
+                             "when acting on site arrays.")
         m_part = self._get_site_family_data(a.family)[0]
         try:
-            delta = ta.dot(m_part, element)
+            delta = array_mod.dot(m_part, element)
         except ValueError:
             msg = 'Expecting a {0}-tuple group element, but got `{1}` instead.'
             raise ValueError(msg.format(self.num_directions, element))
         if self.is_reversed:
             delta = -delta
         if b is None:
-            return builder.Site(a.family, a.tag + delta, True)
+            if is_site:
+                return system.Site(a.family, a.tag + delta, True)
+            else:
+                return system.SiteArray(a.family, a.tags + delta.transpose())
         elif b.family == a.family:
-            return (builder.Site(a.family, a.tag + delta, True),
-                    builder.Site(b.family, b.tag + delta, True))
+            if is_site:
+                return (system.Site(a.family, a.tag + delta, True),
+                        system.Site(b.family, b.tag + delta, True))
+            else:
+                return (system.SiteArray(a.family, a.tags + delta.transpose()),
+                        system.SiteArray(b.family, b.tags + delta.transpose()))
         else:
             m_part = self._get_site_family_data(b.family)[0]
             try:
@@ -728,8 +750,12 @@ class TranslationalSymmetry(builder.Symmetry):
                 raise ValueError(msg.format(self.num_directions, element))
             if self.is_reversed:
                 delta2 = -delta2
-            return (builder.Site(a.family, a.tag + delta, True),
-                    builder.Site(b.family, b.tag + delta2, True))
+            if is_site:
+                return (system.Site(a.family, a.tag + delta, True),
+                        system.Site(b.family, b.tag + delta2, True))
+            else:
+                return (system.SiteArray(a.family, a.tags + delta.transpose()),
+                        system.SiteArray(b.family, b.tags + delta2.transpose()))
 
     def reversed(self):
         """Return a reversed copy of the symmetry.
-- 
GitLab