From cee3cadd12ac9933da887909303aaafd36e85b5b Mon Sep 17 00:00:00 2001
From: Joseph Weston <joseph@weston.cloud>
Date: Mon, 30 Sep 2019 16:26:20 +0200
Subject: [PATCH] refactor parts of finalization into functions

These will be used in vectorized systems later.
---
 kwant/builder.py | 311 +++++++++++++++++++++++++----------------------
 1 file changed, 166 insertions(+), 145 deletions(-)

diff --git a/kwant/builder.py b/kwant/builder.py
index 16bb4f46..3a583f99 100644
--- a/kwant/builder.py
+++ b/kwant/builder.py
@@ -1992,6 +1992,9 @@ class _FinalizedBuilderMixin:
         return DiscreteSymmetry(projectors, *(evaluate(symm) for symm in
                                               self._symmetries))
 
+    def pos(self, i):
+        return self.sites[i].pos
+
 
 # The same (value, parameters) pair will be used for many sites/hoppings,
 # so we cache it to avoid wasting extra memory.
@@ -2024,6 +2027,64 @@ def _value_params_pair_cache(nstrip):
     return get
 
 
+def _make_graph(H, id_by_site):
+    g = graph.Graph()
+    g.num_nodes = len(id_by_site)  # Some sites might not appear in any edge.
+    for tail, hvhv in H.items():
+        for head in islice(hvhv, 2, None, 2):
+            if tail == head:
+                continue
+            g.add_edge(id_by_site[tail], id_by_site[head])
+    return g.compressed()
+
+
+def _finalize_leads(leads, id_by_site):
+    #### Connect leads.
+    finalized_leads = []
+    lead_interfaces = []
+    lead_paddings = []
+    for lead_nr, lead in enumerate(leads):
+        try:
+            with warnings.catch_warnings(record=True) as ws:
+                warnings.simplefilter("always")
+                # The following line is the whole "payload" of the entire
+                # try-block.
+                finalized_leads.append(lead.finalized())
+        except ValueError as e:
+            # Re-raise the exception with an additional message.
+            msg = 'Problem finalizing lead {0}:'.format(lead_nr)
+            e.args = (' '.join((msg,) + e.args),)
+            raise
+        else:
+            for w in ws:
+                # Re-raise any warnings with an additional message and the
+                # proper stacklevel.
+                w = w.message
+                msg = 'When finalizing lead {0}:'.format(lead_nr)
+                warnings.warn(w.__class__(' '.join((msg,) + w.args)),
+                              stacklevel=3)
+        try:
+            interface = [id_by_site[isite] for isite in lead.interface]
+        except KeyError as e:
+            msg = ("Lead {0} is attached to a site that does not "
+                   "belong to the scattering region:\n {1}")
+            raise ValueError(msg.format(lead_nr, e.args[0]))
+
+        lead_interfaces.append(np.array(interface))
+
+        padding = getattr(lead, 'padding', [])
+        # Some padding sites might have been removed after the lead was
+        # attached. Unlike in the case of the interface, this is not a
+        # problem.
+        finalized_padding = [
+            id_by_site[isite] for isite in padding if isite in id_by_site
+        ]
+
+        lead_paddings.append(np.array(finalized_padding))
+
+    return finalized_leads, lead_interfaces, lead_paddings
+
+
 class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem):
     """Finalized `Builder` with leads.
 
@@ -2050,57 +2111,10 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem):
         for site_id, site in enumerate(sites):
             id_by_site[site] = site_id
 
-        #### Make graph.
-        g = graph.Graph()
-        g.num_nodes = len(sites)  # Some sites could not appear in any edge.
-        for tail, hvhv in builder.H.items():
-            for head in islice(hvhv, 2, None, 2):
-                if tail == head:
-                    continue
-                g.add_edge(id_by_site[tail], id_by_site[head])
-        g = g.compressed()
-
-        #### Connect leads.
-        finalized_leads = []
-        lead_interfaces = []
-        lead_paddings = []
-        for lead_nr, lead in enumerate(builder.leads):
-            try:
-                with warnings.catch_warnings(record=True) as ws:
-                    warnings.simplefilter("always")
-                    # The following line is the whole "payload" of the entire
-                    # try-block.
-                    finalized_leads.append(lead.finalized())
-                for w in ws:
-                    # Re-raise any warnings with an additional message and the
-                    # proper stacklevel.
-                    w = w.message
-                    msg = 'When finalizing lead {0}:'.format(lead_nr)
-                    warnings.warn(w.__class__(' '.join((msg,) + w.args)),
-                                  stacklevel=3)
-            except ValueError as e:
-                # Re-raise the exception with an additional message.
-                msg = 'Problem finalizing lead {0}:'.format(lead_nr)
-                e.args = (' '.join((msg,) + e.args),)
-                raise
-            try:
-                interface = [id_by_site[isite] for isite in lead.interface]
-            except KeyError as e:
-                msg = ("Lead {0} is attached to a site that does not "
-                       "belong to the scattering region:\n {1}")
-                raise ValueError(msg.format(lead_nr, e.args[0]))
-
-            lead_interfaces.append(np.array(interface))
-
-            padding = getattr(lead, 'padding', [])
-            # Some padding sites might have been removed after the lead was
-            # attached. Unlike in the case of the interface, this is not a
-            # problem.
-            finalized_padding = [
-                id_by_site[isite] for isite in padding if isite in id_by_site
-            ]
+        graph = _make_graph(builder.H, id_by_site)
 
-            lead_paddings.append(np.array(finalized_padding))
+        finalized_leads, lead_interfaces, lead_paddings =\
+            _finalize_leads(builder.leads, id_by_site)
 
         # Because many onsites/hoppings share the same (value, parameter)
         # pairs, we keep them in a cache so that we only store a given pair
@@ -2109,7 +2123,7 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem):
         onsites = [cache(builder.H[site][1]) for site in sites]
         cache = _value_params_pair_cache(2)
         hoppings = [cache(builder._get_edge(sites[tail], sites[head]))
-                    for tail, head in g]
+                    for tail, head in graph]
 
         # Compute the union of the parameters of onsites and hoppings.  Here,
         # 'onsites' and 'hoppings' are pairs whose second element is one of
@@ -2129,7 +2143,7 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem):
         else:
             parameters = frozenset(parameters)
 
-        self.graph = g
+        self.graph = graph
         self.sites = sites
         self.site_ranges = _site_ranges(sites)
         self.id_by_site = id_by_site
@@ -2142,8 +2156,97 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem):
         self.lead_paddings = lead_paddings
         self._init_discrete_symmetries(builder)
 
-    def pos(self, i):
-        return self.sites[i].pos
+
+def _make_lead_sites(builder, interface_order):
+    #### For each site of the fundamental domain, determine whether it has
+    #### neighbors in the previous domain or not.
+    sym = builder.symmetry
+    lsites_with = []       # Fund. domain sites with neighbors in prev. dom
+    lsites_without = []    # Remaining sites of the fundamental domain
+    for tail in builder.H: # Loop over all sites of the fund. domain.
+        for head in builder._out_neighbors(tail):
+            fd = sym.which(head)[0]
+            if fd == 1:
+                # Tail belongs to fund. domain, head to the next domain.
+                lsites_with.append(tail)
+                break
+        else:
+            # Tail is a fund. domain site not connected to prev. domain.
+            lsites_without.append(tail)
+
+    if not lsites_with:
+        warnings.warn('Infinite system with disconnected cells.',
+                      RuntimeWarning, stacklevel=3)
+
+    ### Create list of sites and a lookup table
+    minus_one = ta.array((-1,))
+    plus_one = ta.array((1,))
+    if interface_order is None:
+        # interface must be sorted
+        interface = [sym.act(minus_one, s) for s in lsites_with]
+        interface.sort()
+    else:
+        lsites_with_set = set(lsites_with)
+        lsites_with = []
+        interface = []
+        if interface_order:
+            shift = ta.array((-sym.which(interface_order[0])[0] - 1,))
+        for shifted_iface_site in interface_order:
+            # Shift the interface domain before the fundamental domain.
+            # That's the right place for the interface of a lead to be, but
+            # the sites of interface_order might live in a different
+            # domain.
+            iface_site = sym.act(shift, shifted_iface_site)
+            lsite = sym.act(plus_one, iface_site)
+
+            try:
+                lsites_with_set.remove(lsite)
+            except KeyError:
+                if (-sym.which(shifted_iface_site)[0] - 1,) != shift:
+                    raise ValueError(
+                        'The sites in interface_order do not all '
+                        'belong to the same lead cell.')
+                else:
+                    raise ValueError('A site in interface_order is not an '
+                                     'interface site:\n' + str(iface_site))
+            interface.append(iface_site)
+            lsites_with.append(lsite)
+        if lsites_with_set:
+            raise ValueError(
+                'interface_order did not contain all interface sites.')
+        # `interface_order` *must* be sorted, hence `interface` must be too.
+        if interface != sorted(interface):
+            raise ValueError('Interface sites must be sorted.')
+        del lsites_with_set
+
+    return sorted(lsites_with), sorted(lsites_without), interface
+
+
+def _make_lead_graph(builder, sites, id_by_site, cell_size):
+    sym = builder.symmetry
+    g = graph.Graph()
+    g.num_nodes = len(sites)  # Some sites might not appear in any edge.
+    for tail_id, tail in enumerate(sites[:cell_size]):
+        for head in builder._out_neighbors(tail):
+            head_id = id_by_site.get(head)
+            if head_id is None:
+                # Head belongs neither to the fundamental domain nor to the
+                # previous domain.  Check that it belongs to the next
+                # domain (anything else is an error), then skip it: the
+                # equivalent edge has been added already or will be added.
+                fd = sym.which(head)[0]
+                if fd != 1:
+                    msg = ('Further-than-nearest-neighbor cells '
+                           'are connected by hopping\n{0}.')
+                    raise ValueError(msg.format((tail, head)))
+                continue
+            if head_id >= cell_size:
+                # Head belongs to previous domain.  The edge added here
+                # corresponds to the one left out just above.
+                g.add_edge(head_id, tail_id)
+            g.add_edge(tail_id, head_id)
+
+    return g.compressed()
 
 
 class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem):
@@ -2179,69 +2282,12 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem):
         sym = builder.symmetry
         assert sym.num_directions == 1
 
-        #### For each site of the fundamental domain, determine whether it has
-        #### neighbors in the previous domain or not.
-        lsites_with = []       # Fund. domain sites with neighbors in prev. dom
-        lsites_without = []    # Remaining sites of the fundamental domain
-        for tail in builder.H: # Loop over all sites of the fund. domain.
-            for head in builder._out_neighbors(tail):
-                fd = sym.which(head)[0]
-                if fd == 1:
-                    # Tail belongs to fund. domain, head to the next domain.
-                    lsites_with.append(tail)
-                    break
-            else:
-                # Tail is a fund. domain site not connected to prev. domain.
-                lsites_without.append(tail)
+        lsites_with, lsites_without, interface =\
+            _make_lead_sites(builder, interface_order)
         cell_size = len(lsites_with) + len(lsites_without)
 
-        if not lsites_with:
-            warnings.warn('Infinite system with disconnected cells.',
-                          RuntimeWarning, stacklevel=3)
-
-        ### Create list of sites and a lookup table
-        minus_one = ta.array((-1,))
-        plus_one = ta.array((1,))
-        if interface_order is None:
-            # interface must be sorted
-            interface = [sym.act(minus_one, s) for s in lsites_with]
-            interface.sort()
-        else:
-            lsites_with_set = set(lsites_with)
-            lsites_with = []
-            interface = []
-            if interface_order:
-                shift = ta.array((-sym.which(interface_order[0])[0] - 1,))
-            for shifted_iface_site in interface_order:
-                # Shift the interface domain before the fundamental domain.
-                # That's the right place for the interface of a lead to be, but
-                # the sites of interface_order might live in a different
-                # domain.
-                iface_site = sym.act(shift, shifted_iface_site)
-                lsite = sym.act(plus_one, iface_site)
-
-                try:
-                    lsites_with_set.remove(lsite)
-                except KeyError:
-                    if (-sym.which(shifted_iface_site)[0] - 1,) != shift:
-                        raise ValueError(
-                            'The sites in interface_order do not all '
-                            'belong to the same lead cell.')
-                    else:
-                        raise ValueError('A site in interface_order is not an '
-                                         'interface site:\n' + str(iface_site))
-                interface.append(iface_site)
-                lsites_with.append(lsite)
-            if lsites_with_set:
-                raise ValueError(
-                    'interface_order did not contain all interface sites.')
-            # `interface_order` *must* be sorted, hence `interface` should also
-            if interface != sorted(interface):
-                raise ValueError('Interface sites must be sorted.')
-            del lsites_with_set
-
         # we previously sorted the interface, so don't sort it again
-        sites = sorted(lsites_with) + sorted(lsites_without) + interface
+        sites = lsites_with + lsites_without + interface
         del lsites_with
         del lsites_without
         del interface
@@ -2249,41 +2295,20 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem):
         for site_id, site in enumerate(sites):
             id_by_site[site] = site_id
 
+        graph = _make_lead_graph(builder, sites, id_by_site, cell_size)
+
         # In the following, because many onsites/hoppings share the same
         # (value, parameter) pairs, we keep them in 'cache' so that we only
         # store a given pair in memory *once*. This is like interning strings.
 
-        #### Make graph and extract onsite Hamiltonians.
+        #### Extract onsites
         cache = _value_params_pair_cache(1)
-        g = graph.Graph()
-        g.num_nodes = len(sites)  # Some sites could not appear in any edge.
-        onsites = []
-        for tail_id, tail in enumerate(sites[:cell_size]):
-            onsites.append(cache(builder.H[tail][1]))
-            for head in builder._out_neighbors(tail):
-                head_id = id_by_site.get(head)
-                if head_id is None:
-                    # Head belongs neither to the fundamental domain nor to the
-                    # previous domain.  Check that it belongs to the next
-                    # domain and ignore it otherwise as an edge corresponding
-                    # to this one has been added already or will be added.
-                    fd = sym.which(head)[0]
-                    if fd != 1:
-                        msg = ('Further-than-nearest-neighbor cells '
-                               'are connected by hopping\n{0}.')
-                        raise ValueError(msg.format((tail, head)))
-                    continue
-                if head_id >= cell_size:
-                    # Head belongs to previous domain.  The edge added here
-                    # correspond to one left out just above.
-                    g.add_edge(head_id, tail_id)
-                g.add_edge(tail_id, head_id)
-        g = g.compressed()
+        onsites = [cache(builder.H[tail][1]) for tail in sites[:cell_size]]
 
         #### Extract hoppings.
         cache = _value_params_pair_cache(2)
         hoppings = []
-        for tail_id, head_id in g:
+        for tail_id, head_id in graph:
             tail = sites[tail_id]
             head = sites[head_id]
             if tail_id >= cell_size:
@@ -2310,7 +2335,7 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem):
         else:
             parameters = frozenset(parameters)
 
-        self.graph = g
+        self.graph = graph
         self.sites = sites
         self.site_ranges = _site_ranges(sites)
         self.id_by_site = id_by_site
@@ -2321,13 +2346,9 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem):
         self.cell_size = cell_size
         self._init_discrete_symmetries(builder)
 
-
     def hamiltonian(self, i, j, *args, params=None):
         cs = self.cell_size
         if i == j >= cs:
             i -= cs
             j -= cs
         return super().hamiltonian(i, j, *args, params=params)
-
-    def pos(self, i):
-        return self.sites[i].pos
-- 
GitLab