Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Commits on Source (4)
......@@ -335,13 +335,13 @@ class BalancingLearner(BaseLearner):
-------
>>> def combo_fname(learner):
... val = learner.function.keywords # because functools.partial
... fname = '__'.join([f'{k}_{v}.pickle' for k, v in val])
... fname = '__'.join([f'{k}_{v}.pickle' for k, v in val.items()])
... return 'data_folder/' + fname
>>>
>>> def f(x, a, b): return a * x**2 + b
>>>
>>> learners = [Learner1D(functools.partial(f, **combo), (-1, 1))
... for combo in adaptive.utils.named_product(a=[1, 2], b=[1]]
... for combo in adaptive.utils.named_product(a=[1, 2], b=[1])]
>>>
>>> learner = BalancingLearner(learners)
>>> # Run the learner
......
......@@ -8,6 +8,7 @@ from collections import Iterable
import numpy as np
import sortedcontainers
import sortedcollections
from adaptive.learner.base_learner import BaseLearner
from adaptive.learner.learnerND import volume
......@@ -225,9 +226,6 @@ class Learner1D(BaseLearner):
self.loss_per_interval = loss_per_interval or default_loss
# A dict storing the loss function for each interval x_n.
self.losses = {}
self.losses_combined = {}
# When the scale changes by a factor 2, the losses are
# recomputed. This is tunable such that we can test
......@@ -249,6 +247,10 @@ class Learner1D(BaseLearner):
self._scale = [bounds[1] - bounds[0], 0]
self._oldscale = deepcopy(self._scale)
# A LossManager storing the loss function for each interval x_n.
self.losses = loss_manager(self._scale[0])
self.losses_combined = loss_manager(self._scale[0])
# The precision in 'x' below which we set losses to 0.
self._dx_eps = 2 * max(np.abs(bounds)) * np.finfo(float).eps
......@@ -284,7 +286,10 @@ class Learner1D(BaseLearner):
@cache_latest
def loss(self, real=True):
losses = self.losses if real else self.losses_combined
return max(losses.values()) if len(losses) > 0 else float('inf')
if not losses:
return np.inf
max_interval, max_loss = losses.peekitem(0)
return max_loss
def _scale_x(self, x):
if x is None:
......@@ -454,8 +459,7 @@ class Learner1D(BaseLearner):
# If the scale has increased enough, recompute all losses.
if self._scale[1] > self._recompute_losses_factor * self._oldscale[1]:
for interval in self.losses:
for interval in reversed(self.losses):
self._update_interpolated_loss_in_interval(*interval)
self._oldscale = deepcopy(self._scale)
......@@ -504,18 +508,18 @@ class Learner1D(BaseLearner):
for neighbors in (self.neighbors, self.neighbors_combined)]
# The the losses for the "real" intervals.
self.losses = {}
self.losses = loss_manager(self._scale[0])
for ival in intervals:
self.losses[ival] = self._get_loss_in_interval(*ival)
# List with "real" intervals that have interpolated intervals inside
to_interpolate = []
self.losses_combined = {}
self.losses_combined = loss_manager(self._scale[0])
for ival in intervals_combined:
# If this interval exists in 'losses' then copy it otherwise
# calculate it.
if ival in self.losses:
if ival in reversed(self.losses):
self.losses_combined[ival] = self.losses[ival]
else:
# Set all losses to inf now, later they might be udpdated if the
......@@ -530,7 +534,7 @@ class Learner1D(BaseLearner):
to_interpolate.append((x_left, x_right))
for ival in to_interpolate:
if ival in self.losses:
if ival in reversed(self.losses):
# If this interval does not exist it should already
# have an inf loss.
self._update_interpolated_loss_in_interval(*ival)
......@@ -566,64 +570,57 @@ class Learner1D(BaseLearner):
if len(missing_bounds) >= n:
return missing_bounds[:n], [np.inf] * n
def finite_loss(loss, xs):
# If the loss is infinite we return the
# distance between the two points.
if math.isinf(loss):
loss = (xs[1] - xs[0]) / self._scale[0]
# We round the loss to 12 digits such that losses
# are equal up to numerical precision will be considered
# equal.
return round(loss, ndigits=12)
quals = [(-finite_loss(loss, x), x, 1)
for x, loss in self.losses_combined.items()]
# Add bound intervals to quals if bounds were missing.
if len(self.data) + len(self.pending_points) == 0:
# We don't have any points, so return a linspace with 'n' points.
return np.linspace(*self.bounds, n).tolist(), [np.inf] * n
elif len(missing_bounds) > 0:
quals = loss_manager(self._scale[0])
if len(missing_bounds) > 0:
# There is at least one point in between the bounds.
all_points = list(self.data.keys()) + list(self.pending_points)
intervals = [(self.bounds[0], min(all_points)),
(max(all_points), self.bounds[1])]
for interval, bound in zip(intervals, self.bounds):
if bound in missing_bounds:
qual = (-finite_loss(np.inf, interval), interval, 1)
quals.append(qual)
# Calculate how many points belong to each interval.
points, loss_improvements = self._subdivide_quals(
quals, n - len(missing_bounds))
points = missing_bounds + points
loss_improvements = [np.inf] * len(missing_bounds) + loss_improvements
quals[(*interval, 1)] = np.inf
return points, loss_improvements
points_to_go = n - len(missing_bounds)
def _subdivide_quals(self, quals, n):
# Calculate how many points belong to each interval.
heapq.heapify(quals)
for _ in range(n):
quality, x, n = quals[0]
if abs(x[1] - x[0]) / (n + 1) <= self._dx_eps:
# The interval is too small and should not be subdivided.
quality = np.inf
# XXX: see https://gitlab.kwant-project.org/qt/adaptive/issues/104
heapq.heapreplace(quals, (quality * n / (n + 1), x, n + 1))
i, i_max = 0, len(self.losses_combined)
for _ in range(points_to_go):
qual, loss_qual = quals.peekitem(0) if quals else (None, 0)
ival, loss_ival = self.losses_combined.peekitem(i) if i < i_max else (None, 0)
if (qual is None
or (ival is not None
and self._loss(self.losses_combined, ival)
>= self._loss(quals, qual))):
i += 1
quals[(*ival, 2)] = loss_ival / 2
else:
quals.pop(qual, None)
*xs, n = qual
quals[(*xs, n+1)] = loss_qual * n / (n+1)
points = list(itertools.chain.from_iterable(
linspace(*interval, n) for quality, interval, n in quals))
linspace(*ival, n) for (*ival, n) in quals))
loss_improvements = list(itertools.chain.from_iterable(
itertools.repeat(-quality, n - 1)
for quality, interval, n in quals))
itertools.repeat(quals[x0, x1, n], n - 1)
for (x0, x1, n) in quals))
# add the missing bounds
points = missing_bounds + points
loss_improvements = [np.inf] * len(missing_bounds) + loss_improvements
return points, loss_improvements
def _loss(self, mapping, ival):
loss = mapping[ival]
return finite_loss(ival, loss, self._scale[0])
def plot(self):
"""Returns a plot of the evaluated data.
......@@ -658,3 +655,44 @@ class Learner1D(BaseLearner):
def _set_data(self, data):
self.tell_many(*zip(*data.items()))
def _fix_deepcopy(sorted_dict, x_scale):
# XXX: until https://github.com/grantjenks/sortedcollections/issues/5 is fixed
import types
def __deepcopy__(self, memo):
items = deepcopy(list(self.items()))
lm = loss_manager(self.x_scale)
lm.update(items)
return lm
sorted_dict.x_scale = x_scale
sorted_dict.__deepcopy__ = types.MethodType(__deepcopy__, sorted_dict)
def loss_manager(x_scale):
def sort_key(ival, loss):
loss, ival = finite_loss(ival, loss, x_scale)
return -loss, ival
sorted_dict = sortedcollections.ItemSortedDict(sort_key)
_fix_deepcopy(sorted_dict, x_scale)
return sorted_dict
def finite_loss(ival, loss, x_scale):
"""Get the socalled finite_loss of an interval in order to be able to
sort intervals that have infinite loss."""
# If the loss is infinite we return the
# distance between the two points.
if math.isinf(loss):
loss = (ival[1] - ival[0]) / x_scale
if len(ival) == 3:
# Used when constructing quals. Last item is
# the number of points inside the qual.
loss /= ival[2]
# We round the loss to 12 digits such that losses
# are equal up to numerical precision will be considered
# equal. This is 3.5x faster than unsing the `round` function.
round_fac = 1e12
loss = int(loss * round_fac + 0.5) / round_fac
return loss, ival
......@@ -6,6 +6,7 @@ channels:
dependencies:
- python=3.6
- sortedcontainers
- sortedcollections
- scipy
- holoviews
- ipyparallel
......
......@@ -26,6 +26,7 @@ version, cmdclass = get_version_and_cmdclass('adaptive')
install_requires = [
'scipy',
'sortedcollections',
'sortedcontainers',
]
......