Skip to content
Snippets Groups Projects

change the simplex_queue to a SortedKeyList

Merged Bas Nijholt requested to merge fix_learnerND_scaling into stable-0.7
2 unresolved threads
Files
3
@@ -9,6 +9,7 @@ import random
import numpy as np
from scipy import interpolate
import scipy.spatial
from sortedcontainers import SortedKeyList
from adaptive.learner.base_learner import BaseLearner
from adaptive.notebook_integration import ensure_holoviews, ensure_plotly
@@ -91,7 +92,6 @@ def choose_point_in_simplex(simplex, transform=None):
distance_matrix = scipy.spatial.distance.squareform(distances)
i, j = np.unravel_index(np.argmax(distance_matrix),
distance_matrix.shape)
point = (simplex[i, :] + simplex[j, :]) / 2
if transform is not None:
@@ -100,6 +100,15 @@ def choose_point_in_simplex(simplex, transform=None):
return point
def _simplex_evaluation_priority(key):
# We round the loss to 8 digits such that losses
# are equal up to numerical precision will be considered
# to be equal. This is needed because we want the learner
# to behave in a deterministic fashion.
loss, simplex, subsimplex = key
return -round(loss, ndigits=8), simplex, subsimplex or (0,)
class LearnerND(BaseLearner):
"""Learns and predicts a function 'f: ℝ^N → ℝ^M'.
@@ -200,7 +209,7 @@ class LearnerND(BaseLearner):
# so when popping an item, you should check that the simplex that has
# been returned has not been deleted. This checking is done by
# _pop_highest_existing_simplex
self._simplex_queue = [] # heap
self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)
@property
def npoints(self):
@@ -227,10 +236,10 @@ class LearnerND(BaseLearner):
def bounds_are_done(self):
return all(p in self.data for p in self._bounds_points)
def ip(self):
def _ip(self):
"""A `scipy.interpolate.LinearNDInterpolator` instance
containing the learner's data."""
# XXX: take our own triangulation into account when generating the ip
# XXX: take our own triangulation into account when generating the _ip
return interpolate.LinearNDInterpolator(self.points, self.values)
@property
@@ -242,7 +251,7 @@ class LearnerND(BaseLearner):
try:
self._tri = Triangulation(self.points)
self.update_losses(set(), self._tri.simplices)
self._update_losses(set(), self._tri.simplices)
return self._tri
except ValueError:
# A ValueError is raised if we do not have enough points or
@@ -283,7 +292,7 @@ class LearnerND(BaseLearner):
simplex = None
to_delete, to_add = tri.add_point(
point, simplex, transform=self._transform)
self.update_losses(to_delete, to_add)
self._update_losses(to_delete, to_add)
def _simplex_exists(self, simplex):
simplex = tuple(sorted(simplex))
@@ -344,9 +353,7 @@ class LearnerND(BaseLearner):
subtriangulation = self._subtriangulations[simplex]
for subsimplex in new_subsimplices:
subloss = subtriangulation.volume(subsimplex) * loss_density
subloss = round(subloss, ndigits=8)
heapq.heappush(self._simplex_queue,
(-subloss, simplex, subsimplex))
self._simplex_queue.add((subloss, simplex, subsimplex))
def _ask_and_tell_pending(self, n=1):
xs, losses = zip(*(self._ask() for _ in range(n)))
@@ -386,7 +393,7 @@ class LearnerND(BaseLearner):
# find the simplex with the highest loss, we do need to check that the
# simplex hasn't been deleted yet
while len(self._simplex_queue):
loss, simplex, subsimplex = heapq.heappop(self._simplex_queue)
loss, simplex, subsimplex = self._simplex_queue.pop(0)
if (subsimplex is None
and simplex in self.tri.simplices
and simplex not in self._subtriangulations):
@@ -441,7 +448,7 @@ class LearnerND(BaseLearner):
return self._ask_best_point() # O(log N)
def update_losses(self, to_delete: set, to_add: set):
def _update_losses(self, to_delete: set, to_add: set):
# XXX: add the points outside the triangulation to this as well
pending_points_unbound = set()
@@ -455,21 +462,20 @@ class LearnerND(BaseLearner):
if p not in self.data)
for simplex in to_add:
loss = self.compute_loss(simplex)
loss = self._compute_loss(simplex)
self._losses[simplex] = loss
for p in pending_points_unbound:
self._try_adding_pending_point_to_simplex(p, simplex)
if simplex not in self._subtriangulations:
loss = round(loss, ndigits=8)
heapq.heappush(self._simplex_queue, (-loss, simplex, None))
self._simplex_queue.add((loss, simplex, None))
continue
self._update_subsimplex_losses(
simplex, self._subtriangulations[simplex].simplices)
def compute_loss(self, simplex):
def _compute_loss(self, simplex):
# get the loss
vertices = self.tri.get_vertices(simplex)
values = [self.data[tuple(v)] for v in vertices]
@@ -481,24 +487,23 @@ class LearnerND(BaseLearner):
# compute the loss on the scaled simplex
return float(self.loss_per_simplex(vertices, values))
def recompute_all_losses(self):
def _recompute_all_losses(self):
"""Recompute all losses and pending losses."""
# amortized O(N) complexity
if self.tri is None:
return
# reset the _simplex_queue
self._simplex_queue = []
self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)
# recompute all losses
for simplex in self.tri.simplices:
loss = self.compute_loss(simplex)
loss = self._compute_loss(simplex)
self._losses[simplex] = loss
# now distribute it around the the children if they are present
if simplex not in self._subtriangulations:
loss = round(loss, ndigits=8)
heapq.heappush(self._simplex_queue, (-loss, simplex, None))
self._simplex_queue.add((loss, simplex, None))
continue
self._update_subsimplex_losses(
@@ -543,27 +548,14 @@ class LearnerND(BaseLearner):
scale_factor = np.max(np.nan_to_num(self._scale / self._old_scale))
if scale_factor > self._recompute_losses_factor:
self._old_scale = self._scale
self.recompute_all_losses()
self._recompute_all_losses()
return True
return False
def losses(self):
"""Get the losses of each simplex in the current triangulation, as dict
Returns
-------
losses : dict
the key is a simplex, the value is the loss of this simplex
"""
# XXX could be a property
if self.tri is None:
return dict()
return self._losses
@cache_latest
def loss(self, real=True):
losses = self.losses() # XXX: compute pending loss if real == False
# XXX: compute pending loss if real == False
losses = self._losses if self.tri is not None else dict()
return max(losses.values()) if losses else float('inf')
def remove_unfinished(self):
@@ -607,7 +599,7 @@ class LearnerND(BaseLearner):
xs = ys = np.linspace(0, 1, n)
xs = xs * (x[1] - x[0]) + x[0]
ys = ys * (y[1] - y[0]) + y[0]
z = self.ip()(xs[:, None], ys[None, :]).squeeze()
z = self._ip()(xs[:, None], ys[None, :]).squeeze()
im = hv.Image(np.rot90(z), bounds=lbrt)
@@ -656,7 +648,7 @@ class LearnerND(BaseLearner):
for i in range(self.ndim)]
ind = next(i for i in range(self.ndim) if i not in cut_mapping)
x = values[ind]
y = self.ip()(*values)
y = self._ip()(*values)
p = hv.Path((x, y))
# Plot with 5% margins such that the boundary points are visible
@@ -686,7 +678,7 @@ class LearnerND(BaseLearner):
lbrt = np.reshape(lbrt, (2, 2)).T.flatten().tolist()
if len(self.data) >= 4:
z = self.ip()(*values).squeeze()
z = self._ip()(*values).squeeze()
im = hv.Image(np.rot90(z), bounds=lbrt)
else:
im = hv.Image([], bounds=lbrt)
Loading