Commit f68bd816 authored by Bas Nijholt

introduce 'uses_nth_neighbors' decorator and rename 'nn_neighbors' to 'nth_neighbors'

This abstracts the 'nth_neighbors' attribute away and makes things easier
for the user: one can now just set the 'loss_per_interval', and
'nth_neighbors' will be set by default.
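
In user code the change looks like this. A minimal sketch of the new pattern, assuming `get_curvature_loss` is importable from `adaptive.learner.learner1D` as defined in the diff below:

import numpy as np
from adaptive import Learner1D
from adaptive.learner.learner1D import get_curvature_loss

# Before this commit: Learner1D(f, (-1, 1),
#     loss_per_interval=get_curvature_loss(), nn_neighbors=1)
# Now the loss function itself carries an 'nth_neighbors' attribute,
# so the extra constructor argument is gone.
loss = get_curvature_loss()
learner = Learner1D(np.tanh, bounds=(-1, 1), loss_per_interval=loss)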
parent ae6af683
1 merge request: !131 Resolve "(Learner1D) add possibility to use the direct neighbors in the loss"
@@ -15,6 +15,61 @@ from ..notebook_integration import ensure_holoviews
 from ..utils import cache_latest
 
 
+def uses_nth_neighbors(n):
"""Decorator to specify how many neighboring intervals the loss function uses.
Wraps loss functions to indicate that they expect intervals together
with ``n`` nearest neighbors
The loss function is then guaranteed to receive the data of at least the
N nearest neighbors (``nth_neighbors``) in a dict that tells you what the
neighboring points of these are. And the `~adaptive.Learner1D` will
then make sure that the loss is updated whenever one of the
``nth_neighbors`` changes.
Examples
--------
The next function is a part of the `get_curvature_loss` function.
+    >>> @uses_nth_neighbors(1)
+    ... def triangle_loss(interval, scale, data, neighbors):
+    ...     x_left, x_right = interval
+    ...     xs = [neighbors[x_left][0], x_left, x_right, neighbors[x_right][1]]
+    ...     # at the boundary, neighbors[<left boundary x>] is (None, <some other x>)
+    ...     xs = [x for x in xs if x is not None]
+    ...     if len(xs) <= 2:
+    ...         return (x_right - x_left) / scale[0]
+    ...
+    ...     y_scale = scale[1] or 1
+    ...     ys_scaled = [data[x] / y_scale for x in xs]
+    ...     xs_scaled = [x / scale[0] for x in xs]
+    ...     N = len(xs) - 2
+    ...     pts = list(zip(xs_scaled, ys_scaled))
+    ...     return sum(volume(pts[i:i+3]) for i in range(N)) / N
+
+    Or you may define a loss that favours the (local) minima of a function.
+
+    >>> @uses_nth_neighbors(1)
+    ... def local_minima_resolving_loss(interval, scale, data, neighbors):
+    ...     x_left, x_right = interval
+    ...     n_left = neighbors[x_left][0]
+    ...     n_right = neighbors[x_right][1]
+    ...     loss = (x_right - x_left) / scale[0]
+    ...
+    ...     # Boost the loss when neither endpoint rises above its outer
+    ...     # neighbor, i.e. the interval may contain a local minimum.
+    ...     if not ((n_left is not None and data[x_left] > data[n_left])
+    ...             or (n_right is not None and data[x_right] > data[n_right])):
+    ...         return loss * 100
+    ...
+    ...     return loss
"""
def _wrapped(loss_per_interval):
loss_per_interval.nth_neighbors = n
return loss_per_interval
return _wrapped
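
The decorator itself only attaches an attribute to the wrapped function; `Learner1D.__init__` reads it later. A quick doctest-style check with a hypothetical loss:

>>> @uses_nth_neighbors(2)
... def my_loss(interval, scale, data, neighbors):
...     x_left, x_right = interval
...     return (x_right - x_left) / scale[0]
>>> my_loss.nth_neighbors
2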
+
+
+@uses_nth_neighbors(0)
 def uniform_loss(interval, scale, data, neighbors):
     """Loss function that samples the domain uniformly.
@@ -36,6 +91,7 @@ def uniform_loss(interval, scale, data, neighbors):
     return dx
 
 
+@uses_nth_neighbors(0)
 def default_loss(interval, scale, data, neighbors):
     """Calculate loss on a single interval.
@@ -70,6 +126,7 @@ def _loss_of_multi_interval(xs, ys):
     return sum(vol(pts[i:i+3]) for i in range(N)) / N
 
 
+@uses_nth_neighbors(1)
 def triangle_loss(interval, scale, data, neighbors):
     x_left, x_right = interval
     xs = [neighbors[x_left][0], x_left, x_right, neighbors[x_right][1]]
@@ -85,6 +142,7 @@ def triangle_loss(interval, scale, data, neighbors):
 def get_curvature_loss(area_factor=1, euclid_factor=0.02, horizontal_factor=0.02):
+    @uses_nth_neighbors(1)
     def curvature_loss(interval, scale, data, neighbors):
         triangle_loss_ = triangle_loss(interval, scale, data, neighbors)
         default_loss_ = default_loss(interval, scale, data, neighbors)
@@ -118,8 +176,8 @@ def _get_neighbors_from_list(xs):
     return sortedcontainers.SortedDict(neighbors)
 
 
-def _get_intervals(x, neighbors, nn_neighbors):
-    nn = nn_neighbors
+def _get_intervals(x, neighbors, nth_neighbors):
+    nn = nth_neighbors
     i = neighbors.index(x)
     start = max(0, i - nn - 1)
     end = min(len(neighbors), i + nn + 2)
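
For intuition: `_get_intervals` returns the intervals whose loss depends on `x`. The rest of the function falls outside this hunk, so the pairing step in this self-contained sketch is an assumption:

import sortedcontainers

def _get_intervals_sketch(x, neighbors, nth_neighbors):
    # The two intervals touching 'x', plus 'nth_neighbors' extra
    # intervals on each side, clipped at the domain boundaries.
    nn = nth_neighbors
    i = neighbors.index(x)
    start = max(0, i - nn - 1)
    end = min(len(neighbors), i + nn + 2)
    points = list(neighbors.keys())[start:end]
    return list(zip(points, points[1:]))

neighbors = sortedcontainers.SortedDict({0.0: None, 0.5: None, 1.0: None, 2.0: None})
_get_intervals_sketch(1.0, neighbors, nth_neighbors=1)
# -> [(0.0, 0.5), (0.5, 1.0), (1.0, 2.0)]  (clipped at the right edge)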
@@ -141,10 +199,6 @@ class Learner1D(BaseLearner):
         A function that returns the loss for a single interval of the domain.
         If not provided, then a default is used, which uses the scaled distance
         in the x-y plane as the loss. See the notes for more details.
-    nn_neighbors : int, default: 0
-        The number of neighboring intervals that the loss function
-        takes into account. If ``loss_per_interval`` doesn't use the neighbors
-        at all, then it should be 0.
 
     Attributes
     ----------
@@ -170,16 +224,25 @@
         A map containing points as keys and their neighbors as a tuple.
         At the left-most boundary ``x_left`` and the right-most boundary
         ``x_right`` it has ``x_left: (None, float)`` and
         ``x_right: (float, None)``.
+
+    The `loss_per_interval` function should also have
+    an attribute `nth_neighbors` that indicates how many of the
+    neighboring intervals to `interval` are used. If `loss_per_interval`
+    doesn't have such an attribute, it is assumed that it uses **no**
+    neighboring intervals. Also see the `uses_nth_neighbors` decorator.
+
     **WARNING**: When modifying the `data` and `neighbors` datastructures
     the learner will behave in an undefined way.
     """
 
-    def __init__(self, function, bounds, loss_per_interval=None, nn_neighbors=0):
+    def __init__(self, function, bounds, loss_per_interval=None):
         self.function = function
-        self.nn_neighbors = nn_neighbors
-        if nn_neighbors == 0:
-            self.loss_per_interval = loss_per_interval or default_loss
+        if hasattr(loss_per_interval, 'nth_neighbors'):
+            self.nth_neighbors = loss_per_interval.nth_neighbors
         else:
-            self.loss_per_interval = loss_per_interval or get_curvature_loss()
+            self.nth_neighbors = 0
+        self.loss_per_interval = loss_per_interval or default_loss
 
         # A dict storing the loss function for each interval x_n.
         self.losses = {}
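
Net effect of the constructor change: decorated losses announce how many neighboring intervals they need, and undecorated losses keep working with `nth_neighbors == 0`. A minimal sketch with a hypothetical undecorated loss:

from adaptive import Learner1D

def plain_loss(interval, scale, data, neighbors):
    # No 'nth_neighbors' attribute here, so the learner assumes 0.
    x_left, x_right = interval
    return (x_right - x_left) / scale[0]

learner = Learner1D(lambda x: x**2, (-1, 1), loss_per_interval=plain_loss)
assert learner.nth_neighbors == 0  # i.e. getattr(loss, 'nth_neighbors', 0)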
@@ -278,10 +341,10 @@
         if real:
             # We need to update all interpolated losses in the interval
-            # (x_left, x), (x, x_right) and the nn_neighbors nearest
+            # (x_left, x), (x, x_right) and the nth_neighbors nearest
            # neighboring intervals. Since the addition of the
             # point 'x' could change their loss.
-            for ival in _get_intervals(x, self.neighbors, self.nn_neighbors):
+            for ival in _get_intervals(x, self.neighbors, self.nth_neighbors):
                 self._update_interpolated_loss_in_interval(*ival)
 
             # Since 'x' is in between (x_left, x_right),
@@ -347,19 +347,19 @@ def test_curvature_loss():
     def f(x):
         return np.tanh(20*x)
 
-    for n in [0, 1]:
-        learner = Learner1D(f, (-1, 1),
-                            loss_per_interval=get_curvature_loss(), nn_neighbors=n)
-        simple(learner, goal=lambda l: l.npoints > 100)
-        assert learner.npoints > 100
+    loss = get_curvature_loss()
+    assert loss.nth_neighbors == 1
+    learner = Learner1D(f, (-1, 1), loss_per_interval=loss)
+    simple(learner, goal=lambda l: l.npoints > 100)
+    assert learner.npoints > 100
 
 
 def test_curvature_loss_vectors():
     def f(x):
         return np.tanh(20*x), np.tanh(20*(x-0.4))
 
-    for n in [0, 1]:
-        learner = Learner1D(f, (-1, 1),
-                            loss_per_interval=get_curvature_loss(), nn_neighbors=n)
-        simple(learner, goal=lambda l: l.npoints > 100)
-        assert learner.npoints > 100
+    loss = get_curvature_loss()
+    assert loss.nth_neighbors == 1
+    learner = Learner1D(f, (-1, 1), loss_per_interval=loss)
+    simple(learner, goal=lambda l: l.npoints > 100)
+    assert learner.npoints > 100