Skip to content
Snippets Groups Projects

LearnerND scale output values before computing loss

Merged Jorn Hoofwijk requested to merge 78-scale-output-values into master
All threads resolved!
2 files
+ 169
14
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -9,13 +9,13 @@ import numpy as np
from scipy import interpolate
import scipy.spatial
from .base_learner import BaseLearner
from adaptive.learner.base_learner import BaseLearner
from ..notebook_integration import ensure_holoviews, ensure_plotly
from .triangulation import (Triangulation, point_in_simplex,
circumsphere, simplex_volume_in_embedding,
fast_det)
from ..utils import restore, cache_latest
from adaptive.notebook_integration import ensure_holoviews, ensure_plotly
from adaptive.learner.triangulation import (
Triangulation, point_in_simplex, circumsphere,
simplex_volume_in_embedding, fast_det)
from adaptive.utils import restore, cache_latest
def volume(simplex, ys=None):
@@ -177,8 +177,14 @@ class LearnerND(BaseLearner):
# triangulation of the pending points inside a specific simplex
self._subtriangulations = dict() # simplex → triangulation
# scale to unit
# scale to unit hypercube
# for the input
self._transform = np.linalg.inv(np.diag(np.diff(self._bbox).flat))
# for the output
self._min_value = None
self._max_value = None
self._output_multiplier = 1 # If we do not know anything, do not scale the values
self._recompute_losses_factor = 1.1
# create a private random number generator with fixed seed
self._random = random.Random(1)
@@ -277,6 +283,7 @@ class LearnerND(BaseLearner):
to_delete, to_add = tri.add_point(
point, simplex, transform=self._transform)
self.update_losses(to_delete, to_add)
self._update_range(value)
def _simplex_exists(self, simplex):
simplex = tuple(sorted(simplex))
@@ -447,10 +454,8 @@ class LearnerND(BaseLearner):
if p not in self.data)
for simplex in to_add:
vertices = self.tri.get_vertices(simplex)
values = [self.data[tuple(v)] for v in vertices]
loss = float(self.loss_per_simplex(vertices, values))
self._losses[simplex] = float(loss)
loss = self.compute_loss(simplex)
self._losses[simplex] = loss
for p in pending_points_unbound:
self._try_adding_pending_point_to_simplex(p, simplex)
@@ -462,6 +467,83 @@ class LearnerND(BaseLearner):
self._update_subsimplex_losses(
simplex, self._subtriangulations[simplex].simplices)
def compute_loss(self, simplex):
    """Return the loss of *simplex*, computed on scaled coordinates.

    Both the simplex vertices (via ``self._transform``) and the output
    values (via ``self._output_multiplier``) are rescaled before being
    handed to ``self.loss_per_simplex``.
    """
    verts = self.tri.get_vertices(simplex)
    vals = [self.data[tuple(vertex)] for vertex in verts]
    # Map the simplex into the unit hypercube and rescale the outputs
    # so the loss is computed on normalised data.
    scaled_verts = verts @ self._transform
    scaled_vals = self._output_multiplier * vals
    return float(self.loss_per_simplex(scaled_verts, scaled_vals))
def recompute_all_losses(self):
    """Recompute all losses and pending losses.

    Amortized O(N) complexity: rebuilds the simplex priority queue from
    scratch using freshly computed losses.
    """
    if self.tri is None:
        return
    # Start from an empty priority queue and refill it below.
    self._simplex_queue = []
    for splx in self.tri.simplices:
        new_loss = self.compute_loss(splx)
        self._losses[splx] = new_loss
        sub = self._subtriangulations.get(splx)
        if sub is None:
            # No pending points inside: queue the simplex itself.
            heapq.heappush(self._simplex_queue, (-new_loss, splx, None))
        else:
            # Distribute the new loss over the simplex's children.
            self._update_subsimplex_losses(splx, sub.simplices)
@property
def _scale(self):
    """Extent of the output range observed so far (max minus min)."""
    # Used to normalise output values before losses are computed.
    return self._max_value - self._min_value
def _update_range(self, new_output):
    """Fold *new_output* into the running output range.

    Updates ``self._min_value`` / ``self._max_value`` and recomputes the
    output scaling factor ``self._output_multiplier``.  If the range has
    grown by more than a factor ``self._recompute_losses_factor`` since
    the last rescale, all losses are recomputed with the new scaling.

    Returns
    -------
    bool
        True when all losses were recomputed, False otherwise.
    """
    if self._min_value is None or self._max_value is None:
        # This is the first point, nothing to do, just set the range.
        self._min_value = np.array(new_output)
        self._max_value = np.array(new_output)
        self._old_scale = self._scale
        return False
    # if range in one or more directions is doubled, then update all losses
    self._min_value = np.minimum(self._min_value, new_output)
    self._max_value = np.maximum(self._max_value, new_output)
    scale_multiplier = 1 / self._scale
    if isinstance(scale_multiplier, float):
        # Scalar output: promote to a 1-element array so the masked
        # assignment below works for scalar and vector outputs alike.
        scale_multiplier = np.array([scale_multiplier], dtype=float)
    # the maximum absolute value that is in the range. Because this is the
    # largest number, this also has the largest absolute numerical error.
    max_absolute_value_in_range = np.max(np.abs([self._min_value, self._max_value]), axis=0)
    # since a float has a relative error of 1e-15, the absolute error is the value * 1e-15
    abs_err = 1e-15 * max_absolute_value_in_range
    # when scaling the floats, the error gets increased.
    scaled_err = abs_err * scale_multiplier
    allowed_numerical_error = 1e-2
    # do not scale along the axis if the numerical error gets too big
    scale_multiplier[scaled_err > allowed_numerical_error] = 1
    self._output_multiplier = scale_multiplier
    # nan_to_num guards against 0/0 when the old scale was zero along an axis.
    scale_factor = np.max(np.nan_to_num(self._scale / self._old_scale))
    if scale_factor > self._recompute_losses_factor:
        # The range grew enough that the stored losses are stale:
        # remember the new scale and recompute everything.
        self._old_scale = self._scale
        self.recompute_all_losses()
        return True
    return False
def losses(self):
"""Get the losses of each simplex in the current triangulation, as dict
Loading