Skip to content
Snippets Groups Projects
Commit cee3cadd authored by Joseph Weston's avatar Joseph Weston
Browse files

refactor parts of finalization into functions

These will be used in vectorized systems later.
parent 64d9d4be
No related branches found
No related tags found
No related merge requests found
...@@ -1992,6 +1992,9 @@ class _FinalizedBuilderMixin: ...@@ -1992,6 +1992,9 @@ class _FinalizedBuilderMixin:
return DiscreteSymmetry(projectors, *(evaluate(symm) for symm in return DiscreteSymmetry(projectors, *(evaluate(symm) for symm in
self._symmetries)) self._symmetries))
def pos(self, i):
return self.sites[i].pos
# The same (value, parameters) pair will be used for many sites/hoppings, # The same (value, parameters) pair will be used for many sites/hoppings,
# so we cache it to avoid wasting extra memory. # so we cache it to avoid wasting extra memory.
...@@ -2024,6 +2027,64 @@ def _value_params_pair_cache(nstrip): ...@@ -2024,6 +2027,64 @@ def _value_params_pair_cache(nstrip):
return get return get
def _make_graph(H, id_by_site):
g = graph.Graph()
g.num_nodes = len(id_by_site) # Some sites could not appear in any edge.
for tail, hvhv in H.items():
for head in islice(hvhv, 2, None, 2):
if tail == head:
continue
g.add_edge(id_by_site[tail], id_by_site[head])
return g.compressed()
def _finalize_leads(leads, id_by_site):
#### Connect leads.
finalized_leads = []
lead_interfaces = []
lead_paddings = []
for lead_nr, lead in enumerate(leads):
try:
with warnings.catch_warnings(record=True) as ws:
warnings.simplefilter("always")
# The following line is the whole "payload" of the entire
# try-block.
finalized_leads.append(lead.finalized())
except ValueError as e:
# Re-raise the exception with an additional message.
msg = 'Problem finalizing lead {0}:'.format(lead_nr)
e.args = (' '.join((msg,) + e.args),)
raise
else:
for w in ws:
# Re-raise any warnings with an additional message and the
# proper stacklevel.
w = w.message
msg = 'When finalizing lead {0}:'.format(lead_nr)
warnings.warn(w.__class__(' '.join((msg,) + w.args)),
stacklevel=3)
try:
interface = [id_by_site[isite] for isite in lead.interface]
except KeyError as e:
msg = ("Lead {0} is attached to a site that does not "
"belong to the scattering region:\n {1}")
raise ValueError(msg.format(lead_nr, e.args[0]))
lead_interfaces.append(np.array(interface))
padding = getattr(lead, 'padding', [])
# Some padding sites might have been removed after the lead was
# attached. Unlike in the case of the interface, this is not a
# problem.
finalized_padding = [
id_by_site[isite] for isite in padding if isite in id_by_site
]
lead_paddings.append(np.array(finalized_padding))
return finalized_leads, lead_interfaces, lead_paddings
class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem): class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem):
"""Finalized `Builder` with leads. """Finalized `Builder` with leads.
...@@ -2050,57 +2111,10 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem): ...@@ -2050,57 +2111,10 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem):
for site_id, site in enumerate(sites): for site_id, site in enumerate(sites):
id_by_site[site] = site_id id_by_site[site] = site_id
#### Make graph. graph = _make_graph(builder.H, id_by_site)
g = graph.Graph()
g.num_nodes = len(sites) # Some sites could not appear in any edge.
for tail, hvhv in builder.H.items():
for head in islice(hvhv, 2, None, 2):
if tail == head:
continue
g.add_edge(id_by_site[tail], id_by_site[head])
g = g.compressed()
#### Connect leads.
finalized_leads = []
lead_interfaces = []
lead_paddings = []
for lead_nr, lead in enumerate(builder.leads):
try:
with warnings.catch_warnings(record=True) as ws:
warnings.simplefilter("always")
# The following line is the whole "payload" of the entire
# try-block.
finalized_leads.append(lead.finalized())
for w in ws:
# Re-raise any warnings with an additional message and the
# proper stacklevel.
w = w.message
msg = 'When finalizing lead {0}:'.format(lead_nr)
warnings.warn(w.__class__(' '.join((msg,) + w.args)),
stacklevel=3)
except ValueError as e:
# Re-raise the exception with an additional message.
msg = 'Problem finalizing lead {0}:'.format(lead_nr)
e.args = (' '.join((msg,) + e.args),)
raise
try:
interface = [id_by_site[isite] for isite in lead.interface]
except KeyError as e:
msg = ("Lead {0} is attached to a site that does not "
"belong to the scattering region:\n {1}")
raise ValueError(msg.format(lead_nr, e.args[0]))
lead_interfaces.append(np.array(interface))
padding = getattr(lead, 'padding', [])
# Some padding sites might have been removed after the lead was
# attached. Unlike in the case of the interface, this is not a
# problem.
finalized_padding = [
id_by_site[isite] for isite in padding if isite in id_by_site
]
lead_paddings.append(np.array(finalized_padding)) finalized_leads, lead_interfaces, lead_paddings =\
_finalize_leads(builder.leads, id_by_site)
# Because many onsites/hoppings share the same (value, parameter) # Because many onsites/hoppings share the same (value, parameter)
# pairs, we keep them in a cache so that we only store a given pair # pairs, we keep them in a cache so that we only store a given pair
...@@ -2109,7 +2123,7 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem): ...@@ -2109,7 +2123,7 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem):
onsites = [cache(builder.H[site][1]) for site in sites] onsites = [cache(builder.H[site][1]) for site in sites]
cache = _value_params_pair_cache(2) cache = _value_params_pair_cache(2)
hoppings = [cache(builder._get_edge(sites[tail], sites[head])) hoppings = [cache(builder._get_edge(sites[tail], sites[head]))
for tail, head in g] for tail, head in graph]
# Compute the union of the parameters of onsites and hoppings. Here, # Compute the union of the parameters of onsites and hoppings. Here,
# 'onsites' and 'hoppings' are pairs whose second element is one of # 'onsites' and 'hoppings' are pairs whose second element is one of
...@@ -2129,7 +2143,7 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem): ...@@ -2129,7 +2143,7 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem):
else: else:
parameters = frozenset(parameters) parameters = frozenset(parameters)
self.graph = g self.graph = graph
self.sites = sites self.sites = sites
self.site_ranges = _site_ranges(sites) self.site_ranges = _site_ranges(sites)
self.id_by_site = id_by_site self.id_by_site = id_by_site
...@@ -2142,8 +2156,97 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem): ...@@ -2142,8 +2156,97 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem):
self.lead_paddings = lead_paddings self.lead_paddings = lead_paddings
self._init_discrete_symmetries(builder) self._init_discrete_symmetries(builder)
def pos(self, i):
return self.sites[i].pos def _make_lead_sites(builder, interface_order):
#### For each site of the fundamental domain, determine whether it has
#### neighbors in the previous domain or not.
sym = builder.symmetry
lsites_with = [] # Fund. domain sites with neighbors in prev. dom
lsites_without = [] # Remaining sites of the fundamental domain
for tail in builder.H: # Loop over all sites of the fund. domain.
for head in builder._out_neighbors(tail):
fd = sym.which(head)[0]
if fd == 1:
# Tail belongs to fund. domain, head to the next domain.
lsites_with.append(tail)
break
else:
# Tail is a fund. domain site not connected to prev. domain.
lsites_without.append(tail)
if not lsites_with:
warnings.warn('Infinite system with disconnected cells.',
RuntimeWarning, stacklevel=3)
### Create list of sites and a lookup table
minus_one = ta.array((-1,))
plus_one = ta.array((1,))
if interface_order is None:
# interface must be sorted
interface = [sym.act(minus_one, s) for s in lsites_with]
interface.sort()
else:
lsites_with_set = set(lsites_with)
lsites_with = []
interface = []
if interface_order:
shift = ta.array((-sym.which(interface_order[0])[0] - 1,))
for shifted_iface_site in interface_order:
# Shift the interface domain before the fundamental domain.
# That's the right place for the interface of a lead to be, but
# the sites of interface_order might live in a different
# domain.
iface_site = sym.act(shift, shifted_iface_site)
lsite = sym.act(plus_one, iface_site)
try:
lsites_with_set.remove(lsite)
except KeyError:
if (-sym.which(shifted_iface_site)[0] - 1,) != shift:
raise ValueError(
'The sites in interface_order do not all '
'belong to the same lead cell.')
else:
raise ValueError('A site in interface_order is not an '
'interface site:\n' + str(iface_site))
interface.append(iface_site)
lsites_with.append(lsite)
if lsites_with_set:
raise ValueError(
'interface_order did not contain all interface sites.')
# `interface_order` *must* be sorted, hence `interface` should also
if interface != sorted(interface):
raise ValueError('Interface sites must be sorted.')
del lsites_with_set
return sorted(lsites_with), sorted(lsites_without), interface
def _make_lead_graph(builder, sites, id_by_site, cell_size):
sym = builder.symmetry
g = graph.Graph()
g.num_nodes = len(sites) # Some sites could not appear in any edge.
for tail_id, tail in enumerate(sites[:cell_size]):
for head in builder._out_neighbors(tail):
head_id = id_by_site.get(head)
if head_id is None:
# Head belongs neither to the fundamental domain nor to the
# previous domain. Check that it belongs to the next
# domain and ignore it otherwise as an edge corresponding
# to this one has been added already or will be added.
fd = sym.which(head)[0]
if fd != 1:
msg = ('Further-than-nearest-neighbor cells '
'are connected by hopping\n{0}.')
raise ValueError(msg.format((tail, head)))
continue
if head_id >= cell_size:
# Head belongs to previous domain. The edge added here
# correspond to one left out just above.
g.add_edge(head_id, tail_id)
g.add_edge(tail_id, head_id)
return g.compressed()
class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem): class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem):
...@@ -2179,69 +2282,12 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem): ...@@ -2179,69 +2282,12 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem):
sym = builder.symmetry sym = builder.symmetry
assert sym.num_directions == 1 assert sym.num_directions == 1
#### For each site of the fundamental domain, determine whether it has lsites_with, lsites_without, interface =\
#### neighbors in the previous domain or not. _make_lead_sites(builder, interface_order)
lsites_with = [] # Fund. domain sites with neighbors in prev. dom
lsites_without = [] # Remaining sites of the fundamental domain
for tail in builder.H: # Loop over all sites of the fund. domain.
for head in builder._out_neighbors(tail):
fd = sym.which(head)[0]
if fd == 1:
# Tail belongs to fund. domain, head to the next domain.
lsites_with.append(tail)
break
else:
# Tail is a fund. domain site not connected to prev. domain.
lsites_without.append(tail)
cell_size = len(lsites_with) + len(lsites_without) cell_size = len(lsites_with) + len(lsites_without)
if not lsites_with:
warnings.warn('Infinite system with disconnected cells.',
RuntimeWarning, stacklevel=3)
### Create list of sites and a lookup table
minus_one = ta.array((-1,))
plus_one = ta.array((1,))
if interface_order is None:
# interface must be sorted
interface = [sym.act(minus_one, s) for s in lsites_with]
interface.sort()
else:
lsites_with_set = set(lsites_with)
lsites_with = []
interface = []
if interface_order:
shift = ta.array((-sym.which(interface_order[0])[0] - 1,))
for shifted_iface_site in interface_order:
# Shift the interface domain before the fundamental domain.
# That's the right place for the interface of a lead to be, but
# the sites of interface_order might live in a different
# domain.
iface_site = sym.act(shift, shifted_iface_site)
lsite = sym.act(plus_one, iface_site)
try:
lsites_with_set.remove(lsite)
except KeyError:
if (-sym.which(shifted_iface_site)[0] - 1,) != shift:
raise ValueError(
'The sites in interface_order do not all '
'belong to the same lead cell.')
else:
raise ValueError('A site in interface_order is not an '
'interface site:\n' + str(iface_site))
interface.append(iface_site)
lsites_with.append(lsite)
if lsites_with_set:
raise ValueError(
'interface_order did not contain all interface sites.')
# `interface_order` *must* be sorted, hence `interface` should also
if interface != sorted(interface):
raise ValueError('Interface sites must be sorted.')
del lsites_with_set
# we previously sorted the interface, so don't sort it again # we previously sorted the interface, so don't sort it again
sites = sorted(lsites_with) + sorted(lsites_without) + interface sites = lsites_with + lsites_without + interface
del lsites_with del lsites_with
del lsites_without del lsites_without
del interface del interface
...@@ -2249,41 +2295,20 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem): ...@@ -2249,41 +2295,20 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem):
for site_id, site in enumerate(sites): for site_id, site in enumerate(sites):
id_by_site[site] = site_id id_by_site[site] = site_id
graph = _make_lead_graph(builder, sites, id_by_site, cell_size)
# In the following, because many onsites/hoppings share the same # In the following, because many onsites/hoppings share the same
# (value, parameter) pairs, we keep them in 'cache' so that we only # (value, parameter) pairs, we keep them in 'cache' so that we only
# store a given pair in memory *once*. This is like interning strings. # store a given pair in memory *once*. This is like interning strings.
#### Make graph and extract onsite Hamiltonians. #### Extract onsites
cache = _value_params_pair_cache(1) cache = _value_params_pair_cache(1)
g = graph.Graph() onsites = [cache(builder.H[tail][1]) for tail in sites[:cell_size]]
g.num_nodes = len(sites) # Some sites could not appear in any edge.
onsites = []
for tail_id, tail in enumerate(sites[:cell_size]):
onsites.append(cache(builder.H[tail][1]))
for head in builder._out_neighbors(tail):
head_id = id_by_site.get(head)
if head_id is None:
# Head belongs neither to the fundamental domain nor to the
# previous domain. Check that it belongs to the next
# domain and ignore it otherwise as an edge corresponding
# to this one has been added already or will be added.
fd = sym.which(head)[0]
if fd != 1:
msg = ('Further-than-nearest-neighbor cells '
'are connected by hopping\n{0}.')
raise ValueError(msg.format((tail, head)))
continue
if head_id >= cell_size:
# Head belongs to previous domain. The edge added here
# correspond to one left out just above.
g.add_edge(head_id, tail_id)
g.add_edge(tail_id, head_id)
g = g.compressed()
#### Extract hoppings. #### Extract hoppings.
cache = _value_params_pair_cache(2) cache = _value_params_pair_cache(2)
hoppings = [] hoppings = []
for tail_id, head_id in g: for tail_id, head_id in graph:
tail = sites[tail_id] tail = sites[tail_id]
head = sites[head_id] head = sites[head_id]
if tail_id >= cell_size: if tail_id >= cell_size:
...@@ -2310,7 +2335,7 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem): ...@@ -2310,7 +2335,7 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem):
else: else:
parameters = frozenset(parameters) parameters = frozenset(parameters)
self.graph = g self.graph = graph
self.sites = sites self.sites = sites
self.site_ranges = _site_ranges(sites) self.site_ranges = _site_ranges(sites)
self.id_by_site = id_by_site self.id_by_site = id_by_site
...@@ -2321,13 +2346,9 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem): ...@@ -2321,13 +2346,9 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem):
self.cell_size = cell_size self.cell_size = cell_size
self._init_discrete_symmetries(builder) self._init_discrete_symmetries(builder)
def hamiltonian(self, i, j, *args, params=None): def hamiltonian(self, i, j, *args, params=None):
cs = self.cell_size cs = self.cell_size
if i == j >= cs: if i == j >= cs:
i -= cs i -= cs
j -= cs j -= cs
return super().hamiltonian(i, j, *args, params=params) return super().hamiltonian(i, j, *args, params=params)
def pos(self, i):
return self.sites[i].pos
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment