Skip to content
Snippets Groups Projects
Commit faf901b5 authored by Bas Nijholt's avatar Bas Nijholt
Browse files

BalancingLearner: create 'tell_pending' which deprecates 'tell(x, None)'

parent ccbdd095
Branches
No related tags found
No related merge requests found
This commit is part of merge request !107. Comments created here will be created in the context of that merge request.
...@@ -71,7 +71,7 @@ class BalancingLearner(BaseLearner): ...@@ -71,7 +71,7 @@ class BalancingLearner(BaseLearner):
for index, learner in enumerate(self.learners): for index, learner in enumerate(self.learners):
if index not in self._points: if index not in self._points:
self._points[index] = learner.ask( self._points[index] = learner.ask(
n=1, add_data=False) n=1, tell_pending=False)
point, loss_improvement = self._points[index] point, loss_improvement = self._points[index]
improvements_per_learner.append(loss_improvement[0]) improvements_per_learner.append(loss_improvement[0])
pairs.append((index, point[0])) pairs.append((index, point[0]))
...@@ -79,13 +79,13 @@ class BalancingLearner(BaseLearner): ...@@ -79,13 +79,13 @@ class BalancingLearner(BaseLearner):
key=itemgetter(1)) key=itemgetter(1))
points.append(x) points.append(x)
loss_improvements.append(l) loss_improvements.append(l)
self.tell(x, None) self.tell_pending(x)
return points, loss_improvements return points, loss_improvements
def ask(self, n, add_data=True): def ask(self, n, tell_pending=True):
"""Chose points for learners.""" """Chose points for learners."""
if not add_data: if not tell_pending:
with restore(*self.learners): with restore(*self.learners):
return self._ask_and_tell(n) return self._ask_and_tell(n)
else: else:
...@@ -97,6 +97,12 @@ class BalancingLearner(BaseLearner): ...@@ -97,6 +97,12 @@ class BalancingLearner(BaseLearner):
self._loss.pop(index, None) self._loss.pop(index, None)
self.learners[index].tell(x, y) self.learners[index].tell(x, y)
def tell_pending(self, x):
index, x = x
self._points.pop(index, None)
self._loss.pop(index, None)
self.learners[index].tell_pending(x)
def loss(self, real=True): def loss(self, real=True):
losses = [] losses = []
for index, learner in enumerate(self.learners): for index, learner in enumerate(self.learners):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment