Skip to content
Snippets Groups Projects
Commit f7c63a26 authored by Bas Nijholt's avatar Bas Nijholt
Browse files

1D: change update losses and neighbors methods to take the data instead of acting inplace

parent 8b00f986
No related branches found
No related tags found
1 merge request!4Implement BalancingLearner
...@@ -234,17 +234,12 @@ class Learner1D(BaseLearner): ...@@ -234,17 +234,12 @@ class Learner1D(BaseLearner):
def loss(self, real=True): def loss(self, real=True):
losses = self.losses if real else self.losses_combined losses = self.losses if real else self.losses_combined
if len(losses) == 0: if len(losses) == 0:
return float('inf') return float('inf')
else: else:
return max(losses.values()) return max(losses.values())
def update_losses(self, x, real): def get_losses(self, x, data, neighbors, losses):
losses = self.losses if real else self.losses_combined
neighbors = self.neighbors if real else self.neighbors_combined
data = self.data if real else self.data_combined
x_lower, x_upper = neighbors[x] x_lower, x_upper = neighbors[x]
if x_lower is not None: if x_lower is not None:
losses[x_lower, x] = self.interval_loss(x_lower, x, data) losses[x_lower, x] = self.interval_loss(x_lower, x, data)
...@@ -254,29 +249,20 @@ class Learner1D(BaseLearner): ...@@ -254,29 +249,20 @@ class Learner1D(BaseLearner):
del losses[x_lower, x_upper] del losses[x_lower, x_upper]
except KeyError: except KeyError:
pass pass
return losses
def loss_improvement(self, points): def loss_improvement(self, points):
current_loss = self.loss(real=False) current_loss = self.loss(real=False)
data_interp = self.interpolate(points) data_interp = self.interpolate(points)
data = {**self.data_combined, **data_interp} data = {**self.data_combined, **data_interp}
# Create a new neighbors and losses dict # Create a new losses and neighbors dict
neighbors = copy(self.neighbors_combined) neighbors = copy(self.neighbors_combined)
losses = copy(self.losses_combined) losses = copy(self.losses_combined)
for x in points: for x in points:
x_lower, x_upper = self.find_neighbors(x, neighbors) neighbors = self.get_neighbors(x, neighbors, check_if_new=False)
neighbors[x] = [x_lower, x_upper] losses = self.get_losses(x, data, neighbors, losses)
neighbors.get(x_lower, [None, None])[1] = x
neighbors.get(x_upper, [None, None])[0] = x
if x_lower is not None:
losses[x_lower, x] = self.interval_loss(x_lower, x, data)
if x_upper is not None:
losses[x, x_upper] = self.interval_loss(x, x_upper, data)
try:
del losses[x_lower, x_upper]
except KeyError:
pass
# Calculate the loss improvement # Calculate the loss improvement
if len(losses) == 0: if len(losses) == 0:
...@@ -294,13 +280,13 @@ class Learner1D(BaseLearner): ...@@ -294,13 +280,13 @@ class Learner1D(BaseLearner):
x_upper = neighbors.iloc[pos] if pos != len(neighbors) else None x_upper = neighbors.iloc[pos] if pos != len(neighbors) else None
return x_lower, x_upper return x_lower, x_upper
def update_neighbors(self, x, real): def get_neighbors(self, x, neighbors, check_if_new=True):
neighbors = self.neighbors if real else self.neighbors_combined if not check_if_new or x not in neighbors: # The point is new
if x not in neighbors: # The point is new
x_lower, x_upper = self.find_neighbors(x, neighbors) x_lower, x_upper = self.find_neighbors(x, neighbors)
neighbors[x] = [x_lower, x_upper] neighbors[x] = [x_lower, x_upper]
neighbors.get(x_lower, [None, None])[1] = x neighbors.get(x_lower, [None, None])[1] = x
neighbors.get(x_upper, [None, None])[0] = x neighbors.get(x_upper, [None, None])[0] = x
return neighbors
def update_scale(self, x, y): def update_scale(self, x, y):
self._bbox[0][0] = min(self._bbox[0][0], x) self._bbox[0][0] = min(self._bbox[0][0], x)
...@@ -332,9 +318,9 @@ class Learner1D(BaseLearner): ...@@ -332,9 +318,9 @@ class Learner1D(BaseLearner):
self.data_interp[x] = None self.data_interp[x] = None
# Update the neighbors # Update the neighbors
self.update_neighbors(x, False) self.neighbors_combined = self.get_neighbors(x, self.neighbors_combined)
if real: if real:
self.update_neighbors(x, True) self.neighbors = self.get_neighbors(x, self.neighbors)
# Update the scale # Update the scale
self.update_scale(x, y) self.update_scale(x, y)
...@@ -344,9 +330,12 @@ class Learner1D(BaseLearner): ...@@ -344,9 +330,12 @@ class Learner1D(BaseLearner):
self.data_interp = self.interpolate() self.data_interp = self.interpolate()
# Update the losses # Update the losses
self.update_losses(x, False) self.losses_combined = self.get_losses(x, self.data_combined,
self.neighbors_combined,
self.losses_combined)
if real: if real:
self.update_losses(x, True) self.losses = self.get_losses(x, self.data, self.neighbors,
self.losses)
if real: if real:
# If the scale has doubled, recompute all losses. # If the scale has doubled, recompute all losses.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment