Skip to content
Snippets Groups Projects
Commit f07282c4 authored by Jorn Hoofwijk's avatar Jorn Hoofwijk Committed by Bas Nijholt
Browse files

add iso_surface_plot to adaptive learnerND

parent 3a69cd99
No related branches found
No related tags found
No related merge requests found
This commit is part of merge request !124. Comments created here will be created in the context of that merge request.
...@@ -10,7 +10,7 @@ import scipy.spatial ...@@ -10,7 +10,7 @@ import scipy.spatial
from .base_learner import BaseLearner from .base_learner import BaseLearner
from ..notebook_integration import ensure_holoviews from ..notebook_integration import ensure_holoviews, ensure_plotly
from .triangulation import (Triangulation, point_in_simplex, from .triangulation import (Triangulation, point_in_simplex,
circumsphere, simplex_volume_in_embedding) circumsphere, simplex_volume_in_embedding)
from ..utils import restore, cache_latest from ..utils import restore, cache_latest
...@@ -576,3 +576,129 @@ class LearnerND(BaseLearner): ...@@ -576,3 +576,129 @@ class LearnerND(BaseLearner):
return im.opts(style=dict(cmap='viridis')) return im.opts(style=dict(cmap='viridis'))
else: else:
raise ValueError("Only 1 or 2-dimensional plots can be generated.") raise ValueError("Only 1 or 2-dimensional plots can be generated.")
def _get_isosurface(self, level=0.0):
if self.ndim != 3 or self.vdim != 1:
raise Exception('Isosurface plotting is only supported'
' for a 3D input and 1D output')
vertices = [] # index -> (x,y,z)
faces = [] # tuple of indices of the corner points
from_line_to_vertex = {} # the interpolated vertex (index) between two known points
def _get_vertex_index(a, b):
if (a, b) in from_line_to_vertex:
return from_line_to_vertex[(a, b)]
# Otherwise compute it and cache the result.
vertex_a = self.tri.vertices[a]
vertex_b = self.tri.vertices[b]
value_a = self.data[vertex_a]
value_b = self.data[vertex_b]
da = abs(value_a - level)
db = abs(value_b - level)
dab = da + db
new_pt = (db / dab * np.array(vertex_a)
+ da / dab * np.array(vertex_b))
new_index = len(vertices)
vertices.append(new_pt)
from_line_to_vertex[(a, b)] = new_index
return new_index
for simplex in self.tri.simplices:
plane = []
for a, b in itertools.combinations(simplex, 2):
va = self.data[self.tri.vertices[a]]
vb = self.data[self.tri.vertices[b]]
if min(va, vb) < level <= max(va, vb):
vi = _get_vertex_index(a, b)
should_add = True
for pi in plane:
if np.allclose(vertices[vi], vertices[pi]):
should_add = False
if should_add:
plane.append(vi)
if len(plane) == 3:
faces.append(plane)
elif len(plane) == 4:
faces.append(plane[:3])
faces.append(plane[1:])
if len(faces) == 0:
r_min = min(self.data[v] for v in self.tri.vertices)
r_max = max(self.data[v] for v in self.tri.vertices)
raise ValueError(
f"Could not draw isosurface for level={level}, as"
" this value is not inside the function range. Please choose"
f" a level strictly inside interval ({r_min}, {r_max})"
)
return vertices, faces
def plot_isosurface(self, level=0.0, hull_opacity=0.2):
"""Plots the linearly interpolated isosurface of the function,
based on the currently evaluated points. This is the 3D analog
of an isoline.
Parameters
----------
level : float, default 0.0
the function value which you are interested in.
hull_opacity : float, default 0.0
the opacity of the hull of the domain.
Returns
-------
plot : plotly.offline.iplot object
The plot object of the isosurface.
"""
plotly = ensure_plotly()
vertices, faces = self._get_isosurface(level)
x, y, z = zip(*vertices)
fig = plotly.figure_factory.create_trisurf(
x=x, y=y, z=z, plot_edges=False,
simplices=faces, title="Isosurface")
if hull_opacity < 1e-3:
# Do not compute the hull_mesh.
return plotly.offline.iplot(fig)
hull_mesh = self._get_hull_mesh(opacity=hull_opacity)
return plotly.offline.iplot([fig.data[0], hull_mesh])
def _get_hull_mesh(self, opacity=0.2):
plotly = ensure_plotly()
hull = scipy.spatial.ConvexHull(self._bounds_points)
# Find the colors of each plane, giving triangles which are coplanar
# the same color, such that a square face has the same color.
color_dict = {}
def _get_plane_color(simplex):
simplex = tuple(simplex)
# If the volume of the two triangles combined is zero then they
# belong to the same plane.
for simplex_key, color in color_dict.items():
points = [hull.points[i] for i in set(simplex_key + simplex)]
points = np.array(points)
if np.linalg.matrix_rank(points[1:] - points[0]) < 3:
return color
if scipy.spatial.ConvexHull(points).volume < 1e-5:
return color
color_dict[simplex] = tuple(random.randint(0, 255)
for _ in range(3))
return color_dict[simplex]
colors = [_get_plane_color(simplex) for simplex in hull.simplices]
x, y, z = zip(*self._bounds_points)
i, j, k = hull.simplices.T
return plotly.graph_objs.Mesh3d(x=x, y=y, z=z, i=i, j=j, k=k,
facecolor=colors, opacity=opacity)
...@@ -8,7 +8,7 @@ import warnings ...@@ -8,7 +8,7 @@ import warnings
_async_enabled = False _async_enabled = False
_plotting_enabled = False _plotting_enabled = False
_plotly_notebook_imported = False
def notebook_extension(): def notebook_extension():
if not in_ipynb(): if not in_ipynb():
...@@ -38,6 +38,27 @@ def ensure_holoviews(): ...@@ -38,6 +38,27 @@ def ensure_holoviews():
raise RuntimeError('holoviews is not installed; plotting is disabled.') raise RuntimeError('holoviews is not installed; plotting is disabled.')
def ensure_plotly():
if not in_ipynb():
raise RuntimeError(
"plotting functions using 'plotly' may only be run "
"from a Jupyter notebook.")
global _plotly_notebook_imported
try:
import plotly.graph_objs
import plotly.figure_factory
import plotly.offline
plotly = importlib.import_module('plotly')
if not _plotly_notebook_imported:
plotly.offline.init_notebook_mode()
_plotly_notebook_imported = True
return plotly
except ModuleNotFoundError:
raise RuntimeError('plotly is not installed; plotting is disabled.')
def in_ipynb(): def in_ipynb():
try: try:
# If we are running in IPython, then `get_ipython()` is always a global # If we are running in IPython, then `get_ipython()` is always a global
......
...@@ -38,6 +38,7 @@ extras_require = { ...@@ -38,6 +38,7 @@ extras_require = {
'ipywidgets', 'ipywidgets',
'bokeh', 'bokeh',
'matplotlib', 'matplotlib',
'plotly'
], ],
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment