Skip to content
Snippets Groups Projects
Commit 571bbbf4 authored by Bas Nijholt's avatar Bas Nijholt
Browse files

2D: fix choose_points for arbitrary n

parent 3aa44f6b
No related branches found
No related tags found
1 merge request!7implement 2D learner
......@@ -695,10 +695,10 @@ class Learner2D(BaseLearner):
# Interpolate the unfinished points
if self._interp:
try:
if self.n - len(self._interp) > 3:
ip = interpolate.LinearNDInterpolator(self.points_real,
self.values_real)
except ValueError:
else:
ip = lambda x: np.empty(len(x)) # Important not to return exact zeros
n_interp = list(self._interp.values())
values = ip(p[n_interp])
......@@ -772,7 +772,7 @@ class Learner2D(BaseLearner):
continue
# Add to stack
self._stack.append(point_new.copy())
self._stack.append(tuple(point_new))
if len(self._stack) >= nstack:
break
......@@ -783,16 +783,22 @@ class Learner2D(BaseLearner):
pass
def choose_points(self, n, add_data=True):
if len(self._stack) < n:
if n > self.n:
raise NotImplementedError('Need to recursively fill up the stack')
else:
self._fill_stack(stack_till=max(n, self.nstack))
points = self._stack[:n]
if add_data:
for point in points:
self.add_point(point, None)
if n <= len(self._stack):
points = self._stack[:n]
self.add_data(points, itertools.repeat(None))
self._stack = self._stack[n:]
else:
points = []
n_left = n
while n_left > 0:
if self.n >= 2**self.ndim:
self._fill_stack(stack_till=max(n_left, self.nstack))
from_stack = self._stack[:n_left]
points += from_stack
self.add_data(from_stack, itertools.repeat(None))
self._stack = self._stack[n_left:]
n_left -= len(from_stack)
loss_improvements = [1] * n
return points, loss_improvements
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment