Skip to content
Snippets Groups Projects
Commit e7df6eb0 authored by Bas Nijholt
Browse files

make learner work with concurrent.futures

parent 28c1cc74
No related branches found
No related tags found
No related merge requests found
......@@ -4,10 +4,10 @@
# TODO:
#-------------------------------------------------------------------------------
import numpy as np
from math import sqrt
import heapq
from math import sqrt
import itertools
import numpy as np
class Learner1D(object):
""" Learns and predicts a 1D function.
......@@ -50,22 +50,24 @@ class Learner1D(object):
self._scale = [0, 0]
self._oldscale = [0, 0]
self.unfinished = {}
# Add initial data if provided
if xdata is not None:
self.add_data(xdata, ydata)
def loss(self, x_left, x_right):
    """Calculate the loss of the interval (x_left, x_right).

    Currently returns the rescaled length of the interval: the Euclidean
    distance between its end points after the x difference is divided by
    ``self._scale[0]`` and the y difference by ``self._scale[1]``.  If one
    of the y-values is missing (``None``), returns 0, so the intervals with
    missing data are never touched.  This behavior should be improved
    later.

    Parameters
    ----------
    x_left, x_right : number
        End points of the interval.  `x_right` must be the right neighbor
        of `x_left` as recorded in ``self._neighbors``.

    Returns
    -------
    float or int
        The rescaled length, or 0 when a y-value is missing.
    """
    # Internal consistency check: the two points must form an interval.
    assert x_left < x_right and self._neighbors[x_left][1] == x_right

    y_right, y_left = self._ydata[x_right], self._ydata[x_left]
    # Test for missing data explicitly instead of catching TypeError around
    # the whole computation, which would also hide unrelated type errors.
    if y_left is None or y_right is None:
        return 0
    return sqrt(((x_right - x_left) / self._scale[0])**2 +
                ((y_right - y_left) / self._scale[1])**2)
......@@ -80,22 +82,25 @@ class Learner1D(object):
Values of the y coordinate. `None` means that the value will be
provided later.
"""
for x, y in zip(xvalues, yvalues):
self.add_point(x, y)
try:
for x, y in zip(xvalues, yvalues):
self.add_point(x, y)
except TypeError:
self.add_point(xvalues, yvalues)
def add_point(self, x, y):
# Update the data
"""Update the data."""
self._ydata[x] = y
# Update the neighbors.
if x not in self._neighbors: # The point is new
xvals = np.sort(list(self._neighbors.keys()))
pos = np.searchsorted(xvals, x)
self._neighbors[None] = [None, None] # To reduce the number of
# condititons.
xvals = sorted(self._neighbors)
pos = np.searchsorted(xvals, x) # This could be done for multiple vals at once
self._neighbors[None] = [None, None] # To reduce the number of condititons.
x_lower = xvals[pos-1] if pos != 0 else None
x_upper = xvals[pos] if pos != len(xvals) else None
# print x_lower, x_upper, x
self._neighbors[x] = [x_lower, x_upper]
self._neighbors[x_lower][1] = x
self._neighbors[x_upper][0] = x
......@@ -126,7 +131,7 @@ class Learner1D(object):
self._losses = {key: self.loss(*key) for key in self._losses}
self._oldscale = self._scale
def choose_points(self, n=10, add_to_data=False):
    """Return n points that are expected to maximally reduce the loss.

    The n points are divided over the known intervals greedily: every
    interval starts as a single segment, and n times the interval with the
    largest loss per segment gets one more segment.  An interval that ends
    up with k segments contributes its k - 1 equally spaced interior
    points.

    Parameters
    ----------
    n : int
        Number of points to distribute over the intervals.
    add_to_data : bool
        If True, the chosen points are passed to `add_data` with `None` as
        y-value, so that the same point will not be returned upon a next
        request.  This can be used for parallelization.

    Returns
    -------
    list
        The chosen x-values; empty when there are no intervals yet.
    """
    # Without intervals there is nothing to subdivide.
    if not self._losses:
        return []

    def points(x_range, segments):
        # Equally spaced interior points that split x_range into
        # `segments` pieces (the left end point is dropped).
        return list(np.linspace(x_range[0], x_range[1], segments,
                                endpoint=False)[1:])

    # Heap entries are (-loss per segment, interval, number of segments),
    # so the interval with the largest loss per segment is always on top.
    heap = [(-loss, x_range, 1) for x_range, loss in self._losses.items()]
    heapq.heapify(heap)
    for _ in range(n):
        quality, x_range, segments = heap[0]
        heapq.heapreplace(
            heap,
            (quality * segments / (segments + 1), x_range, segments + 1))

    # Flatten in heap order; chain avoids the quadratic sum(lists, []).
    xs = list(itertools.chain.from_iterable(
        points(x_range, segments) for _, x_range, segments in heap))

    # Add `None`s to data because then the same point will not be returned
    # upon a next request.
    if add_to_data:
        self.add_data(xs, itertools.repeat(None))

    return xs
def get_status(self):
    """Report current status.

    So far just returns some internal variables.

    Returns
    -------
    tuple
        ``(losses, neighbors, data)``: the ``_losses``, ``_neighbors`` and
        ``_ydata`` attributes themselves (no copies are made).
    """
    return self._losses, self._neighbors, self._ydata
def get_results(self):
    """Move the results of finished futures into the data.

    Every entry of ``self.unfinished`` maps an x-value to a future-like
    object providing ``done()`` and ``result()``.  Entries whose future is
    done are removed from ``self.unfinished`` and their result is added
    with `add_point`; pending entries are left untouched.
    """
    # Iterate over a snapshot: popping from the dict while iterating over
    # its live view raises "dictionary changed size during iteration".
    for x, future in list(self.unfinished.items()):
        if future.done():
            y = self.unfinished.pop(x).result()
            self.add_point(x, y)
def add_futures(self, xs, ys):
    """Register pending results in ``self.unfinished``.

    Parameters
    ----------
    xs : value or iterable of values
        The x-value(s) the results belong to.
    ys : object or iterable of objects
        The matching pending result(s), paired with `xs` element-wise.

    When pairing `xs` with `ys` raises a TypeError (for example because a
    single x and a single y were passed), `xs` itself is stored as the key
    with `ys` as its value.
    """
    pending = self.unfinished
    try:
        pending.update(zip(xs, ys))
    except TypeError:
        pending[xs] = ys
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment