From 78cd74eb86aa605d1e5066d61fa7d2cefb4d3a7f Mon Sep 17 00:00:00 2001
From: Bas Nijholt <basnijholt@gmail.com>
Date: Wed, 19 Sep 2018 14:51:19 +0200
Subject: [PATCH] fix that the 'AverageLearner' only returns new points, and
 introduce 'learner.pending_points'

---
 adaptive/learner/average_learner.py | 42 ++++++++++++++++++++---------
 1 file changed, 29 insertions(+), 13 deletions(-)

diff --git a/adaptive/learner/average_learner.py b/adaptive/learner/average_learner.py
index 0b0a9e4f..f9163071 100644
--- a/adaptive/learner/average_learner.py
+++ b/adaptive/learner/average_learner.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 import itertools
 from math import sqrt
+import warnings
 
 import numpy as np
 
@@ -14,12 +15,19 @@ class AverageLearner(BaseLearner):
     The learned function must depend on an integer input variable that
     represents the source of randomness.
 
-    Parameters:
-    -----------
+    Parameters
+    ----------
     atol : float
         Desired absolute tolerance
     rtol : float
         Desired relative tolerance
+
+    Attributes
+    ----------
+    data : dict
+        Sampled points and values.
+    pending_points : set
+        Points that still have to be evaluated.
     """
 
     def __init__(self, function, atol=None, rtol=None):
@@ -31,6 +39,7 @@ class AverageLearner(BaseLearner):
             rtol = np.inf
 
         self.data = {}
+        self.pending_points = set()
         self.function = function
         self.atol = atol
         self.rtol = rtol
@@ -40,28 +49,35 @@ class AverageLearner(BaseLearner):
 
     @property
     def n_requested(self):
-        return len(self.data)
+        return len(self.data) + len(self.pending_points)
 
     def ask(self, n, add_data=True):
         points = list(range(self.n_requested, self.n_requested + n))
+
+        if any(True for p in points if p in self.data or p in self.pending_points):
+            # This means some of the points `< self.n_requested` do not exist.
+            points = list(set(range(self.n_requested + n))
+                          - set(self.data)
+                          - set(self.pending_points))[:n]
+
         loss_improvements = [self.loss_improvement(n) / n] * n
         if add_data:
             self.tell_many(points, itertools.repeat(None))
         return points, loss_improvements
 
     def tell(self, n, value):
-        value_is_new = not (n in self.data and value == self.data[n])
-        if not value_is_new:
-            value_old = self.data[n]
-        self.data[n] = value
-        if value is not None:
+        if n in self.data:
+            warnings.warn(f"n={n} already exists in `learner.data` with value {self.data[n]}.")
+            return
+
+        if value is None:
+            self.pending_points.add(n)
+        else:
+            self.data[n] = value
+            self.pending_points.discard(n)
             self.sum_f += value
             self.sum_f_sq += value**2
-            if value_is_new:
-                self.npoints += 1
-            else:
-                self.sum_f -= value_old
-                self.sum_f_sq -= value_old**2
+            self.npoints += 1
 
     @property
     def mean(self):
-- 
GitLab