# ------------------------------------------------------------------------------
# Filename: learner1D.py
# Description: Contains 'Learner1D' object, a learner for 1D data.
# ------------------------------------------------------------------------------

import heapq
from math import sqrt
import itertools
import multiprocessing

import numpy as np
import tornado.concurrent


def add_arg(func):
    """Make func return (arg, func(arg))."""
    def wrapper(*args):
        return (args[0], func(*args))
    return wrapper


class Learner1D(object):
    """Learns and predicts a 1D function.

    Description
    -----------
    Answers questions like:
    * "How much data do you need to get 2% accuracy?"
    * "What is the current status?"
    * "If I give you n data points, which ones would you like?"
    (initialise/request/promise/put/describe current state)
    """

    def __init__(self, xdata=None, ydata=None):
        """Initialize the learner.

        Parameters
        ----------
        xdata : iterable of numbers, optional
            x coordinates of the initial data.
        ydata : iterable of numbers and None, optional
            y coordinates of the initial data.  `None` means that the value
            will be provided later.
        """
        # Set internal variables.

        # A dict {(x_left, x_right): loss} storing the loss of each interval.
        self.losses = {}

        # A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
        # properties.
        self.neighbors = {}

        # A dict {x_n: y_n} with the known function values.
        self.data = {}

        # Bounding box [[minx, maxx], [miny, maxy]].
        self._bbox = [[np.inf, -np.inf], [np.inf, -np.inf]]

        # Data scale: [maxx - minx, maxy - miny].
        self._scale = [0, 0]
        self._oldscale = [0, 0]

        # Add initial data if provided.
        if xdata is not None:
            self.add_data(xdata, ydata)

    def interval_loss(self, x_left, x_right):
        """Calculate the loss of the interval (x_left, x_right).

        Currently returns the length of the interval, rescaled by the data
        extent in each dimension.  Missing y-values have already been filled
        in by `interpolate`, so intervals with unfinished points are scored
        using the current interpolant.  This behavior should be improved
        later.
        """
        y_right, y_left = self.interp_data[x_right], self.interp_data[x_left]
        return sqrt(((x_right - x_left) / self._scale[0])**2 +
                    ((y_right - y_left) / self._scale[1])**2)

    def loss(self):
        """Return the largest interval loss, or infinity if there is none."""
        if len(self.losses) == 0:
            return float('inf')
        else:
            return max(self.losses.values())

    def add_data(self, xvalues, yvalues):
        """Add data to the learner.

        Parameters
        ----------
        xvalues : iterable of numbers
            Values of the x coordinate.
        yvalues : iterable of numbers and None
            Values of the y coordinate.  `None` means that the value will be
            provided later.
        """
        try:
            for x, y in zip(xvalues, yvalues):
                self.add_point(x, y)
        except TypeError:
            # `xvalues` is a single number, not an iterable.
            self.add_point(xvalues, yvalues)

    def add_point(self, x, y):
        """Update the data."""
        self.data[x] = y

        # Update the bounding box and the scale.
        self._bbox[0][0] = min(self._bbox[0][0], x)
        self._bbox[0][1] = max(self._bbox[0][1], x)
        if y is not None:
            self._bbox[1][0] = min(self._bbox[1][0], y)
            self._bbox[1][1] = max(self._bbox[1][1], y)
        self._scale = [self._bbox[0][1] - self._bbox[0][0],
                       self._bbox[1][1] - self._bbox[1][0]]

    def choose_points(self, n=10):
        """Return n points that are expected to maximally reduce the loss."""
        # Find out how to divide the n points over the intervals by finding
        # positive integers n_i that minimize max(L_i / n_i) subject to the
        # constraint sum(n_i) = n + N, with N the total number of intervals.
        # Return equally spaced points within each interval to which points
        # will be added.

        self.interpolate()  # Apply a new interpolation step if there are new results.

        def points(x, n):
            """Return n - 1 equally spaced points in the interior of interval x."""
            return list(np.linspace(x[0], x[1], n, endpoint=False)[1:])

        # Calculate how many points belong to each interval.
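        # Each heap entry is (-L_i / n_i, interval, n_i); `heapq` is a
        # min-heap, so the interval with the largest loss per subdivision is
        # always on top.  Illustrative example (not from the original code):
        # with losses {(0, 1): 0.8, (1, 2): 0.2} and n = 2, the first
        # iteration replaces (-0.8, (0, 1), 1) with (-0.4, (0, 1), 2) and the
        # second replaces that with (-0.8 / 3, (0, 1), 3), so both new points
        # (at 1/3 and 2/3) land in the interval with the larger loss.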
        quals = [(-loss, x_range, 1) for (x_range, loss) in
                 self.losses.items()]
        heapq.heapify(quals)

        for point_number in range(n):
            quality, x, n_i = quals[0]
            heapq.heapreplace(quals, (quality * n_i / (n_i + 1), x, n_i + 1))

        xs = sum((points(x, n_i) for quality, x, n_i in quals), [])

        # Add `None`s to the data so that the same points will not be
        # returned upon a next request.  This can be used for parallelization.
        self.add_data(xs, itertools.repeat(None))

        return xs

    def get_largest_interval(self):
        """Return the length of the largest interval between finished points."""
        xs = sorted(x for x, y in self.data.items() if y is not None)
        if len(xs) < 2:
            return np.inf
        else:
            return np.diff(xs).max()

    def interpolate(self):
        """Fill in the missing y-values by linear interpolation, then rebuild
        the neighbor structure and the interval losses."""
        xdata = []
        ydata = []
        xdata_unfinished = []
        self.interp_data = {}
        for x in sorted(self.data):
            y = self.data[x]
            if y is None:
                xdata_unfinished.append(x)
            else:
                xdata.append(x)
                ydata.append(y)
                self.interp_data[x] = y

        if len(ydata) == 0:
            ydata_unfinished = (0,) * len(xdata_unfinished)
        else:
            ydata_unfinished = np.interp(xdata_unfinished, xdata, ydata)

        for x, y in zip(xdata_unfinished, ydata_unfinished):
            self.interp_data[x] = y

        self.neighbors = {}
        xdata_sorted = sorted(self.interp_data)
        for i, x in enumerate(xdata_sorted):
            if i == 0:
                self.neighbors[x] = [None, xdata_sorted[1]]
            elif i == len(xdata_sorted) - 1:
                self.neighbors[x] = [xdata_sorted[i - 1], None]
            else:
                self.neighbors[x] = [xdata_sorted[i - 1], xdata_sorted[i + 1]]

        self.losses = {}
        for x, (x_left, x_right) in self.neighbors.items():
            if x_left is not None:
                self.losses[x_left, x] = self.interval_loss(x_left, x)
            if x_right is not None:
                self.losses[x, x_right] = self.interval_loss(x, x_right)
            # Remove the stale loss of the interval that x subdivided.
            try:
                del self.losses[x_left, x_right]
            except KeyError:
                pass


# We cannot use an API that is specific to any particular asynchronous
# framework, so we have to roll our own utility functions.

async def any_complete(futures):
    """Wait until at least one future is done; return all completed futures."""
    total = tornado.concurrent.Future()
    for f in futures:
        f.add_done_callback(lambda f: total.set_result(None)
                            if not total.done() else None)
    await total
    return [f for f in futures if f.done()]


async def run(f, executor, learner, goal,
              ncores=multiprocessing.cpu_count()):
    xs = dict()
    done = [None] * ncores

    while not goal(learner):
        # Launch tasks to replace the ones that completed
        # on the last iteration.
        for x in learner.choose_points(len(done)):
            xs[executor.submit(f, x)] = x

        # Collect the results and add them to the learner.
        futures = list(xs.keys())
        await any_complete(futures)
        done = [fut for fut in futures if fut.done()]
        for fut in done:
            x = xs.pop(fut)
            # Need to explicitly await the future (even though we know the
            # result is there) to be compatible with Dask, whose futures'
            # 'result' method returns a future itself.
            y = await fut
            learner.add_point(x, y)

    # Cancel any outstanding tasks.
    for fut in xs.keys():
        fut.cancel()

    # XXX: we should introduce an API for removing data points, and remove
    # all the data points with a 'None' value from the learner, or add a
    # method to simply remove all "unfinished" points from the learner.
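

# ------------------------------------------------------------------------------
# Example usage: an illustrative sketch, not part of the module above.  It
# drives the learner synchronously until the largest rescaled interval loss
# drops below a tolerance.  The target function `f` below is hypothetical,
# chosen only because it varies sharply near x = 0.
# ------------------------------------------------------------------------------
if __name__ == '__main__':
    def f(x):
        return x / (x**2 + 0.01)

    learner = Learner1D(xdata=[-1, 1], ydata=[f(-1), f(1)])
    while learner.loss() > 0.01:
        # Request a batch of points, evaluate them, and feed back the results.
        for x in learner.choose_points(n=10):
            learner.add_point(x, f(x))
    print('Evaluated {} points.'.format(len(learner.data)))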