diff --git a/adaptive/tests/test_learner.py b/adaptive/tests/test_learner.py index d3352e123e7dc334b2370cc0f602f1895ae90910..8ac6d2f2516d62e19e60e7a01774668b18f7a377 100644 --- a/adaptive/tests/test_learner.py +++ b/adaptive/tests/test_learner.py @@ -163,6 +163,44 @@ def test_uniform_sampling1D(learner_type, f, learner_kwargs): assert max(ivals) / min(ivals) < 2 + 1e-8 +@pytest.mark.focus +def test_learnerND_as_described_in_issue_99(): + # https://gitlab.kwant-project.org/qt/adaptive/issues/99 + l = Learner1D(lambda x: x, (0, 4)) + + l.tell(0, 0) + l.tell(1, 0) + l.tell(2, 0) + + assert l.ask(1) == ([4], [np.inf]) + assert l.losses == {(0, 1): 0.25, (1, 2): 0.25} + assert l.losses_combined == {(0, 1): 0.25, (1, 2): 0.25, (2, 4.0): np.inf} + + assert l.ask(1) == ([3], [np.inf]) + # l.ask(1) + assert l.losses == {(0, 1): 0.25, (1, 2): 0.25} + assert l.losses_combined == {(0, 1): 0.25, (1, 2): 0.25, (2, 3.0): np.inf, (3.0, 4.0): np.inf} + + l.tell(4, 0) + + assert l.losses_combined == {(0, 1): 0.25, (1, 2): 0.25, (2, 3): 0.25, (3, 4): 0.25} + + +@pytest.mark.focus +def test_learnerND_as_described_in_issue_99_comment(): + # https://gitlab.kwant-project.org/qt/adaptive/issues/99 + l = Learner1D(lambda x: x, (0, 4)) + + l.tell(0, 0) + l.tell(1, 0) + l.tell(2, 0) + assert set(l.losses_combined.keys()) == {(0, 1), (1, 2)} + l.ask(1) + assert set(l.losses_combined.keys()) == {(0, 1), (1, 2), (2, 4)} + l.tell(3.5, 0) + assert set(l.losses_combined.keys()) == {(0, 1), (1, 2), (2, 3.5), (3.5, 4.0)} + + @pytest.mark.xfail @run_with(Learner2D, LearnerND) def test_uniform_sampling2D(learner_type, f, learner_kwargs):