Resolve "(Learner1D) add possibility to use the direct neighbors in the loss"
Closes #119 (closed)
Currently works for one
Still have to make it work for
Also performance is actually quite good. As in: the learner slows down about 1.5 times. Going (on my laptop) from 3 seconds per 1000 points to 4.5 seconds per 1000 points.
Which, I believe, will be more than compensated by the fact that the chosen points are generally better
Merge request reports
Activity
added 2 commits
- Resolved by Bas Nijholt
added 1 commit
- b393efc2 - use _get_loss_in_interval in tell_many method
Perfect! I am using it ATM and it seems to work really well.
The only thing I don't like it that it adds an extra complication for the user.
There might be an automatic way of determining that
loss_depends_on_neighbors = True
such that the user wouldn't have to specify it. We could of course also change the API such that neighbors are always passed, which is a very cheap operation to, the updating of the neigboring losses is a bit expensive to always do though. Probably @anton-akhmerov and @jbweston have an opinion.Well, the only way would be to either make it a seperate learner (like CurvatureLearner1D) or make it true by default. As I cannot think of an easy way to automatically determine it.
Unless one would use a decorator around a loss or something that indicates to the learner we need the neighbours:
e.g.
@uses_neighbours def loss(interval, neighbours, scale, all_points): return "something"
Edited by Jorn Hoofwijkadded 1 commit
- be4765b7 - remove 'loss_depends_on_neighbours' and replace by 'nn_neighbors'
I removed
loss_depends_on_neighbours
and replaced it bynn_neighbors
(N
nearest neighbors). This makes the code work with any number of neighbors. Now the new triangle loss even works withnn_neigbors=0
. I also added that we pass the neighbors in all loss_per_interval functions.Edited by Bas Nijholtadded 1 commit
- 9d49bc09 - rename 'function_values' to 'data' because it's more obvious what it is
added 2 commits
added 1 commit
- 0e0dd5da - rename 'function_values' to 'data' because its more obvious
@Jorn I am really confused:
loss = loss_per_interval=adaptive.learner.learner1D.get_curvature_loss() learner1 = adaptive.Learner1D(f, bounds=(-1, 1), nn_neighbors=0, loss_per_interval=loss) learner2 = adaptive.Learner1D(f, bounds=(-1, 1), nn_neighbors=1, loss_per_interval=loss) runner = adaptive.BlockingRunner(learner1, goal=lambda l: l.npoints > 400, log=True) adaptive.runner.replay_log(learner2, runner.log) assert learner1.data == learner2.data assert learner1.losses == learner2.losses
This implies that the losses of the neighbors don't really need to be updated, or not?
edit: I understand this now. This happens on my computer with just 2 cores, however on
io
(with 48 cores) there are more interpolated intervals, so there anAssertionError
is raised.Edited by Bas NijholtI would expect that the second assertion should fail for any function that is not a straight line, regardless of the number of cores.
Edited by Jorn Hoofwijk@basnijholt I found why it goes wrong and solved it in: f9032bc4
added 1 commit
- 1d637cc8 - introduce 'fast_det' for 2x2 and 3x3 matrices
mentioned in merge request !132
- Resolved by Bas Nijholt
added 1 commit
- 86363644 - introduce 'annotate_with_nn_neighbors' decorator
- Resolved by Bas Nijholt
- Resolved by Bas Nijholt
added 1 commit
- 51b8ea9c - remove nn_neighbors from Learner1D signature
added 1 commit
- f73878e0 - add two samples of loss functions with nn_neighbors=1
added 10 commits
-
008d47e1 - 1 commit from branch
master
- bcd213c3 - added a curvature_loss function to learner1D
- ff20135d - remove 'loss_depends_on_neighbours' and replace by 'nn_neighbors'
- d1c4e279 - rename 'function_values' to 'data' because its more obvious
- a0e73d8b - simplifications
- 4a41f4f5 - introduce 'fast_det' for 2x2 and 3x3 matrices
- 20eba318 - fix tests
- 590e9738 - introduce 'annotate_with_nn_neighbors' decorator
- a28192b3 - added curvature docs
- 260b0ddb - add two samples of loss functions with nn_neighbors=1
Toggle commit list-
008d47e1 - 1 commit from branch
added 5 commits
Toggle commit list