Skip to content
Snippets Groups Projects

More efficient 'tell_many'

Merged Bas Nijholt requested to merge efficient_tell_many into master
All threads resolved!
Compare and
2 files
+ 177
11
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -67,6 +67,17 @@ def linspace(x_left, x_right, n):
return [x_left + step * i for i in range(1, n)]
def _get_neighbors_from_list(xs):
    """Build a neighbor mapping for a collection of points.

    Parameters
    ----------
    xs : array-like
        One-dimensional collection of points. It is sorted with
        ``np.sort`` before the neighbors are determined, so the input
        order does not matter. May be empty.

    Returns
    -------
    sortedcontainers.SortedDict
        Mapping ``{x: [x_left, x_right]}`` where ``x_left`` and ``x_right``
        are the points directly before and after ``x`` in sorted order.
        The smallest point has ``None`` as its left neighbor and the
        largest point has ``None`` as its right neighbor. An empty input
        yields an empty mapping.
    """
    xs = np.sort(xs)
    as_list = xs.tolist()
    # Shift the sorted points by one position in each direction and pad the
    # open end with None. Unlike assigning into index 0 / -1 of a rolled
    # copy, this also works when there are no points at all.
    xs_left = [None] + as_list[:-1]
    xs_right = as_list[1:] + [None]
    # zip stops at the shortest iterable (xs), so the extra padding element
    # that appears for empty input is simply ignored.
    neighbors = {x: [x_l, x_r]
                 for x, x_l, x_r in zip(xs, xs_left, xs_right)}
    return sortedcontainers.SortedDict(neighbors)
class Learner1D(BaseLearner):
"""Learns and predicts a function 'f:ℝ → ℝ^N'.
@@ -105,7 +116,7 @@ class Learner1D(BaseLearner):
self.losses = {}
self.losses_combined = {}
self.data = sortedcontainers.SortedDict()
self.data = {}
self.pending_points = set()
# A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
@@ -129,7 +140,17 @@ class Learner1D(BaseLearner):
@property
def vdim(self):
    """Length of the output of the learned function.

    Returns
    -------
    int
        The cached value of ``self._vdim`` when it is set. Otherwise the
        value is inferred from the first entry of ``self.data`` and cached
        in ``self._vdim``. While ``self.data`` is empty, 1 is returned
        without caching, so that the value can still be inferred once
        data is available.
    """
    if self._vdim is None:
        if not self.data:
            # Nothing to infer from yet; do not cache this fallback.
            return 1
        y = next(iter(self.data.values()))
        try:
            self._vdim = len(np.squeeze(y))
        except TypeError:
            # np.squeeze gave something without a length (e.g. a float
            # or a single-element sequence), so the output is scalar.
            self._vdim = 1
    return self._vdim
@property
def npoints(self):
@@ -258,12 +279,6 @@ class Learner1D(BaseLearner):
# remove from set of pending points
self.pending_points.discard(x)
if self._vdim is None:
try:
self._vdim = len(np.squeeze(y))
except TypeError:
self._vdim = 1
if not self.bounds[0] <= x <= self.bounds[1]:
return
@@ -273,7 +288,7 @@ class Learner1D(BaseLearner):
self.update_losses(x, real=True)
# If the scale has increased enough, recompute all losses.
if self._scale[1] > self._oldscale[1] * 2:
if self._scale[1] > 2 * self._oldscale[1]:
for interval in self.losses:
self.update_interpolated_loss_in_interval(*interval)
@@ -288,6 +303,73 @@ class Learner1D(BaseLearner):
self.update_neighbors(x, self.neighbors_combined)
self.update_losses(x, real=False)
def tell_many(self, xs, ys):
    """Tell the learner about many points at once.

    When more than 2 points are given and their number exceeds half the
    number of points already in ``self.data``, the neighbors, bounding
    box, scale and losses are rebuilt from scratch in one pass. Otherwise
    the call is forwarded to the parent class' ``tell_many``.

    Parameters
    ----------
    xs : sequence
        The points that were evaluated.
    ys : sequence
        The values belonging to ``xs``, in the same order.
    """
    n_new = len(xs)
    rebuild = n_new > 2 and n_new > 0.5 * len(self.data)
    if not rebuild:
        # Too few new points for a full rebuild to be worthwhile.
        super().tell_many(xs, ys)
        return

    # Store the new data; a point that is now known is no longer pending.
    for x, y in zip(xs, ys):
        self.data[x] = y
        self.pending_points.discard(x)

    # Snapshot everything as numpy arrays.
    points = np.array(list(self.data.keys()))
    values = np.array(list(self.data.values()))
    points_pending = np.array(list(self.pending_points))
    points_combined = np.hstack([points_pending, points])

    # Rebuild both neighbor mappings from scratch.
    self.neighbors = _get_neighbors_from_list(points)
    self.neighbors_combined = _get_neighbors_from_list(points_combined)

    # Recompute the bounding box and the scale derived from it.
    self._bbox[0] = [points_combined.min(), points_combined.max()]
    self._bbox[1] = [values.min(axis=0), values.max(axis=0)]
    x_min, x_max = self._bbox[0]
    y_min, y_max = self._bbox[1]
    self._scale[0] = x_max - x_min
    self._scale[1] = np.max(y_max - y_min)
    self._oldscale = deepcopy(self._scale)

    # Every point except the last one starts an interval that runs to its
    # right neighbor.
    real_intervals = [
        (x, x_next) for x, (_, x_next) in self.neighbors.items()][:-1]
    combined_intervals = [
        (x, x_next) for x, (_, x_next) in self.neighbors_combined.items()][:-1]

    # Losses of the "real" intervals; intervals narrower than
    # 'self._dx_eps' get a loss of 0.
    self.losses = {}
    for ival in real_intervals:
        left, right = ival
        if right - left >= self._dx_eps:
            self.losses[ival] = self.loss_per_interval(
                ival, self._scale, self.data)
        else:
            self.losses[ival] = 0

    # Spans built by gluing adjacent intervals that are not "real"
    # intervals; afterwards the spans that match a "real" interval get
    # their interpolated losses updated.
    to_interpolate = []
    self.losses_combined = {}
    for ival in combined_intervals:
        if ival in self.losses:
            # A "real" interval without pending points inside: reuse its loss.
            self.losses_combined[ival] = self.losses[ival]
            continue

        # Start with an infinite loss; it may be updated below if the
        # interval turns out to lie inside a "real" interval.
        self.losses_combined[ival] = np.inf
        left, right = ival
        if to_interpolate:
            prev_left, prev_right = to_interpolate[-1]
        else:
            prev_left, prev_right = None, None
        touches_previous = prev_right == left
        if touches_previous and (prev_left, prev_right) not in self.losses:
            # Join (prev_left, prev_right) and (left, right)
            # into (prev_left, right).
            to_interpolate[-1] = (prev_left, right)
        else:
            to_interpolate.append((left, right))

    for ival in to_interpolate:
        if ival in self.losses:
            self.update_interpolated_loss_in_interval(*ival)
def ask(self, n, tell_pending=True):
"""Return n points that are expected to maximally reduce the loss."""
points, loss_improvements = self._ask_points_without_adding(n)
@@ -379,8 +461,7 @@ class Learner1D(BaseLearner):
elif not self.vdim > 1:
p = hv.Scatter(self.data) * hv.Path([])
else:
xs = list(self.data.keys())
ys = list(self.data.values())
xs, ys = zip(*sorted(self.data.items()))
p = hv.Path((xs, ys)) * hv.Scatter([])
# Plot with 5% empty margins such that the boundary points are visible
Loading