
Resolve "(Learner1D) add possibility to use the direct neighbors in the loss"

Merged: Jorn Hoofwijk requested to merge 119-add-second-order-loss-to-adaptive into master
1 unresolved thread
1 file changed: +24 -7
@@ -15,6 +15,14 @@ from ..notebook_integration import ensure_holoviews
 from ..utils import cache_latest
 
 
+def annotate_with_nn_neighbors(nn_neighbors):
+    def _wrapped(loss_per_interval):
+        loss_per_interval.nn_neighbors = nn_neighbors
+        return loss_per_interval
+    return _wrapped
+
+
+@annotate_with_nn_neighbors(0)
 def uniform_loss(interval, scale, data, neighbors):
     """Loss function that samples the domain uniformly.
@@ -36,6 +44,7 @@ def uniform_loss(interval, scale, data, neighbors):
     return dx
 
 
+@annotate_with_nn_neighbors(0)
 def default_loss(interval, scale, data, neighbors):
     """Calculate loss on a single interval.
@@ -70,6 +79,7 @@ def _loss_of_multi_interval(xs, ys):
     return sum(vol(pts[i:i+3]) for i in range(N)) / N
 
 
+@annotate_with_nn_neighbors(1)
 def triangle_loss(interval, scale, data, neighbors):
     x_left, x_right = interval
     xs = [neighbors[x_left][0], x_left, x_right, neighbors[x_right][1]]
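
`triangle_loss` reaches one point past the interval on each side through the `neighbors` mapping (described further down as "a map containing points as keys to its neighbors as a tuple"). A small illustration with made-up values:

```python
# Hypothetical snapshot of the learner's neighbors structure: each known x
# maps to (left neighbor, right neighbor), with None at the domain bounds.
neighbors = {
    -1.0: (None, -0.5),
    -0.5: (-1.0, 0.5),
     0.5: (-0.5, 1.0),
     1.0: (0.5, None),
}

x_left, x_right = -0.5, 0.5  # the interval being scored
xs = [neighbors[x_left][0], x_left, x_right, neighbors[x_right][1]]
print(xs)  # [-1.0, -0.5, 0.5, 1.0]: one extra point per side,
           # which is why triangle_loss is tagged with nn_neighbors = 1
```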
@@ -85,6 +95,7 @@ def triangle_loss(interval, scale, data, neighbors):
 def get_curvature_loss(area_factor=1, euclid_factor=0.02, horizontal_factor=0.02):
+    @annotate_with_nn_neighbors(1)
     def curvature_loss(interval, scale, data, neighbors):
         triangle_loss_ = triangle_loss(interval, scale, data, neighbors)
         default_loss_ = default_loss(interval, scale, data, neighbors)
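
Because `curvature_loss` is created inside a factory, the decorator runs on every call to `get_curvature_loss()`, so each returned closure carries the attribute. A sketch, again assuming the import path:

```python
from adaptive.learner.learner1D import get_curvature_loss

loss = get_curvature_loss(area_factor=1, euclid_factor=0.02)
print(loss.nn_neighbors)  # -> 1: the closure is tagged just like triangle_loss
```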
@@ -141,10 +152,12 @@ class Learner1D(BaseLearner):
         A function that returns the loss for a single interval of the domain.
         If not provided, then a default is used, which uses the scaled distance
         in the x-y plane as the loss. See the notes for more details.
-    nn_neighbors : int, default: 0
+    nn_neighbors : int, optional, default: None
         The number of neighboring intervals that the loss function
         takes into account. If ``loss_per_interval`` doesn't use the neighbors
-        at all, then it should be 0.
+        at all, then it should be 0. By default we try to access the
+        ``loss_per_interval.nn_neighbors`` attribute which is set for all
+        implemented loss functions.
 
     Attributes
     ----------
@@ -170,14 +183,18 @@ class Learner1D(BaseLearner):
         A map containing points as keys to its neighbors as a tuple.
     """
 
-    def __init__(self, function, bounds, loss_per_interval=None, nn_neighbors=0):
+    def __init__(self, function, bounds, loss_per_interval=None,
+                 *, nn_neighbors=None):
         self.function = function
-        self.nn_neighbors = nn_neighbors
-        if nn_neighbors == 0:
-            self.loss_per_interval = loss_per_interval or default_loss
+        if nn_neighbors is not None:
+            self.nn_neighbors = nn_neighbors
+        elif hasattr(loss_per_interval, 'nn_neighbors'):
+            self.nn_neighbors = loss_per_interval.nn_neighbors
         else:
-            self.loss_per_interval = loss_per_interval or get_curvature_loss()
+            self.nn_neighbors = 0
+        self.loss_per_interval = loss_per_interval or default_loss
 
         # A dict storing the loss function for each interval x_n.
         self.losses = {}
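
The new `__init__` resolves `nn_neighbors` in three steps: an explicit keyword wins, otherwise the attribute set by the decorator is used, otherwise it falls back to 0. A sketch of the resulting behaviour (the import paths and the toy function are assumptions on my part):

```python
from adaptive import Learner1D
from adaptive.learner.learner1D import get_curvature_loss

f = lambda x: x**3

# Default loss: decorated with nn_neighbors = 0.
learner = Learner1D(f, bounds=(-1, 1))
print(learner.nn_neighbors)  # -> 0

# Curvature loss: the attribute on the closure is picked up automatically.
learner = Learner1D(f, bounds=(-1, 1), loss_per_interval=get_curvature_loss())
print(learner.nn_neighbors)  # -> 1

# An undecorated custom loss: state the neighbor count explicitly.
def my_loss(interval, scale, data, neighbors):
    x_left, x_right = interval
    return x_right - x_left

learner = Learner1D(f, bounds=(-1, 1), loss_per_interval=my_loss, nn_neighbors=0)
print(learner.nn_neighbors)  # -> 0
```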