Skip to content
Snippets Groups Projects
Commit 962ba684 authored by Bas Nijholt's avatar Bas Nijholt
Browse files

dedent common code in 'update_losses'

parent 7972db03
No related branches found
No related tags found
No related merge requests found
This commit is part of merge request !104. Comments created here will be created in the context of that merge request.
...@@ -147,22 +147,23 @@ class Learner1D(BaseLearner): ...@@ -147,22 +147,23 @@ class Learner1D(BaseLearner):
self.losses_combined[x_left, x_right] = loss self.losses_combined[x_left, x_right] = loss
def update_losses(self, x, real=True): def update_losses(self, x, real=True):
a, b = self.find_neighbors(x, self.neighbors_combined)
x_left, x_right = self.find_neighbors(x, self.neighbors)
losses_combined = self.losses_combined
if real: if real:
x_left, x_right = self.find_neighbors(x, self.neighbors)
self.update_interpolated_loss_in_interval(x_left, x) self.update_interpolated_loss_in_interval(x_left, x)
self.update_interpolated_loss_in_interval(x, x_right) self.update_interpolated_loss_in_interval(x, x_right)
self.losses.pop((x_left, x_right), None) self.losses.pop((x_left, x_right), None)
self.losses_combined.pop((x_left, x_right), None) self.losses_combined.pop((x_left, x_right), None)
a, b = self.find_neighbors(x, self.neighbors_combined) a, b = self.find_neighbors(x, self.neighbors_combined)
self.losses_combined.pop((a, b), None) self.losses_combined.pop((a, b), None)
if x_left is None and a is not None: if x_left is None and a is not None:
self.losses_combined[a, x] = float('inf') losses_combined[a, x] = float('inf')
if x_right is None and b is not None: if x_right is None and b is not None:
self.losses_combined[x, b] = float('inf') losses_combined[x, b] = float('inf')
else: else:
losses_combined = self.losses_combined
x_left, x_right = self.find_neighbors(x, self.neighbors)
a, b = self.find_neighbors(x, self.neighbors_combined)
if x_left is not None and x_right is not None: if x_left is not None and x_right is not None:
dx = x_right - x_left dx = x_right - x_left
loss = self.losses[x_left, x_right] loss = self.losses[x_left, x_right]
...@@ -174,7 +175,7 @@ class Learner1D(BaseLearner): ...@@ -174,7 +175,7 @@ class Learner1D(BaseLearner):
if b is not None: if b is not None:
losses_combined[x, b] = float('inf') losses_combined[x, b] = float('inf')
losses_combined.pop((a, b), None) losses_combined.pop((a, b), None)
def find_neighbors(self, x, neighbors): def find_neighbors(self, x, neighbors):
if x in neighbors: if x in neighbors:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment