From f761e74f7a048b425861b841d7af8b1229da57bb Mon Sep 17 00:00:00 2001
From: Bas Nijholt <basnijholt@gmail.com>
Date: Thu, 20 Sep 2018 12:52:16 +0200
Subject: [PATCH] AverageLearner: create 'tell_pending' which deprecates
 'tell(x, None)'

---
 adaptive/learner/average_learner.py    | 23 ++++++++++++-----------
 adaptive/tests/test_average_learner.py |  2 +-
 2 files changed, 13 insertions(+), 12 deletions(-)

diff --git a/adaptive/learner/average_learner.py b/adaptive/learner/average_learner.py
index 0b72b32f..d9832b11 100644
--- a/adaptive/learner/average_learner.py
+++ b/adaptive/learner/average_learner.py
@@ -50,7 +50,7 @@ class AverageLearner(BaseLearner):
     def n_requested(self):
         return len(self.data) + len(self.pending_points)
 
-    def ask(self, n, add_data=True):
+    def ask(self, n, tell_pending=True):
         points = list(range(self.n_requested, self.n_requested + n))
 
         if any(p in self.data or p in self.pending_points for p in points):
@@ -60,8 +60,9 @@ class AverageLearner(BaseLearner):
                           - set(self.pending_points))[:n]
 
         loss_improvements = [self.loss_improvement(n) / n] * n
-        if add_data:
-            self.tell_many(points, itertools.repeat(None))
+        if tell_pending:
+            for p in points:
+                self.tell_pending(p)
         return points, loss_improvements
 
     def tell(self, n, value):
@@ -69,14 +70,14 @@ class AverageLearner(BaseLearner):
             # The point has already been added before.
             return
 
-        if value is None:
-            self.pending_points.add(n)
-        else:
-            self.data[n] = value
-            self.pending_points.discard(n)
-            self.sum_f += value
-            self.sum_f_sq += value**2
-            self.npoints += 1
+        self.data[n] = value
+        self.pending_points.discard(n)
+        self.sum_f += value
+        self.sum_f_sq += value**2
+        self.npoints += 1
+
+    def tell_pending(self, n):
+        self.pending_points.add(n)
 
     @property
     def mean(self):
diff --git a/adaptive/tests/test_average_learner.py b/adaptive/tests/test_average_learner.py
index 23b17de8..b652cc85 100644
--- a/adaptive/tests/test_average_learner.py
+++ b/adaptive/tests/test_average_learner.py
@@ -10,7 +10,7 @@ def test_only_returns_new_points():
     for i in range(5, 10):
         learner.tell(i, 1)
 
-    learner.tell(0, None)  # This means it shouldn't return 0 anymore
+    learner.tell_pending(0)  # This means it shouldn't return 0 anymore
 
     assert learner.ask(1)[0][0] == 1
     assert learner.ask(1)[0][0] == 2
-- 
GitLab