Skip to content
Snippets Groups Projects
Commit 340cd6aa authored by Joseph Weston's avatar Joseph Weston
Browse files

Merge branch 'average_learner_improvements' into 'master'

Make the AverageLearner only return new points, and ignore
subsequent updates to existing points (as in the rest of
adaptive).

See merge request !104
parents e378367d c1d0885d
No related branches found
No related tags found
1 merge request!104Make the AverageLearner only return new points ...
Pipeline #12180 passed
......@@ -14,12 +14,19 @@ class AverageLearner(BaseLearner):
The learned function must depend on an integer input variable that
represents the source of randomness.
Parameters:
-----------
Parameters
----------
atol : float
Desired absolute tolerance
rtol : float
Desired relative tolerance
Attributes
----------
data : dict
Sampled points and values.
pending_points : set
Points that still have to be evaluated.
"""
def __init__(self, function, atol=None, rtol=None):
......@@ -31,6 +38,7 @@ class AverageLearner(BaseLearner):
rtol = np.inf
self.data = {}
self.pending_points = set()
self.function = function
self.atol = atol
self.rtol = rtol
......@@ -40,28 +48,35 @@ class AverageLearner(BaseLearner):
@property
def n_requested(self):
return len(self.data)
return len(self.data) + len(self.pending_points)
def ask(self, n, add_data=True):
points = list(range(self.n_requested, self.n_requested + n))
if any(p in self.data or p in self.pending_points for p in points):
# This means some of the points `< self.n_requested` do not exist.
points = list(set(range(self.n_requested + n))
- set(self.data)
- set(self.pending_points))[:n]
loss_improvements = [self.loss_improvement(n) / n] * n
if add_data:
self.tell_many(points, itertools.repeat(None))
return points, loss_improvements
def tell(self, n, value):
value_is_new = not (n in self.data and value == self.data[n])
if not value_is_new:
value_old = self.data[n]
self.data[n] = value
if value is not None:
if n in self.data:
# The point has already been added before.
return
if value is None:
self.pending_points.add(n)
else:
self.data[n] = value
self.pending_points.discard(n)
self.sum_f += value
self.sum_f_sq += value**2
if value_is_new:
self.npoints += 1
else:
self.sum_f -= value_old
self.sum_f_sq -= value_old**2
self.npoints += 1
@property
def mean(self):
......@@ -94,7 +109,7 @@ class AverageLearner(BaseLearner):
def remove_unfinished(self):
"""Remove uncomputed data from the learner."""
pass
self.pending_points = set()
def plot(self):
hv = ensure_holoviews()
......
# -*- coding: utf-8 -*-
from ..learner import AverageLearner
def test_only_returns_new_points():
learner = AverageLearner(lambda x: x, atol=None, rtol=0.01)
# Only tell it n = 5...10
for i in range(5, 10):
learner.tell(i, 1)
learner.tell(0, None) # This means it shouldn't return 0 anymore
assert learner.ask(1)[0][0] == 1
assert learner.ask(1)[0][0] == 2
assert learner.ask(1)[0][0] == 3
assert learner.ask(1)[0][0] == 4
assert learner.ask(1)[0][0] == 10
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment