Skip to content
Snippets Groups Projects
Commit fdc31b7b authored by Bas Nijholt's avatar Bas Nijholt
Browse files

1D: swap xs <--> points as variable names to be more consistent

parent 3286d4b7
No related branches found
No related tags found
1 merge request!7implement 2D learner
......@@ -326,19 +326,19 @@ class Learner1D(BaseLearner):
return []
# If the bounds have not been chosen yet, we choose them first.
xs = []
points = []
for bound in self.bounds:
if bound not in self.data and bound not in self.data_interp:
xs.append(bound)
points.append(bound)
# Ensure we return exactly 'n' points.
if xs:
if points:
loss_improvements = [float('inf')] * n
if n <= 2:
return xs[:n], loss_improvements
return points[:n], loss_improvements
else:
return np.linspace(*self.bounds, n), loss_improvements
def points(x, n):
def xs(x, n):
if n == 1:
return []
else:
......@@ -355,8 +355,8 @@ class Learner1D(BaseLearner):
quality, x, n = quals[0]
heapq.heapreplace(quals, (quality * n / (n + 1), x, n + 1))
xs = list(itertools.chain.from_iterable(points(x, n)
for quality, x, n in quals))
points = list(itertools.chain.from_iterable(xs(x, n)
for quality, x, n in quals))
loss_improvements = list(itertools.chain.from_iterable(
itertools.repeat(-quality, n)
......@@ -365,7 +365,7 @@ class Learner1D(BaseLearner):
if add_data:
self.add_data(points, itertools.repeat(None))
return xs, loss_improvements
return points, loss_improvements
def interpolate(self, extra_points=None):
xs = list(self.data.keys())
......@@ -472,7 +472,7 @@ class BalancingLearner(BaseLearner):
def _max_disagreement_location_in_simplex(points, values, grad, transform):
"""Find the point of maximum disagreement between linear and quadratic model
"""Find the point of maximum disagreement between linear and quadratic model.
Parameters
----------
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment