Skip to content
Snippets Groups Projects
Commit a8866d4a authored by Bas Nijholt's avatar Bas Nijholt
Browse files

remove _choose_points, closes #19

parent afd9dbb8
No related branches found
No related tags found
1 merge request!7implement 2D learner
......@@ -72,6 +72,7 @@ class BaseLearner(metaclass=abc.ABCMeta):
(possibly by interpolation).
"""
@abc.abstractmethod
def choose_points(self, n, add_data=True):
"""Choose the next 'n' points to evaluate.
......@@ -85,22 +86,7 @@ class BaseLearner(metaclass=abc.ABCMeta):
values. Set this to False if you do not
want to modify the state of the learner.
"""
points, loss_improvements = self._choose_points(n)
if add_data:
self.add_data(points, itertools.repeat(None))
return points, loss_improvements
@abc.abstractmethod
def _choose_points(self, n):
"""Choose the next 'n' points to evaluate.
Should be overridden by subclasses.
Parameters
----------
n : int
The number of points to choose.
"""
pass
def __getstate__(self):
return copy(self.__dict__)
......@@ -139,9 +125,11 @@ class AverageLearner(BaseLearner):
self.sum_f = 0
self.sum_f_sq = 0
def _choose_points(self, n=10):
def choose_points(self, n=10, add_data=True):
points = list(range(self.n_requested, self.n_requested + n))
loss_improvements = [None] * n
if add_data:
self.add_data(points, itertools.repeat(None))
return points, loss_improvements
def add_point(self, n, value):
......@@ -324,7 +312,7 @@ class Learner1D(BaseLearner):
self._oldscale = self._scale
def _choose_points(self, n=10):
def choose_points(self, n=10, add_data=True):
"""Return n points that are expected to maximally reduce the loss."""
# Find out how to divide the n points over the intervals
# by finding positive integer n_i that minimize max(L_i / n_i) subject
......@@ -373,7 +361,10 @@ class Learner1D(BaseLearner):
itertools.repeat(-quality, n)
for quality, x, n in quals))
return (xs, loss_improvements)
if add_data:
self.add_data(points, itertools.repeat(None))
return xs, loss_improvements
def interpolate(self, extra_points=None):
xs = list(self.data.keys())
......@@ -462,9 +453,6 @@ class BalancingLearner(BaseLearner):
else:
return self._choose_and_add_points(n)
def _choose_points(self, n):
pass
def add_point(self, x, y):
index, x = x
self.learners[index].add_point(x, y)
......@@ -793,9 +781,6 @@ class Learner2D(BaseLearner):
else:
dev[jsimplex] = 0
def _choose_points(self, n):
pass
def _choose_and_add_points(self, n):
if n <= len(self._stack):
points = self._stack[:n]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment