Skip to content
Snippets Groups Projects
Commit 8ec83c23 authored by Bas Nijholt's avatar Bas Nijholt
Browse files

change the simplex_queue to a SortedKeyList

parent 7c340516
No related branches found
No related tags found
1 merge request !141: change the simplex_queue to a SortedKeyList
......@@ -9,6 +9,7 @@ import random
import numpy as np
from scipy import interpolate
import scipy.spatial
from sortedcontainers import SortedKeyList
from adaptive.learner.base_learner import BaseLearner
from adaptive.notebook_integration import ensure_holoviews, ensure_plotly
......@@ -91,7 +92,6 @@ def choose_point_in_simplex(simplex, transform=None):
distance_matrix = scipy.spatial.distance.squareform(distances)
i, j = np.unravel_index(np.argmax(distance_matrix),
distance_matrix.shape)
point = (simplex[i, :] + simplex[j, :]) / 2
if transform is not None:
......@@ -100,6 +100,15 @@ def choose_point_in_simplex(simplex, transform=None):
return point
def _simplex_evaluation_priority(key):
# We round the loss to 8 digits such that losses
# are equal up to numerical precision will be considered
# to be equal. This is needed because we want the learner
# to behave in a deterministic fashion.
loss, simplex, subsimplex = key
return -round(loss, ndigits=8), simplex, subsimplex or (0,)
class LearnerND(BaseLearner):
"""Learns and predicts a function 'f: ℝ^N → ℝ^M'.
......@@ -200,7 +209,7 @@ class LearnerND(BaseLearner):
# so when popping an item, you should check that the simplex that has
# been returned has not been deleted. This checking is done by
# _pop_highest_existing_simplex
self._simplex_queue = [] # heap
self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)
@property
def npoints(self):
......@@ -344,9 +353,7 @@ class LearnerND(BaseLearner):
subtriangulation = self._subtriangulations[simplex]
for subsimplex in new_subsimplices:
subloss = subtriangulation.volume(subsimplex) * loss_density
subloss = round(subloss, ndigits=8)
heapq.heappush(self._simplex_queue,
(-subloss, simplex, subsimplex))
self._simplex_queue.add((subloss, simplex, subsimplex))
def _ask_and_tell_pending(self, n=1):
xs, losses = zip(*(self._ask() for _ in range(n)))
......@@ -386,7 +393,7 @@ class LearnerND(BaseLearner):
# find the simplex with the highest loss, we do need to check that the
# simplex hasn't been deleted yet
while len(self._simplex_queue):
loss, simplex, subsimplex = heapq.heappop(self._simplex_queue)
loss, simplex, subsimplex = self._simplex_queue.pop(0)
if (subsimplex is None
and simplex in self.tri.simplices
and simplex not in self._subtriangulations):
......@@ -462,8 +469,7 @@ class LearnerND(BaseLearner):
self._try_adding_pending_point_to_simplex(p, simplex)
if simplex not in self._subtriangulations:
loss = round(loss, ndigits=8)
heapq.heappush(self._simplex_queue, (-loss, simplex, None))
self._simplex_queue.add((loss, simplex, None))
continue
self._update_subsimplex_losses(
......@@ -488,7 +494,7 @@ class LearnerND(BaseLearner):
return
# reset the _simplex_queue
self._simplex_queue = []
self._simplex_queue = SortedKeyList(key=_simplex_evaluation_priority)
# recompute all losses
for simplex in self.tri.simplices:
......@@ -497,8 +503,7 @@ class LearnerND(BaseLearner):
# now distribute it around the the children if they are present
if simplex not in self._subtriangulations:
loss = round(loss, ndigits=8)
heapq.heappush(self._simplex_queue, (-loss, simplex, None))
self._simplex_queue.add((loss, simplex, None))
continue
self._update_subsimplex_losses(
......
......@@ -362,9 +362,7 @@ def test_expected_loss_improvement_is_less_than_total_loss(learner_type, f, lear
# XXX: This *should* pass (https://gitlab.kwant-project.org/qt/adaptive/issues/84)
# but we xfail it now, as Learner2D will be deprecated anyway
# The LearnerND fails sometimes, see
# https://gitlab.kwant-project.org/qt/adaptive/merge_requests/128#note_21807
@run_with(Learner1D, xfail(Learner2D), xfail(LearnerND))
@run_with(Learner1D, xfail(Learner2D), LearnerND)
def test_learner_performance_is_invariant_under_scaling(learner_type, f, learner_kwargs):
"""Learners behave identically under transformations that leave
the loss invariant.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment