Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
No results found
Show changes
Commits on Source (3)
......@@ -14,12 +14,19 @@ class AverageLearner(BaseLearner):
The learned function must depend on an integer input variable that
represents the source of randomness.
Parameters:
-----------
Parameters
----------
atol : float
Desired absolute tolerance
rtol : float
Desired relative tolerance
Attributes
----------
data : dict
Sampled points and values.
pending_points : set
Points that still have to be evaluated.
"""
def __init__(self, function, atol=None, rtol=None):
......@@ -31,6 +38,7 @@ class AverageLearner(BaseLearner):
rtol = np.inf
self.data = {}
self.pending_points = set()
self.function = function
self.atol = atol
self.rtol = rtol
......@@ -40,28 +48,35 @@ class AverageLearner(BaseLearner):
@property
def n_requested(self):
return len(self.data)
return len(self.data) + len(self.pending_points)
def ask(self, n, add_data=True):
points = list(range(self.n_requested, self.n_requested + n))
if any(p in self.data or p in self.pending_points for p in points):
# This means some of the points `< self.n_requested` do not exist.
points = list(set(range(self.n_requested + n))
- set(self.data)
- set(self.pending_points))[:n]
loss_improvements = [self.loss_improvement(n) / n] * n
if add_data:
self.tell_many(points, itertools.repeat(None))
return points, loss_improvements
def tell(self, n, value):
value_is_new = not (n in self.data and value == self.data[n])
if not value_is_new:
value_old = self.data[n]
self.data[n] = value
if value is not None:
if n in self.data:
# The point has already been added before.
return
if value is None:
self.pending_points.add(n)
else:
self.data[n] = value
self.pending_points.discard(n)
self.sum_f += value
self.sum_f_sq += value**2
if value_is_new:
self.npoints += 1
else:
self.sum_f -= value_old
self.sum_f_sq -= value_old**2
self.npoints += 1
@property
def mean(self):
......@@ -94,7 +109,7 @@ class AverageLearner(BaseLearner):
def remove_unfinished(self):
"""Remove uncomputed data from the learner."""
pass
self.pending_points = set()
def plot(self):
hv = ensure_holoviews()
......
# -*- coding: utf-8 -*-
from ..learner import AverageLearner
def test_only_returns_new_points():
learner = AverageLearner(lambda x: x, atol=None, rtol=0.01)
# Only tell it n = 5...10
for i in range(5, 10):
learner.tell(i, 1)
learner.tell(0, None) # This means it shouldn't return 0 anymore
assert learner.ask(1)[0][0] == 1
assert learner.ask(1)[0][0] == 2
assert learner.ask(1)[0][0] == 3
assert learner.ask(1)[0][0] == 4
assert learner.ask(1)[0][0] == 10