From a954beab0c78f83bdf707e329fc21ec13d0ae209 Mon Sep 17 00:00:00 2001 From: Anton Akhmerov <anton.akhmerov@gmail.com> Date: Tue, 29 Jan 2013 12:06:10 +0100 Subject: [PATCH] minor refactoring of TranslationalSymmetry --- kwant/builder.py | 1 - kwant/lattice.py | 26 ++++++++++++-------------- 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/kwant/builder.py b/kwant/builder.py index 8bddf07e..bae5a3d5 100644 --- a/kwant/builder.py +++ b/kwant/builder.py @@ -970,7 +970,6 @@ class Builder(object): 'different site groups. See tutorial for more details.' raise ValueError(msg.format(tuple(groups))) - all_doms = list(sym.which(site)[0] for site in self.H if sym.to_fd(site) in H) if origin is not None: diff --git a/kwant/lattice.py b/kwant/lattice.py index 5a661470..af54c1a7 100644 --- a/kwant/lattice.py +++ b/kwant/lattice.py @@ -307,28 +307,26 @@ class TranslationalSymmetry(builder.Symmetry): det_x_inv_m_part = det_x_inv_m[:num_dir, :] m_part = m[:, :num_dir] - self.site_group_data[gr] = (ta.array(det_x_inv_m_part), - ta.array(m_part), - det_m) + self.site_group_data[gr] = (ta.array(m_part), + ta.array(det_x_inv_m_part), det_m) @property def num_directions(self): return len(self.periods) - def which(self, site): + def _get_site_group_data(self, group): try: - det_x_inv_m_part, m_part, det_m = self.site_group_data[site.group] + return self.site_group_data[group] except KeyError: - self.add_site_group(site.group) - return self.which(site) + self.add_site_group(group) + return self.site_group_data[group] + + def which(self, site): + det_x_inv_m_part, det_m = self._get_site_group_data(site.group)[-2:] return ta.dot(det_x_inv_m_part, site.tag) // det_m def act(self, element, a, b=None): - try: - det_x_inv_m_part, m_part, det_m = self.site_group_data[a.group] - except KeyError: - self.add_site_group(a.group) - return self.act(element, a, b) + m_part = self._get_site_group_data(a.group)[0] try: delta = ta.dot(m_part, element) except ValueError: @@ -353,13 +351,13 @@ class TranslationalSymmetry(builder.Symmetry): periods = [[-i for i in j] for j in self.periods] result = TranslationalSymmetry(*periods) for gr in self.site_group_data: - det_x_inv_m_part, m_part, det_m = self.site_group_data[gr] + m_part, det_x_inv_m_part, det_m = self.site_group_data[gr] if self.num_directions % 2: det_m = -det_m else: det_x_inv_m_part = -det_x_inv_m_part m_part = -m_part - result.site_group_data[gr] = (det_x_inv_m_part, m_part, det_m) + result.site_group_data[gr] = (m_part, det_x_inv_m_part, det_m) return result -- GitLab