Skip to content
Snippets Groups Projects
Commit 217c356a authored by Bas Nijholt's avatar Bas Nijholt
Browse files

2D: fix loss_improvements

parent 8011b08e
No related branches found
No related tags found
1 merge request!7implement 2D learner
......@@ -5,7 +5,7 @@ from copy import deepcopy as copy
import functools
import heapq
import itertools
from math import sqrt, isinf
from math import sqrt, isinf, hypot
from operator import itemgetter
import holoviews as hv
......@@ -622,6 +622,7 @@ class Learner2D(BaseLearner):
# XXX: Remove this once we correctly implemented the loss_improvements
self._loss_improvements = []
self.tri_combined = None
self._loss = np.inf
# Keeps track till which index _points and _values are filled
self.n = 0
......@@ -736,7 +737,7 @@ class Learner2D(BaseLearner):
return dev * vol
def tri_radius(self, points):
center = points.mean(axis=-2) / (self.ndim + 1)
center = points.mean(axis=-2)
return np.linalg.norm(points - center, axis=1).max()
def _fill_stack(self, stack_till=None):
......@@ -791,12 +792,10 @@ class Learner2D(BaseLearner):
point_new = _max_disagreement_location_in_simplex(
p, v, g, transform)
# XXX: scale dev[jsimplex] by max(z) - min(z) and tri_radius by bounds diagonal
z_scale = self.values_combined.max() - self.values_combined.min()
x, y = self.bounds
xy_scale = (x[1] - x[0])**2 + (y[1] - y[0])**2
loss_improvement = sqrt((dev[jsimplex] / z_scale)**2 +
self.tri_radius(p)**2 / xy_scale)
xy_scale = hypot(x[1]-x[0], y[1]-y[0])
loss_improvement = hypot(dev[jsimplex] / (v.max() - v.min()),
self.tri_radius(p) / xy_scale)
# Reduce to bounds
point_new = np.clip(point_new, *zip(*self.bounds))
......@@ -842,8 +841,10 @@ class Learner2D(BaseLearner):
loss_improvements += new_loss_improvements
self.add_data(new_points, itertools.repeat(None))
n_left -= len(new_points)
# XXX: Remove this once we correctly implemented the loss_improvements
self._loss_improvements += loss_improvements
self._loss = min(self._loss, max(loss_improvements))
return points, loss_improvements
def choose_points(self, n, add_data=True):
......@@ -854,8 +855,8 @@ class Learner2D(BaseLearner):
return self._choose_and_add_points(n)
def loss(self, real=True):
# XXX: we need a smarter way of determining the loss
return self.n_real
# XXX: currently the loss is set before the result of the point is known.
return self._loss
def remove_unfinished(self):
n_real = self.n_real
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment