Skip to content
Snippets Groups Projects

Resolve "(BalancingLearner) loss is cached incorrectly"

Merged Jorn Hoofwijk requested to merge 108-balancinglearner-loss-is-cached-incorrectly into master
All threads resolved!
1 file
+ 13
5
Compare changes
  • Side-by-side
  • Inline
@@ -56,6 +56,7 @@ class BalancingLearner(BaseLearner):
@@ -56,6 +56,7 @@ class BalancingLearner(BaseLearner):
self._points = {}
self._points = {}
self._loss = {}
self._loss = {}
 
self._pending_loss = {}
self._cdims_default = cdims
self._cdims_default = cdims
if len(set(learner.__class__ for learner in self.learners)) > 1:
if len(set(learner.__class__ for learner in self.learners)) > 1:
@@ -95,6 +96,7 @@ class BalancingLearner(BaseLearner):
@@ -95,6 +96,7 @@ class BalancingLearner(BaseLearner):
index, x = x
index, x = x
self._points.pop(index, None)
self._points.pop(index, None)
self._loss.pop(index, None)
self._loss.pop(index, None)
 
self._pending_loss.pop(index, None)
self.learners[index].tell(x, y)
self.learners[index].tell(x, y)
def tell_pending(self, x):
def tell_pending(self, x):
@@ -103,13 +105,19 @@ class BalancingLearner(BaseLearner):
@@ -103,13 +105,19 @@ class BalancingLearner(BaseLearner):
self._loss.pop(index, None)
self._loss.pop(index, None)
self.learners[index].tell_pending(x)
self.learners[index].tell_pending(x)
def loss(self, real=True):
def losses(self, real=True):
losses = []
losses = []
 
loss_dict = self._loss if real else self._pending_loss
 
for index, learner in enumerate(self.learners):
for index, learner in enumerate(self.learners):
if index not in self._loss:
if index not in loss_dict:
self._loss[index] = learner.loss(real)
loss_dict[index] = learner.loss(real)
loss = self._loss[index]
losses.append(loss_dict[index])
losses.append(loss)
 
return losses
 
 
def loss(self, real=True):
 
losses = self.losses(real)
return max(losses)
return max(losses)
def plot(self, cdims=None, plotter=None):
def plot(self, cdims=None, plotter=None):
Loading