diff --git a/adaptive/learner/average_learner.py b/adaptive/learner/average_learner.py index 0b72b32f1dedac8ec4ce49b3efa853fa6e1f8444..d9832b110f6d2cc6f43babcdb5b9006b9129fff0 100644 --- a/adaptive/learner/average_learner.py +++ b/adaptive/learner/average_learner.py @@ -50,7 +50,7 @@ class AverageLearner(BaseLearner): def n_requested(self): return len(self.data) + len(self.pending_points) - def ask(self, n, add_data=True): + def ask(self, n, tell_pending=True): points = list(range(self.n_requested, self.n_requested + n)) if any(p in self.data or p in self.pending_points for p in points): @@ -60,8 +60,9 @@ class AverageLearner(BaseLearner): - set(self.pending_points))[:n] loss_improvements = [self.loss_improvement(n) / n] * n - if add_data: - self.tell_many(points, itertools.repeat(None)) + if tell_pending: + for p in points: + self.tell_pending(p) return points, loss_improvements def tell(self, n, value): @@ -69,14 +70,14 @@ class AverageLearner(BaseLearner): # The point has already been added before. return - if value is None: - self.pending_points.add(n) - else: - self.data[n] = value - self.pending_points.discard(n) - self.sum_f += value - self.sum_f_sq += value**2 - self.npoints += 1 + self.data[n] = value + self.pending_points.discard(n) + self.sum_f += value + self.sum_f_sq += value**2 + self.npoints += 1 + + def tell_pending(self, n): + self.pending_points.add(n) @property def mean(self): diff --git a/adaptive/tests/test_average_learner.py b/adaptive/tests/test_average_learner.py index 23b17de824f1e5ac1b716503a4538770b7dd9cfc..b652cc8549c1f4cf9b2a0fb0609beea2a09e6ec0 100644 --- a/adaptive/tests/test_average_learner.py +++ b/adaptive/tests/test_average_learner.py @@ -10,7 +10,7 @@ def test_only_returns_new_points(): for i in range(5, 10): learner.tell(i, 1) - learner.tell(0, None) # This means it shouldn't return 0 anymore + learner.tell_pending(0) # This means it shouldn't return 0 anymore assert learner.ask(1)[0][0] == 1 assert learner.ask(1)[0][0] == 2