Skip to content
Snippets Groups Projects
Commit 8602e24f authored by Bas Nijholt's avatar Bas Nijholt
Browse files

make the 'BalancingLearner' return 'loss_improvements'

parent 93562ad3
No related branches found
No related tags found
1 merge request!94add runner.max_retries
......@@ -64,21 +64,24 @@ class BalancingLearner(BaseLearner):
def _ask_and_tell(self, n):
points = []
loss_improvements = []
for _ in range(n):
loss_improvements = []
improvements_per_learner = []
pairs = []
for index, learner in enumerate(self.learners):
if index not in self._points:
self._points[index] = learner.ask(
n=1, add_data=False)
point, loss_improvement = self._points[index]
loss_improvements.append(loss_improvement[0])
improvements_per_learner.append(loss_improvement[0])
pairs.append((index, point[0]))
x, _ = max(zip(pairs, loss_improvements), key=itemgetter(1))
x, l = max(zip(pairs, improvements_per_learner),
key=itemgetter(1))
points.append(x)
loss_improvements.append(l)
self.tell(x, None)
return points, None
return points, loss_improvements
def ask(self, n, add_data=True):
"""Chose points for learners."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment