Skip to content
Snippets Groups Projects
Commit 0ee4c372 authored by Bas Nijholt's avatar Bas Nijholt
Browse files

use sortedcontainers.SortedDict() for neighbors

parent aa4ca3e6
Branches
Tags
1 merge request!2rename variables and begin implementing loss_improvement(points)
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import abc import abc
from copy import copy
import heapq import heapq
import itertools import itertools
from math import sqrt from math import sqrt
...@@ -25,7 +26,7 @@ class BaseLearner(metaclass=abc.ABCMeta): ...@@ -25,7 +26,7 @@ class BaseLearner(metaclass=abc.ABCMeta):
and returns a holoviews plot. and returns a holoviews plot.
""" """
def __init__(self, function): def __init__(self, function):
self.data = sortedcontainers.SortedDict() self.data = {}
self.function = function self.function = function
def add_data(self, xvalues, yvalues): def add_data(self, xvalues, yvalues):
...@@ -190,8 +191,8 @@ class Learner1D(BaseLearner): ...@@ -190,8 +191,8 @@ class Learner1D(BaseLearner):
# A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local # A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
# properties. # properties.
self.neighbors = {} self.neighbors = sortedcontainers.SortedDict()
self.neighbors_interp = {} self.neighbors_interp = sortedcontainers.SortedDict()
# Bounding box [[minx, maxx], [miny, maxy]]. # Bounding box [[minx, maxx], [miny, maxy]].
self._bbox = [list(bounds), [np.inf, -np.inf]] self._bbox = [list(bounds), [np.inf, -np.inf]]
...@@ -247,10 +248,9 @@ class Learner1D(BaseLearner): ...@@ -247,10 +248,9 @@ class Learner1D(BaseLearner):
pass pass
def find_neighbors(self, x, neighbors): def find_neighbors(self, x, neighbors):
xvals = sorted(neighbors) pos = neighbors.bisect_left(x)
pos = np.searchsorted(xvals, x) x_lower = neighbors.iloc[pos-1] if pos != 0 else None
x_lower = xvals[pos-1] if pos != 0 else None x_upper = neighbors.iloc[pos] if pos != len(neighbors) else None
x_upper = xvals[pos] if pos != len(xvals) else None
return x_lower, x_upper return x_lower, x_upper
def update_neighbors(self, x, real): def update_neighbors(self, x, real):
...@@ -258,10 +258,8 @@ class Learner1D(BaseLearner): ...@@ -258,10 +258,8 @@ class Learner1D(BaseLearner):
if x not in neighbors: # The point is new if x not in neighbors: # The point is new
x_lower, x_upper = self.find_neighbors(x, neighbors) x_lower, x_upper = self.find_neighbors(x, neighbors)
neighbors[x] = [x_lower, x_upper] neighbors[x] = [x_lower, x_upper]
neighbors[None] = [None, None] # To reduce the number of condititons. neighbors.get(x_lower, [None, None])[1] = x
neighbors[x_lower][1] = x neighbors.get(x_upper, [None, None])[0] = x
neighbors[x_upper][0] = x
del neighbors[None]
def update_scale(self, x, y): def update_scale(self, x, y):
self._bbox[0][0] = min(self._bbox[0][0], x) self._bbox[0][0] = min(self._bbox[0][0], x)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment