Skip to content
Snippets Groups Projects

change the simplex_queue to a SortedKeyList

Merged Bas Nijholt requested to merge fix_learnerND_scaling into stable-0.7
2 unresolved threads
Files
2
@@ -9,6 +9,7 @@ import random
import numpy as np
from scipy import interpolate
import scipy.spatial
from sortedcontainers import SortedKeyList
from adaptive.learner.base_learner import BaseLearner
from adaptive.notebook_integration import ensure_holoviews, ensure_plotly
@@ -91,7 +92,6 @@ def choose_point_in_simplex(simplex, transform=None):
distance_matrix = scipy.spatial.distance.squareform(distances)
i, j = np.unravel_index(np.argmax(distance_matrix),
distance_matrix.shape)
point = (simplex[i, :] + simplex[j, :]) / 2
if transform is not None:
@@ -100,6 +100,13 @@ def choose_point_in_simplex(simplex, transform=None):
return point
def make_simplex_queue():
def sort_key(key):
loss, simpl, subsimpl = key
return -round(loss, ndigits=6), simpl, subsimpl or (0,)
return SortedKeyList(key=sort_key)
class LearnerND(BaseLearner):
"""Learns and predicts a function 'f: ℝ^N → ℝ^M'.
@@ -200,7 +207,7 @@ class LearnerND(BaseLearner):
# so when popping an item, you should check that the simplex that has
# been returned has not been deleted. This checking is done by
# _pop_highest_existing_simplex
self._simplex_queue = [] # heap
self._simplex_queue = make_simplex_queue() # heap
@property
def npoints(self):
@@ -344,9 +351,7 @@ class LearnerND(BaseLearner):
subtriangulation = self._subtriangulations[simplex]
for subsimplex in new_subsimplices:
subloss = subtriangulation.volume(subsimplex) * loss_density
subloss = round(subloss, ndigits=8)
heapq.heappush(self._simplex_queue,
(-subloss, simplex, subsimplex))
self._simplex_queue.add((subloss, simplex, subsimplex))
def _ask_and_tell_pending(self, n=1):
xs, losses = zip(*(self._ask() for _ in range(n)))
@@ -386,7 +391,7 @@ class LearnerND(BaseLearner):
# find the simplex with the highest loss, we do need to check that the
# simplex hasn't been deleted yet
while len(self._simplex_queue):
loss, simplex, subsimplex = heapq.heappop(self._simplex_queue)
loss, simplex, subsimplex = self._simplex_queue.pop(0)
if (subsimplex is None
and simplex in self.tri.simplices
and simplex not in self._subtriangulations):
@@ -462,8 +467,7 @@ class LearnerND(BaseLearner):
self._try_adding_pending_point_to_simplex(p, simplex)
if simplex not in self._subtriangulations:
loss = round(loss, ndigits=8)
heapq.heappush(self._simplex_queue, (-loss, simplex, None))
self._simplex_queue.add((loss, simplex, None))
continue
self._update_subsimplex_losses(
@@ -488,7 +492,7 @@ class LearnerND(BaseLearner):
return
# reset the _simplex_queue
self._simplex_queue = []
self._simplex_queue = make_simplex_queue()
# recompute all losses
for simplex in self.tri.simplices:
@@ -497,8 +501,7 @@ class LearnerND(BaseLearner):
# now distribute it around the the children if they are present
if simplex not in self._subtriangulations:
loss = round(loss, ndigits=8)
heapq.heappush(self._simplex_queue, (-loss, simplex, None))
self._simplex_queue.add((loss, simplex, None))
continue
self._update_subsimplex_losses(
Loading