Skip to content
Snippets Groups Projects
Commit 571bbbf4 authored by Bas Nijholt's avatar Bas Nijholt
Browse files

2D: fix choose_points for arbitrary n

parent 3aa44f6b
No related branches found
No related tags found
1 merge request!7implement 2D learner
...@@ -695,10 +695,10 @@ class Learner2D(BaseLearner): ...@@ -695,10 +695,10 @@ class Learner2D(BaseLearner):
# Interpolate the unfinished points # Interpolate the unfinished points
if self._interp: if self._interp:
try: if self.n - len(self._interp) > 3:
ip = interpolate.LinearNDInterpolator(self.points_real, ip = interpolate.LinearNDInterpolator(self.points_real,
self.values_real) self.values_real)
except ValueError: else:
ip = lambda x: np.empty(len(x)) # Important not to return exact zeros ip = lambda x: np.empty(len(x)) # Important not to return exact zeros
n_interp = list(self._interp.values()) n_interp = list(self._interp.values())
values = ip(p[n_interp]) values = ip(p[n_interp])
...@@ -772,7 +772,7 @@ class Learner2D(BaseLearner): ...@@ -772,7 +772,7 @@ class Learner2D(BaseLearner):
continue continue
# Add to stack # Add to stack
self._stack.append(point_new.copy()) self._stack.append(tuple(point_new))
if len(self._stack) >= nstack: if len(self._stack) >= nstack:
break break
...@@ -783,16 +783,22 @@ class Learner2D(BaseLearner): ...@@ -783,16 +783,22 @@ class Learner2D(BaseLearner):
pass pass
def choose_points(self, n, add_data=True): def choose_points(self, n, add_data=True):
if len(self._stack) < n: if n <= len(self._stack):
if n > self.n: points = self._stack[:n]
raise NotImplementedError('Need to recursively fill up the stack') self.add_data(points, itertools.repeat(None))
else:
self._fill_stack(stack_till=max(n, self.nstack))
points = self._stack[:n]
if add_data:
for point in points:
self.add_point(point, None)
self._stack = self._stack[n:] self._stack = self._stack[n:]
else:
points = []
n_left = n
while n_left > 0:
if self.n >= 2**self.ndim:
self._fill_stack(stack_till=max(n_left, self.nstack))
from_stack = self._stack[:n_left]
points += from_stack
self.add_data(from_stack, itertools.repeat(None))
self._stack = self._stack[n_left:]
n_left -= len(from_stack)
loss_improvements = [1] * n loss_improvements = [1] * n
return points, loss_improvements return points, loss_improvements
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment