Skip to content
Snippets Groups Projects

Resolve "(Learner1D) add possibility to use the direct neighbors in the loss"

Merged Jorn Hoofwijk requested to merge 119-add-second-order-loss-to-adaptive into master
Compare and Show latest version
2 files
+ 55
49
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -18,13 +18,14 @@ from ..utils import cache_latest
def use_nn_neighbors(n):
"""Decorator to specify how many neighboring intervals the loss function uses.
This decorator can wrap around a loss function to let `~adaptive.Learner1D`
know that you would like to look at the N-nearest neighboring intervals.
Wraps loss functions to indicate that they expect intervals together
with ``n`` nearest neighbors
The loss function is then guaranteed to receive the data of at least the
nn-neighbors and a dict that tells you what the neighboring points of these
are. And the Learner1D will then make sure that the loss is updated whenever
on of the nn-neighbours changes.
N nearest neighbors (nn_neighbors) in a dict that tells you what the
neighboring points of these are. And the `~adaptive.Learner1D` will
then make sure that the loss is updated whenever one of the
nn_neighbors changes.
Examples
--------
@@ -32,40 +33,40 @@ def use_nn_neighbors(n):
This is a part of the curvature loss function
>>> @use_nn_neighbors(1)
...def triangle_loss(interval, scale, data, neighbors):
... x_left, x_right = interval
... xs = [neighbors[x_left][0], x_left, x_right, neighbors[x_right][1]]
... # at the boundary, neighbours[<left boundary x>] is (None, <some other x>)
... xs = [x for x in xs if x is not None]
... if len(xs) <= 2:
... return (x_right - x_left) / scale[0]
...
... y_scale = scale[1] or 1
... ys_scaled = [data[x] / y_scale for x in xs]
... xs_scaled = [x / scale[0] for x in xs]
... N = len(xs) - 2
... pts = [(x, y) for x, y in zip(xs_scaled, ys_scaled)]
... return sum(volume(pts[i:i+3]) for i in range(N)) / N
... def triangle_loss(interval, scale, data, neighbors):
... x_left, x_right = interval
... xs = [neighbors[x_left][0], x_left, x_right, neighbors[x_right][1]]
... # at the boundary, neighbors[<left boundary x>] is (None, <some other x>)
... xs = [x for x in xs if x is not None]
... if len(xs) <= 2:
... return (x_right - x_left) / scale[0]
...
... y_scale = scale[1] or 1
... ys_scaled = [data[x] / y_scale for x in xs]
... xs_scaled = [x / scale[0] for x in xs]
... N = len(xs) - 2
... pts = [(x, y) for x, y in zip(xs_scaled, ys_scaled)]
... return sum(volume(pts[i:i+3]) for i in range(N)) / N
Or you may define a loss that favours the (local) minima of a function.
>>>@use_nn_neighbors(1)
...def loss(interval, scale, data, neighbors):
... x_left, x_right = interval
... n_left = neighbors[x_left][0]
... n_right = neighbors[x_right][1]
... is_min = True
...
... if n_left is not None and data[x_left] > data[n_left]:
... is_min = False
... if n_right is not None and data[x_right] > data[n_right]:
... is_min = False
...
... loss = (x_right - x_left) / scale[0]
...
... if is_min:
... return loss * 100
... return loss
>>> @use_nn_neighbors(1)
... def loss(interval, scale, data, neighbors):
... x_left, x_right = interval
... n_left = neighbors[x_left][0]
... n_right = neighbors[x_right][1]
... is_min = True
...
... if n_left is not None and data[x_left] > data[n_left]:
... is_min = False
... if n_right is not None and data[x_right] > data[n_right]:
... is_min = False
...
... loss = (x_right - x_left) / scale[0]
...
... if is_min:
... return loss * 100
... return loss
"""
def _wrapped(loss_per_interval):
loss_per_interval.nn_neighbors = n
@@ -221,7 +222,11 @@ class Learner1D(BaseLearner):
-----
`loss_per_interval` takes 4 parameters: ``interval``, ``scale``,
``data``, and ``neighbors``, and returns a scalar; the loss over
the interval.
the interval. The `loss_per_interval` function should also have
an attribute `nn_neighbors` that indicates how many of the neighboring
intervals to `interval` are used. If `loss_per_interval` doesn't
have such an attribute, it's assumed that is uses **no** neighboring
intervals. Also see the `use_nn_neighbors` decorator.
interval : (float, float)
The bounds of the interval.
scale : (float, float)
Loading