Commit dde9093a authored by Bas Nijholt

small style changes and renames

parent 02f1671e
Merge request !7: implement 2D learner
@@ -13,8 +13,9 @@ import numpy as np
 from scipy import interpolate, optimize, special
 import sortedcontainers
 
+
 class BaseLearner(metaclass=abc.ABCMeta):
-    """Base class for algorithms for learning a function 'f: X → Y'
+    """Base class for algorithms for learning a function 'f: X → Y'.
 
     Attributes
     ----------
@@ -94,6 +95,7 @@ class BaseLearner(metaclass=abc.ABCMeta):
     def __setstate__(self, state):
         self.__dict__ = state
 
+
 class AverageLearner(BaseLearner):
     def __init__(self, function, atol=None, rtol=None):
         """A naive implementation of adaptive computing of averages.
@@ -305,12 +307,12 @@ class Learner1D(BaseLearner):
         if real:
             # If the scale has doubled, recompute all losses.
             if self._scale > self._oldscale * 2:
-                self.losses = {key: self.interval_loss(*key, self.data)
-                               for key in self.losses}
-                self.losses_combined = {key: self.interval_loss(*key,
-                    self.data_combined) for key in self.losses_combined}
+                self.losses = {xs: self.interval_loss(*xs, self.data)
+                               for xs in self.losses}
+                self.losses_combined = {x: self.interval_loss(*x,
+                                                              self.data_combined)
+                                        for x in self.losses_combined}
                 self._oldscale = self._scale
 
     def choose_points(self, n=10, add_data=True):
         """Return n points that are expected to maximally reduce the loss."""
@@ -437,7 +439,8 @@ class BalancingLearner(BaseLearner):
         loss_improvements = []
         pairs = []
         for index, learner in enumerate(self.learners):
-            point, loss_improvement = learner.choose_points(n=1, add_data=False)
+            point, loss_improvement = learner.choose_points(n=1,
+                                                            add_data=False)
             loss_improvements.append(loss_improvement[0])
             pairs.append((index, point[0]))
         x, _ = max(zip(pairs, loss_improvements), key=itemgetter(1))
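
The loop above is a greedy argmax over sub-learners: each is asked for its single best candidate without committing it, and the pair with the largest loss improvement wins. The same pattern in isolation (the wrapper function is illustrative; choose_points mirrors the diff):

    from operator import itemgetter

    def choose_balanced(learners):
        candidates = []
        for index, learner in enumerate(learners):
            # Ask for one point but do not add it to the learner yet.
            points, improvements = learner.choose_points(n=1, add_data=False)
            candidates.append(((index, points[0]), improvements[0]))
        # Keep the (learner index, point) pair with the best improvement.
        best_pair, _ = max(candidates, key=itemgetter(1))
        return best_pair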
@@ -483,7 +486,7 @@ def _max_disagreement_location_in_simplex(points, values, grad, transform):
     Notes
     -----
-    Based on maximizing the disagreement between a linear and a cubic model::
+    Based on maximizing the disagreement between a linear and a cubic model:
 
         f_1(x) = a + sum_j b_j (x_j - x_0)
         f_2(x) = a + sum_j c_j (x_j - x_0) + sum_ij d_ij (x_i - x_0) (x_j - x_0)
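
To make the Notes concrete: f_1 is the local linear model and f_2 adds quadratic (and, in the full routine, cubic) corrections; the point where the two disagree most is where a new sample is most informative. A small sketch of the squared disagreement, with the cubic e-term omitted for brevity:

    import numpy as np

    def disagreement(x, a, b, c, d, x0):
        dx = x - x0
        f1 = a + b @ dx                 # linear model f_1
        f2 = a + c @ dx + dx @ d @ dx   # f_2 up to the quadratic term
        return (f1 - f2) ** 2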
@@ -508,29 +511,26 @@ def _max_disagreement_location_in_simplex(points, values, grad, transform):
     # -- Least-squares fit: (ii) cubic model
     # (ii.a) fitting function values
-    x2 = (x[:-1, :, None] * x[:-1, None, :]).reshape(m-1, ndim*ndim)
+    x2 = (x[:-1, :, None] * x[:-1, None, :]).reshape(m - 1, ndim**2)
     x3 = (x[:-1, :, None, None] * x[:-1, None, :, None] * x[:-1, None, None, :]
-         ).reshape(m-1, ndim*ndim*ndim)
+          ).reshape(m - 1, ndim**3)
     lhs1 = np.c_[x[:-1], x2, x3]
     rhs1 = z[:-1]
 
     # (ii.b) fitting gradients
-    d_b = np.tile(np.eye(ndim)[None, :, :], (m, 1, 1)).reshape(m*ndim, ndim)
+    d_b = np.tile(np.eye(ndim)[None, :, :], (m, 1, 1)).reshape(m * ndim, ndim)
     o = np.eye(ndim)
     d_d = (o[None, :, None, :] * x[:, None, :, None] +
-           x[:, None, None, :] * o[None, :, :, None]).reshape(m * ndim, ndim * ndim)
-    d_e = (
-        o[:, None, :, None, None] *
-        x[None, :, None, :, None] * x[None, :, None, None, :]
-        +
-        x[None, :, :, None, None] *
-        o[:, None, None, :, None] * x[None, :, None, None, :]
-        +
-        x[None, :, :, None, None] *
-        x[None, :, None, :, None] * o[:, None, None, None, :]
-    ).reshape(m*ndim, ndim*ndim*ndim)
+           x[:, None, None, :] * o[None, :, :, None]).reshape(m * ndim,
+                                                              ndim * ndim)
+    d_e = (o[:, None, :, None, None] * x[None, :, None, :, None] *
+           x[None, :, None, None, :] +
+           x[None, :, :, None, None] * o[:, None, None, :, None] *
+           x[None, :, None, None, :] +
+           x[None, :, :, None, None] * x[None, :, None, :, None] *
+           o[:, None, None, None, :]).reshape(m * ndim, ndim**3)
     lhs2 = np.c_[d_b, d_d, d_e]
     rhs2 = grad.ravel()
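
The lhs1/rhs1 rows constrain the fit with function values and the lhs2/rhs2 rows with gradients; stacking both gives one overdetermined system solved by lstsq further down. A 1-D analogue of the same stacking, with made-up data for illustration:

    import numpy as np

    # Fit f(x) ~ c*x + d*x**2 + e*x**3 to values *and* derivatives.
    x = np.array([-1.0, -0.5, 0.5, 1.0])
    z = np.sin(x)   # sample "function values"
    g = np.cos(x)   # sample "gradients"

    lhs1 = np.c_[x, x**2, x**3]                     # value rows
    lhs2 = np.c_[np.ones_like(x), 2 * x, 3 * x**2]  # derivative rows
    lhs = np.r_[lhs1, lhs2]
    rhs = np.r_[z, g]
    (c, d, e), *_ = np.linalg.lstsq(lhs, rhs, rcond=None)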
@@ -540,8 +540,8 @@ def _max_disagreement_location_in_simplex(points, values, grad, transform):
     rhs = np.r_[rhs1, rhs2]
     cd, _, rank, _ = np.linalg.lstsq(lhs, rhs)
     c = cd[:ndim]
-    d = cd[ndim:ndim+ndim*ndim].reshape(ndim, ndim)
-    e = cd[ndim+ndim*ndim:].reshape(ndim, ndim, ndim)
+    d = cd[ndim:ndim + ndim**2].reshape(ndim, ndim)
+    e = cd[ndim + ndim**2:].reshape(ndim, ndim, ndim)
 
     # -- Find point of maximum disagreement, inside the triangle
@@ -549,26 +549,25 @@ def _max_disagreement_location_in_simplex(points, values, grad, transform):
     def func(x):
         x = itr.dot(x)
-        v = -abs(((c - b)*x).sum()
-                 + (d*x[:, None]*x[None, :]).sum()
-                 + (e*x[:, None, None]*x[None, :, None]*x[None, None, :]).sum()
-                 )**2
+        v = (((c - b) * x).sum() +
+             (d * x[:, None] * x[None, :]).sum() +
+             (e * x[:, None, None] * x[None, :, None] * x[None, None, :]).sum())
+        v = -abs(v)**2
         return np.array(v)
 
     cons = [lambda x: np.array([1 - x.sum()])]
     for j in range(ndim):
         cons.append(lambda x: np.array([x[j]]))
-    ps = [1.0/(ndim+1)] * ndim
+    ps = [1.0 / (ndim + 1)] * ndim
     p = optimize.fmin_slsqp(func, ps, ieqcons=cons, disp=False,
-                            bounds=[(0, 1)]*ndim)
+                            bounds=[(0, 1)] * ndim)
     p = itr.dot(p) + points[-1]
     return p
 
 
 class Learner2D(BaseLearner):
     """Sample a 2-D function adaptively.
 
     Parameters
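
The func/cons/fmin_slsqp block above minimizes the negative squared disagreement over barycentric coordinates, constrained to the unit simplex (each coordinate >= 0 and their sum <= 1). A self-contained sketch of that constraint setup with a toy objective; note the j=j default argument, which binds the loop variable early (a plain lambda in a loop would see only the final j):

    import numpy as np
    from scipy import optimize

    def neg_objective(w):
        # Toy stand-in for the disagreement measure; SLSQP minimizes,
        # so the quantity to maximize is negated.
        return -(w[0] * w[1])

    ndim = 2
    # Inequality constraints g_i(w) >= 0 keeping w in the unit simplex.
    cons = [lambda w: np.array([1 - w.sum()])]
    cons += [lambda w, j=j: np.array([w[j]]) for j in range(ndim)]

    w0 = [1.0 / (ndim + 1)] * ndim   # start at the simplex centroid
    w_opt = optimize.fmin_slsqp(neg_objective, w0, ieqcons=cons,
                                bounds=[(0, 1)] * ndim, disp=False)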
@@ -687,18 +686,19 @@ class Learner2D(BaseLearner):
         p = self.points
         v = self.values
 
-        if v.shape[0] < self.ndim+1:
+        if v.shape[0] < self.ndim + 1:
             raise ValueError("too few points...")
 
         # Interpolate the unfinished points
         if self._interp:
             if self.n_real >= 4:
-                ip = interpolate.LinearNDInterpolator(self.points_real,
-                                                      self.values_real)
+                ip_real = interpolate.LinearNDInterpolator(self.points_real,
+                                                           self.values_real)
             else:
-                ip = lambda x: np.empty(len(x))  # Important not to return exact zeros
+                # It is important not to return exact zeros
+                ip_real = lambda x: np.empty(len(x))
             n_interp = list(self._interp.values())
-            values = ip(p[n_interp])
+            values = ip_real(p[n_interp])
             for n, value in zip(n_interp, values):
                 v[n] = value
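
For reference, the renamed ip_real is a scipy LinearNDInterpolator built on the finished points and evaluated at the pending ones; the empty-array fallback deliberately yields arbitrary non-zero placeholders when fewer than four real points exist. Standalone usage of the interpolator, with made-up data:

    import numpy as np
    from scipy import interpolate

    pts = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
    vals = np.array([0.0, 1.0, 1.0, 2.0])

    ip_real = interpolate.LinearNDInterpolator(pts, vals)
    pending = np.array([[0.5, 0.5], [0.25, 0.75]])
    estimates = ip_real(pending)   # linear estimates for unfinished points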
@@ -716,13 +716,12 @@ class Learner2D(BaseLearner):
         dev = 0
         for j in range(self.ndim):
-            vest = v[:, j, None] + ((p[:, :, :] -
-                                     p[:, j, None, :]) *
-                                    g[:, j, None, :]).sum(axis=-1)
+            vest = v[:, j, None] + ((p[:, :, :] - p[:, j, None, :]) *
+                                    g[:, j, None, :]).sum(axis=-1)
             dev += abs(vest - v).max(axis=1)
 
         q = p[:, :-1, :] - p[:, -1, None, :]
-        vol = abs(q[:, 0, 0]*q[:, 1, 1] - q[:, 0, 1]*q[:, 1, 0])
+        vol = abs(q[:, 0, 0] * q[:, 1, 1] - q[:, 0, 1] * q[:, 1, 0])
         vol /= special.gamma(1 + self.ndim)
 
         dev *= vol
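
The vol lines compute the standard simplex volume |det q| / ndim!, written out as a 2x2 determinant for the 2-D case, with special.gamma(1 + ndim) supplying the factorial. The same formula for a general simplex (the helper name is illustrative):

    import numpy as np
    from scipy import special

    def simplex_volume(vertices):
        # vertices has shape (ndim + 1, ndim); q holds the edge vectors
        # from the last vertex, and gamma(1 + ndim) == ndim!.
        vertices = np.asarray(vertices)
        q = vertices[:-1] - vertices[-1]
        return abs(np.linalg.det(q)) / special.gamma(1 + q.shape[1])

    # A right triangle with unit legs has area 1/2:
    assert np.isclose(simplex_volume([[0, 0], [1, 0], [0, 1]]), 0.5)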
@@ -730,7 +729,7 @@ class Learner2D(BaseLearner):
         if stack_till is None:
             # Take new points
             try:
-                cp = 0.9*dev.max()
+                cp = 0.9 * dev.max()
                 nstack = min(self.nstack, (dev > cp).sum())
                 if nstack <= 0:
                     raise ValueError()
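
The try-block above keeps at most nstack simplices whose deviation exceeds 90% of the current maximum, raising (and presumably falling back, in code beyond this hunk) when none qualify. The thresholding step in isolation (the helper is illustrative):

    import numpy as np

    def stack_candidates(dev, nstack, frac=0.9):
        cp = frac * dev.max()
        n = min(nstack, int((dev > cp).sum()))
        if n <= 0:
            raise ValueError("no simplex exceeds the threshold")
        # Indices of the n largest deviations, worst first.
        return np.argsort(dev)[::-1][:n]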
@@ -748,9 +747,9 @@ class Learner2D(BaseLearner):
                     return True
                 return False
 
-            for j in range(len(dev)):
+            for j, _ in enumerate(dev):
                 jsimplex = np.argmax(dev)
-                # -- Estimate point of maximum curvature inside the simplex
+                # Estimate point of maximum curvature inside the simplex
                 p = tri.points[tri.vertices[jsimplex]]
                 v = ip.values[tri.vertices[jsimplex]]
                 g = grad[tri.vertices[jsimplex]]
@@ -759,11 +758,11 @@ class Learner2D(BaseLearner):
                 point_new = _max_disagreement_location_in_simplex(
                     p, v, g, transform)
 
-                # -- Reduce to bounds
+                # Reduce to bounds
                 for j, (a, b) in enumerate(self.bounds):
                     point_new[j] = max(a, min(b, point_new[j]))
 
-                # -- Check if it is really new (also, revert to mean point optionally)
+                # Check if it is really new (also, revert to mean point optionally)
                 if point_exists(point_new):
                     dev[jsimplex] = 0
                     continue
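
Across these last two hunks the selection is roughly greedy: each pass targets the simplex with the largest deviation, clamps the proposed point to the bounds, and, if the point already exists, zeroes that simplex's deviation so the next-worst one is tried. A condensed sketch of that control flow; propose and is_new stand in for the disagreement search and point_exists:

    import numpy as np

    def pick_points(dev, propose, is_new, bounds, n):
        dev = dev.copy()
        chosen = []
        for _ in range(len(dev)):
            jsimplex = int(np.argmax(dev))
            point = propose(jsimplex)
            # Clamp the candidate to the rectangular bounds.
            point = [max(a, min(b, x)) for (a, b), x in zip(bounds, point)]
            if not is_new(point):
                dev[jsimplex] = 0   # suppress; try the next-worst simplex
                continue
            chosen.append(point)
            dev[jsimplex] = 0
            if len(chosen) == n:
                break
        return chosen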
...