#-------------------------------------------------------------------------------
# Filename:     learner1D.py
# Description:  Contains 'Learner1D' object, a learner for 1D data.
#               TODO:
#-------------------------------------------------------------------------------
import heapq
from math import sqrt
import itertools
import numpy as np

class Learner1D(object):
    """ Learns and predicts a 1D function.

    Description
    -----------
    Answers questions like:
    * "How much data do you need to get 2% accuracy?"
    * "What is the current status?"
    * "If I give you n data points, which ones would you like?"
    (initialise/request/promise/put/describe current state)

    """

    def __init__(self, xdata=None, ydata=None):
        """Initialize the learner.

        Parameters
        ----------
        xdata : number or iterable of numbers, optional
            Initial x values. If None, the learner starts without data.
        ydata : number, None, or iterable of numbers and None, optional
            Initial y values matching `xdata`. `None` marks a value that
            will be provided later.
        """
        # A dict {(x_left, x_right): loss} for each interval between
        # neighboring points.
        self._losses = {}

        # A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
        # properties. A missing neighbor is stored as None.
        self._neighbors = {}
        # A dict {x_n: y_n}; y_n is None while the value is pending.
        self._ydata = {}

        # Bounding box [[minx, maxx], [miny, maxy]].
        self._bbox = [[np.inf, -np.inf], [np.inf, -np.inf]]
        # Data scale [maxx - minx, maxy - miny], and the scale at which all
        # losses were last recomputed.
        self._scale = [0, 0]
        self._oldscale = [0, 0]

        # A dict {x_n: future} of evaluations that have not been collected.
        self.unfinished = {}

        # Add initial data if provided.
        if xdata is not None:
            self.add_data(xdata, ydata)

    def loss(self, x_left, x_right):
        """Calculate the loss of the interval [x_left, x_right].

        The loss is the length of the segment between the two data points
        after rescaling x and y by the current data scale. If one of the
        y-values is missing (None), returns 0, so intervals with pending data
        rank last when new points are chosen. This behavior should be
        improved later.
        """
        assert x_left < x_right and self._neighbors[x_left][1] == x_right
        y_left, y_right = self._ydata[x_left], self._ydata[x_right]
        if y_left is None or y_right is None:
            return 0
        x_scale, y_scale = self._scale
        # x_scale > 0 here: both x values lie inside the bounding box and
        # x_left < x_right. y_scale is 0 when every known y-value is equal,
        # in which case y_right - y_left is 0 as well, so the y term is
        # dropped instead of dividing by zero.
        dx = (x_right - x_left) / x_scale
        dy = (y_right - y_left) / y_scale if y_scale else 0
        return sqrt(dx**2 + dy**2)

    def add_data(self, xvalues, yvalues):
        """Add data to the intervals.

        Parameters
        ----------
        xvalues : number or iterable of numbers
            Values of the x coordinate.
        yvalues : number, None, or iterable of numbers and None
            Values of the y coordinate. `None` means that the value will be
            provided later.
        """
        try:
            # zip raises TypeError right away if either argument is not
            # iterable; only that case means "a single point was passed".
            pairs = zip(xvalues, yvalues)
        except TypeError:
            self.add_point(xvalues, yvalues)
            return
        # Errors raised while adding individual points propagate unchanged.
        for x, y in pairs:
            self.add_point(x, y)

    def add_point(self, x, y):
        """Add a point, or fill in the y-value of an existing one.

        Updates the neighbor links, the bounding box and scale, and the losses
        of the intervals adjacent to `x`.

        Parameters
        ----------
        x : number
            The x coordinate.
        y : number or None
            The y coordinate. `None` marks a value that will be provided by a
            later call with the same `x`.
        """
        self._ydata[x] = y

        # Update the neighbors.
        if x not in self._neighbors:  # The point is new.
            xvals = sorted(self._neighbors)
            # This could be done for multiple values at once.
            pos = np.searchsorted(xvals, x)
            x_lower = xvals[pos - 1] if pos != 0 else None
            x_upper = xvals[pos] if pos != len(xvals) else None
            self._neighbors[x] = [x_lower, x_upper]
            if x_lower is not None:
                self._neighbors[x_lower][1] = x
            if x_upper is not None:
                self._neighbors[x_upper][0] = x

        # Update the bounding box and the scale.
        self._bbox[0][0] = min(self._bbox[0][0], x)
        self._bbox[0][1] = max(self._bbox[0][1], x)
        if y is not None:
            self._bbox[1][0] = min(self._bbox[1][0], y)
            self._bbox[1][1] = max(self._bbox[1][1], y)
        self._scale = [self._bbox[0][1] - self._bbox[0][0],
                       self._bbox[1][1] - self._bbox[1][0]]

        # Update the losses of the intervals touching x, and drop the
        # interval that a new x has split (if any).
        x_lower, x_upper = self._neighbors[x]
        if x_lower is not None:
            self._losses[x_lower, x] = self.loss(x_lower, x)
        if x_upper is not None:
            self._losses[x, x_upper] = self.loss(x, x_upper)
        self._losses.pop((x_lower, x_upper), None)

        # Losses are normalized by the scale. If the scale along either axis
        # has more than doubled since the last full recomputation, recompute
        # all losses. This comparison is element-wise; `list * 2` would
        # repeat the list instead of doubling its values.
        if any(new > 2 * old for new, old in zip(self._scale, self._oldscale)):
            self._losses = {key: self.loss(*key) for key in self._losses}
            self._oldscale = list(self._scale)

    def choose_points(self, n=10, add_to_data=False):
        """Return n points that are expected to maximally reduce the loss.

        Parameters
        ----------
        n : int
            Number of points to return.
        add_to_data : bool
            If True, the returned points are added with y=None, so that a
            later call does not suggest them again (useful for parallel
            evaluation).

        Returns
        -------
        list
            The new x values, grouped per interval.

        Raises
        ------
        IndexError
            If n > 0 and there is no interval yet (fewer than two points).
        """
        # Find out how to divide the n points over the intervals by finding
        # positive integers n_i that minimize max(L_i / n_i), subject to the
        # constraint sum(n_i) = n + N, with N the total number of intervals.
        # Greedily give one more subdivision to the interval with the
        # largest L_i / n_i. heapq is a min-heap, so priorities are negated.
        quals = [(-loss, x_range, 1) for x_range, loss in self._losses.items()]
        heapq.heapify(quals)
        for _ in range(n):
            quality, x_range, n_sub = quals[0]
            heapq.heapreplace(
                quals, (quality * n_sub / (n_sub + 1), x_range, n_sub + 1))

        # An interval divided into n_sub parts contributes n_sub - 1 equally
        # spaced interior points (both endpoints already exist).
        xs = []
        for _, (x_left, x_right), n_sub in quals:
            xs.extend(np.linspace(x_left, x_right, n_sub, endpoint=False)[1:])

        # Add `None`s to data because then the same point will not be returned
        # upon a next request. This can be used for parallelization.
        if add_to_data:
            self.add_data(xs, itertools.repeat(None))

        return xs

    def get_status(self):
        """Report current status.

        So far just returns some internal variables [losses, intervals and
        data].
        """
        return self._losses, self._neighbors, self._ydata

    def get_results(self):
        """Collect finished futures from `self.unfinished` and add their data.

        Works with future objects exposing `done()` and `result()`, such as
        `concurrent.futures.Future` or `distributed.client.Future`. Futures
        that are not done yet stay in `self.unfinished`.
        """
        done = [(x, fut.result())
                for x, fut in self.unfinished.items() if fut.done()]
        for x, _ in done:
            del self.unfinished[x]
        if done:
            # Unzip directly rather than via np.array, which would coerce x
            # and y values to a common dtype.
            xs, ys = zip(*done)
            self.add_data(xs, ys)

    def add_futures(self, xs, ys):
        """Register futures in `self.unfinished`, keyed by their x value.

        Parameters
        ----------
        xs : number or iterable of numbers
            The x value(s) being evaluated.
        ys : future or iterable of futures
            The corresponding pending evaluation(s).
        """
        try:
            # Only a non-iterable argument means "a single future was passed".
            pairs = zip(xs, ys)
        except TypeError:
            self.unfinished[xs] = ys
            return
        self.unfinished.update(pairs)