Skip to content
Snippets Groups Projects
Commit 87eca332 authored by Bas Nijholt's avatar Bas Nijholt
Browse files

dedent common code in 'update_losses'

parent 2a03f093
No related branches found
No related tags found
1 merge request!100Resolve "Learner1D doesn't correctly set the interpolated loss when a point is added"
......@@ -147,22 +147,23 @@ class Learner1D(BaseLearner):
self.losses_combined[x_left, x_right] = loss
def update_losses(self, x, real=True):
a, b = self.find_neighbors(x, self.neighbors_combined)
x_left, x_right = self.find_neighbors(x, self.neighbors)
losses_combined = self.losses_combined
if real:
x_left, x_right = self.find_neighbors(x, self.neighbors)
self.update_interpolated_loss_in_interval(x_left, x)
self.update_interpolated_loss_in_interval(x, x_right)
self.losses.pop((x_left, x_right), None)
self.losses_combined.pop((x_left, x_right), None)
a, b = self.find_neighbors(x, self.neighbors_combined)
self.losses_combined.pop((a, b), None)
if x_left is None and a is not None:
self.losses_combined[a, x] = float('inf')
losses_combined[a, x] = float('inf')
if x_right is None and b is not None:
self.losses_combined[x, b] = float('inf')
losses_combined[x, b] = float('inf')
else:
losses_combined = self.losses_combined
x_left, x_right = self.find_neighbors(x, self.neighbors)
a, b = self.find_neighbors(x, self.neighbors_combined)
if x_left is not None and x_right is not None:
dx = x_right - x_left
loss = self.losses[x_left, x_right]
......@@ -174,7 +175,7 @@ class Learner1D(BaseLearner):
if b is not None:
losses_combined[x, b] = float('inf')
losses_combined.pop((a, b), None)
losses_combined.pop((a, b), None)
def find_neighbors(self, x, neighbors):
if x in neighbors:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment