From ad7a3d2347c0fbb36b152a09d46e6435c1087327 Mon Sep 17 00:00:00 2001
From: Bas Nijholt <basnijholt@gmail.com>
Date: Sat, 20 Oct 2018 13:14:53 +0200
Subject: [PATCH] add 'LearnerND.plot_3D' and add an example to the docs

---
 adaptive/learner/learnerND.py | 65 ++++++++++++++++++++++++++++++++++-
 1 file changed, 64 insertions(+), 1 deletion(-)

diff --git a/adaptive/learner/learnerND.py b/adaptive/learner/learnerND.py
index c8081878..aa89803f 100644
--- a/adaptive/learner/learnerND.py
+++ b/adaptive/learner/learnerND.py
@@ -10,7 +10,7 @@ import scipy.spatial
 
 from .base_learner import BaseLearner
 
-from ..notebook_integration import ensure_holoviews
+from ..notebook_integration import ensure_holoviews, ensure_plotly
 from .triangulation import (Triangulation, point_in_simplex,
                             circumsphere, simplex_volume_in_embedding)
 from ..utils import restore, cache_latest
@@ -585,6 +585,69 @@ class LearnerND(BaseLearner):
         else:
             raise ValueError("Only 1 or 2-dimensional plots can be generated.")
 
+    def plot_3D(self, with_triangulation=False):
+        """Plot the learner's data in 3D using plotly.
+
+        Parameters
+        ----------
+        with_triangulation : bool, default: False
+            Add the verticices to the plot.
+
+        Returns
+        -------
+        plot : plotly.offline.iplot object
+            The 3D plot of ``learner.data``.
+        """
+        plotly = ensure_plotly()
+
+        plots = []
+
+        vertices = self.tri.vertices
+        if with_triangulation:
+            Xe, Ye, Ze = [], [], []
+            for simplex in self.tri.simplices:
+                for s in itertools.combinations(simplex, 2):
+                    Xe += [vertices[i][0] for i in s] + [None]
+                    Ye += [vertices[i][1] for i in s] + [None]
+                    Ze += [vertices[i][2] for i in s] + [None]
+
+            plots.append(plotly.graph_objs.Scatter3d(
+                x=Xe, y=Ye, z=Ze, mode='lines',
+                line=dict(color='rgb(125,125,125)', width=1),
+                hoverinfo='none'
+            ))
+
+        Xn, Yn, Zn = zip(*vertices)
+        colors = [self.data[p] for p in self.tri.vertices]
+        marker = dict(symbol='circle', size=3, color=colors,
+            colorscale='Viridis',
+            line=dict(color='rgb(50,50,50)', width=0.5))
+
+        plots.append(plotly.graph_objs.Scatter3d(
+            x=Xn, y=Yn, z=Zn, mode='markers',
+            name='actors', marker=marker,
+            hoverinfo='text'
+        ))
+
+        axis = dict(
+            showbackground=False,
+            showline=False,
+            zeroline=False,
+            showgrid=False,
+            showticklabels=False,
+            title='',
+        )
+
+        layout = plotly.graph_objs.Layout(
+            showlegend=False,
+            scene=dict(xaxis=axis, yaxis=axis, zaxis=axis),
+            margin=dict(t=100),
+            hovermode='closest')
+
+        fig = plotly.graph_objs.Figure(data=plots, layout=layout)
+
+        return plotly.offline.iplot(fig)
+
     def _get_data(self):
         return self.data
 
-- 
GitLab