From a954beab0c78f83bdf707e329fc21ec13d0ae209 Mon Sep 17 00:00:00 2001
From: Anton Akhmerov <anton.akhmerov@gmail.com>
Date: Tue, 29 Jan 2013 12:06:10 +0100
Subject: [PATCH] minor refactoring of TranslationalSymmetry

---
 kwant/builder.py |  1 -
 kwant/lattice.py | 26 ++++++++++++--------------
 2 files changed, 12 insertions(+), 15 deletions(-)

diff --git a/kwant/builder.py b/kwant/builder.py
index 8bddf07e..bae5a3d5 100644
--- a/kwant/builder.py
+++ b/kwant/builder.py
@@ -970,7 +970,6 @@ class Builder(object):
                 'different site groups. See tutorial for more details.'
             raise ValueError(msg.format(tuple(groups)))
 
-
         all_doms = list(sym.which(site)[0]
                         for site in self.H if sym.to_fd(site) in H)
         if origin is not None:
diff --git a/kwant/lattice.py b/kwant/lattice.py
index 5a661470..af54c1a7 100644
--- a/kwant/lattice.py
+++ b/kwant/lattice.py
@@ -307,28 +307,26 @@ class TranslationalSymmetry(builder.Symmetry):
 
         det_x_inv_m_part = det_x_inv_m[:num_dir, :]
         m_part = m[:, :num_dir]
-        self.site_group_data[gr] = (ta.array(det_x_inv_m_part),
-                                    ta.array(m_part),
-                                    det_m)
+        self.site_group_data[gr] = (ta.array(m_part),
+                                    ta.array(det_x_inv_m_part), det_m)
 
     @property
     def num_directions(self):
         return len(self.periods)
 
-    def which(self, site):
+    def _get_site_group_data(self, group):
         try:
-            det_x_inv_m_part, m_part, det_m = self.site_group_data[site.group]
+            return self.site_group_data[group]
         except KeyError:
-            self.add_site_group(site.group)
-            return self.which(site)
+            self.add_site_group(group)
+            return self.site_group_data[group]
+
+    def which(self, site):
+        det_x_inv_m_part, det_m = self._get_site_group_data(site.group)[-2:]
         return ta.dot(det_x_inv_m_part, site.tag) // det_m
 
     def act(self, element, a, b=None):
-        try:
-            det_x_inv_m_part, m_part, det_m = self.site_group_data[a.group]
-        except KeyError:
-            self.add_site_group(a.group)
-            return self.act(element, a, b)
+        m_part = self._get_site_group_data(a.group)[0]
         try:
             delta = ta.dot(m_part, element)
         except ValueError:
@@ -353,13 +351,13 @@ class TranslationalSymmetry(builder.Symmetry):
         periods = [[-i for i in j] for j in self.periods]
         result = TranslationalSymmetry(*periods)
         for gr in self.site_group_data:
-            det_x_inv_m_part, m_part, det_m = self.site_group_data[gr]
+            m_part, det_x_inv_m_part, det_m = self.site_group_data[gr]
             if self.num_directions % 2:
                 det_m = -det_m
             else:
                 det_x_inv_m_part = -det_x_inv_m_part
             m_part = -m_part
-            result.site_group_data[gr] = (det_x_inv_m_part, m_part, det_m)
+            result.site_group_data[gr] = (m_part, det_x_inv_m_part, det_m)
         return result
 
 
-- 
GitLab