Skip to content
Snippets Groups Projects

Resolve "(BalancingLearner) loss is cached incorrectly"

Merged Jorn Hoofwijk requested to merge 108-balancinglearner-loss-is-cached-incorrectly into master
All threads resolved!
Files
2
@@ -56,6 +56,7 @@ class BalancingLearner(BaseLearner):
self._points = {}
self._loss = {}
self._pending_loss = {}
self._cdims_default = cdims
if len(set(learner.__class__ for learner in self.learners)) > 1:
@@ -95,6 +96,7 @@ class BalancingLearner(BaseLearner):
index, x = x
self._points.pop(index, None)
self._loss.pop(index, None)
self._pending_loss.pop(index, None)
self.learners[index].tell(x, y)
def tell_pending(self, x):
@@ -103,13 +105,19 @@ class BalancingLearner(BaseLearner):
self._loss.pop(index, None)
self.learners[index].tell_pending(x)
def loss(self, real=True):
def losses(self, real=True):
losses = []
loss_dict = self._loss if real else self._pending_loss
for index, learner in enumerate(self.learners):
if index not in self._loss:
self._loss[index] = learner.loss(real)
loss = self._loss[index]
losses.append(loss)
if index not in loss_dict:
loss_dict[index] = learner.loss(real)
losses.append(loss_dict[index])
return losses
def loss(self, real=True):
losses = self.losses(real)
return max(losses)
def plot(self, cdims=None, plotter=None):
Loading