From 8ca980f705dc76954c5ddbfec7dbe78acdcf87bc Mon Sep 17 00:00:00 2001
From: Christoph Groth <christoph.groth@cea.fr>
Date: Wed, 13 Dec 2017 16:05:17 +0100
Subject: [PATCH] select matplotlib backend only when needed

Previously, in order to not to fix the matplotlib backend, we required
users to import matplotlib.pyplot before calling any Kwant plotting
functions.  This did not have the desired effect, since we did import
`matplotlib.backends` and that also fixes the backend.

Now, both backends and pyplot are imported at the last possible moment
and a warning is emitted if this fixes the backend.
---
 doc/source/tutorial/first_steps.rst | 12 +++++------
 kwant/plotter.py                    | 33 ++++++++++++++++++-----------
 kwant/tests/test_plotter.py         | 10 +++++++--
 3 files changed, 35 insertions(+), 20 deletions(-)

diff --git a/doc/source/tutorial/first_steps.rst b/doc/source/tutorial/first_steps.rst
index 9a74435f..1ddce048 100644
--- a/doc/source/tutorial/first_steps.rst
+++ b/doc/source/tutorial/first_steps.rst
@@ -315,12 +315,12 @@ subbands that increases with energy.
 
    - Instead of plotting to the screen (which is standard)
      `~kwant.plotter.plot` can also write to a file specified by the argument
-     ``file``.  For the plotting to the screen to work the module
-     ``matplotlib.pyplot`` has to be imported.  (An informative error message
-     will remind you if you forget.)  The reason for this is pretty technical:
-     matplotlib's "backend" can only be chosen before ``matplotlib.pyplot`` has
-     been imported.  Would Kwant import that module by itself, it would deprive
-     you of the possibility to choose a non-default backend later.
+     ``file``.
+
+   - Due to matplotlib's limitations, Kwant's plotting routines have the
+     side effect of fixing matplotlib's "backend".  If you would like to choose
+     a different backend than the standard one, you must do so before asking
+     Kwant to plot anything.
 
 
 .. rubric:: Footnotes
diff --git a/kwant/plotter.py b/kwant/plotter.py
index 4eef5271..b591d585 100644
--- a/kwant/plotter.py
+++ b/kwant/plotter.py
@@ -16,6 +16,7 @@ system in two or three dimensions.
 """
 
 from collections import defaultdict
+import sys
 import itertools
 import functools
 import warnings
@@ -35,7 +36,6 @@ try:
     import matplotlib.cm
     from matplotlib.figure import Figure
     from matplotlib import collections
-    from matplotlib.backends.backend_agg import FigureCanvasAgg
     from . import _colormaps
     mpl_available = True
     try:
@@ -659,15 +659,23 @@ def output_fig(fig, output_mode='auto', file=None, savefile_opts=None,
     """
     if not mpl_available:
         raise RuntimeError('matplotlib is not installed.')
+
+    # We import backends and pyplot only at the last possible moment (=now)
+    # because this has the side effect of selecting the matplotlib backend for
+    # good.  Warn if backend has not been set yet.  This check is the same as
+    # the one performed inside matplotlib.use.
+    if 'matplotlib.backends' not in sys.modules:
+        warnings.warn("Kwant's plotting functions have\nthe side effect of "
+                      "selecting the matplotlib backend. To avoid this "
+                      "warning,\nimport matplotlib.pyplot, "
+                      "matplotlib.backends or call matplotlib.use().",
+                      RuntimeWarning, stacklevel=3)
+
     if output_mode == 'auto':
         output_mode = 'pyplot' if file is None else 'file'
     if output_mode == 'pyplot':
-        try:
-            fake_fig = matplotlib.pyplot.figure()
-        except AttributeError:
-            msg = ('matplotlib.pyplot is unavailable.  Execute `import '
-                   'matplotlib.pyplot` or use a different output mode.')
-            raise RuntimeError(msg)
+        from matplotlib import pyplot
+        fake_fig = pyplot.figure()
         fake_fig.canvas.figure = fig
         fig.canvas = fake_fig.canvas
         for ax in fig.axes:
@@ -676,16 +684,16 @@ def output_fig(fig, output_mode='auto', file=None, savefile_opts=None,
             except AttributeError:
                 pass
         if show:
-            matplotlib.pyplot.show()
+            pyplot.show()
     elif output_mode in ['return', 'file']:
+        from matplotlib.backends.backend_agg import FigureCanvasAgg
         fig.canvas = FigureCanvasAgg(fig)
         if output_mode == 'file':
-            fig.canvas = canvas = FigureCanvasAgg(fig)
             if savefile_opts is None:
                 savefile_opts = ([], {})
             if 'dpi' not in savefile_opts[1]:
                 savefile_opts[1]['dpi'] = fig.dpi
-            canvas.print_figure(file, *savefile_opts[0], **savefile_opts[1])
+            fig.canvas.print_figure(file, *savefile_opts[0], **savefile_opts[1])
     else:
         raise ValueError('Unknown output_mode')
     return fig
@@ -2158,8 +2166,9 @@ def current(syst, current, relwidth=0.05, **kwargs):
         A figure with the output if `ax` is not set, else None.
 
     """
-    return streamplot(*interpolate_current(syst, current, relwidth),
-                      **kwargs)
+    with _common.reraise_warnings(4):
+        return streamplot(*interpolate_current(syst, current, relwidth),
+                          **kwargs)
 
 
 # TODO (Anton): Fix plotting of parts of the system using color = np.nan.
diff --git a/kwant/tests/test_plotter.py b/kwant/tests/test_plotter.py
index 12ea35ab..23b0d01a 100644
--- a/kwant/tests/test_plotter.py
+++ b/kwant/tests/test_plotter.py
@@ -24,10 +24,11 @@ try:
     from mpl_toolkits import mplot3d
     import matplotlib
 
+    # This check is the same as the one performed inside matplotlib.use.
+    matplotlib_backend_chosen = 'matplotlib.backends' in sys.modules
     # If the user did not already choose a backend, then choose
     # the one with the least dependencies.
-    # This check is the same as the one performed inside matplotlib.use.
-    if 'matplotlib.backends' not in sys.modules:
+    if not matplotlib_backend_chosen:
         matplotlib.use('Agg')
 
     from matplotlib import pyplot  # pragma: no flakes
@@ -37,6 +38,11 @@ except ImportError:
 from kwant import plotter
 
 
+def test_matplotlib_backend_unset():
+    """Simply importing Kwant should not set the matplotlib backend."""
+    assert matplotlib_backend_chosen is False
+
+
 def test_importable_without_matplotlib():
     prefix, sep, suffix = plotter.__file__.rpartition('.')
     if suffix in ['pyc', 'pyo']:
-- 
GitLab