Commit 669cf077 authored by Bas Nijholt

Merge branch '126--speed-up-learner1D' into 'master'

Resolve "(Learner1D) improve time complexity"

Closes #126 and #104

See merge request !139
parents c0012a9a 4381a9e5
Pipeline #13928 passed
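
In short: the learner's loss bookkeeping moves from plain dicts (with an O(n) max(losses.values()) scan to find the worst interval) to sortedcollections.ItemSortedDict, a SortedDict variant whose ordering is computed from the (key, value) pair, so the maximum-loss interval is always available at index 0. A minimal sketch of the idea, with invented interval values (this is not the learner's exact code):

import sortedcollections

# Sort items by value: the largest loss gets the most negative sort key
# and therefore comes first; the interval itself breaks ties.
losses = sortedcollections.ItemSortedDict(lambda ival, loss: (-loss, ival))
losses[(0.0, 0.5)] = 0.2
losses[(0.5, 1.0)] = 0.7

# Reading the worst interval is now a peek at index 0 instead of a full scan.
assert losses.peekitem(0) == ((0.5, 1.0), 0.7)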
--- a/adaptive/learner/learner1D.py
+++ b/adaptive/learner/learner1D.py
@@ -8,6 +8,7 @@ from collections import Iterable
 
 import numpy as np
 import sortedcontainers
+import sortedcollections
 
 from adaptive.learner.base_learner import BaseLearner
 from adaptive.learner.learnerND import volume
@@ -225,9 +226,6 @@ class Learner1D(BaseLearner):
         self.loss_per_interval = loss_per_interval or default_loss
 
-        # A dict storing the loss function for each interval x_n.
-        self.losses = {}
-        self.losses_combined = {}
 
         # When the scale changes by a factor 2, the losses are
         # recomputed. This is tunable such that we can test
@@ -249,6 +247,10 @@ class Learner1D(BaseLearner):
         self._scale = [bounds[1] - bounds[0], 0]
         self._oldscale = deepcopy(self._scale)
 
+        # A LossManager storing the loss function for each interval x_n.
+        self.losses = loss_manager(self._scale[0])
+        self.losses_combined = loss_manager(self._scale[0])
+
         # The precision in 'x' below which we set losses to 0.
         self._dx_eps = 2 * max(np.abs(bounds)) * np.finfo(float).eps
@@ -284,7 +286,10 @@ class Learner1D(BaseLearner):
     @cache_latest
     def loss(self, real=True):
         losses = self.losses if real else self.losses_combined
-        return max(losses.values()) if len(losses) > 0 else float('inf')
+        if not losses:
+            return np.inf
+        max_interval, max_loss = losses.peekitem(0)
+        return max_loss
 
     def _scale_x(self, x):
         if x is None:
@@ -454,8 +459,7 @@ class Learner1D(BaseLearner):
         # If the scale has increased enough, recompute all losses.
         if self._scale[1] > self._recompute_losses_factor * self._oldscale[1]:
-            for interval in self.losses:
+            for interval in reversed(self.losses):
                 self._update_interpolated_loss_in_interval(*interval)
             self._oldscale = deepcopy(self._scale)
......@@ -504,18 +508,18 @@ class Learner1D(BaseLearner):
for neighbors in (self.neighbors, self.neighbors_combined)]
# The the losses for the "real" intervals.
self.losses = {}
self.losses = loss_manager(self._scale[0])
for ival in intervals:
self.losses[ival] = self._get_loss_in_interval(*ival)
# List with "real" intervals that have interpolated intervals inside
to_interpolate = []
self.losses_combined = {}
self.losses_combined = loss_manager(self._scale[0])
for ival in intervals_combined:
# If this interval exists in 'losses' then copy it otherwise
# calculate it.
if ival in self.losses:
if ival in reversed(self.losses):
self.losses_combined[ival] = self.losses[ival]
else:
# Set all losses to inf now, later they might be udpdated if the
@@ -530,7 +534,7 @@ class Learner1D(BaseLearner):
                 to_interpolate.append((x_left, x_right))
 
         for ival in to_interpolate:
-            if ival in self.losses:
+            if ival in reversed(self.losses):
                 # If this interval does not exist it should already
                 # have an inf loss.
                 self._update_interpolated_loss_in_interval(*ival)
@@ -566,64 +570,57 @@ class Learner1D(BaseLearner):
         if len(missing_bounds) >= n:
             return missing_bounds[:n], [np.inf] * n
 
-        def finite_loss(loss, xs):
-            # If the loss is infinite we return the
-            # distance between the two points.
-            if math.isinf(loss):
-                loss = (xs[1] - xs[0]) / self._scale[0]
-            # We round the loss to 12 digits such that losses
-            # are equal up to numerical precision will be considered
-            # equal.
-            return round(loss, ndigits=12)
-
-        quals = [(-finite_loss(loss, x), x, 1)
-                 for x, loss in self.losses_combined.items()]
-
         # Add bound intervals to quals if bounds were missing.
         if len(self.data) + len(self.pending_points) == 0:
             # We don't have any points, so return a linspace with 'n' points.
             return np.linspace(*self.bounds, n).tolist(), [np.inf] * n
-        elif len(missing_bounds) > 0:
+
+        quals = loss_manager(self._scale[0])
+        if len(missing_bounds) > 0:
             # There is at least one point in between the bounds.
             all_points = list(self.data.keys()) + list(self.pending_points)
             intervals = [(self.bounds[0], min(all_points)),
                          (max(all_points), self.bounds[1])]
             for interval, bound in zip(intervals, self.bounds):
                 if bound in missing_bounds:
-                    qual = (-finite_loss(np.inf, interval), interval, 1)
-                    quals.append(qual)
-
-            # Calculate how many points belong to each interval.
-            points, loss_improvements = self._subdivide_quals(
-                quals, n - len(missing_bounds))
-
-            points = missing_bounds + points
-            loss_improvements = [np.inf] * len(missing_bounds) + loss_improvements
-
-            return points, loss_improvements
-
-    def _subdivide_quals(self, quals, n):
-        # Calculate how many points belong to each interval.
-        heapq.heapify(quals)
-
-        for _ in range(n):
-            quality, x, n = quals[0]
-            if abs(x[1] - x[0]) / (n + 1) <= self._dx_eps:
-                # The interval is too small and should not be subdivided.
-                quality = np.inf
-                # XXX: see https://gitlab.kwant-project.org/qt/adaptive/issues/104
-            heapq.heapreplace(quals, (quality * n / (n + 1), x, n + 1))
+                    quals[(*interval, 1)] = np.inf
+
+        points_to_go = n - len(missing_bounds)
+
+        i, i_max = 0, len(self.losses_combined)
+        for _ in range(points_to_go):
+            qual, loss_qual = quals.peekitem(0) if quals else (None, 0)
+            ival, loss_ival = self.losses_combined.peekitem(i) if i < i_max else (None, 0)
+            if (qual is None
+                or (ival is not None
+                    and self._loss(self.losses_combined, ival)
+                    >= self._loss(quals, qual))):
+                i += 1
+                quals[(*ival, 2)] = loss_ival / 2
+            else:
+                quals.pop(qual, None)
+                *xs, n = qual
+                quals[(*xs, n+1)] = loss_qual * n / (n+1)
 
         points = list(itertools.chain.from_iterable(
-            linspace(*interval, n) for quality, interval, n in quals))
+            linspace(*ival, n) for (*ival, n) in quals))
 
         loss_improvements = list(itertools.chain.from_iterable(
-            itertools.repeat(-quality, n - 1)
-            for quality, interval, n in quals))
+            itertools.repeat(quals[x0, x1, n], n - 1)
+            for (x0, x1, n) in quals))
+
+        # add the missing bounds
+        points = missing_bounds + points
+        loss_improvements = [np.inf] * len(missing_bounds) + loss_improvements
 
         return points, loss_improvements
 
+    def _loss(self, mapping, ival):
+        loss = mapping[ival]
+        return finite_loss(ival, loss, self._scale[0])
+
     def plot(self):
         """Returns a plot of the evaluated data.
@@ -658,3 +655,42 @@ class Learner1D(BaseLearner):
 
     def _set_data(self, data):
         self.tell_many(*zip(*data.items()))
+
+
+def _fix_deepcopy(sorted_dict, x_scale):
+    # XXX: until https://github.com/grantjenks/sortedcollections/issues/5 is fixed
+    import types
+
+    def __deepcopy__(self, memo):
+        items = deepcopy(list(self.items()))
+        lm = loss_manager(self.x_scale)
+        lm.update(items)
+        return lm
+
+    sorted_dict.x_scale = x_scale
+    sorted_dict.__deepcopy__ = types.MethodType(__deepcopy__, sorted_dict)
+
+
+def loss_manager(x_scale):
+    def sort_key(ival, loss):
+        loss, ival = finite_loss(ival, loss, x_scale)
+        return -loss, ival
+
+    sorted_dict = sortedcollections.ItemSortedDict(sort_key)
+    _fix_deepcopy(sorted_dict, x_scale)
+    return sorted_dict
+
+
+def finite_loss(ival, loss, x_scale):
+    """Get the so-called finite_loss of an interval in order to be able to
+    sort intervals that have infinite loss."""
+    # If the loss is infinite we return the
+    # distance between the two points.
+    if math.isinf(loss):
+        loss = (ival[1] - ival[0]) / x_scale
+        if len(ival) == 3:
+            # Used when constructing quals. The last item is
+            # the number of points inside the qual.
+            loss /= ival[2]
+
+    # We round the loss to 12 digits so that losses that are equal
+    # up to numerical precision will be considered equal.
+    return round(loss, ndigits=12), ival
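
Assuming the three module-level helpers above are importable (the module name below is hypothetical), a short usage sketch with invented values: two infinite-loss intervals are ranked by their scaled width, so the wider unknown region is refined first, and deepcopy round-trips through loss_manager thanks to the patched __deepcopy__:

import numpy as np
from copy import deepcopy

from learner1d_helpers import loss_manager  # hypothetical module holding the code above

lm = loss_manager(x_scale=1.0)
lm[(0.0, 0.4)] = np.inf   # infinite loss: finite_loss falls back to the interval width
lm[(0.0, 0.1)] = np.inf
lm[(0.4, 1.0)] = 0.05

# The wider of the two infinite-loss intervals sorts first.
print(lm.peekitem(0))            # ((0.0, 0.4), inf)

# deepcopy works because _fix_deepcopy rebuilds the mapping via loss_manager.
print(deepcopy(lm).peekitem(0))  # ((0.0, 0.4), inf)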
--- a/environment.yml
+++ b/environment.yml
@@ -6,6 +6,7 @@ channels:
 dependencies:
   - python=3.6
   - sortedcontainers
+  - sortedcollections
   - scipy
   - holoviews
   - ipyparallel
--- a/setup.py
+++ b/setup.py
@@ -26,6 +26,7 @@ version, cmdclass = get_version_and_cmdclass('adaptive')
 
 install_requires = [
     'scipy',
+    'sortedcollections',
     'sortedcontainers',
 ]