Skip to content
Snippets Groups Projects

make a BalancingLearner strategy that compares the total loss rather than loss improvement

1 file
+ 24
2
Compare changes
  • Side-by-side
  • Inline
@@ -4,6 +4,8 @@ from contextlib import suppress
from functools import partial
from operator import itemgetter
import numpy as np
from .base_learner import BaseLearner
from ..notebook_integration import ensure_holoviews
from ..utils import restore, named_product
@@ -34,6 +36,10 @@ class BalancingLearner(BaseLearner):
>>> cdims = (['A', 'B'], itertools.product([True, False], [0, 1]))
>>> cdims = (['A', 'B'], [(True, 0), (True, 1),
... (False, 0), (False, 1)])
strategy : str, default 'loss_improvements'
The points that the 'BalancingLearner' choses can be either based on
the best 'loss_improvements' or the smallest total 'loss' of the
child learners.
Notes
-----
@@ -46,7 +52,7 @@ class BalancingLearner(BaseLearner):
undefined way.
"""
def __init__(self, learners, *, cdims=None):
def __init__(self, learners, *, cdims=None, strategy='loss_improvements'):
self.learners = learners
# Naively we would make 'function' a method, but this causes problems
@@ -63,7 +69,12 @@ class BalancingLearner(BaseLearner):
raise TypeError('A BalacingLearner can handle only one type'
'of learners.')
def _ask_and_tell(self, n):
if strategy == 'loss_improvements':
self._ask_and_tell = self._ask_and_tell_based_on_loss_improvements
elif strategy == 'loss':
self._ask_and_tell = self._ask_and_tell_based_on_loss
def _ask_and_tell_based_on_loss_improvements(self, n):
points = []
loss_improvements = []
for _ in range(n):
@@ -84,6 +95,17 @@ class BalancingLearner(BaseLearner):
return points, loss_improvements
def _ask_and_tell_based_on_loss(self, n):
points = []
loss_improvements = []
for _ in range(n):
losses = self.losses(real=False)
max_ind = np.argmax(losses)
xs, ls = self.learners[max_ind].ask(1)
points.append((max_ind, xs[0]))
loss_improvements.append(ls[0])
return points, loss_improvements
def ask(self, n, tell_pending=True):
"""Chose points for learners."""
if not tell_pending:
Loading