Skip to content
Snippets Groups Projects

Make the AverageLearner only return new points ...

Merged Bas Nijholt requested to merge average_learner_improvements into master
All threads resolved!
1 file
+ 29
13
Compare changes
  • Side-by-side
  • Inline
# -*- coding: utf-8 -*-
import itertools
from math import sqrt
import warnings
import numpy as np
@@ -14,12 +15,19 @@ class AverageLearner(BaseLearner):
The learned function must depend on an integer input variable that
represents the source of randomness.
Parameters:
-----------
Parameters
----------
atol : float
Desired absolute tolerance
rtol : float
Desired relative tolerance
Attributes
----------
data : dict
Sampled points and values.
pending_points : set
Points that still have to be evaluated.
"""
def __init__(self, function, atol=None, rtol=None):
@@ -31,6 +39,7 @@ class AverageLearner(BaseLearner):
rtol = np.inf
self.data = {}
self.pending_points = set()
self.function = function
self.atol = atol
self.rtol = rtol
@@ -40,28 +49,35 @@ class AverageLearner(BaseLearner):
@property
def n_requested(self):
return len(self.data)
return len(self.data) + len(self.pending_points)
def ask(self, n, add_data=True):
points = list(range(self.n_requested, self.n_requested + n))
if any(True for p in points if p in self.data or p in self.pending_points):
# This means some of the points `< self.n_requested` do not exist.
points = list(set(range(self.n_requested + n))
- set(self.data)
- set(self.pending_points))[:n]
loss_improvements = [self.loss_improvement(n) / n] * n
if add_data:
self.tell_many(points, itertools.repeat(None))
return points, loss_improvements
def tell(self, n, value):
value_is_new = not (n in self.data and value == self.data[n])
if not value_is_new:
value_old = self.data[n]
self.data[n] = value
if value is not None:
if n in self.data:
warnings.warn(f"n={n} already exists in `learner.data` with value {self.data[n]}.")
return
if value is None:
self.pending_points.add(n)
else:
self.data[n] = value
self.pending_points.discard(n)
self.sum_f += value
self.sum_f_sq += value**2
if value_is_new:
self.npoints += 1
else:
self.sum_f -= value_old
self.sum_f_sq -= value_old**2
self.npoints += 1
@property
def mean(self):
Loading