Skip to content
Snippets Groups Projects

make methods private in the LearnerND, closes #85

Merged Bas Nijholt requested to merge private_methods_learnernd into master
1 file
+ 15
28
Compare changes
  • Side-by-side
  • Inline
@@ -227,10 +227,10 @@ class LearnerND(BaseLearner):
def bounds_are_done(self):
return all(p in self.data for p in self._bounds_points)
def ip(self):
def _ip(self):
"""A `scipy.interpolate.LinearNDInterpolator` instance
containing the learner's data."""
# XXX: take our own triangulation into account when generating the ip
# XXX: take our own triangulation into account when generating the _ip
return interpolate.LinearNDInterpolator(self.points, self.values)
@property
@@ -242,7 +242,7 @@ class LearnerND(BaseLearner):
try:
self._tri = Triangulation(self.points)
self.update_losses(set(), self._tri.simplices)
self._update_losses(set(), self._tri.simplices)
return self._tri
except ValueError:
# A ValueError is raised if we do not have enough points or
@@ -283,7 +283,7 @@ class LearnerND(BaseLearner):
simplex = None
to_delete, to_add = tri.add_point(
point, simplex, transform=self._transform)
self.update_losses(to_delete, to_add)
self._update_losses(to_delete, to_add)
def _simplex_exists(self, simplex):
simplex = tuple(sorted(simplex))
@@ -441,7 +441,7 @@ class LearnerND(BaseLearner):
return self._ask_best_point() # O(log N)
def update_losses(self, to_delete: set, to_add: set):
def _update_losses(self, to_delete: set, to_add: set):
# XXX: add the points outside the triangulation to this as well
pending_points_unbound = set()
@@ -455,7 +455,7 @@ class LearnerND(BaseLearner):
if p not in self.data)
for simplex in to_add:
loss = self.compute_loss(simplex)
loss = self._compute_loss(simplex)
self._losses[simplex] = loss
for p in pending_points_unbound:
@@ -469,7 +469,7 @@ class LearnerND(BaseLearner):
self._update_subsimplex_losses(
simplex, self._subtriangulations[simplex].simplices)
def compute_loss(self, simplex):
def _compute_loss(self, simplex):
# get the loss
vertices = self.tri.get_vertices(simplex)
values = [self.data[tuple(v)] for v in vertices]
@@ -481,7 +481,7 @@ class LearnerND(BaseLearner):
# compute the loss on the scaled simplex
return float(self.loss_per_simplex(vertices, values))
def recompute_all_losses(self):
def _recompute_all_losses(self):
"""Recompute all losses and pending losses."""
# amortized O(N) complexity
if self.tri is None:
@@ -492,7 +492,7 @@ class LearnerND(BaseLearner):
# recompute all losses
for simplex in self.tri.simplices:
loss = self.compute_loss(simplex)
loss = self._compute_loss(simplex)
self._losses[simplex] = loss
# now distribute it around the the children if they are present
@@ -543,27 +543,14 @@ class LearnerND(BaseLearner):
scale_factor = np.max(np.nan_to_num(self._scale / self._old_scale))
if scale_factor > self._recompute_losses_factor:
self._old_scale = self._scale
self.recompute_all_losses()
self._recompute_all_losses()
return True
return False
def losses(self):
"""Get the losses of each simplex in the current triangulation, as dict
Returns
-------
losses : dict
the key is a simplex, the value is the loss of this simplex
"""
# XXX could be a property
if self.tri is None:
return dict()
return self._losses
@cache_latest
def loss(self, real=True):
losses = self.losses() # XXX: compute pending loss if real == False
# XXX: compute pending loss if real == False
losses = self._losses if self.tri is not None else dict()
return max(losses.values()) if losses else float('inf')
def remove_unfinished(self):
@@ -607,7 +594,7 @@ class LearnerND(BaseLearner):
xs = ys = np.linspace(0, 1, n)
xs = xs * (x[1] - x[0]) + x[0]
ys = ys * (y[1] - y[0]) + y[0]
z = self.ip()(xs[:, None], ys[None, :]).squeeze()
z = self._ip()(xs[:, None], ys[None, :]).squeeze()
im = hv.Image(np.rot90(z), bounds=lbrt)
@@ -656,7 +643,7 @@ class LearnerND(BaseLearner):
for i in range(self.ndim)]
ind = next(i for i in range(self.ndim) if i not in cut_mapping)
x = values[ind]
y = self.ip()(*values)
y = self._ip()(*values)
p = hv.Path((x, y))
# Plot with 5% margins such that the boundary points are visible
@@ -686,7 +673,7 @@ class LearnerND(BaseLearner):
lbrt = np.reshape(lbrt, (2, 2)).T.flatten().tolist()
if len(self.data) >= 4:
z = self.ip()(*values).squeeze()
z = self._ip()(*values).squeeze()
im = hv.Image(np.rot90(z), bounds=lbrt)
else:
im = hv.Image([], bounds=lbrt)
Loading