Skip to content
Snippets Groups Projects

LearnerND scale output values before computing loss

Merged Jorn Hoofwijk requested to merge 78-scale-output-values into master
2 files
+ 162
9
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -11,11 +11,10 @@ from scipy import interpolate
import scipy.spatial
from adaptive.learner.base_learner import BaseLearner
from adaptive.notebook_integration import ensure_holoviews, ensure_plotly
from adaptive.learner.triangulation import (
Triangulation, point_in_simplex, circumsphere,
simplex_volume_in_embedding, fast_det
)
from adaptive.notebook_integration import ensure_holoviews, ensure_plotly
simplex_volume_in_embedding, fast_det)
from adaptive.utils import restore, cache_latest
@@ -178,8 +177,14 @@ class LearnerND(BaseLearner):
# triangulation of the pending points inside a specific simplex
self._subtriangulations = dict() # simplex → triangulation
# scale to unit
# scale to unit hypercube
# for the input
self._transform = np.linalg.inv(np.diag(np.diff(self._bbox).flat))
# for the output
self._min_value = None
self._max_value = None
self._output_multiplier = 1 # If we do not know anything, do not scale the values
self._recompute_losses_factor = 1.1
# create a private random number generator with fixed seed
self._random = random.Random(1)
@@ -271,6 +276,7 @@ class LearnerND(BaseLearner):
if not self.inside_bounds(point):
return
self._update_range(value)
if tri is not None:
simplex = self._pending_to_simplex.get(point)
if simplex is not None and not self._simplex_exists(simplex):
@@ -448,10 +454,8 @@ class LearnerND(BaseLearner):
if p not in self.data)
for simplex in to_add:
vertices = self.tri.get_vertices(simplex)
values = [self.data[tuple(v)] for v in vertices]
loss = float(self.loss_per_simplex(vertices, values))
self._losses[simplex] = float(loss)
loss = self.compute_loss(simplex)
self._losses[simplex] = loss
for p in pending_points_unbound:
self._try_adding_pending_point_to_simplex(p, simplex)
@@ -463,6 +467,83 @@ class LearnerND(BaseLearner):
self._update_subsimplex_losses(
simplex, self._subtriangulations[simplex].simplices)
def compute_loss(self, simplex):
# get the loss
vertices = self.tri.get_vertices(simplex)
values = [self.data[tuple(v)] for v in vertices]
# scale them to a cube with sides 1
vertices = vertices @ self._transform
values = self._output_multiplier * values
# compute the loss on the scaled simplex
return float(self.loss_per_simplex(vertices, values))
def recompute_all_losses(self):
"""Recompute all losses and pending losses."""
# amortized O(N) complexity
if self.tri is None:
return
# reset the _simplex_queue
self._simplex_queue = []
# recompute all losses
for simplex in self.tri.simplices:
loss = self.compute_loss(simplex)
self._losses[simplex] = loss
# now distribute it around the the children if they are present
if simplex not in self._subtriangulations:
heapq.heappush(self._simplex_queue, (-loss, simplex, None))
continue
self._update_subsimplex_losses(
simplex, self._subtriangulations[simplex].simplices)
@property
def _scale(self):
# get the output scale
return self._max_value - self._min_value
def _update_range(self, new_output):
if self._min_value is None or self._max_value is None:
# this is the first point, nothing to do, just set the range
self._min_value = np.array(new_output)
self._max_value = np.array(new_output)
self._old_scale = self._scale
return False
# if range in one or more directions is doubled, then update all losses
self._min_value = np.minimum(self._min_value, new_output)
self._max_value = np.maximum(self._max_value, new_output)
scale_multiplier = 1 / self._scale
if isinstance(scale_multiplier, float):
scale_multiplier = np.array([scale_multiplier], dtype=float)
# the maximum absolute value that is in the range. Because this is the
# largest number, this also has the largest absolute numerical error.
max_absolute_value_in_range = np.max(np.abs([self._min_value, self._max_value]), axis=0)
# since a float has a relative error of 1e-15, the absolute error is the value * 1e-15
abs_err = 1e-15 * max_absolute_value_in_range
# when scaling the floats, the error gets increased.
scaled_err = abs_err * scale_multiplier
allowed_numerical_error = 1e-2
# do not scale along the axis if the numerical error gets too big
scale_multiplier[scaled_err > allowed_numerical_error] = 1
self._output_multiplier = scale_multiplier
scale_factor = np.max(np.nan_to_num(self._scale / self._old_scale))
if scale_factor > self._recompute_losses_factor:
self._old_scale = self._scale
self.recompute_all_losses()
return True
return False
def losses(self):
"""Get the losses of each simplex in the current triangulation, as dict
Loading