From 17845d558d8bb9de0370353214abb73fd9cd1c0c Mon Sep 17 00:00:00 2001
From: Christoph Groth <christoph.groth@cea.fr>
Date: Thu, 27 Jul 2017 15:22:43 +0200
Subject: [PATCH] builder: gather common code in _FinalizedBuilderMixin

---
 doc/source/conf.py |   3 +-
 kwant/builder.py   | 306 ++++++++++++++++++---------------------------
 2 files changed, 126 insertions(+), 183 deletions(-)

diff --git a/doc/source/conf.py b/doc/source/conf.py
index 1af1f28a..9b5c14cd 100644
--- a/doc/source/conf.py
+++ b/doc/source/conf.py
@@ -267,7 +267,8 @@ nitpick_ignore = [('py:class', 'Warning'), ('py:class', 'Exception'),
                   ('py:class', 'object'), ('py:class', 'tuple'),
                   ('py:class', 'kwant.operator._LocalOperator'),
                   ('py:class', 'numpy.ndarray'),
-                  ('py:class', 'kwant.solvers.common.BlockResult')]
+                  ('py:class', 'kwant.solvers.common.BlockResult'),
+                  ('py:class', 'kwant.builder._FinalizedBuilderMixin')]
 
 
 # Use custom MathJax CDN, as cdn.mathjax.org will soon shut down
diff --git a/kwant/builder.py b/kwant/builder.py
index fd8146a2..997a7e9c 100644
--- a/kwant/builder.py
+++ b/kwant/builder.py
@@ -1590,29 +1590,6 @@ def _raise_user_error(exc, func):
     raise UserCodeError(msg.format(func.__name__)) from exc
 
 
-def discrete_symmetry(self, args=(), *, params=None):
-    if self._cons_law is not None:
-        eigvals, eigvecs = self._cons_law
-        eigvals = eigvals.tocoo(args, params=params)
-        if not np.allclose(eigvals.data, np.round(eigvals.data)):
-            raise ValueError("Conservation law must have integer eigenvalues.")
-        eigvals = np.round(eigvals).tocsr()
-        # Avoid appearance of zero eigenvalues
-        eigvals = eigvals + 0.5 * sparse.identity(eigvals.shape[0])
-        eigvals.eliminate_zeros()
-        eigvecs = eigvecs.tocoo(args, params=params)
-        projectors = [eigvecs.dot(eigvals == val)
-                      for val in sorted(np.unique(eigvals.data))]
-    else:
-        projectors = None
-
-    def evaluate(op):
-        return None if op is None else op.tocoo(args, params=params)
-
-    return DiscreteSymmetry(projectors, *(evaluate(symm) for symm in
-                                          self._symmetries))
-
-
 def _translate_cons_law(cons_law):
     """Translate a conservation law from builder format to something that can
     be used to initialize operator.Density.
@@ -1650,22 +1627,123 @@ def _translate_cons_law(cons_law):
     return vals, vecs
 
 
-def _init_discrete_symmetries(self, builder):
-        def _operator(op):
+class _FinalizedBuilderMixin:
+    """Common functionality for all finalized builders"""
+
+    def _init_ham_param_maps(self):
+        """Find parameters taken by all value functions
+        """
+        ham_param_map = {}
+        for hams, skip in [(self.onsite_hamiltonians, 1), (self.hoppings, 2)]:
+            for ham in hams:
+                if (not callable(ham) or ham is Other or
+                    ham in ham_param_map):
+                    continue
+                # parameters come in the same order as in the function signature
+                params, defaults, takes_kwargs = get_parameters(ham)
+                params = params[skip:]  # remove site argument(s)
+                ham_param_map[ham] = (params, defaults, takes_kwargs)
+
+        self._ham_param_map = ham_param_map
+
+
+    def _init_discrete_symmetries(self, builder):
+        def operator(op):
             return Density(self, op, check_hermiticity=False)
 
         if builder.conservation_law is None:
             self._cons_law = None
         else:
-            self._cons_law = tuple(map(
-                _operator, _translate_cons_law(builder.conservation_law)))
-        self._symmetries = tuple(None if op is None else _operator(op)
+            a = _translate_cons_law(builder.conservation_law)
+            self._cons_law = tuple(map(operator, a))
+        self._symmetries = tuple(None if op is None else operator(op)
                                  for op in [builder.time_reversal,
                                             builder.particle_hole,
                                             builder.chiral])
 
+    def hamiltonian(self, i, j, *args, params=None):
+        if args and params:
+            raise TypeError("'args' and 'params' are mutually exclusive.")
+        if i == j:
+            value = self.onsite_hamiltonians[i]
+            if callable(value):
+                site = self.symmetry.to_fd(self.sites[i])
+                if params:
+                    required, defaults, takes_kw = self._ham_param_map[value]
+                    invalid_params = set(params).intersection(set(defaults))
+                    if invalid_params:
+                        raise ValueError("Parameters {} have default values "
+                                         "and may not be set with 'params'"
+                                         .format(', '.join(invalid_params)))
+                    if not takes_kw:
+                        params = {pn: params[pn] for pn in required}
+                    try:
+                        value = value(site, **params)
+                    except Exception as exc:
+                        _raise_user_error(exc, value)
+                else:
+                    try:
+                        value = value(site, *args)
+                    except Exception as exc:
+                        _raise_user_error(exc, value)
+        else:
+            edge_id = self.graph.first_edge_id(i, j)
+            value = self.hoppings[edge_id]
+            conj = value is Other
+            if conj:
+                i, j = j, i
+                edge_id = self.graph.first_edge_id(i, j)
+                value = self.hoppings[edge_id]
+            if callable(value):
+                sites = self.sites
+                site_i, site_j = self.symmetry.to_fd(sites[i], sites[j])
+                if params:
+                    required, defaults, takes_kw = self._ham_param_map[value]
+                    invalid_params = set(params).intersection(set(defaults))
+                    if invalid_params:
+                        raise ValueError("Parameters {} have default values "
+                                         "and may not be set with 'params'"
+                                         .format(', '.join(invalid_params)))
+                    if not takes_kw:
+                        params = {pn: params[pn] for pn in required}
+                    try:
+                        value = value(site_i, site_j, **params)
+                    except Exception as exc:
+                        _raise_user_error(exc, value)
+                else:
+                    try:
+                        value = value(site_i, site_j, *args)
+                    except Exception as exc:
+                        _raise_user_error(exc, value)
+            if conj:
+                value = herm_conj(value)
+        return value
+
+    def discrete_symmetry(self, args=(), *, params=None):
+        if self._cons_law is not None:
+            eigvals, eigvecs = self._cons_law
+            eigvals = eigvals.tocoo(args, params=params)
+            if not np.allclose(eigvals.data, np.round(eigvals.data)):
+                raise ValueError("Conservation law must have integer"
+                                 " eigenvalues.")
+            eigvals = np.round(eigvals).tocsr()
+            # Avoid appearance of zero eigenvalues
+            eigvals = eigvals + 0.5 * sparse.identity(eigvals.shape[0])
+            eigvals.eliminate_zeros()
+            eigvecs = eigvecs.tocoo(args, params=params)
+            projectors = [eigvecs.dot(eigvals == val)
+                          for val in sorted(np.unique(eigvals.data))]
+        else:
+            projectors = None
 
-class FiniteSystem(system.FiniteSystem):
+        def evaluate(op):
+            return None if op is None else op.tocoo(args, params=params)
+
+        return DiscreteSymmetry(projectors, *(evaluate(symm) for symm in
+                                              self._symmetries))
+
+
+class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem):
     """Finalized `Builder` with leads.
 
     Usable as input for the solvers in `kwant.solvers`.
@@ -1730,97 +1808,28 @@ class FiniteSystem(system.FiniteSystem):
 
             lead_interfaces.append(np.array(interface))
 
-        #### Find parameters taken by all value functions
+        onsite_hamiltonians = [builder.H[site][1] for site in sites]
         hoppings = [builder._get_edge(sites[tail], sites[head])
                            for tail, head in g]
-        onsite_hamiltonians = [builder.H[site][1] for site in sites]
-
-        _ham_param_map = {}
-        for hams, skip in [(onsite_hamiltonians, 1), (hoppings, 2)]:
-            for ham in hams:
-                if (not callable(ham) or ham is Other or
-                    ham in _ham_param_map):
-                    continue
-                # parameters come in the same order as in the function signature
-                params, defaults, takes_kwargs = get_parameters(ham)
-                params = params[skip:]  # remove site argument(s)
-                _ham_param_map[ham] = (params, defaults, takes_kwargs)
 
         self.graph = g
         self.sites = sites
         self.site_ranges = _site_ranges(sites)
         self.id_by_site = id_by_site
-        self.leads = finalized_leads
         self.hoppings = hoppings
         self.onsite_hamiltonians = onsite_hamiltonians
-        self._ham_param_map = _ham_param_map
-        self.lead_interfaces = lead_interfaces
         self.symmetry = builder.symmetry
-        _init_discrete_symmetries(self, builder)
+        self.leads = finalized_leads
+        self.lead_interfaces = lead_interfaces
+        self._init_ham_param_maps()
+        self._init_discrete_symmetries(builder)
 
-    def hamiltonian(self, i, j, *args, params=None):
-        if args and params:
-            raise TypeError("'args' and 'params' are mutually exclusive.")
-        if i == j:
-            value = self.onsite_hamiltonians[i]
-            if callable(value):
-                if params:
-                    required, defaults, takes_kw = self._ham_param_map[value]
-                    invalid_params = set(params).intersection(set(defaults))
-                    if invalid_params:
-                        raise ValueError("Parameters {} have default values "
-                                         "and may not be set with 'params'"
-                                         .format(', '.join(invalid_params)))
-                    if not takes_kw:
-                        params = {pn: params[pn] for pn in required}
-                    try:
-                        value = value(self.sites[i], **params)
-                    except Exception as exc:
-                        _raise_user_error(exc, value)
-                else:
-                    try:
-                        value = value(self.sites[i], *args)
-                    except Exception as exc:
-                        _raise_user_error(exc, value)
-        else:
-            edge_id = self.graph.first_edge_id(i, j)
-            value = self.hoppings[edge_id]
-            conj = value is Other
-            if conj:
-                i, j = j, i
-                edge_id = self.graph.first_edge_id(i, j)
-                value = self.hoppings[edge_id]
-            if callable(value):
-                sites = self.sites
-                if params:
-                    required, defaults, takes_kw = self._ham_param_map[value]
-                    invalid_params = set(params).intersection(set(defaults))
-                    if invalid_params:
-                        raise ValueError("Parameters {} have default values "
-                                         "and may not be set with 'params'"
-                                         .format(', '.join(invalid_params)))
-                    if not takes_kw:
-                        params = {pn: params[pn] for pn in required}
-                    try:
-                        value = value(sites[i], sites[j], **params)
-                    except Exception as exc:
-                        _raise_user_error(exc, value)
-                else:
-                    try:
-                        value = value(sites[i], sites[j], *args)
-                    except Exception as exc:
-                        _raise_user_error(exc, value)
-            if conj:
-                value = herm_conj(value)
-        return value
 
     def pos(self, i):
         return self.sites[i].pos
 
-    discrete_symmetry = discrete_symmetry
 
-
-class InfiniteSystem(system.InfiniteSystem):
+class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem):
     """Finalized infinite system, extracted from a `Builder`.
 
     Attributes
@@ -1851,7 +1860,6 @@ class InfiniteSystem(system.InfiniteSystem):
         sym = builder.symmetry
         assert sym.num_directions == 1
 
-
         #### For each site of the fundamental domain, determine whether it has
         #### neighbors in the previous domain or not.
         lsites_with = []       # Fund. domain sites with neighbors in prev. dom
@@ -1959,90 +1967,24 @@ class InfiniteSystem(system.InfiniteSystem):
                 tail, head = sym.to_fd(tail, head)
             hoppings.append(builder._get_edge(tail, head))
 
-        #### Find parameters taken by all value functions
-        _ham_param_map = {}
-        for hams, skip in [(onsite_hamiltonians, 1), (hoppings, 2)]:
-            for ham in hams:
-                if (not callable(ham) or ham is Other or
-                    ham in _ham_param_map):
-                    continue
-                # parameters come in the same order as in the function signature
-                params, defaults, takes_kwargs = get_parameters(ham)
-                params = params[skip:]  # remove site argument(s)
-                _ham_param_map[ham] = (params, defaults, takes_kwargs)
-
-        self.cell_size = cell_size
+        self.graph = g
         self.sites = sites
-        self.id_by_site = id_by_site
         self.site_ranges = _site_ranges(sites)
-        self.graph = g
+        self.id_by_site = id_by_site
         self.hoppings = hoppings
         self.onsite_hamiltonians = onsite_hamiltonians
-        self._ham_param_map = _ham_param_map
         self.symmetry = builder.symmetry
-        _init_discrete_symmetries(self, builder)
+        self.cell_size = cell_size
+        self._init_ham_param_maps()
+        self._init_discrete_symmetries(builder)
+
 
     def hamiltonian(self, i, j, *args, params=None):
-        if args and params:
-            raise TypeError("'args' and 'params' are mutually exclusive.")
-        if i == j:
-            if i >= self.cell_size:
-                i -= self.cell_size
-            value = self.onsite_hamiltonians[i]
-            if callable(value):
-                site = self.symmetry.to_fd(self.sites[i])
-                if params:
-                    required, defaults, takes_kw = self._ham_param_map[value]
-                    invalid_params = set(params).intersection(set(defaults))
-                    if invalid_params:
-                        raise ValueError("Parameters {} have default values "
-                                         "and may not be set with 'params'"
-                                         .format(', '.join(invalid_params)))
-                    if not takes_kw:
-                        params = {pn: params[pn] for pn in required}
-                    try:
-                        value = value(site, **params)
-                    except Exception as exc:
-                        _raise_user_error(exc, value)
-                else:
-                    try:
-                        value = value(site, *args)
-                    except Exception as exc:
-                        _raise_user_error(exc, value)
-        else:
-            edge_id = self.graph.first_edge_id(i, j)
-            value = self.hoppings[edge_id]
-            conj = value is Other
-            if conj:
-                i, j = j, i
-                edge_id = self.graph.first_edge_id(i, j)
-                value = self.hoppings[edge_id]
-            if callable(value):
-                sites = self.sites
-                site_i, site_j = self.symmetry.to_fd(sites[i], sites[j])
-                if params:
-                    required, defaults, takes_kw = self._ham_param_map[value]
-                    invalid_params = set(params).intersection(set(defaults))
-                    if invalid_params:
-                        raise ValueError("Parameters {} have default values "
-                                         "and may not be set with 'params'"
-                                         .format(', '.join(invalid_params)))
-                    if not takes_kw:
-                        params = {pn: params[pn] for pn in required}
-                    try:
-                        value = value(site_i, site_j, **params)
-                    except Exception as exc:
-                        _raise_user_error(exc, value)
-                else:
-                    try:
-                        value = value(site_i, site_j, *args)
-                    except Exception as exc:
-                        _raise_user_error(exc, value)
-            if conj:
-                value = herm_conj(value)
-        return value
+        cs = self.cell_size
+        if i == j >= cs:
+            i -= cs
+            j -= cs
+        return super().hamiltonian(i, j, *args, params=params)
 
     def pos(self, i):
         return self.sites[i].pos
-
-    discrete_symmetry = discrete_symmetry
-- 
GitLab