diff --git a/kwant/plotter.py b/kwant/plotter.py
index abd9fe2cc2ba10a973f208d05e24ee440d759c64..4eef52710f3ed0af045de21649aa8faa152be5aa 100644
--- a/kwant/plotter.py
+++ b/kwant/plotter.py
@@ -37,7 +37,7 @@ try:
     from matplotlib import collections
     from matplotlib.backends.backend_agg import FigureCanvasAgg
     from . import _colormaps
-    mpl_enabled = True
+    mpl_available = True
     try:
         from mpl_toolkits import mplot3d
         has3d = True
@@ -47,7 +47,7 @@ try:
 except ImportError:
     warnings.warn("matplotlib is not available, only iterator-providing "
                   "functions will work.", RuntimeWarning)
-    mpl_enabled = False
+    mpl_available = False
 
 from . import system, builder, _common
 
@@ -72,7 +72,7 @@ def matplotlib_chores():
     pre_1_4_matplotlib = [int(x) for x in ver.split('.')[:2]] < [1, 4]
 
 
-if mpl_enabled:
+if mpl_available:
     matplotlib_chores()
 
 
@@ -95,7 +95,7 @@ def _sample_array(array, n_samples, rng=None):
     return array[rng.choice(range(la), min(n_samples, la))]
 
 
-if mpl_enabled:
+if mpl_available:
     class LineCollection(collections.LineCollection):
         def __init__(self, segments, reflen=None, **kwargs):
             super().__init__(segments, **kwargs)
@@ -637,10 +637,9 @@ def output_fig(fig, output_mode='auto', file=None, savefile_opts=None,
         The output mode to be used.  Can be one of the following:
         'pyplot' : attach the figure to pyplot, with the same behavior as if
         pyplot.plot was called to create this figure.
-        'ipython' : attach a `FigureCanvasAgg` to the figure and return it.
-        'return' : return the figure.
-        'file' : same as 'ipython', but also save the figure into a file.
-        'auto' : if fname is given, save to a file, else if pyplot
+        'return' : attach a `FigureCanvasAgg` to the figure and return it.
+        'file' : same as 'return', but also save the figure into a file.
+        'auto' : if fname is given, save to a file, otherwise like pyplot
         is imported, attach to pyplot, otherwise just return.  See also the
         notes below.
     file : string or a file object
@@ -658,17 +657,10 @@ def output_fig(fig, output_mode='auto', file=None, savefile_opts=None,
     matplotlib in that the `dpi` attribute of the figure is used by defaul
     instead of the matplotlib config setting.
     """
-    if not mpl_enabled:
+    if not mpl_available:
         raise RuntimeError('matplotlib is not installed.')
     if output_mode == 'auto':
-        if file is not None:
-            output_mode = 'file'
-        else:
-            try:
-                matplotlib.pyplot.get_backend()
-                output_mode = 'pyplot'
-            except AttributeError:
-                output_mode = 'pyplot'
+        output_mode = 'pyplot' if file is None else 'file'
     if output_mode == 'pyplot':
         try:
             fake_fig = matplotlib.pyplot.figure()
@@ -685,21 +677,18 @@ def output_fig(fig, output_mode='auto', file=None, savefile_opts=None,
                 pass
         if show:
             matplotlib.pyplot.show()
-        return fig
-    elif output_mode == 'return':
-        canvas = FigureCanvasAgg(fig)
-        fig.canvas = canvas
-        return fig
-    elif output_mode == 'file':
-        canvas = FigureCanvasAgg(fig)
-        if savefile_opts is None:
-            savefile_opts = ([], {})
-        if 'dpi' not in savefile_opts[1]:
-            savefile_opts[1]['dpi'] = fig.dpi
-        canvas.print_figure(file, *savefile_opts[0], **savefile_opts[1])
-        return fig
+    elif output_mode in ['return', 'file']:
+        fig.canvas = FigureCanvasAgg(fig)
+        if output_mode == 'file':
+            fig.canvas = canvas = FigureCanvasAgg(fig)
+            if savefile_opts is None:
+                savefile_opts = ([], {})
+            if 'dpi' not in savefile_opts[1]:
+                savefile_opts[1]['dpi'] = fig.dpi
+            canvas.print_figure(file, *savefile_opts[0], **savefile_opts[1])
     else:
-        assert False, 'Unknown output_mode'
+        raise ValueError('Unknown output_mode')
+    return fig
 
 
 # Extracting necessary data from the system.
@@ -1139,7 +1128,7 @@ def plot(sys, num_lead_cells=2, unit='nn',
       its aspect ratio.
 
     """
-    if not mpl_enabled:
+    if not mpl_available:
         raise RuntimeError("matplotlib was not found, but is required "
                            "for plot()")
 
@@ -1555,7 +1544,7 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
       correspond to exactly one pixel.
     """
 
-    if not mpl_enabled:
+    if not mpl_available:
         raise RuntimeError("matplotlib was not found, but is required "
                            "for map()")
 
@@ -1652,7 +1641,7 @@ def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None,
     See `~kwant.physics.Bands` for the calculation of dispersion without plotting.
     """
 
-    if not mpl_enabled:
+    if not mpl_available:
         raise RuntimeError("matplotlib was not found, but is required "
                            "for bands()")
 
@@ -1727,7 +1716,7 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None,
         A figure with the output if `ax` is not set, else None.
     """
 
-    if not mpl_enabled:
+    if not mpl_available:
         raise RuntimeError("matplotlib was not found, but is required "
                            "for plot_spectrum()")
     if y is not None and not has3d:
@@ -2067,7 +2056,7 @@ def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k',
     fig : matplotlib figure
         A figure with the output if `ax` is not set, else None.
     """
-    if not mpl_enabled:
+    if not mpl_available:
         raise RuntimeError("matplotlib was not found, but is required "
                            "for current()")
 
diff --git a/kwant/tests/test_plotter.py b/kwant/tests/test_plotter.py
index 11266b4e8e167a46acdb40cb9fcd286c9f74de8f..12ea35ab54900e5fa29aa09e17f06de174dcc1c6 100644
--- a/kwant/tests/test_plotter.py
+++ b/kwant/tests/test_plotter.py
@@ -105,7 +105,7 @@ def syst_3d(W=3, r1=2, r2=4, a=1, t=1.0):
     return syst
 
 
-@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
+@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.")
 def test_plot():
     plot = plotter.plot
     syst2d = syst_2d()
@@ -155,7 +155,7 @@ def bad_transform(pos):
     x, y = pos
     return x, y, 0
 
-@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
+@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.")
 def test_map():
     syst = syst_2d()
     with tempfile.TemporaryFile('w+b') as out:
@@ -191,7 +191,7 @@ def test_mask_interpolate():
                       coords, np.ones(2 * len(coords)))
 
 
-@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
+@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.")
 def test_bands():
 
     syst = syst_2d().finalized().leads[0]
@@ -206,7 +206,7 @@ def test_bands():
         plotter.bands(syst, ax=ax, file=out)
 
 
-@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
+@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.")
 def test_spectrum():
 
     def ham_1d(a, b, c):
@@ -417,7 +417,7 @@ def test_current_interpolation():
     assert scipy.stats.linregress(np.log(data))[2] < -0.8
 
 
-@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
+@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.")
 def test_current():
     syst = syst_2d().finalized()
     J = kwant.operator.Current(syst)
diff --git a/kwant/tests/test_wraparound.py b/kwant/tests/test_wraparound.py
index dc8ed3a927ea2248d4991d856b9f00310b812474..eb981d2798fa1b8fd42e36d26a0cd88039274b11 100644
--- a/kwant/tests/test_wraparound.py
+++ b/kwant/tests/test_wraparound.py
@@ -17,7 +17,7 @@ from kwant import plotter
 from kwant.wraparound import wraparound, plot_2d_bands
 from kwant._common import get_parameters
 
-if plotter.mpl_enabled:
+if plotter.mpl_available:
     from mpl_toolkits import mplot3d  # pragma: no flakes
     from matplotlib import pyplot  # pragma: no flakes
 
@@ -201,7 +201,7 @@ def test_symmetry():
                 assert np.all(orig == new)
 
 
-@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
+@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.")
 def test_plot_2d_bands():
     chain = kwant.lattice.chain()
     square = kwant.lattice.square()