From cee3cadd12ac9933da887909303aaafd36e85b5b Mon Sep 17 00:00:00 2001 From: Joseph Weston <joseph@weston.cloud> Date: Mon, 30 Sep 2019 16:26:20 +0200 Subject: [PATCH] refactor parts of finalization into functions These will be used in vectorized systems later. --- kwant/builder.py | 311 +++++++++++++++++++++++++---------------------- 1 file changed, 166 insertions(+), 145 deletions(-) diff --git a/kwant/builder.py b/kwant/builder.py index 16bb4f46..3a583f99 100644 --- a/kwant/builder.py +++ b/kwant/builder.py @@ -1992,6 +1992,9 @@ class _FinalizedBuilderMixin: return DiscreteSymmetry(projectors, *(evaluate(symm) for symm in self._symmetries)) + def pos(self, i): + return self.sites[i].pos + # The same (value, parameters) pair will be used for many sites/hoppings, # so we cache it to avoid wasting extra memory. @@ -2024,6 +2027,64 @@ def _value_params_pair_cache(nstrip): return get +def _make_graph(H, id_by_site): + g = graph.Graph() + g.num_nodes = len(id_by_site) # Some sites could not appear in any edge. + for tail, hvhv in H.items(): + for head in islice(hvhv, 2, None, 2): + if tail == head: + continue + g.add_edge(id_by_site[tail], id_by_site[head]) + return g.compressed() + + +def _finalize_leads(leads, id_by_site): + #### Connect leads. + finalized_leads = [] + lead_interfaces = [] + lead_paddings = [] + for lead_nr, lead in enumerate(leads): + try: + with warnings.catch_warnings(record=True) as ws: + warnings.simplefilter("always") + # The following line is the whole "payload" of the entire + # try-block. + finalized_leads.append(lead.finalized()) + except ValueError as e: + # Re-raise the exception with an additional message. + msg = 'Problem finalizing lead {0}:'.format(lead_nr) + e.args = (' '.join((msg,) + e.args),) + raise + else: + for w in ws: + # Re-raise any warnings with an additional message and the + # proper stacklevel. + w = w.message + msg = 'When finalizing lead {0}:'.format(lead_nr) + warnings.warn(w.__class__(' '.join((msg,) + w.args)), + stacklevel=3) + try: + interface = [id_by_site[isite] for isite in lead.interface] + except KeyError as e: + msg = ("Lead {0} is attached to a site that does not " + "belong to the scattering region:\n {1}") + raise ValueError(msg.format(lead_nr, e.args[0])) + + lead_interfaces.append(np.array(interface)) + + padding = getattr(lead, 'padding', []) + # Some padding sites might have been removed after the lead was + # attached. Unlike in the case of the interface, this is not a + # problem. + finalized_padding = [ + id_by_site[isite] for isite in padding if isite in id_by_site + ] + + lead_paddings.append(np.array(finalized_padding)) + + return finalized_leads, lead_interfaces, lead_paddings + + class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem): """Finalized `Builder` with leads. @@ -2050,57 +2111,10 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem): for site_id, site in enumerate(sites): id_by_site[site] = site_id - #### Make graph. - g = graph.Graph() - g.num_nodes = len(sites) # Some sites could not appear in any edge. - for tail, hvhv in builder.H.items(): - for head in islice(hvhv, 2, None, 2): - if tail == head: - continue - g.add_edge(id_by_site[tail], id_by_site[head]) - g = g.compressed() - - #### Connect leads. - finalized_leads = [] - lead_interfaces = [] - lead_paddings = [] - for lead_nr, lead in enumerate(builder.leads): - try: - with warnings.catch_warnings(record=True) as ws: - warnings.simplefilter("always") - # The following line is the whole "payload" of the entire - # try-block. - finalized_leads.append(lead.finalized()) - for w in ws: - # Re-raise any warnings with an additional message and the - # proper stacklevel. - w = w.message - msg = 'When finalizing lead {0}:'.format(lead_nr) - warnings.warn(w.__class__(' '.join((msg,) + w.args)), - stacklevel=3) - except ValueError as e: - # Re-raise the exception with an additional message. - msg = 'Problem finalizing lead {0}:'.format(lead_nr) - e.args = (' '.join((msg,) + e.args),) - raise - try: - interface = [id_by_site[isite] for isite in lead.interface] - except KeyError as e: - msg = ("Lead {0} is attached to a site that does not " - "belong to the scattering region:\n {1}") - raise ValueError(msg.format(lead_nr, e.args[0])) - - lead_interfaces.append(np.array(interface)) - - padding = getattr(lead, 'padding', []) - # Some padding sites might have been removed after the lead was - # attached. Unlike in the case of the interface, this is not a - # problem. - finalized_padding = [ - id_by_site[isite] for isite in padding if isite in id_by_site - ] + graph = _make_graph(builder.H, id_by_site) - lead_paddings.append(np.array(finalized_padding)) + finalized_leads, lead_interfaces, lead_paddings =\ + _finalize_leads(builder.leads, id_by_site) # Because many onsites/hoppings share the same (value, parameter) # pairs, we keep them in a cache so that we only store a given pair @@ -2109,7 +2123,7 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem): onsites = [cache(builder.H[site][1]) for site in sites] cache = _value_params_pair_cache(2) hoppings = [cache(builder._get_edge(sites[tail], sites[head])) - for tail, head in g] + for tail, head in graph] # Compute the union of the parameters of onsites and hoppings. Here, # 'onsites' and 'hoppings' are pairs whose second element is one of @@ -2129,7 +2143,7 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem): else: parameters = frozenset(parameters) - self.graph = g + self.graph = graph self.sites = sites self.site_ranges = _site_ranges(sites) self.id_by_site = id_by_site @@ -2142,8 +2156,97 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem): self.lead_paddings = lead_paddings self._init_discrete_symmetries(builder) - def pos(self, i): - return self.sites[i].pos + +def _make_lead_sites(builder, interface_order): + #### For each site of the fundamental domain, determine whether it has + #### neighbors in the previous domain or not. + sym = builder.symmetry + lsites_with = [] # Fund. domain sites with neighbors in prev. dom + lsites_without = [] # Remaining sites of the fundamental domain + for tail in builder.H: # Loop over all sites of the fund. domain. + for head in builder._out_neighbors(tail): + fd = sym.which(head)[0] + if fd == 1: + # Tail belongs to fund. domain, head to the next domain. + lsites_with.append(tail) + break + else: + # Tail is a fund. domain site not connected to prev. domain. + lsites_without.append(tail) + + if not lsites_with: + warnings.warn('Infinite system with disconnected cells.', + RuntimeWarning, stacklevel=3) + + ### Create list of sites and a lookup table + minus_one = ta.array((-1,)) + plus_one = ta.array((1,)) + if interface_order is None: + # interface must be sorted + interface = [sym.act(minus_one, s) for s in lsites_with] + interface.sort() + else: + lsites_with_set = set(lsites_with) + lsites_with = [] + interface = [] + if interface_order: + shift = ta.array((-sym.which(interface_order[0])[0] - 1,)) + for shifted_iface_site in interface_order: + # Shift the interface domain before the fundamental domain. + # That's the right place for the interface of a lead to be, but + # the sites of interface_order might live in a different + # domain. + iface_site = sym.act(shift, shifted_iface_site) + lsite = sym.act(plus_one, iface_site) + + try: + lsites_with_set.remove(lsite) + except KeyError: + if (-sym.which(shifted_iface_site)[0] - 1,) != shift: + raise ValueError( + 'The sites in interface_order do not all ' + 'belong to the same lead cell.') + else: + raise ValueError('A site in interface_order is not an ' + 'interface site:\n' + str(iface_site)) + interface.append(iface_site) + lsites_with.append(lsite) + if lsites_with_set: + raise ValueError( + 'interface_order did not contain all interface sites.') + # `interface_order` *must* be sorted, hence `interface` should also + if interface != sorted(interface): + raise ValueError('Interface sites must be sorted.') + del lsites_with_set + + return sorted(lsites_with), sorted(lsites_without), interface + + +def _make_lead_graph(builder, sites, id_by_site, cell_size): + sym = builder.symmetry + g = graph.Graph() + g.num_nodes = len(sites) # Some sites could not appear in any edge. + for tail_id, tail in enumerate(sites[:cell_size]): + for head in builder._out_neighbors(tail): + head_id = id_by_site.get(head) + if head_id is None: + # Head belongs neither to the fundamental domain nor to the + # previous domain. Check that it belongs to the next + # domain and ignore it otherwise as an edge corresponding + # to this one has been added already or will be added. + fd = sym.which(head)[0] + if fd != 1: + msg = ('Further-than-nearest-neighbor cells ' + 'are connected by hopping\n{0}.') + raise ValueError(msg.format((tail, head))) + continue + if head_id >= cell_size: + # Head belongs to previous domain. The edge added here + # correspond to one left out just above. + g.add_edge(head_id, tail_id) + g.add_edge(tail_id, head_id) + + return g.compressed() class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem): @@ -2179,69 +2282,12 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem): sym = builder.symmetry assert sym.num_directions == 1 - #### For each site of the fundamental domain, determine whether it has - #### neighbors in the previous domain or not. - lsites_with = [] # Fund. domain sites with neighbors in prev. dom - lsites_without = [] # Remaining sites of the fundamental domain - for tail in builder.H: # Loop over all sites of the fund. domain. - for head in builder._out_neighbors(tail): - fd = sym.which(head)[0] - if fd == 1: - # Tail belongs to fund. domain, head to the next domain. - lsites_with.append(tail) - break - else: - # Tail is a fund. domain site not connected to prev. domain. - lsites_without.append(tail) + lsites_with, lsites_without, interface =\ + _make_lead_sites(builder, interface_order) cell_size = len(lsites_with) + len(lsites_without) - if not lsites_with: - warnings.warn('Infinite system with disconnected cells.', - RuntimeWarning, stacklevel=3) - - ### Create list of sites and a lookup table - minus_one = ta.array((-1,)) - plus_one = ta.array((1,)) - if interface_order is None: - # interface must be sorted - interface = [sym.act(minus_one, s) for s in lsites_with] - interface.sort() - else: - lsites_with_set = set(lsites_with) - lsites_with = [] - interface = [] - if interface_order: - shift = ta.array((-sym.which(interface_order[0])[0] - 1,)) - for shifted_iface_site in interface_order: - # Shift the interface domain before the fundamental domain. - # That's the right place for the interface of a lead to be, but - # the sites of interface_order might live in a different - # domain. - iface_site = sym.act(shift, shifted_iface_site) - lsite = sym.act(plus_one, iface_site) - - try: - lsites_with_set.remove(lsite) - except KeyError: - if (-sym.which(shifted_iface_site)[0] - 1,) != shift: - raise ValueError( - 'The sites in interface_order do not all ' - 'belong to the same lead cell.') - else: - raise ValueError('A site in interface_order is not an ' - 'interface site:\n' + str(iface_site)) - interface.append(iface_site) - lsites_with.append(lsite) - if lsites_with_set: - raise ValueError( - 'interface_order did not contain all interface sites.') - # `interface_order` *must* be sorted, hence `interface` should also - if interface != sorted(interface): - raise ValueError('Interface sites must be sorted.') - del lsites_with_set - # we previously sorted the interface, so don't sort it again - sites = sorted(lsites_with) + sorted(lsites_without) + interface + sites = lsites_with + lsites_without + interface del lsites_with del lsites_without del interface @@ -2249,41 +2295,20 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem): for site_id, site in enumerate(sites): id_by_site[site] = site_id + graph = _make_lead_graph(builder, sites, id_by_site, cell_size) + # In the following, because many onsites/hoppings share the same # (value, parameter) pairs, we keep them in 'cache' so that we only # store a given pair in memory *once*. This is like interning strings. - #### Make graph and extract onsite Hamiltonians. + #### Extract onsites cache = _value_params_pair_cache(1) - g = graph.Graph() - g.num_nodes = len(sites) # Some sites could not appear in any edge. - onsites = [] - for tail_id, tail in enumerate(sites[:cell_size]): - onsites.append(cache(builder.H[tail][1])) - for head in builder._out_neighbors(tail): - head_id = id_by_site.get(head) - if head_id is None: - # Head belongs neither to the fundamental domain nor to the - # previous domain. Check that it belongs to the next - # domain and ignore it otherwise as an edge corresponding - # to this one has been added already or will be added. - fd = sym.which(head)[0] - if fd != 1: - msg = ('Further-than-nearest-neighbor cells ' - 'are connected by hopping\n{0}.') - raise ValueError(msg.format((tail, head))) - continue - if head_id >= cell_size: - # Head belongs to previous domain. The edge added here - # correspond to one left out just above. - g.add_edge(head_id, tail_id) - g.add_edge(tail_id, head_id) - g = g.compressed() + onsites = [cache(builder.H[tail][1]) for tail in sites[:cell_size]] #### Extract hoppings. cache = _value_params_pair_cache(2) hoppings = [] - for tail_id, head_id in g: + for tail_id, head_id in graph: tail = sites[tail_id] head = sites[head_id] if tail_id >= cell_size: @@ -2310,7 +2335,7 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem): else: parameters = frozenset(parameters) - self.graph = g + self.graph = graph self.sites = sites self.site_ranges = _site_ranges(sites) self.id_by_site = id_by_site @@ -2321,13 +2346,9 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem): self.cell_size = cell_size self._init_discrete_symmetries(builder) - def hamiltonian(self, i, j, *args, params=None): cs = self.cell_size if i == j >= cs: i -= cs j -= cs return super().hamiltonian(i, j, *args, params=params) - - def pos(self, i): - return self.sites[i].pos -- GitLab