Skip to content
Snippets Groups Projects
Commit f846e1fb authored by Bas Nijholt's avatar Bas Nijholt
Browse files

Add a test for the 'BalancingLearner' with various learners, closes #102'

parent 896b7b4e
No related branches found
No related tags found
No related merge requests found
Pipeline #12187 failed
......@@ -123,6 +123,8 @@ def ask_randomly(learner, rounds, points):
return xs, ls
# Tests
@run_with(Learner1D)
def test_uniform_sampling1D(learner_type, f, learner_kwargs):
"""Points are sampled uniformly if no data is provided.
......@@ -349,6 +351,37 @@ def test_learner_performance_is_invariant_under_scaling(learner_type, f, learner
assert abs(learner.loss() - control.loss()) / learner.loss() < 1e-11
# XXX: the LearnerND currently fails because there is no `add_data=False` argument in ask.
@run_with(Learner1D, Learner2D, xfail(LearnerND), AverageLearner)
def test_balancing_learner(learner_type, f, learner_kwargs):
"""Test if the BalancingLearner works with the different types of learners."""
learners = [learner_type(generate_random_parametrization(f), **learner_kwargs)
for i in range(5)]
learner = BalancingLearner(learners)
# Emulate parallel execution
stash = []
for i in range(200):
xs, _ = learner.ask(10)
# Save 5 random points out of `xs` for later
random.shuffle(xs)
for _ in range(5):
stash.append(xs.pop())
for x in xs:
learner.tell(x, learner.function(x))
# Evaluate and add 5 random points from `stash`
random.shuffle(stash)
for _ in range(5):
learner.tell(stash.pop(), learner.function(x))
assert all(l.npoints > 20 for l in learner.learners)
@pytest.mark.xfail
@run_with(Learner1D, Learner2D, LearnerND)
def test_convergence_for_arbitrary_ordering(learner_type, f, learner_kwargs):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment