Commit 46dcc082 authored by Christoph Groth's avatar Christoph Groth Committed by Joseph Weston

finalized builders: store parameter names together with values

With this, wrapping parameter-substituted value functions is no longer
necessary and any overhead of using builder parameter substitutions
disappears.
Co-authored-by: Joseph Weston's avatarJoseph Weston <joseph@weston.cloud>
parent c2da972a
......@@ -1812,18 +1812,6 @@ def _translate_cons_law(cons_law):
class _FinalizedBuilderMixin:
"""Common functionality for all finalized builders"""
def _init_param_names(self):
"""For each value function, store the required parameters.
"""
pn = {}
for values, skip in [(self.onsite_hamiltonians, 1), (self.hoppings, 2)]:
for value in values:
if not callable(value) or value is Other or value in pn:
continue
pn[value] = get_parameters(value)[skip:]
self._param_names = pn
def _init_discrete_symmetries(self, builder):
def operator(op):
return Density(self, op, check_hermiticity=False)
......@@ -1842,31 +1830,43 @@ class _FinalizedBuilderMixin:
if args and params:
raise TypeError("'args' and 'params' are mutually exclusive.")
if i == j:
value = self.onsite_hamiltonians[i]
if callable(value):
value, param_names = self.onsites[i]
if param_names is not None: # 'value' is callable
site = self.symmetry.to_fd(self.sites[i])
if params:
args = map(params.__getitem__, self._param_names[value])
args = map(params.__getitem__, param_names)
try:
value = value(site, *args)
except Exception as exc:
if isinstance(exc, KeyError) and params:
missing = [p for p in param_names if p not in params]
if missing:
msg = ('System is missing required arguments: ',
', '.join(map('"{}"'.format, missing)))
raise TypeError(''.join(msg))
_raise_user_error(exc, value)
else:
edge_id = self.graph.first_edge_id(i, j)
value = self.hoppings[edge_id]
value, param_names = self.hoppings[edge_id]
conj = value is Other
if conj:
i, j = j, i
edge_id = self.graph.first_edge_id(i, j)
value = self.hoppings[edge_id]
if callable(value):
value, param_names = self.hoppings[edge_id]
if param_names is not None: # 'value' is callable
sites = self.sites
site_i, site_j = self.symmetry.to_fd(sites[i], sites[j])
if params:
args = map(params.__getitem__, self._param_names[value])
args = map(params.__getitem__, param_names)
try:
value = value(site_i, site_j, *args)
except Exception as exc:
if isinstance(exc, KeyError) and params:
missing = [p for p in param_names if p not in params]
if missing:
msg = ('System is missing required arguments: ',
', '.join(map('"{}"'.format, missing)))
raise TypeError(''.join(msg))
_raise_user_error(exc, value)
if conj:
value = herm_conj(value)
......@@ -1896,6 +1896,26 @@ class _FinalizedBuilderMixin:
self._symmetries))
# The same (value, parameters) pair will be used for many sites/hoppings,
# so we cache it to avoid wasting extra memory.
def _value_params_pair_cache(nstrip):
def get(value):
entry = cache.get(id(value))
if entry is None:
if isinstance(value, _Substituted):
entry = value.func, value.params[nstrip:]
elif callable(value):
entry = value, get_parameters(value)[nstrip:]
else:
# None means: value is not callable. (That's faster to check.)
entry = value, None
cache[id(value)] = entry
return entry
assert nstrip in [1, 2]
cache = {}
return get
class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem):
"""Finalized `Builder` with leads.
......@@ -1961,20 +1981,24 @@ class FiniteSystem(_FinalizedBuilderMixin, system.FiniteSystem):
lead_interfaces.append(np.array(interface))
onsite_hamiltonians = [builder.H[site][1] for site in sites]
hoppings = [builder._get_edge(sites[tail], sites[head])
for tail, head in g]
# Because many onsites/hoppings share the same (value, parameter)
# pairs, we keep them in a cache so that we only store a given pair
# in memory *once*. This is a similar idea to interning strings.
cache = _value_params_pair_cache(1)
onsites = [cache(builder.H[site][1]) for site in sites]
cache = _value_params_pair_cache(2)
hoppings = [cache(builder._get_edge(sites[tail], sites[head]))
for tail, head in g]
self.graph = g
self.sites = sites
self.site_ranges = _site_ranges(sites)
self.id_by_site = id_by_site
self.hoppings = hoppings
self.onsite_hamiltonians = onsite_hamiltonians
self.onsites = onsites
self.symmetry = builder.symmetry
self.leads = finalized_leads
self.lead_interfaces = lead_interfaces
self._init_param_names()
self._init_discrete_symmetries(builder)
......@@ -2083,12 +2107,17 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem):
for site_id, site in enumerate(sites):
id_by_site[site] = site_id
# In the following, because many onsites/hoppings share the same
# (value, parameter) pairs, we keep them in 'cache' so that we only
# store a given pair in memory *once*. This is like interning strings.
#### Make graph and extract onsite Hamiltonians.
cache = _value_params_pair_cache(1)
g = graph.Graph()
g.num_nodes = len(sites) # Some sites could not appear in any edge.
onsite_hamiltonians = []
onsites = []
for tail_id, tail in enumerate(sites[:cell_size]):
onsite_hamiltonians.append(builder.H[tail][1])
onsites.append(cache(builder.H[tail][1]))
for head in builder._out_neighbors(tail):
head_id = id_by_site.get(head)
if head_id is None:
......@@ -2110,6 +2139,7 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem):
g = g.compressed()
#### Extract hoppings.
cache = _value_params_pair_cache(2)
hoppings = []
for tail_id, head_id in g:
tail = sites[tail_id]
......@@ -2118,17 +2148,16 @@ class InfiniteSystem(_FinalizedBuilderMixin, system.InfiniteSystem):
# The tail belongs to the previous domain. Find the
# corresponding hopping with the tail in the fund. domain.
tail, head = sym.to_fd(tail, head)
hoppings.append(builder._get_edge(tail, head))
hoppings.append(cache(builder._get_edge(tail, head)))
self.graph = g
self.sites = sites
self.site_ranges = _site_ranges(sites)
self.id_by_site = id_by_site
self.hoppings = hoppings
self.onsite_hamiltonians = onsite_hamiltonians
self.onsites = onsites
self.symmetry = builder.symmetry
self.cell_size = cell_size
self._init_param_names()
self._init_discrete_symmetries(builder)
......
......@@ -294,7 +294,7 @@ def check_onsite(fsyst, sites, subset=False, check_values=True):
site = fsyst.sites[node].tag
freq[site] = freq.get(site, 0) + 1
if check_values and site in sites:
assert fsyst.onsite_hamiltonians[node] is sites[site]
assert fsyst.onsites[node][0] is sites[site]
if not subset:
# Check that all sites of `fsyst` are in `sites`.
for site in freq.keys():
......@@ -310,7 +310,7 @@ def check_hoppings(fsyst, hops):
tail, head = edge
tail = fsyst.sites[tail].tag
head = fsyst.sites[head].tag
value = fsyst.hoppings[edge_id]
value = fsyst.hoppings[edge_id][0]
if value is builder.Other:
assert (head, tail) in hops
else:
......@@ -1208,6 +1208,10 @@ def test_argument_passing():
with raises(TypeError):
inf_syst.hamiltonian(0, 0, *(2, 1), params=dict(p1=2, p2=1))
# test that missing any parameters raises TypeError
with raises(TypeError):
syst.hamiltonian(0, 0, params=dict(fake=10))
# test that passing parameters without default values works, and that
# passing parameters with default values fails
def onsite(site, p1, p2):
......
......@@ -27,7 +27,7 @@ def test_hamiltonian_submatrix():
mat = syst2.hamiltonian_submatrix()
assert mat.shape == (3, 3)
# Sorting is required due to unknown compression order of builder.
perm = np.argsort(syst2.onsite_hamiltonians)
perm = np.argsort([os[0] for os in syst2.onsites])
mat_should_be = np.array([[0, 1j, 0], [-1j, 0.5, 2j], [0, -2j, 1]])
mat = mat[perm, :]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment