Skip to content
Snippets Groups Projects
Commit f53c378b authored by Jorn Hoofwijk's avatar Jorn Hoofwijk Committed by Bas Nijholt
Browse files

add support for neighbours in loss computation in LearnerND

parent 2b8d7492
No related tags found
1 merge request!132WIP: add support for neighbours in loss computation in LearnerND
Pipeline #13210 passed
......@@ -17,6 +17,12 @@ from .triangulation import (Triangulation, point_in_simplex,
from ..utils import restore, cache_latest
def to_list(inp):
if isinstance(inp, Iterable):
return list(inp)
return [inp]
def volume(simplex, ys=None):
# Notice the parameter ys is there so you can use this volume method as
# as loss function
......@@ -57,6 +63,66 @@ def default_loss(simplex, ys):
return simplex_volume_in_embedding(pts)
def triangle_loss(simplex, neighbors):
"""
Simplex and the neighbors are a list of coordinates,
both input and output coordinates.
Parameters
----------
simplex : list of tuples
Each entry is one point of the simplex, it should contain both
input and output coordinates in the same tuple.
i.e. (x1, x2, x3, y1, y2) in the case of `f: R^3 -> R^2`.
neighbors : list of tuples
The missing vertex of the simplex that shares exactly one
face with `simplex`. May contain None if the simplex is
at the boundary.
Returns
-------
loss : float
"""
neighbors = [n for n in neighbors if n is not None]
if len(neighbors) == 0:
return 0
return sum(simplex_volume_in_embedding([*simplex, neighbour])
for neighbour in neighbors) / len(neighbors)
def get_curvature_loss(curvature_factor=1, volume_factor=0, input_volume_factor=0.05):
# XXX: add doc-string!
def curvature_loss(simplex, neighbors):
"""Simplex and the neighbors are a list of coordinates,
both input and output coordinates.
Parameters
----------
simplex : list of tuples
Each entry is one point of the simplex, it should contain both
input and output coordinates in the same tuple.
i.e. (x1, x2, x3, y1, y2) in the case of `f: R^3 -> R^2`.
neighbors : list of tuples
The missing vertex of the simplex that shares exactly one
face with `simplex`. May contain None if the simplex is
at the boundary.
Returns
-------
loss : float
"""
dim = len(simplex) - 1
loss_volume = simplex_volume_in_embedding(simplex)
xs = [pt[:dim] for pt in simplex]
loss_input_volume = volume(xs)
loss_curvature = triangle_loss(simplex, neighbors)
return (curvature_factor * loss_curvature ** (dim / (dim+1))
+ volume_factor * loss_volume
+ input_volume_factor * loss_input_volume)
return curvature_loss
def choose_point_in_simplex(simplex, transform=None):
"""Choose a new point in inside a simplex.
......@@ -67,9 +133,10 @@ def choose_point_in_simplex(simplex, transform=None):
Parameters
----------
simplex : numpy array
The coordinates of a triangle with shape (N+1, N)
The coordinates of a triangle with shape (N+1, N).
transform : N*N matrix
The multiplication to apply to the simplex before choosing the new point
The multiplication to apply to the simplex before choosing
the new point.
Returns
-------
......@@ -149,10 +216,15 @@ class LearnerND(BaseLearner):
children based on volume.
"""
def __init__(self, func, bounds, loss_per_simplex=None):
def __init__(self, func, bounds, loss_per_simplex=None, loss_depends_on_neighbors=False):
self.ndim = len(bounds)
self._vdim = None
self.loss_per_simplex = loss_per_simplex or default_loss
self._loss_depends_on_neighbors = loss_depends_on_neighbors
if loss_depends_on_neighbors:
self.loss_per_simplex = loss_per_simplex or get_curvature_loss()
else:
self.loss_per_simplex = loss_per_simplex or default_loss
self.bounds = tuple(tuple(map(float, b)) for b in bounds)
self.data = OrderedDict()
self.pending_points = set()
......@@ -295,10 +367,10 @@ class LearnerND(BaseLearner):
simplex = tuple(simplex)
simplices = [self.tri.vertex_to_simplices[i] for i in simplex]
neighbours = set.union(*simplices)
neighbors = set.union(*simplices)
# Neighbours also includes the simplex itself
for simpl in neighbours:
for simpl in neighbors:
_, to_add = self._try_adding_pending_point_to_simplex(point, simpl)
if to_add is None:
continue
......@@ -362,6 +434,7 @@ class LearnerND(BaseLearner):
# find the simplex with the highest loss, we do need to check that the
# simplex hasn't been deleted yet
while len(self._simplex_queue):
# XXX: Need to add check that the loss is the most recent computed loss
loss, simplex, subsimplex = heapq.heappop(self._simplex_queue)
if (subsimplex is None
and simplex in self.tri.simplices
......@@ -417,6 +490,22 @@ class LearnerND(BaseLearner):
return self._ask_best_point() # O(log N)
def _compute_loss(self, simplex):
vertices = self.tri.get_vertices(simplex)
values = [self.data[tuple(v)] for v in vertices]
if not self._loss_depends_on_neighbors:
return float(self.loss_per_simplex(vertices, values))
neighbors = self.tri.get_simplices_attached_to_points(simplex)
neighbour_indices = [next(iter(set(simpl) - set(simplex)))
for simpl in neighbors]
neighbour_points = self.tri.get_vertices(neighbour_indices)
simpl = [(*x, *to_list(y)) for x, y in zip(vertices, values)]
neigh = [(*x, *to_list(self.data[tuple(x)])) for x in neighbour_points]
return float(self.loss_per_simplex(simpl, neigh))
def update_losses(self, to_delete: set, to_add: set):
# XXX: add the points outside the triangulation to this as well
pending_points_unbound = set()
......@@ -429,12 +518,9 @@ class LearnerND(BaseLearner):
pending_points_unbound = set(p for p in pending_points_unbound
if p not in self.data)
for simplex in to_add:
vertices = self.tri.get_vertices(simplex)
values = [self.data[tuple(v)] for v in vertices]
loss = float(self.loss_per_simplex(vertices, values))
self._losses[simplex] = float(loss)
loss = self._compute_loss(simplex)
self._losses[simplex] = loss
for p in pending_points_unbound:
self._try_adding_pending_point_to_simplex(p, simplex)
......@@ -446,6 +532,21 @@ class LearnerND(BaseLearner):
self._update_subsimplex_losses(
simplex, self._subtriangulations[simplex].simplices)
points_of_added_simplices = set.union(*[set(s) for s in to_add])
neighbors = self.tri.get_simplices_attached_to_points(
points_of_added_simplices) - to_add
for simplex in neighbors:
loss = self._compute_loss(simplex)
self._losses[simplex] = loss
if simplex not in self._subtriangulations:
heapq.heappush(self._simplex_queue, (-loss, simplex, None))
continue
self._update_subsimplex_losses(
simplex, self._subtriangulations[simplex].simplices)
def losses(self):
"""Get the losses of each simplex in the current triangulation, as dict
......@@ -763,7 +864,7 @@ class LearnerND(BaseLearner):
plot = plot * self.plot_isoline(level=l, n=-1)
return plot
vertices, lines = self.self._get_iso(level, which='line')
vertices, lines = self._get_iso(level, which='line')
paths = [[vertices[i], vertices[j]] for i, j in lines]
contour = hv.Path(paths)
......
......@@ -204,7 +204,6 @@ def simplex_volume_in_embedding(vertices) -> float:
if the vertices do not form a simplex (for example,
because they are coplanar, colinear or coincident).
"""
# Implements http://mathworld.wolfram.com/Cayley-MengerDeterminant.html
# Modified from https://codereview.stackexchange.com/questions/77593/calculating-the-volume-of-a-tetrahedron
......@@ -230,6 +229,8 @@ def simplex_volume_in_embedding(vertices) -> float:
vol_square = np.linalg.det(sq_dists_mat) / coeff
if vol_square <= 0:
if abs(vol_square) < 1e-15:
return 0
raise ValueError('Provided vertices do not form a simplex')
return np.sqrt(vol_square)
......@@ -493,17 +494,11 @@ class Triangulation:
bad_triangles.add(simplex)
# Get all simplices that share at least a point with the simplex
neighbours = set.union(*[self.vertex_to_simplices[p]
for p in todo_points])
neighbors = self.get_neighbors_from_vertices(todo_points)
# Filter out the already evaluated simplices
neighbours = neighbours - done_simplices
# Keep only the simplices sharing a whole face with the current simplex
neighbours = set(
simpl for simpl in neighbours
if len(set(simpl) & set(simplex)) == self.dim # they share a face
)
queue.update(neighbours)
neighbors = neighbors - done_simplices
neighbors = self.get_face_sharing_neighbors(neighbors, simplex)
queue.update(neighbors)
faces = list(self.faces(simplices=bad_triangles))
......@@ -606,6 +601,20 @@ class Triangulation:
"""Simplices originating from a vertex don't overlap."""
raise NotImplementedError
def get_neighbors_from_vertices(self, simplex):
return set.union(*[self.vertex_to_simplices[p]
for p in simplex])
def get_face_sharing_neighbors(self, neighbors, simplex):
"""Keep only the simplices sharing a whole face with simplex."""
return set(simpl for simpl in neighbors
if len(set(simpl) & set(simplex)) == self.dim) # they share a face
def get_simplices_attached_to_points(self, indices):
# Get all simplices that share at least a point with the simplex
neighbors = self.get_neighbors_from_vertices(indices)
return self.get_face_sharing_neighbors(neighbors, indices)
@property
def hull(self):
"""Compute hull from triangulation.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment