Skip to content
Snippets Groups Projects
Commit 6d7d8e58 authored by Bas Nijholt's avatar Bas Nijholt
Browse files

1D: implement returning of n points for empty learner

parent f7c63a26
No related branches found
No related tags found
1 merge request!4Implement BalancingLearner
...@@ -212,7 +212,7 @@ class Learner1D(BaseLearner): ...@@ -212,7 +212,7 @@ class Learner1D(BaseLearner):
self._oldscale = copy(self._scale) self._oldscale = copy(self._scale)
self.bounds = list(bounds) self.bounds = list(bounds)
self._included_bounds = list(bounds) self._include_bounds = list(bounds)
@property @property
def data_combined(self): def data_combined(self):
...@@ -301,9 +301,9 @@ class Learner1D(BaseLearner): ...@@ -301,9 +301,9 @@ class Learner1D(BaseLearner):
def add_point(self, x, y): def add_point(self, x, y):
real = y is not None real = y is not None
# Remove the point from _included_bounds # Remove the point from _include_bounds
if real and x in self._included_bounds: if x in self._include_bounds:
self._included_bounds.remove(x) self._include_bounds.remove(x)
if real: if real:
# Add point to the real data dict and pop from the unfinished # Add point to the real data dict and pop from the unfinished
...@@ -359,10 +359,13 @@ class Learner1D(BaseLearner): ...@@ -359,10 +359,13 @@ class Learner1D(BaseLearner):
if n == 0: if n == 0:
return [] return []
for bound in self._included_bounds: for bound in self._include_bounds:
if bound not in self.data_combined: if bound not in self.data_combined:
n = min(n, len(self._included_bounds)) bounds = self._include_bounds[:min(n, len(self._include_bounds))]
return self._included_bounds[:n] if n <= 2:
return bounds
else:
return np.linspace(*bounds, n)
def points(x, n): def points(x, n):
return list(np.linspace(x[0], x[1], n, endpoint=False)[1:]) return list(np.linspace(x[0], x[1], n, endpoint=False)[1:])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment