From f761e74f7a048b425861b841d7af8b1229da57bb Mon Sep 17 00:00:00 2001 From: Bas Nijholt <basnijholt@gmail.com> Date: Thu, 20 Sep 2018 12:52:16 +0200 Subject: [PATCH] AverageLearner: create 'tell_pending' which deprecates 'tell(x, None)' --- adaptive/learner/average_learner.py | 23 ++++++++++++----------- adaptive/tests/test_average_learner.py | 2 +- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/adaptive/learner/average_learner.py b/adaptive/learner/average_learner.py index 0b72b32f..d9832b11 100644 --- a/adaptive/learner/average_learner.py +++ b/adaptive/learner/average_learner.py @@ -50,7 +50,7 @@ class AverageLearner(BaseLearner): def n_requested(self): return len(self.data) + len(self.pending_points) - def ask(self, n, add_data=True): + def ask(self, n, tell_pending=True): points = list(range(self.n_requested, self.n_requested + n)) if any(p in self.data or p in self.pending_points for p in points): @@ -60,8 +60,9 @@ class AverageLearner(BaseLearner): - set(self.pending_points))[:n] loss_improvements = [self.loss_improvement(n) / n] * n - if add_data: - self.tell_many(points, itertools.repeat(None)) + if tell_pending: + for p in points: + self.tell_pending(p) return points, loss_improvements def tell(self, n, value): @@ -69,14 +70,14 @@ class AverageLearner(BaseLearner): # The point has already been added before. return - if value is None: - self.pending_points.add(n) - else: - self.data[n] = value - self.pending_points.discard(n) - self.sum_f += value - self.sum_f_sq += value**2 - self.npoints += 1 + self.data[n] = value + self.pending_points.discard(n) + self.sum_f += value + self.sum_f_sq += value**2 + self.npoints += 1 + + def tell_pending(self, n): + self.pending_points.add(n) @property def mean(self): diff --git a/adaptive/tests/test_average_learner.py b/adaptive/tests/test_average_learner.py index 23b17de8..b652cc85 100644 --- a/adaptive/tests/test_average_learner.py +++ b/adaptive/tests/test_average_learner.py @@ -10,7 +10,7 @@ def test_only_returns_new_points(): for i in range(5, 10): learner.tell(i, 1) - learner.tell(0, None) # This means it shouldn't return 0 anymore + learner.tell_pending(0) # This means it shouldn't return 0 anymore assert learner.ask(1)[0][0] == 1 assert learner.ask(1)[0][0] == 2 -- GitLab