From 78cd74eb86aa605d1e5066d61fa7d2cefb4d3a7f Mon Sep 17 00:00:00 2001 From: Bas Nijholt <basnijholt@gmail.com> Date: Wed, 19 Sep 2018 14:51:19 +0200 Subject: [PATCH] fix that the 'AverageLearner' only returns new points, and introduce 'learner.pending_points' --- adaptive/learner/average_learner.py | 42 ++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/adaptive/learner/average_learner.py b/adaptive/learner/average_learner.py index 0b0a9e4f..f9163071 100644 --- a/adaptive/learner/average_learner.py +++ b/adaptive/learner/average_learner.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- import itertools from math import sqrt +import warnings import numpy as np @@ -14,12 +15,19 @@ class AverageLearner(BaseLearner): The learned function must depend on an integer input variable that represents the source of randomness. - Parameters: - ----------- + Parameters + ---------- atol : float Desired absolute tolerance rtol : float Desired relative tolerance + + Attributes + ---------- + data : dict + Sampled points and values. + pending_points : set + Points that still have to be evaluated. """ def __init__(self, function, atol=None, rtol=None): @@ -31,6 +39,7 @@ class AverageLearner(BaseLearner): rtol = np.inf self.data = {} + self.pending_points = set() self.function = function self.atol = atol self.rtol = rtol @@ -40,28 +49,35 @@ class AverageLearner(BaseLearner): @property def n_requested(self): - return len(self.data) + return len(self.data) + len(self.pending_points) def ask(self, n, add_data=True): points = list(range(self.n_requested, self.n_requested + n)) + + if any(True for p in points if p in self.data or p in self.pending_points): + # This means some of the points `< self.n_requested` do not exist. + points = list(set(range(self.n_requested + n)) + - set(self.data) + - set(self.pending_points))[:n] + loss_improvements = [self.loss_improvement(n) / n] * n if add_data: self.tell_many(points, itertools.repeat(None)) return points, loss_improvements def tell(self, n, value): - value_is_new = not (n in self.data and value == self.data[n]) - if not value_is_new: - value_old = self.data[n] - self.data[n] = value - if value is not None: + if n in self.data: + warnings.warn(f"n={n} already exists in `learner.data` with value {self.data[n]}.") + return + + if value is None: + self.pending_points.add(n) + else: + self.data[n] = value + self.pending_points.discard(n) self.sum_f += value self.sum_f_sq += value**2 - if value_is_new: - self.npoints += 1 - else: - self.sum_f -= value_old - self.sum_f_sq -= value_old**2 + self.npoints += 1 @property def mean(self): -- GitLab