Skip to content
Snippets Groups Projects

Resolve "(Learner1D) add possibility to use the direct neighbors in the loss"

Merged Jorn Hoofwijk requested to merge 119-add-second-order-loss-to-adaptive into master
1 unresolved thread
Compare and Show latest version
4 files
+ 108
56
Compare changes
  • Side-by-side
  • Inline
Files
4
@@ -15,14 +15,61 @@ from ..notebook_integration import ensure_holoviews
from ..utils import cache_latest
def annotate_with_nn_neighbors(nn_neighbors):
def uses_nth_neighbors(n):
"""Decorator to specify how many neighboring intervals the loss function uses.
Wraps loss functions to indicate that they expect intervals together
with ``n`` nearest neighbors
The loss function is then guaranteed to receive the data of at least the
N nearest neighbors (``nth_neighbors``) in a dict that tells you what the
neighboring points of these are. And the `~adaptive.Learner1D` will
then make sure that the loss is updated whenever one of the
``nth_neighbors`` changes.
Examples
--------
The next function is a part of the `get_curvature_loss` function.
>>> @uses_nth_neighbors(1)
... def triangle_loss(interval, scale, data, neighbors):
... x_left, x_right = interval
... xs = [neighbors[x_left][0], x_left, x_right, neighbors[x_right][1]]
... # at the boundary, neighbors[<left boundary x>] is (None, <some other x>)
... xs = [x for x in xs if x is not None]
... if len(xs) <= 2:
... return (x_right - x_left) / scale[0]
...
... y_scale = scale[1] or 1
... ys_scaled = [data[x] / y_scale for x in xs]
... xs_scaled = [x / scale[0] for x in xs]
... N = len(xs) - 2
... pts = [(x, y) for x, y in zip(xs_scaled, ys_scaled)]
... return sum(volume(pts[i:i+3]) for i in range(N)) / N
Or you may define a loss that favours the (local) minima of a function.
>>> @uses_nth_neighbors(1)
... def local_minima_resolving_loss(interval, scale, data, neighbors):
... x_left, x_right = interval
... n_left = neighbors[x_left][0]
... n_right = neighbors[x_right][1]
... loss = (x_right - x_left) / scale[0]
...
... if not ((n_left is not None and data[x_left] > data[n_left])
... or (n_right is not None and data[x_right] > data[n_right])):
... return loss * 100
...
... return loss
"""
def _wrapped(loss_per_interval):
loss_per_interval.nn_neighbors = nn_neighbors
loss_per_interval.nth_neighbors = n
return loss_per_interval
return _wrapped
@annotate_with_nn_neighbors(0)
@uses_nth_neighbors(0)
def uniform_loss(interval, scale, data, neighbors):
"""Loss function that samples the domain uniformly.
@@ -44,7 +91,7 @@ def uniform_loss(interval, scale, data, neighbors):
return dx
@annotate_with_nn_neighbors(0)
@uses_nth_neighbors(0)
def default_loss(interval, scale, data, neighbors):
"""Calculate loss on a single interval.
@@ -79,7 +126,7 @@ def _loss_of_multi_interval(xs, ys):
return sum(vol(pts[i:i+3]) for i in range(N)) / N
@annotate_with_nn_neighbors(1)
@uses_nth_neighbors(1)
def triangle_loss(interval, scale, data, neighbors):
x_left, x_right = interval
xs = [neighbors[x_left][0], x_left, x_right, neighbors[x_right][1]]
@@ -95,7 +142,7 @@ def triangle_loss(interval, scale, data, neighbors):
def get_curvature_loss(area_factor=1, euclid_factor=0.02, horizontal_factor=0.02):
@annotate_with_nn_neighbors(1)
@uses_nth_neighbors(1)
def curvature_loss(interval, scale, data, neighbors):
triangle_loss_ = triangle_loss(interval, scale, data, neighbors)
default_loss_ = default_loss(interval, scale, data, neighbors)
@@ -129,8 +176,8 @@ def _get_neighbors_from_list(xs):
return sortedcontainers.SortedDict(neighbors)
def _get_intervals(x, neighbors, nn_neighbors):
nn = nn_neighbors
def _get_intervals(x, neighbors, nth_neighbors):
nn = nth_neighbors
i = neighbors.index(x)
start = max(0, i - nn - 1)
end = min(len(neighbors), i + nn + 2)
@@ -152,12 +199,6 @@ class Learner1D(BaseLearner):
A function that returns the loss for a single interval of the domain.
If not provided, then a default is used, which uses the scaled distance
in the x-y plane as the loss. See the notes for more details.
nn_neighbors : int, optional, default: None
The number of neighboring intervals that the loss function
takes into account. If ``loss_per_interval`` doesn't use the neighbors
at all, then it should be 0. By default we try to access the
``loss_per_interval.nn_neighbors`` attribute which is set for all
implemented loss functions.
Attributes
----------
@@ -181,18 +222,25 @@ class Learner1D(BaseLearner):
to have values for both of the points in 'interval'.
neighbors : dict(float → (float, float))
A map containing points as keys to its neighbors as a tuple.
At the left ``x_left`` and right ``x_left`` most boundary it has
``x_left: (None, float)`` and ``x_right: (float, None)``.
The `loss_per_interval` function should also have
an attribute `nth_neighbors` that indicates how many of the neighboring
intervals to `interval` are used. If `loss_per_interval` doesn't
have such an attribute, it's assumed that is uses **no** neighboring
intervals. Also see the `uses_nth_neighbors` decorator.
**WARNING**: When modifying the `data` and `neighbors` datastructures
the learner will behave in an undefined way.
"""
def __init__(self, function, bounds, loss_per_interval=None,
*, nn_neighbors=None):
def __init__(self, function, bounds, loss_per_interval=None):
self.function = function
if nn_neighbors is not None:
self.nn_neighbors = nn_neighbors
elif hasattr(loss_per_interval, 'nn_neighbors'):
self.nn_neighbors = loss_per_interval.nn_neighbors
if hasattr(loss_per_interval, 'nth_neighbors'):
self.nth_neighbors = loss_per_interval.nth_neighbors
else:
self.nn_neighbors = 0
self.nth_neighbors = 0
self.loss_per_interval = loss_per_interval or default_loss
@@ -293,10 +341,10 @@ class Learner1D(BaseLearner):
if real:
# We need to update all interpolated losses in the interval
# (x_left, x), (x, x_right) and the nn_neighbors nearest
# (x_left, x), (x, x_right) and the nth_neighbors nearest
# neighboring intervals. Since the addition of the
# point 'x' could change their loss.
for ival in _get_intervals(x, self.neighbors, self.nn_neighbors):
for ival in _get_intervals(x, self.neighbors, self.nth_neighbors):
self._update_interpolated_loss_in_interval(*ival)
# Since 'x' is in between (x_left, x_right),
Loading