Commit a48fc888 authored by Bas Nijholt

add a test for the 'BalancingLearner' with various learners, closes #102

parent 896b7b4e
Pipeline #12193 passed
@@ -123,6 +123,8 @@ def ask_randomly(learner, rounds, points):
    return xs, ls
# Tests
@run_with(Learner1D)
def test_uniform_sampling1D(learner_type, f, learner_kwargs):
    """Points are sampled uniformly if no data is provided.
@@ -349,6 +351,40 @@ def test_learner_performance_is_invariant_under_scaling(learner_type, f, learner
    assert abs(learner.loss() - control.loss()) / learner.loss() < 1e-11
# XXX: the LearnerND currently fails because there is no `add_data=False` argument in ask.
@run_with(Learner1D, Learner2D, xfail(LearnerND), AverageLearner)
def test_balancing_learner(learner_type, f, learner_kwargs):
    """Test if the BalancingLearner works with the different types of learners."""
    learners = [learner_type(generate_random_parametrization(f), **learner_kwargs)
                for i in range(4)]

    learner = BalancingLearner(learners)

    # Emulate parallel execution
    stash = []

    for i in range(100):
        n = random.randint(1, 10)
        m = random.randint(0, n)
        xs, _ = learner.ask(n, add_data=False)

        # Save 'm' random points out of `xs` for later
        random.shuffle(xs)
        for _ in range(m):
            stash.append(xs.pop())

        for x in xs:
            learner.tell(x, learner.function(x))

        # Evaluate and add 'm' random points from `stash`
        random.shuffle(stash)
        for _ in range(m):
            x = stash.pop()
            learner.tell(x, learner.function(x))

    assert all(l.npoints > 10 for l in learner.learners), [l.npoints for l in learner.learners]
@pytest.mark.xfail
@run_with(Learner1D, Learner2D, LearnerND)
def test_convergence_for_arbitrary_ordering(learner_type, f, learner_kwargs):
...
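For context, a minimal sketch (not part of this commit) of how the BalancingLearner exercised by the new test is typically driven, assuming the adaptive API of this period (Learner1D, BalancingLearner, and adaptive.runner.simple); the toy function, offsets, and loss goal below are made up for illustration:

import functools
import adaptive

def f(x, offset):
    # toy 1D function; `offset` gives each learner a slightly different curve,
    # standing in for generate_random_parametrization(f) in the test above
    return x + offset * x**3

# four independent learners of the same function with different parametrizations
learners = [adaptive.Learner1D(functools.partial(f, offset=0.01 * i), bounds=(-1, 1))
            for i in range(4)]
balancer = adaptive.BalancingLearner(learners)

# run sequentially until the worst child learner's loss drops below the goal;
# the balancer hands each requested point to the child with the largest loss
adaptive.runner.simple(balancer, goal=lambda l: l.loss() < 0.01)

print([l.npoints for l in balancer.learners])  # points end up spread over all learners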