From a8866d4acfe8fa9a7ea69785ef341552aadb9974 Mon Sep 17 00:00:00 2001 From: Bas Nijholt <basnijholt@gmail.com> Date: Thu, 7 Sep 2017 10:49:15 +0200 Subject: [PATCH] remove _choose_points, closes #19 --- adaptive/learner.py | 35 ++++++++++------------------------- 1 file changed, 10 insertions(+), 25 deletions(-) diff --git a/adaptive/learner.py b/adaptive/learner.py index 8ff79804..ff9af043 100644 --- a/adaptive/learner.py +++ b/adaptive/learner.py @@ -72,6 +72,7 @@ class BaseLearner(metaclass=abc.ABCMeta): (possibly by interpolation). """ + @abc.abstractmethod def choose_points(self, n, add_data=True): """Choose the next 'n' points to evaluate. @@ -85,22 +86,7 @@ class BaseLearner(metaclass=abc.ABCMeta): values. Set this to False if you do not want to modify the state of the learner. """ - points, loss_improvements = self._choose_points(n) - if add_data: - self.add_data(points, itertools.repeat(None)) - return points, loss_improvements - - @abc.abstractmethod - def _choose_points(self, n): - """Choose the next 'n' points to evaluate. - - Should be overridden by subclasses. - - Parameters - ---------- - n : int - The number of points to choose. - """ + pass def __getstate__(self): return copy(self.__dict__) @@ -139,9 +125,11 @@ class AverageLearner(BaseLearner): self.sum_f = 0 self.sum_f_sq = 0 - def _choose_points(self, n=10): + def choose_points(self, n=10, add_data=True): points = list(range(self.n_requested, self.n_requested + n)) loss_improvements = [None] * n + if add_data: + self.add_data(points, itertools.repeat(None)) return points, loss_improvements def add_point(self, n, value): @@ -324,7 +312,7 @@ class Learner1D(BaseLearner): self._oldscale = self._scale - def _choose_points(self, n=10): + def choose_points(self, n=10, add_data=True): """Return n points that are expected to maximally reduce the loss.""" # Find out how to divide the n points over the intervals # by finding positive integer n_i that minimize max(L_i / n_i) subject @@ -373,7 +361,10 @@ class Learner1D(BaseLearner): itertools.repeat(-quality, n) for quality, x, n in quals)) - return (xs, loss_improvements) + if add_data: + self.add_data(points, itertools.repeat(None)) + + return xs, loss_improvements def interpolate(self, extra_points=None): xs = list(self.data.keys()) @@ -462,9 +453,6 @@ class BalancingLearner(BaseLearner): else: return self._choose_and_add_points(n) - def _choose_points(self, n): - pass - def add_point(self, x, y): index, x = x self.learners[index].add_point(x, y) @@ -793,9 +781,6 @@ class Learner2D(BaseLearner): else: dev[jsimplex] = 0 - def _choose_points(self, n): - pass - def _choose_and_add_points(self, n): if n <= len(self._stack): points = self._stack[:n] -- GitLab