From 881628700a67d55be7bf1a40a4d00dbe52872f21 Mon Sep 17 00:00:00 2001
From: Bas Nijholt <basnijholt@gmail.com>
Date: Sat, 27 Oct 2018 03:43:07 +0200
Subject: [PATCH] remove 'loss_depends_on_neighbours' and replace by
 'nn_neighbors'

This makes the code work with any number of neighbors. The new
triangle loss now even works with nn_neighbors=0. The neighbors are
also now passed to all loss_per_interval functions.
---
 adaptive/learner/learner1D.py    | 81 +++++++++++++++-----------------
 adaptive/tests/test_learner1d.py | 16 ++++---
 2 files changed, 49 insertions(+), 48 deletions(-)

diff --git a/adaptive/learner/learner1D.py b/adaptive/learner/learner1D.py
index d802ae6d..c0fa73b6 100644
--- a/adaptive/learner/learner1D.py
+++ b/adaptive/learner/learner1D.py
@@ -15,7 +15,7 @@ from ..notebook_integration import ensure_holoviews
 from ..utils import cache_latest
 
 
-def uniform_loss(interval, scale, function_values):
+def uniform_loss(interval, scale, function_values, neighbors):
     """Loss function that samples the domain uniformly.
 
     Works with `~adaptive.Learner1D` only.
@@ -36,7 +36,7 @@ def uniform_loss(interval, scale, function_values):
     return dx
 
 
-def default_loss(interval, scale, function_values):
+def default_loss(interval, scale, function_values, neighbors):
     """Calculate loss on a single interval.
 
     Currently returns the rescaled length of the interval. If one of the
@@ -70,12 +70,9 @@ def _loss_of_multi_interval(xs, ys):
     return sum(vol(pts[i:i+3]) for i in range(N)) / N
 
 
-def triangle_loss(interval, neighbours, scale, function_values):
+def triangle_loss(interval, scale, function_values, neighbors):
     x_left, x_right = interval
-    neighbour_left, neighbour_right = neighbours
-    xs = [neighbour_left, x_left, x_right, neighbour_right]
-    # The neighbours could be None if we are at the boundary, in that case we
-    # have to filter this out
+    xs = [neighbors[x_left][0], x_left, x_right, neighbors[x_right][1]]
     xs = [x for x in xs if x is not None]
 
     if len(xs) <= 2:
@@ -88,9 +85,9 @@ def triangle_loss(interval, neighbours, scale, function_values):
 
 
 def get_curvature_loss(area_factor=1, euclid_factor=0.02, horizontal_factor=0.02):
-    def curvature_loss(interval, neighbours, scale, function_values):
-        triangle_loss_ = triangle_loss(interval, neighbours, scale, function_values)
-        default_loss_ = default_loss(interval, scale, function_values)
+    def curvature_loss(interval, scale, function_values, neighbors):
+        triangle_loss_ = triangle_loss(interval, scale, function_values, neighbors)
+        default_loss_ = default_loss(interval, scale, function_values, neighbors)
         dx = (interval[1] - interval[0]) / scale[0]
         return (area_factor * (triangle_loss_**0.5)
                 + euclid_factor * default_loss_
@@ -121,6 +118,15 @@ def _get_neighbors_from_list(xs):
     return sortedcontainers.SortedDict(neighbors)
 
 
+def _get_intervals(x, neighbors, nn_neighbors):
+    nn = nn_neighbors
+    i = neighbors.index(x)
+    start = max(0, i - nn - 1)
+    end = min(len(neighbors), i + nn + 2)
+    points = neighbors.keys()[start:end]
+    return list(zip(points, points[1:]))
+
+
 class Learner1D(BaseLearner):
     """Learns and predicts a function 'f:ℝ → ℝ^N'.
 
@@ -135,6 +141,10 @@ class Learner1D(BaseLearner):
         A function that returns the loss for a single interval of the domain.
         If not provided, then a default is used, which uses the scaled distance
         in the x-y plane as the loss. See the notes for more details.
+    nn_neighbors : int, default: 0
+        The number of neighboring intervals that the loss function
+        takes into account. If ``loss_per_interval`` doesn't use the neighbors
+        at all, then it should be 0.
 
     Attributes
     ----------
@@ -145,9 +155,9 @@ class Learner1D(BaseLearner):
 
     Notes
     -----
-    `loss_per_interval` takes 3 parameters: ``interval``,  ``scale``, and
-    ``function_values``, and returns a scalar; the loss over the interval.
-
+    `loss_per_interval` takes 4 parameters: ``interval``, ``scale``,
+    ``function_values``, and ``neighbors``, and returns a scalar;
+    the loss over the interval.
     interval : (float, float)
         The bounds of the interval.
     scale : (float, float)
@@ -156,16 +166,18 @@ class Learner1D(BaseLearner):
     function_values : dict(float → float)
         A map containing evaluated function values. It is guaranteed
         to have values for both of the points in 'interval'.
+    neighbors : dict(float → (float, float))
+        A map containing points as keys to its neighbors as a tuple.
     """
 
-    def __init__(self, function, bounds, loss_per_interval=None, loss_depends_on_neighbours=False):
+    def __init__(self, function, bounds, loss_per_interval=None, nn_neighbors=0):
         self.function = function
-        self._loss_depends_on_neighbours = loss_depends_on_neighbours
+        self.nn_neighbors = nn_neighbors
 
-        if loss_depends_on_neighbours:
-            self.loss_per_interval = loss_per_interval or get_curvature_loss()
-        else:
+        if nn_neighbors == 0:
             self.loss_per_interval = loss_per_interval or default_loss
+        else:
+            self.loss_per_interval = loss_per_interval or get_curvature_loss()
 
         # A dict storing the loss function for each interval x_n.
         self.losses = {}
@@ -230,15 +242,8 @@ class Learner1D(BaseLearner):
             return 0
 
         # we need to compute the loss for this interval
-        interval = (x_left, x_right)
-        if self._loss_depends_on_neighbours:
-            neighbour_left = self.neighbors.get(x_left, (None, None))[0]
-            neighbour_right = self.neighbors.get(x_right, (None, None))[1]
-            neighbours = neighbour_left, neighbour_right
-            return self.loss_per_interval(interval, neighbours,
-                                          self._scale, self.data)
-        else:
-            return self.loss_per_interval(interval, self._scale, self.data)
+        return self.loss_per_interval(
+            (x_left, x_right), self._scale, self.data, self.neighbors)
 
 
     def _update_interpolated_loss_in_interval(self, x_left, x_right):
@@ -271,17 +276,11 @@ class Learner1D(BaseLearner):
 
         if real:
             # We need to update all interpolated losses in the interval
-            # (x_left, x) and (x, x_right). Since the addition of the point
-            # 'x' could change their loss.
-            self._update_interpolated_loss_in_interval(x_left, x)
-            self._update_interpolated_loss_in_interval(x, x_right)
-
-            # if the loss depends on the neighbors we should also update those losses
-            if self._loss_depends_on_neighbours:
-                neighbour_left = self.neighbors.get(x_left, (None, None))[0]
-                neighbour_right = self.neighbors.get(x_right, (None, None))[1]
-                self._update_interpolated_loss_in_interval(neighbour_left, x_left)
-                self._update_interpolated_loss_in_interval(x_right, neighbour_right)
+            # (x_left, x), (x, x_right) and the nn_neighbors nearest
+            # neighboring intervals. Since the addition of the
+            # point 'x' could change their loss.
+            for ival in _get_intervals(x, self.neighbors, self.nn_neighbors):
+                self._update_interpolated_loss_in_interval(*ival)
 
             # Since 'x' is in between (x_left, x_right),
             # we get rid of the interval.
@@ -427,10 +426,8 @@ class Learner1D(BaseLearner):
 
         # The the losses for the "real" intervals.
         self.losses = {}
-        for x_left, x_right in intervals:
-            self.losses[x_left, x_right] = (
-                self._get_loss_in_interval(x_left, x_right)
-                if x_right - x_left >= self._dx_eps else 0)
+        for ival in intervals:
+            self.losses[ival] = self._get_loss_in_interval(*ival)
 
         # List with "real" intervals that have interpolated intervals inside
         to_interpolate = []
diff --git a/adaptive/tests/test_learner1d.py b/adaptive/tests/test_learner1d.py
index d2f5b26a..d5d325f2 100644
--- a/adaptive/tests/test_learner1d.py
+++ b/adaptive/tests/test_learner1d.py
@@ -347,15 +347,19 @@ def test_curvature_loss():
     def f(x):
         return np.tanh(20*x)
 
-    learner = Learner1D(f, (-1, 1), loss_per_interval=get_curvature_loss(), loss_depends_on_neighbours=True)
-    simple(learner, goal=lambda l: l.npoints > 100)
-    # assert this is reached without error
+    for n in [0, 1]:
+        learner = Learner1D(f, (-1, 1),
+            loss_per_interval=get_curvature_loss(), nn_neighbors=n)
+        simple(learner, goal=lambda l: l.npoints > 100)
+        assert learner.npoints > 100
 
 
 def test_curvature_loss_vectors():
     def f(x):
         return np.tanh(20*x), np.tanh(20*(x-0.4))
 
-    learner = Learner1D(f, (-1, 1), loss_per_interval=get_curvature_loss(), loss_depends_on_neighbours=True)
-    simple(learner, goal=lambda l: l.npoints > 100)
-    assert learner.npoints > 100
+    for n in [0, 1]:
+        learner = Learner1D(f, (-1, 1),
+            loss_per_interval=get_curvature_loss(), nn_neighbors=n)
+        simple(learner, goal=lambda l: l.npoints > 100)
+        assert learner.npoints > 100
-- 
GitLab