Skip to content
Snippets Groups Projects

Resolve "Learner1D's bound check algo in self.ask doesn't take self.data or self.pending_points"

All threads resolved!
Compare and
3 files
+ 127
35
Compare changes
  • Side-by-side
  • Inline
Files
3
@@ -287,44 +287,92 @@ class Learner1D(BaseLearner):
if n == 0:
return [], []
# Some temporary functions we are gonna need later in ask().
# xs is very similar to np.linspace but doesn't include the bounds
def xs(x_left, x_right, n):
if n == 1:
# This is just an optimization
return []
else:
step = (x_right - x_left) / n
return [x_left + step * i for i in range(1, n)]
x_scale = self._scale[0]
+2
def finite_loss(loss, x):
# if the loss is infinite, return the distance between the two points
if math.isinf(loss):
return (x[1] - x[0]) / x_scale
return loss
# If the bounds have not been chosen yet, we choose them first.
missing_bounds = [b for b in self.bounds if b not in self.data
and b not in self.pending_points]
if missing_bounds:
loss_improvements = [np.inf] * n
# XXX: should check if points are present in self.data or self.pending_points
points = np.linspace(*self.bounds, n + 2 - len(missing_bounds)).tolist()
if len(missing_bounds) == 1:
points = points[1:] if missing_bounds[0] == self.bounds[1] else points[:-1]
else:
def xs(x_left, x_right, n):
if n == 1:
# This is just an optimization
return []
else:
step = (x_right - x_left) / n
return [x_left + step * i for i in range(1, n)]
# Calculate how many points belong to each interval.
x_scale = self._scale[0]
quals = [((-loss if not math.isinf(loss) else -(x[1] - x[0]) / x_scale, x, 1))
for x, loss in self.losses_combined.items()]
heapq.heapify(quals)
for point_number in range(n):
quality, x, n = quals[0]
if abs(x[1] - x[0]) / (n + 1) <= self._dx_eps:
# The interval is too small and should not be subdivided
quality = np.inf
heapq.heapreplace(quals, (quality * n / (n + 1), x, n + 1))
points = list(itertools.chain.from_iterable(
xs(*x, n) for quality, x, n in quals))
loss_improvements = list(itertools.chain.from_iterable(
itertools.repeat(-quality, n - 1)
for quality, x, n in quals))
if len(missing_bounds) >= n:
# shortcut as we do not need to find any more points
if add_data:
self.tell_many(missing_bounds[:n], itertools.repeat(None))
return missing_bounds[:n], [np.inf] * n
quals = [(-finite_loss(loss, x), x, 1)
for x, loss in self.losses_combined.items()]
if len(missing_bounds) == 0:
pass # perfect, do nothing
elif len(missing_bounds) == 1:
# add a connection from the bound to the nearest point/pending_point
x1, = missing_bounds
all_x_combined = list(set(self.data.keys()) | self.pending_points)
assert len(all_x_combined) > 0 # because at least the other bound should be in here
if x1 == self.bounds[0]:
# left bound -> find minimum x and connect
other = np.min(all_x_combined)
x = (x1, other)
quals.append((-finite_loss(np.inf, x), x, 1))
else:
# right bound -> find maximum x and connect
other = np.max(all_x_combined)
x = (other, x1)
quals.append((-finite_loss(np.inf, x), x, 1))
elif len(missing_bounds) == 2:
# add the bounds to quals
x1, x2 = sorted(missing_bounds) # sort just to be sure
all_x_combined = list(set(self.data.keys()) | self.pending_points)
if len(all_x_combined) == 0:
# XXX: shortcut for performance, just return linspace
x = x1, x2
quals.append((-finite_loss(np.inf, x), x, 1))
else:
left_interval = (x1, np.min(all_x_combined))
right_interval = (np.max(all_x_combined), x2)
quals.append((-finite_loss(np.inf, right_interval), right_interval, 1))
quals.append((-finite_loss(np.inf, left_interval), left_interval, 1))
# Calculate how many points belong to each interval.
heapq.heapify(quals)
for _point_number in range(n - len(missing_bounds)):
quality, x, n = quals[0]
if abs(x[1] - x[0]) / (n + 1) <= self._dx_eps:
# The interval is too small and should not be subdivided
quality = np.inf
heapq.heapreplace(quals, (quality * n / (n + 1), x, n + 1))
points = list(itertools.chain.from_iterable(
xs(*x, n) for quality, x, n in quals))
loss_improvements = list(itertools.chain.from_iterable(
itertools.repeat(-quality, n - 1)
for quality, x, n in quals))
points = missing_bounds + points
loss_improvements = [np.inf] * len(missing_bounds) + loss_improvements
if add_data:
self.tell_many(points, itertools.repeat(None))
Loading