Skip to content
Snippets Groups Projects
Commit c731c877 authored by Bas Nijholt's avatar Bas Nijholt
Browse files

adds interpolation for unknown points

parent 7f1a2620
No related branches found
No related tags found
No related merge requests found
......@@ -57,16 +57,24 @@ class Learner1D(object):
if xdata is not None:
self.add_data(xdata, ydata)
def loss(self, x_left, x_right):
def loss(self, x_left, x_right, interpolate=False):
"""Calculate loss in the interval x_left, x_right.
Currently returns the rescaled length of the interval. If one of the
y-values is missing, returns 0 (so the intervals with missing data are
never touched. This behavior should be improved later.
"""
if interpolate:
ydata = self.interp_ydata
assert ydata.keys() == self._ydata.keys()
else:
ydata = self._ydata
assert x_left < x_right and self._neighbors[x_left][1] == x_right
try:
y_right, y_left = self._ydata[x_right], self._ydata[x_left]
y_right, y_left = ydata[x_right], ydata[x_left]
return sqrt(((x_right - x_left) / self._scale[0])**2 +
((y_right - y_left) / self._scale[1])**2)
except TypeError: # One of y-values is None.
......@@ -99,12 +107,12 @@ class Learner1D(object):
pos = np.searchsorted(xvals, x) # This could be done for multiple vals at once
self._neighbors[None] = [None, None] # To reduce the number of condititons.
x_lower = xvals[pos-1] if pos != 0 else None
x_upper = xvals[pos] if pos != len(xvals) else None
x_left = xvals[pos-1] if pos != 0 else None
x_right = xvals[pos] if pos != len(xvals) else None
self._neighbors[x] = [x_lower, x_upper]
self._neighbors[x_lower][1] = x
self._neighbors[x_upper][0] = x
self._neighbors[x] = [x_left, x_right]
self._neighbors[x_left][1] = x
self._neighbors[x_right][0] = x
del self._neighbors[None]
# Update the scale.
......@@ -117,13 +125,13 @@ class Learner1D(object):
self._bbox[1][1] - self._bbox[1][0]]
# Update the losses.
x_lower, x_upper = self._neighbors[x]
if x_lower is not None:
self._losses[x_lower, x] = self.loss(x_lower, x)
if x_upper is not None:
self._losses[x, x_upper] = self.loss(x, x_upper)
x_left, x_right = self._neighbors[x]
if x_left is not None:
self._losses[x_left, x] = self.loss(x_left, x)
if x_right is not None:
self._losses[x, x_right] = self.loss(x, x_right)
try:
del self._losses[x_lower, x_upper]
del self._losses[x_left, x_right]
except KeyError:
pass
......@@ -132,7 +140,7 @@ class Learner1D(object):
self._losses = {key: self.loss(*key) for key in self._losses}
self._oldscale = self._scale
def choose_points(self, n=10, add_to_data=False):
def choose_points(self, n=10, add_to_data=False, interpolate=False):
"""Return n points that are expected to maximally reduce the loss."""
# Find out how to divide the n points over the intervals
# by finding positive integer n_i that minimize max(L_i / n_i) subject
......@@ -141,12 +149,18 @@ class Learner1D(object):
# Return equally spaced points within each interval to which points
# will be added.
if interpolate:
self.interpolate()
losses = self.interp_losses.items()
else:
losses = self._losses.items()
def points(x, n):
return list(np.linspace(x[0], x[1], n, endpoint=False)[1:])
# Calculate how many points belong to each interval.
quals = [(-loss, x_range, 1) for (x_range, loss) in
self._losses.items()]
losses]
heapq.heapify(quals)
for point_number in range(n):
quality, x, n = quals[0]
......@@ -183,3 +197,54 @@ class Learner1D(object):
self.unfinished[x] = y
except TypeError:
self.unfinished[xs] = ys
def interpolate(self):
"""Estimates the approximate positions of unknown y-values by
interpolating and assuming the unknown point lies on a line between
its nearest known neighbors.
Upon running this function it adds:
self.interp_ydata
self.interp_losses
self.real_neighbors
"""
ydata = sorted([x for x, y in self._ydata.items() if y is not None])
self.real_neighbors = {}
for i, y in enumerate(ydata):
if i == 0:
self.real_neighbors[y] = [None, ydata[1]]
elif i == len(ydata) - 1:
self.real_neighbors[y] = [ydata[i-1], None]
else:
self.real_neighbors[y] = [ydata[i-1], ydata[i+1]]
ydata_unfinished = [x for x, y in self._ydata.items() if y is None]
indices = np.searchsorted(ydata, ydata_unfinished)
for i, y in zip(indices, ydata_unfinished):
x_left, x_right = self.real_neighbors[ydata[i]]
self.real_neighbors[y] = [x_left, ydata[i]]
self.interp_ydata = {}
for x, (x_left, x_right) in self.real_neighbors.items():
y = self._ydata[x]
if y is None:
y_left = self._ydata[x_left]
y_right = self._ydata[x_right]
y = np.interp(x, [x_left, x_right], [y_left, y_right])
self.interp_ydata[x] = y
self.interp_losses = {}
for x, (x_left, x_right) in self.real_neighbors.items():
if x_left is not None:
self.interp_losses[(x_left, x)] = self.loss(
x_left, x, interpolate=True)
if x_right is not None:
self.interp_losses[x, x_right] = self.loss(
x, x_right, interpolate=True)
try:
del self.interp_losses[x_left, x_right]
except KeyError:
pass
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment