diff --git a/adaptive/learner/learner1D.py b/adaptive/learner/learner1D.py index d802ae6d71c623c5bbbef02017bab5332dc7cbc0..c0fa73b606f873e42b4cb842be174ad5484b6b25 100644 --- a/adaptive/learner/learner1D.py +++ b/adaptive/learner/learner1D.py @@ -15,7 +15,7 @@ from ..notebook_integration import ensure_holoviews from ..utils import cache_latest -def uniform_loss(interval, scale, function_values): +def uniform_loss(interval, scale, function_values, neighbors): """Loss function that samples the domain uniformly. Works with `~adaptive.Learner1D` only. @@ -36,7 +36,7 @@ def uniform_loss(interval, scale, function_values): return dx -def default_loss(interval, scale, function_values): +def default_loss(interval, scale, function_values, neighbors): """Calculate loss on a single interval. Currently returns the rescaled length of the interval. If one of the @@ -70,12 +70,9 @@ def _loss_of_multi_interval(xs, ys): return sum(vol(pts[i:i+3]) for i in range(N)) / N -def triangle_loss(interval, neighbours, scale, function_values): +def triangle_loss(interval, scale, function_values, neighbors): x_left, x_right = interval - neighbour_left, neighbour_right = neighbours - xs = [neighbour_left, x_left, x_right, neighbour_right] - # The neighbours could be None if we are at the boundary, in that case we - # have to filter this out + xs = [neighbors[x_left][0], x_left, x_right, neighbors[x_right][1]] xs = [x for x in xs if x is not None] if len(xs) <= 2: @@ -88,9 +85,9 @@ def triangle_loss(interval, neighbours, scale, function_values): def get_curvature_loss(area_factor=1, euclid_factor=0.02, horizontal_factor=0.02): - def curvature_loss(interval, neighbours, scale, function_values): - triangle_loss_ = triangle_loss(interval, neighbours, scale, function_values) - default_loss_ = default_loss(interval, scale, function_values) + def curvature_loss(interval, scale, function_values, neighbors): + triangle_loss_ = triangle_loss(interval, scale, function_values, neighbors) + default_loss_ = default_loss(interval, scale, function_values, neighbors) dx = (interval[1] - interval[0]) / scale[0] return (area_factor * (triangle_loss_**0.5) + euclid_factor * default_loss_ @@ -121,6 +118,15 @@ def _get_neighbors_from_list(xs): return sortedcontainers.SortedDict(neighbors) +def _get_intervals(x, neighbors, nn_neighbors): + nn = nn_neighbors + i = neighbors.index(x) + start = max(0, i - nn - 1) + end = min(len(neighbors), i + nn + 2) + points = neighbors.keys()[start:end] + return list(zip(points, points[1:])) + + class Learner1D(BaseLearner): """Learns and predicts a function 'f:℠→ â„^N'. @@ -135,6 +141,10 @@ class Learner1D(BaseLearner): A function that returns the loss for a single interval of the domain. If not provided, then a default is used, which uses the scaled distance in the x-y plane as the loss. See the notes for more details. + nn_neighbors : int, default: 0 + The number of neighboring intervals that the loss function + takes into account. If ``loss_per_interval`` doesn't use the neighbors + at all, then it should be 0. Attributes ---------- @@ -145,9 +155,9 @@ class Learner1D(BaseLearner): Notes ----- - `loss_per_interval` takes 3 parameters: ``interval``, ``scale``, and - ``function_values``, and returns a scalar; the loss over the interval. - + `loss_per_interval` takes 4 parameters: ``interval``, ``scale``, + ``data``, and ``neighbors``, and returns a scalar; the loss over + the interval. interval : (float, float) The bounds of the interval. scale : (float, float) @@ -156,16 +166,18 @@ class Learner1D(BaseLearner): function_values : dict(float → float) A map containing evaluated function values. It is guaranteed to have values for both of the points in 'interval'. + neighbors : dict(float → (float, float)) + A map containing points as keys to its neighbors as a tuple. """ - def __init__(self, function, bounds, loss_per_interval=None, loss_depends_on_neighbours=False): + def __init__(self, function, bounds, loss_per_interval=None, nn_neighbors=0): self.function = function - self._loss_depends_on_neighbours = loss_depends_on_neighbours + self.nn_neighbors = nn_neighbors - if loss_depends_on_neighbours: - self.loss_per_interval = loss_per_interval or get_curvature_loss() - else: + if nn_neighbors == 0: self.loss_per_interval = loss_per_interval or default_loss + else: + self.loss_per_interval = loss_per_interval or get_curvature_loss() # A dict storing the loss function for each interval x_n. self.losses = {} @@ -230,15 +242,8 @@ class Learner1D(BaseLearner): return 0 # we need to compute the loss for this interval - interval = (x_left, x_right) - if self._loss_depends_on_neighbours: - neighbour_left = self.neighbors.get(x_left, (None, None))[0] - neighbour_right = self.neighbors.get(x_right, (None, None))[1] - neighbours = neighbour_left, neighbour_right - return self.loss_per_interval(interval, neighbours, - self._scale, self.data) - else: - return self.loss_per_interval(interval, self._scale, self.data) + return self.loss_per_interval( + (x_left, x_right), self._scale, self.data, self.neighbors) def _update_interpolated_loss_in_interval(self, x_left, x_right): @@ -271,17 +276,11 @@ class Learner1D(BaseLearner): if real: # We need to update all interpolated losses in the interval - # (x_left, x) and (x, x_right). Since the addition of the point - # 'x' could change their loss. - self._update_interpolated_loss_in_interval(x_left, x) - self._update_interpolated_loss_in_interval(x, x_right) - - # if the loss depends on the neighbors we should also update those losses - if self._loss_depends_on_neighbours: - neighbour_left = self.neighbors.get(x_left, (None, None))[0] - neighbour_right = self.neighbors.get(x_right, (None, None))[1] - self._update_interpolated_loss_in_interval(neighbour_left, x_left) - self._update_interpolated_loss_in_interval(x_right, neighbour_right) + # (x_left, x), (x, x_right) and the nn_neighbors nearest + # neighboring intervals. Since the addition of the + # point 'x' could change their loss. + for ival in _get_intervals(x, self.neighbors, self.nn_neighbors): + self._update_interpolated_loss_in_interval(*ival) # Since 'x' is in between (x_left, x_right), # we get rid of the interval. @@ -427,10 +426,8 @@ class Learner1D(BaseLearner): # The the losses for the "real" intervals. self.losses = {} - for x_left, x_right in intervals: - self.losses[x_left, x_right] = ( - self._get_loss_in_interval(x_left, x_right) - if x_right - x_left >= self._dx_eps else 0) + for ival in intervals: + self.losses[ival] = self._get_loss_in_interval(*ival) # List with "real" intervals that have interpolated intervals inside to_interpolate = [] diff --git a/adaptive/tests/test_learner1d.py b/adaptive/tests/test_learner1d.py index d2f5b26ad0e729f96a1ede02006ac84fc57b86e4..d5d325f26d3a688670c7b39625500d51ccc6078d 100644 --- a/adaptive/tests/test_learner1d.py +++ b/adaptive/tests/test_learner1d.py @@ -347,15 +347,19 @@ def test_curvature_loss(): def f(x): return np.tanh(20*x) - learner = Learner1D(f, (-1, 1), loss_per_interval=get_curvature_loss(), loss_depends_on_neighbours=True) - simple(learner, goal=lambda l: l.npoints > 100) - # assert this is reached without error + for n in [0, 1]: + learner = Learner1D(f, (-1, 1), + loss_per_interval=get_curvature_loss(), nn_neighbors=n) + simple(learner, goal=lambda l: l.npoints > 100) + assert learner.npoints > 100 def test_curvature_loss_vectors(): def f(x): return np.tanh(20*x), np.tanh(20*(x-0.4)) - learner = Learner1D(f, (-1, 1), loss_per_interval=get_curvature_loss(), loss_depends_on_neighbours=True) - simple(learner, goal=lambda l: l.npoints > 100) - assert learner.npoints > 100 + for n in [0, 1]: + learner = Learner1D(f, (-1, 1), + loss_per_interval=get_curvature_loss(), nn_neighbors=n) + simple(learner, goal=lambda l: l.npoints > 100) + assert learner.npoints > 100