Skip to content
Snippets Groups Projects
Commit 691c56a9 authored by Bas Nijholt's avatar Bas Nijholt
Browse files

1D: implement a more efficient 'tell_many'

parent dc57ba19
No related branches found
No related tags found
1 merge request: !96 "More efficient 'tell_many'"
......@@ -67,6 +67,17 @@ def linspace(x_left, x_right, n):
return [x_left + step * i for i in range(1, n)]
def _get_neighbors_from_list(xs):
    """Map every point in 'xs' to its left and right neighbor.

    Parameters
    ----------
    xs : sequence or numpy.ndarray of numbers
        The points; they do not need to be sorted.

    Returns
    -------
    neighbors : sortedcontainers.SortedDict
        A dict '{x: [x_left, x_right]}' ordered by 'x'. The smallest
        point has 'None' as its left neighbor and the largest point has
        'None' as its right neighbor. An empty 'xs' gives an empty dict.
    """
    # Convert to Python scalars once, so that the keys and the neighbor
    # values have the same type (iterating over the array itself would
    # yield numpy scalars for the keys only).
    xs = np.sort(xs).tolist()
    if not xs:
        # Without points there is no first/last element that could get
        # the 'None' sentinel, so return an empty mapping right away.
        return sortedcontainers.SortedDict()

    # Shift the sorted points by one position in each direction; 'None'
    # marks the missing neighbor at either end.
    xs_left = [None] + xs[:-1]
    xs_right = xs[1:] + [None]
    neighbors = {x: [x_l, x_r]
                 for x, x_l, x_r in zip(xs, xs_left, xs_right)}
    return sortedcontainers.SortedDict(neighbors)
class Learner1D(BaseLearner):
"""Learns and predicts a function 'f:ℝ → ℝ^N'.
......@@ -105,7 +116,7 @@ class Learner1D(BaseLearner):
self.losses = {}
self.losses_combined = {}
self.data = sortedcontainers.SortedDict()
self.data = {}
self.pending_points = set()
# A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
......@@ -273,7 +284,7 @@ class Learner1D(BaseLearner):
self.update_losses(x, real=True)
# If the scale has increased enough, recompute all losses.
if self._scale[1] > self._oldscale[1] * 2:
if self._scale[1] > 2 * self._oldscale[1]:
for interval in self.losses:
self.update_interpolated_loss_in_interval(*interval)
......@@ -288,6 +299,75 @@ class Learner1D(BaseLearner):
self.update_neighbors(x, self.neighbors_combined)
self.update_losses(x, real=False)
def tell_many(self, xs, ys, *, force=False):
    """Add the points 'xs' with the values 'ys' in one go.

    For a large enough batch everything that is derived from the data
    (the neighbors, the bounding box, the scale, and the losses) is
    rebuilt from scratch instead of being updated point by point.
    Smaller batches are passed on to the parent class' 'tell_many'.

    Parameters
    ----------
    xs : sized iterable of numbers
        The points to add.
    ys : iterable
        The values that belong to 'xs', in the same order.
    force : bool, default: False
        If True, always do the rebuild from scratch, regardless of the
        number of points in 'xs'.
    """
    # The rebuild is only used when more than 2 points are added and
    # the number of added points is larger than half of the number of
    # points already in 'data'. These "magic numbers" are somewhat
    # arbitrary.
    rebuild = force or (len(xs) > 0.5 * len(self.data) and len(xs) > 2)
    if not rebuild:
        super().tell_many(xs, ys)
        return

    # Store the new data; these points are no longer pending.
    self.data.update(zip(xs, ys))
    self.pending_points.difference_update(xs)

    # All points and values as numpy arrays.
    real_xs = np.array(list(self.data.keys()))
    real_ys = np.array(list(self.data.values()))
    pending_xs = np.array(list(self.pending_points))
    all_xs = np.hstack([pending_xs, real_xs])

    # Rebuild both neighbor mappings.
    self.neighbors = _get_neighbors_from_list(real_xs)
    self.neighbors_combined = _get_neighbors_from_list(all_xs)

    # Recompute the bounding box and the scale, and remember the scale.
    x_lo, x_hi = all_xs.min(), all_xs.max()
    self._bbox[0] = [x_lo, x_hi]
    y_lo, y_hi = real_ys.min(axis=0), real_ys.max(axis=0)
    self._bbox[1] = [y_lo, y_hi]
    self._scale[0] = x_hi - x_lo
    self._scale[1] = np.max(y_hi - y_lo)
    self._oldscale = deepcopy(self._scale)

    def intervals_of(neighbors):
        # Pair each point with its right neighbor; the last point has
        # no right neighbor, so its entry is dropped.
        pairs = [(x, right) for x, (_, right) in neighbors.items()]
        return pairs[:-1]

    real_intervals = intervals_of(self.neighbors)
    combined_intervals = intervals_of(self.neighbors_combined)

    # The losses of the "real" intervals; an interval narrower than
    # '_dx_eps' gets a loss of 0.
    self.losses = {}
    for ival in real_intervals:
        left, right = ival
        if right - left >= self._dx_eps:
            loss = self.loss_per_interval(ival, self._scale, self.data)
        else:
            loss = 0
        self.losses[ival] = loss

    # Spans built out of consecutive combined intervals that are not
    # "real" intervals themselves.
    to_interpolate = []

    self.losses_combined = {}
    for ival in combined_intervals:
        if ival in self.losses:
            # A real interval without pending points inside: copy it.
            self.losses_combined[ival] = self.losses[ival]
            continue

        # Start with an infinite loss; it might be updated below if the
        # interval turns out to lie inside a real interval.
        self.losses_combined[ival] = np.inf
        left, right = ival

        extends_previous = False
        if to_interpolate:
            prev_left, prev_right = to_interpolate[-1]
            extends_previous = (
                prev_right == left
                and (prev_left, prev_right) not in self.losses)

        if extends_previous:
            # Join (prev_left, prev_right) and (left, right)
            # into (prev_left, right).
            to_interpolate[-1] = (prev_left, right)
        else:
            to_interpolate.append(ival)

    for ival in to_interpolate:
        if ival not in self.losses:
            # Not a real interval; its pieces keep their infinite loss.
            continue
        self.update_interpolated_loss_in_interval(*ival)
def ask(self, n, tell_pending=True):
"""Return n points that are expected to maximally reduce the loss."""
points, loss_improvements = self._ask_points_without_adding(n)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment