diff --git a/doc/source/code/figure/discretize.py.diff b/doc/source/code/figure/discretize.py.diff
index 9ca0df555e3b8553f7772d6f2253e485e6d4df73..aebaf3e03c10d291d2b25789c1145b90d6210a15 100644
--- a/doc/source/code/figure/discretize.py.diff
+++ b/doc/source/code/figure/discretize.py.diff
@@ -1,4 +1,4 @@
-@@ -1,222 +1,236 @@
+@@ -1,225 +1,239 @@
  # Tutorial 2.9. Processing continuum Hamiltonians with discretize
  # ===============================================================
  #
@@ -13,6 +13,9 @@
 +import _defs
  
  import kwant
+ #HIDDEN_BEGIN_import
+ import kwant.continuum
+ #HIDDEN_END_import
  import scipy.sparse.linalg
  import scipy.linalg
  import numpy as np
diff --git a/doc/source/code/figure/plot_qahe.py.diff b/doc/source/code/figure/plot_qahe.py.diff
index d726f98e908feac22a767b4387c8e6b561523bb4..1788f677df857d1b75b8706aa0907fad4b9a467d 100644
--- a/doc/source/code/figure/plot_qahe.py.diff
+++ b/doc/source/code/figure/plot_qahe.py.diff
@@ -1,4 +1,4 @@
-@@ -1,71 +1,75 @@
+@@ -1,72 +1,76 @@
  # Comprehensive example: quantum anomalous Hall effect
  # ====================================================
  #
@@ -16,6 +16,7 @@
  import math
  import matplotlib.pyplot
  import kwant
+ import kwant.continuum
  
  
  # 2 band model exhibiting quantum anomalous Hall effect
diff --git a/doc/source/conf.py b/doc/source/conf.py
index f0272556a1cee688b73c60fdea6574d8792ba5ee..9790aaa58434a0eb868d66bf2b2243677611ee4c 100644
--- a/doc/source/conf.py
+++ b/doc/source/conf.py
@@ -16,7 +16,9 @@ import sys, os
 from distutils.util import get_platform
 sys.path.insert(0, "../../build/lib.{0}-{1}.{2}".format(
         get_platform(), *sys.version_info[:2]))
+
 import kwant
+import kwant.continuum  # sphinx gets confused with lazy loading
 
 # -- General configuration -----------------------------------------------------
 
diff --git a/doc/source/tutorial/discretize.rst b/doc/source/tutorial/discretize.rst
index d3c58219ea2baf593bc96fc6e9ded1073aabc634..1ed5015cd2203e100d1d4e4cf9f21ae283716f04 100644
--- a/doc/source/tutorial/discretize.rst
+++ b/doc/source/tutorial/discretize.rst
@@ -53,6 +53,11 @@ with :math:`A(x) = \frac{\hbar^2}{2 m(x)}`.
 
 Using `~kwant.continuum.discretize` to obtain a template
 ........................................................
+First we must explicitly import the `kwant.continuum` package:
+
+.. literalinclude:: /code/include/discretize.py
+    :start-after: #HIDDEN_BEGIN_import
+    :end-before: #HIDDEN_END_import
 
 The function `kwant.continuum.discretize` takes a symbolic Hamiltonian and
 turns it into a `~kwant.builder.Builder` instance with appropriate spatial
diff --git a/kwant/__init__.py b/kwant/__init__.py
index 1fddab25426261753a81e358a30ddab009ac3171..054657c308607608b1c9c8a893680231b21c59f6 100644
--- a/kwant/__init__.py
+++ b/kwant/__init__.py
@@ -33,7 +33,7 @@ from ._common import KwantDeprecationWarning, UserCodeError
 __all__.extend(['KwantDeprecationWarning', 'UserCodeError'])
 
 for module in ['system', 'builder', 'lattice', 'solvers', 'digest', 'rmt',
-               'operator', 'kpm', 'wraparound', 'continuum']:
+               'operator', 'kpm', 'wraparound']:
     exec('from . import {0}'.format(module))
     __all__.append(module)
 
@@ -56,6 +56,13 @@ except:
 else:
     __all__.extend(['plotter', 'plot'])
 
+# Lazy import continuum package for backwards compatibility
+from ._common import lazy_import
+
+continuum = lazy_import('continuum', deprecation_warning=True)
+__all__.append('continuum')
+del lazy_import
+
 
 def test(verbose=True):
     from pytest import main
diff --git a/kwant/_common.py b/kwant/_common.py
index 76e952cf6c0d2ffc437d42ca5c093a5825aac0e9..4cd0de1a3e6d4b084f3434bbfffef08871d51ded 100644
--- a/kwant/_common.py
+++ b/kwant/_common.py
@@ -6,10 +6,12 @@
 # the file AUTHORS.rst at the top-level directory of this distribution and at
 # http://kwant-project.org/authors.
 
+import sys
 import numpy as np
 import numbers
 import inspect
 import warnings
+import importlib
 from contextlib import contextmanager
 
 __all__ = ['KwantDeprecationWarning', 'UserCodeError']
@@ -39,30 +41,6 @@ class UserCodeError(Exception):
     pass
 
 
-class ExtensionUnavailable:
-    """Class that replaces unavailable extension modules in the Kwant namespace.
-
-    Some extensions for Kwant (e.g. 'kwant.continuum') require additional
-    dependencies that are not required for core functionality. When the
-    additional dependencies are not installed an instance of this class will
-    be inserted into Kwant's root namespace to simulate the presence of the
-    extension and inform users that they need to install additional
-    dependencies.
-
-    See https://mail.python.org/pipermail/python-ideas/2012-May/014969.html
-    for more details.
-    """
-
-    def __init__(self, name, dependencies):
-        self.name = name
-        self.dependencies = ', '.join(dependencies)
-
-    def __getattr__(self, _):
-        msg = ("'{}' is not available because one or more of the following "
-               "dependencies are not installed: {}")
-        raise RuntimeError(msg.format(self.name, self.dependencies))
-
-
 def ensure_isinstance(obj, typ, msg=None):
     if isinstance(obj, typ):
         return
@@ -123,3 +101,25 @@ def get_parameters(func):
     takes_kwargs = any(i.kind is inspect.Parameter.VAR_KEYWORD
                        for i in pars.values())
     return required_params, default_params, takes_kwargs
+
+
+class lazy_import:
+    def __init__(self, module, package='kwant', deprecation_warning=False):
+        if module.startswith('.') and not package:
+            raise ValueError('Cannot import a relative module without a package.')
+        self.__module = module
+        self.__package = package
+        self.__deprecation_warning = deprecation_warning
+
+    def __getattr__(self, name):
+        if self.__deprecation_warning:
+            msg = ("Accessing {0} without an explicit import is deprecated. "
+                   "Instead, explicitly 'import {0}'."
+                  ).format('.'.join((self.__package, self.__module)))
+            warnings.warn(msg, KwantDeprecationWarning, stacklevel=2)
+        relative_module = '.' + self.__module
+        mod = importlib.import_module(relative_module, self.__package)
+        # Replace this _LazyModuleProxy with an actual module
+        package = sys.modules[self.__package]
+        setattr(package, self.__module, mod)
+        return getattr(mod, name)
diff --git a/kwant/_plotter.py b/kwant/_plotter.py
new file mode 100644
index 0000000000000000000000000000000000000000..771e9ba281999efc07f2aece0a7fc0169d19be63
--- /dev/null
+++ b/kwant/_plotter.py
@@ -0,0 +1,320 @@
+# -*- coding: utf-8 -*-
+# Copyright 2011-2018 Kwant authors.
+#
+# This file is part of Kwant.  It is subject to the license terms in the file
+# LICENSE.rst found in the top-level directory of this distribution and at
+# http://kwant-project.org/license.  A list of Kwant authors can be found in
+# the file AUTHORS.rst at the top-level directory of this distribution and at
+# http://kwant-project.org/authors.
+
+# This module is imported by plotter.py. It contains all the expensive imports
+# that we want to remove from plotter.py
+
+# All matplotlib imports must be isolated in a try, because even without
+# matplotlib iterators remain useful.  Further, mpl_toolkits used for 3D
+# plotting are also imported separately, to ensure that 2D plotting works even
+# if 3D does not.
+
+import warnings
+from math import sqrt, pi
+import numpy as np
+
+try:
+    import matplotlib
+    import matplotlib.colors
+    import matplotlib.cm
+    from matplotlib.figure import Figure
+    from matplotlib import collections
+    from . import _colormaps
+    mpl_available = True
+    try:
+        from mpl_toolkits import mplot3d
+        has3d = True
+    except ImportError:
+        warnings.warn("3D plotting not available.", RuntimeWarning)
+        has3d = False
+except ImportError:
+    warnings.warn("matplotlib is not available, only iterator-providing "
+                  "functions will work.", RuntimeWarning)
+    mpl_available = False
+
+
+# Collections that allow for symbols and linewiths to be given in data space
+# (not for general use, only implement what's needed for plotter)
+def isarray(var):
+    if hasattr(var, '__getitem__') and not isinstance(var, str):
+        return True
+    else:
+        return False
+
+
+def nparray_if_array(var):
+    return np.asarray(var) if isarray(var) else var
+
+
+if mpl_available:
+    class LineCollection(collections.LineCollection):
+        def __init__(self, segments, reflen=None, **kwargs):
+            super().__init__(segments, **kwargs)
+            self.reflen = reflen
+
+        def set_linewidths(self, linewidths):
+            self.linewidths_orig = nparray_if_array(linewidths)
+
+        def draw(self, renderer):
+            if self.reflen is not None:
+                # Note: only works for aspect ratio 1!
+                #       72.0 - there is 72 points in an inch
+                factor = (self.axes.transData.frozen().to_values()[0] * 72.0 *
+                          self.reflen / self.figure.dpi)
+            else:
+                factor = 1
+
+            super().set_linewidths(self.linewidths_orig *
+                                                       factor)
+            return super().draw(renderer)
+
+
+    class PathCollection(collections.PathCollection):
+        def __init__(self, paths, sizes=None, reflen=None, **kwargs):
+            super().__init__(paths, sizes=sizes, **kwargs)
+
+            self.reflen = reflen
+            self.linewidths_orig = nparray_if_array(self.get_linewidths())
+
+            self.transforms = np.array(
+                [matplotlib.transforms.Affine2D().scale(x).get_matrix()
+                 for x in sizes])
+
+        def get_transforms(self):
+            return self.transforms
+
+        def get_transform(self):
+            Affine2D = matplotlib.transforms.Affine2D
+            if self.reflen is not None:
+                # For the paths, use the data transformation but strip the
+                # offset (will be added later with offsets)
+                args = self.axes.transData.frozen().to_values()[:4] + (0, 0)
+                return Affine2D().from_values(*args).scale(self.reflen)
+            else:
+                return Affine2D().scale(self.figure.dpi / 72.0)
+
+        def draw(self, renderer):
+            if self.reflen:
+                # Note: only works for aspect ratio 1!
+                factor = (self.axes.transData.frozen().to_values()[0] /
+                          self.figure.dpi * 72.0 * self.reflen)
+                self.set_linewidths(self.linewidths_orig * factor)
+
+            return collections.Collection.draw(self, renderer)
+
+
+    if has3d:
+        # Sorting is optional.
+        sort3d = True
+
+        # Compute the projection of a 3D length into 2D data coordinates
+        # for this we use 2 3D half-circles that are projected into 2D.
+        # (This gives the same length as projecting the full unit sphere.)
+
+        phi = np.linspace(0, pi, 21)
+        xyz = np.c_[np.cos(phi), np.sin(phi), 0 * phi].T.reshape(-1, 1, 21)
+        unit_sphere = np.bmat([[xyz[0], xyz[2]], [xyz[1], xyz[0]],
+                                [xyz[2], xyz[1]]])
+        unit_sphere = np.asarray(unit_sphere)
+
+        def projected_length(ax, length):
+            rc = np.array([ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()])
+            rc = np.apply_along_axis(np.sum, 1, rc) / 2.
+
+            rs = unit_sphere * length + rc.reshape(-1, 1)
+
+            transform = mplot3d.proj3d.proj_transform
+            rp = np.asarray(transform(*(list(rs) + [ax.get_proj()]))[:2])
+            rc[:2] = transform(*(list(rc) + [ax.get_proj()]))[:2]
+
+            coords = rp - np.repeat(rc[:2].reshape(-1, 1), len(rs[0]), axis=1)
+            return sqrt(np.sum(coords**2, axis=0).max())
+
+
+        # Auxiliary array for calculating corners of a cube.
+        corners = np.zeros((3, 8, 6), np.float_)
+        corners[0, [0, 1, 2, 3], 0] = corners[0, [4, 5, 6, 7], 1] = \
+        corners[0, [0, 1, 4, 5], 2] = corners[0, [2, 3, 6, 7], 3] = \
+        corners[0, [0, 2, 4, 6], 4] = corners[0, [1, 3, 5, 7], 5] = 1.0
+
+
+        class Line3DCollection(mplot3d.art3d.Line3DCollection):
+            def __init__(self, segments, reflen=None, zorder=0, **kwargs):
+                super().__init__(segments, **kwargs)
+                self.reflen = reflen
+                self.zorder3d = zorder
+
+            def set_linewidths(self, linewidths):
+                self.linewidths_orig = nparray_if_array(linewidths)
+
+            def do_3d_projection(self, renderer):
+                super().do_3d_projection(renderer)
+                # The whole 3D ordering is flawed in mplot3d when several
+                # collections are added. We just use normal zorder. Note the
+                # "-" due to the different logic in the 3d plotting, we still
+                # want larger zorder values to be plotted on top of smaller
+                # ones.
+                return -self.zorder3d
+
+            def draw(self, renderer):
+                if self.reflen:
+                    proj_len = projected_length(self.axes, self.reflen)
+                    args = self.axes.transData.frozen().to_values()
+                    # Note: unlike in the 2D case, where we can enforce equal
+                    #       aspect ratio, this (currently) does not work with
+                    #       3D plots in matplotlib. As an approximation, we
+                    #       thus scale with the average of the x- and y-axis
+                    #       transformation.
+                    factor = proj_len * (args[0] +
+                                         args[3]) * 0.5 * 72.0 / self.figure.dpi
+                else:
+                    factor = 1
+
+                super().set_linewidths(
+                                                self.linewidths_orig * factor)
+                super().draw(renderer)
+
+
+        class Path3DCollection(mplot3d.art3d.Patch3DCollection):
+            def __init__(self, paths, sizes, reflen=None, zorder=0,
+                         offsets=None, **kwargs):
+                paths = [matplotlib.patches.PathPatch(path) for path in paths]
+
+                if offsets is not None:
+                    kwargs['offsets'] = offsets[:, :2]
+
+                super().__init__(paths, **kwargs)
+
+                if offsets is not None:
+                    self.set_3d_properties(zs=offsets[:, 2], zdir="z")
+
+                self.reflen = reflen
+                self.zorder3d = zorder
+
+                self.paths_orig = np.array(paths, dtype='object')
+                self.linewidths_orig = nparray_if_array(self.get_linewidths())
+                self.linewidths_orig2 = self.linewidths_orig
+                self.array_orig = nparray_if_array(self.get_array())
+                self.facecolors_orig = nparray_if_array(self.get_facecolors())
+                self.edgecolors_orig = nparray_if_array(self.get_edgecolors())
+
+                Affine2D = matplotlib.transforms.Affine2D
+                self.orig_transforms = np.array(
+                    [Affine2D().scale(x).get_matrix() for x in sizes])
+                self.transforms = self.orig_transforms
+
+            def set_array(self, array):
+                self.array_orig = nparray_if_array(array)
+                super().set_array(array)
+
+            def set_color(self, colors):
+                self.facecolors_orig = nparray_if_array(colors)
+                self.edgecolors_orig = self.facecolors_orig
+                super().set_color(colors)
+
+            def set_edgecolors(self, colors):
+                colors = matplotlib.colors.colorConverter.to_rgba_array(colors)
+                self.edgecolors_orig = nparray_if_array(colors)
+                super().set_edgecolors(colors)
+
+            def get_transforms(self):
+                # this is exact only for an isometric projection, for the
+                # perspective projection used in mplot3d it's an approximation
+                return self.transforms
+
+            def get_transform(self):
+                Affine2D = matplotlib.transforms.Affine2D
+                if self.reflen:
+                    proj_len = projected_length(self.axes, self.reflen)
+
+                    # For the paths, use the data transformation but strip the
+                    # offset (will be added later with the offsets).
+                    args = self.axes.transData.frozen().to_values()[:4] + (0, 0)
+                    return Affine2D().from_values(*args).scale(proj_len)
+                else:
+                    return Affine2D().scale(self.figure.dpi / 72.0)
+
+            def do_3d_projection(self, renderer):
+                xs, ys, zs = self._offsets3d
+
+                # numpy complains about zero-length index arrays
+                if len(xs) == 0:
+                    return -self.zorder3d
+
+                proj = mplot3d.proj3d.proj_transform_clip
+                vs = np.array(proj(xs, ys, zs, renderer.M)[:3])
+
+                if sort3d:
+                    indx = vs[2].argsort()[::-1]
+
+                    self.set_offsets(vs[:2, indx].T)
+
+                    if len(self.paths_orig) > 1:
+                        paths = np.resize(self.paths_orig, (vs.shape[1],))
+                        self.set_paths(paths[indx])
+
+                    if len(self.orig_transforms) > 1:
+                        self.transforms = self.transforms[indx]
+
+                    lw_orig = self.linewidths_orig
+                    if (isinstance(lw_orig, np.ndarray) and len(lw_orig) > 1):
+                        self.linewidths_orig2 = np.resize(lw_orig,
+                                                           (vs.shape[1],))[indx]
+
+                    # Note: here array, facecolors and edgecolors are
+                    #       guaranteed to be 2d numpy arrays or None.  (And
+                    #       array is the same length as the coordinates)
+
+                    if self.array_orig is not None:
+                        super(Path3DCollection,
+                              self).set_array(self.array_orig[indx])
+
+                    if (self.facecolors_orig is not None and
+                        self.facecolors_orig.shape[0] > 1):
+                        shape = list(self.facecolors_orig.shape)
+                        shape[0] = vs.shape[1]
+                        super().set_facecolors(
+                            np.resize(self.facecolors_orig, shape)[indx])
+
+                    if (self.edgecolors_orig is not None and
+                        self.edgecolors_orig.shape[0] > 1):
+                        shape = list(self.edgecolors_orig.shape)
+                        shape[0] = vs.shape[1]
+                        super().set_edgecolors(
+                                                np.resize(self.edgecolors_orig,
+                                                          shape)[indx])
+                else:
+                    self.set_offsets(vs[:2].T)
+
+                # the whole 3D ordering is flawed in mplot3d when several
+                # collections are added. We just use normal zorder, but correct
+                # by the projected z-coord of the "center of gravity",
+                # normalized by the projected z-coord of the world coordinates.
+                # In doing so, several Path3DCollections are plotted probably
+                # in the right order (it's not exact) if they have the same
+                # zorder. Still, smaller and larger integer zorders are plotted
+                # below or on top.
+
+                bbox = np.asarray(self.axes.get_w_lims())
+
+                proj = mplot3d.proj3d.proj_transform_clip
+                cz = proj(*(list(np.dot(corners, bbox)) + [renderer.M]))[2]
+
+                return -self.zorder3d + vs[2].mean() / cz.ptp()
+
+            def draw(self, renderer):
+                if self.reflen:
+                    proj_len = projected_length(self.axes, self.reflen)
+                    args = self.axes.transData.frozen().to_values()
+                    factor = proj_len * (args[0] +
+                                         args[3]) * 0.5 * 72.0 / self.figure.dpi
+
+                    self.set_linewidths(self.linewidths_orig2 * factor)
+
+                super().draw(renderer)
diff --git a/kwant/continuum/__init__.py b/kwant/continuum/__init__.py
index 718f50151795830b73df1f9327aa591997d0cbdc..bc2420e9c7998af5cb7f3b33fe72555e778156d4 100644
--- a/kwant/continuum/__init__.py
+++ b/kwant/continuum/__init__.py
@@ -6,18 +6,15 @@
 # the file AUTHORS.rst at the top-level directory of this distribution and at
 # http://kwant-project.org/authors.
 
-import sys
-
-from .._common import ExtensionUnavailable
-
 try:
     from .discretizer import discretize, discretize_symbolic, build_discretized
     from ._common import sympify, lambdify
     from ._common import momentum_operators, position_operators
-except ImportError:
-    sys.modules[__name__] = ExtensionUnavailable(__name__, ('sympy',))
+except ImportError as error:
+    msg = ("'kwant.continuum' is not available because one or more of its "
+           "dependencies is not installed.")
+    raise ImportError(msg) from error
 
-del sys, ExtensionUnavailable
 
 __all__ = ['discretize', 'discretize_symbolic', 'build_discretized',
            'sympify', 'lambdify']
diff --git a/kwant/plotter.py b/kwant/plotter.py
index 2112e5b98334cd15e22ef1dcebc09aaa59ed61c9..2fc6b6acf2e21e7d65f8b0b3336394c836389858 100644
--- a/kwant/plotter.py
+++ b/kwant/plotter.py
@@ -26,29 +26,6 @@ import tinyarray as ta
 from scipy import spatial, interpolate
 from math import cos, sin, pi, sqrt
 
-# All matplotlib imports must be isolated in a try, because even without
-# matplotlib iterators remain useful.  Further, mpl_toolkits used for 3D
-# plotting are also imported separately, to ensure that 2D plotting works even
-# if 3D does not.
-try:
-    import matplotlib
-    import matplotlib.colors
-    import matplotlib.cm
-    from matplotlib.figure import Figure
-    from matplotlib import collections
-    from . import _colormaps
-    mpl_available = True
-    try:
-        from mpl_toolkits import mplot3d
-        has3d = True
-    except ImportError:
-        warnings.warn("3D plotting not available.", RuntimeWarning)
-        has3d = False
-except ImportError:
-    warnings.warn("matplotlib is not available, only iterator-providing "
-                  "functions will work.", RuntimeWarning)
-    mpl_available = False
-
 from . import system, builder, _common
 
 
@@ -57,18 +34,9 @@ __all__ = ['plot', 'map', 'bands', 'spectrum', 'current',
            'sys_leads_sites', 'sys_leads_hoppings', 'sys_leads_pos',
            'sys_leads_hopping_pos', 'mask_interpolate']
 
-
-# Collections that allow for symbols and linewiths to be given in data space
-# (not for general use, only implement what's needed for plotter)
-def isarray(var):
-    if hasattr(var, '__getitem__') and not isinstance(var, str):
-        return True
-    else:
-        return False
-
-
-def nparray_if_array(var):
-    return np.asarray(var) if isarray(var) else var
+# All the expensive imports are done in _plotter.py. We lazy load the module
+# to avoid slowing down the initial import of Kwant.
+_p = _common.lazy_import('_plotter')
 
 
 def _sample_array(array, n_samples, rng=None):
@@ -77,278 +45,10 @@ def _sample_array(array, n_samples, rng=None):
     return array[rng.choice(range(la), min(n_samples, la))]
 
 
-if mpl_available:
-    class LineCollection(collections.LineCollection):
-        def __init__(self, segments, reflen=None, **kwargs):
-            super().__init__(segments, **kwargs)
-            self.reflen = reflen
-
-        def set_linewidths(self, linewidths):
-            self.linewidths_orig = nparray_if_array(linewidths)
-
-        def draw(self, renderer):
-            if self.reflen is not None:
-                # Note: only works for aspect ratio 1!
-                #       72.0 - there is 72 points in an inch
-                factor = (self.axes.transData.frozen().to_values()[0] * 72.0 *
-                          self.reflen / self.figure.dpi)
-            else:
-                factor = 1
-
-            super().set_linewidths(self.linewidths_orig *
-                                                       factor)
-            return super().draw(renderer)
-
-
-    class PathCollection(collections.PathCollection):
-        def __init__(self, paths, sizes=None, reflen=None, **kwargs):
-            super().__init__(paths, sizes=sizes, **kwargs)
-
-            self.reflen = reflen
-            self.linewidths_orig = nparray_if_array(self.get_linewidths())
-
-            self.transforms = np.array(
-                [matplotlib.transforms.Affine2D().scale(x).get_matrix()
-                 for x in sizes])
-
-        def get_transforms(self):
-            return self.transforms
-
-        def get_transform(self):
-            Affine2D = matplotlib.transforms.Affine2D
-            if self.reflen is not None:
-                # For the paths, use the data transformation but strip the
-                # offset (will be added later with offsets)
-                args = self.axes.transData.frozen().to_values()[:4] + (0, 0)
-                return Affine2D().from_values(*args).scale(self.reflen)
-            else:
-                return Affine2D().scale(self.figure.dpi / 72.0)
-
-        def draw(self, renderer):
-            if self.reflen:
-                # Note: only works for aspect ratio 1!
-                factor = (self.axes.transData.frozen().to_values()[0] /
-                          self.figure.dpi * 72.0 * self.reflen)
-                self.set_linewidths(self.linewidths_orig * factor)
-
-            return collections.Collection.draw(self, renderer)
-
-
-    if has3d:
-        # Sorting is optional.
-        sort3d = True
-
-        # Compute the projection of a 3D length into 2D data coordinates
-        # for this we use 2 3D half-circles that are projected into 2D.
-        # (This gives the same length as projecting the full unit sphere.)
-
-        phi = np.linspace(0, pi, 21)
-        xyz = np.c_[np.cos(phi), np.sin(phi), 0 * phi].T.reshape(-1, 1, 21)
-        unit_sphere = np.bmat([[xyz[0], xyz[2]], [xyz[1], xyz[0]],
-                                [xyz[2], xyz[1]]])
-        unit_sphere = np.asarray(unit_sphere)
-
-        def projected_length(ax, length):
-            rc = np.array([ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()])
-            rc = np.apply_along_axis(np.sum, 1, rc) / 2.
-
-            rs = unit_sphere * length + rc.reshape(-1, 1)
-
-            transform = mplot3d.proj3d.proj_transform
-            rp = np.asarray(transform(*(list(rs) + [ax.get_proj()]))[:2])
-            rc[:2] = transform(*(list(rc) + [ax.get_proj()]))[:2]
-
-            coords = rp - np.repeat(rc[:2].reshape(-1, 1), len(rs[0]), axis=1)
-            return sqrt(np.sum(coords**2, axis=0).max())
-
-
-        # Auxiliary array for calculating corners of a cube.
-        corners = np.zeros((3, 8, 6), np.float_)
-        corners[0, [0, 1, 2, 3], 0] = corners[0, [4, 5, 6, 7], 1] = \
-        corners[0, [0, 1, 4, 5], 2] = corners[0, [2, 3, 6, 7], 3] = \
-        corners[0, [0, 2, 4, 6], 4] = corners[0, [1, 3, 5, 7], 5] = 1.0
-
-
-        class Line3DCollection(mplot3d.art3d.Line3DCollection):
-            def __init__(self, segments, reflen=None, zorder=0, **kwargs):
-                super().__init__(segments, **kwargs)
-                self.reflen = reflen
-                self.zorder3d = zorder
-
-            def set_linewidths(self, linewidths):
-                self.linewidths_orig = nparray_if_array(linewidths)
-
-            def do_3d_projection(self, renderer):
-                super().do_3d_projection(renderer)
-                # The whole 3D ordering is flawed in mplot3d when several
-                # collections are added. We just use normal zorder. Note the
-                # "-" due to the different logic in the 3d plotting, we still
-                # want larger zorder values to be plotted on top of smaller
-                # ones.
-                return -self.zorder3d
-
-            def draw(self, renderer):
-                if self.reflen:
-                    proj_len = projected_length(self.axes, self.reflen)
-                    args = self.axes.transData.frozen().to_values()
-                    # Note: unlike in the 2D case, where we can enforce equal
-                    #       aspect ratio, this (currently) does not work with
-                    #       3D plots in matplotlib. As an approximation, we
-                    #       thus scale with the average of the x- and y-axis
-                    #       transformation.
-                    factor = proj_len * (args[0] +
-                                         args[3]) * 0.5 * 72.0 / self.figure.dpi
-                else:
-                    factor = 1
-
-                super().set_linewidths(
-                                                self.linewidths_orig * factor)
-                super().draw(renderer)
-
-
-        class Path3DCollection(mplot3d.art3d.Patch3DCollection):
-            def __init__(self, paths, sizes, reflen=None, zorder=0,
-                         offsets=None, **kwargs):
-                paths = [matplotlib.patches.PathPatch(path) for path in paths]
-
-                if offsets is not None:
-                    kwargs['offsets'] = offsets[:, :2]
-
-                super().__init__(paths, **kwargs)
-
-                if offsets is not None:
-                    self.set_3d_properties(zs=offsets[:, 2], zdir="z")
-
-                self.reflen = reflen
-                self.zorder3d = zorder
-
-                self.paths_orig = np.array(paths, dtype='object')
-                self.linewidths_orig = nparray_if_array(self.get_linewidths())
-                self.linewidths_orig2 = self.linewidths_orig
-                self.array_orig = nparray_if_array(self.get_array())
-                self.facecolors_orig = nparray_if_array(self.get_facecolors())
-                self.edgecolors_orig = nparray_if_array(self.get_edgecolors())
-
-                Affine2D = matplotlib.transforms.Affine2D
-                self.orig_transforms = np.array(
-                    [Affine2D().scale(x).get_matrix() for x in sizes])
-                self.transforms = self.orig_transforms
-
-            def set_array(self, array):
-                self.array_orig = nparray_if_array(array)
-                super().set_array(array)
-
-            def set_color(self, colors):
-                self.facecolors_orig = nparray_if_array(colors)
-                self.edgecolors_orig = self.facecolors_orig
-                super().set_color(colors)
-
-            def set_edgecolors(self, colors):
-                colors = matplotlib.colors.colorConverter.to_rgba_array(colors)
-                self.edgecolors_orig = nparray_if_array(colors)
-                super().set_edgecolors(colors)
-
-            def get_transforms(self):
-                # this is exact only for an isometric projection, for the
-                # perspective projection used in mplot3d it's an approximation
-                return self.transforms
-
-            def get_transform(self):
-                Affine2D = matplotlib.transforms.Affine2D
-                if self.reflen:
-                    proj_len = projected_length(self.axes, self.reflen)
-
-                    # For the paths, use the data transformation but strip the
-                    # offset (will be added later with the offsets).
-                    args = self.axes.transData.frozen().to_values()[:4] + (0, 0)
-                    return Affine2D().from_values(*args).scale(proj_len)
-                else:
-                    return Affine2D().scale(self.figure.dpi / 72.0)
-
-            def do_3d_projection(self, renderer):
-                xs, ys, zs = self._offsets3d
-
-                # numpy complains about zero-length index arrays
-                if len(xs) == 0:
-                    return -self.zorder3d
-
-                proj = mplot3d.proj3d.proj_transform_clip
-                vs = np.array(proj(xs, ys, zs, renderer.M)[:3])
-
-                if sort3d:
-                    indx = vs[2].argsort()[::-1]
-
-                    self.set_offsets(vs[:2, indx].T)
-
-                    if len(self.paths_orig) > 1:
-                        paths = np.resize(self.paths_orig, (vs.shape[1],))
-                        self.set_paths(paths[indx])
-
-                    if len(self.orig_transforms) > 1:
-                        self.transforms = self.transforms[indx]
-
-                    lw_orig = self.linewidths_orig
-                    if (isinstance(lw_orig, np.ndarray) and len(lw_orig) > 1):
-                        self.linewidths_orig2 = np.resize(lw_orig,
-                                                           (vs.shape[1],))[indx]
-
-                    # Note: here array, facecolors and edgecolors are
-                    #       guaranteed to be 2d numpy arrays or None.  (And
-                    #       array is the same length as the coordinates)
-
-                    if self.array_orig is not None:
-                        super(Path3DCollection,
-                              self).set_array(self.array_orig[indx])
-
-                    if (self.facecolors_orig is not None and
-                        self.facecolors_orig.shape[0] > 1):
-                        shape = list(self.facecolors_orig.shape)
-                        shape[0] = vs.shape[1]
-                        super().set_facecolors(
-                            np.resize(self.facecolors_orig, shape)[indx])
-
-                    if (self.edgecolors_orig is not None and
-                        self.edgecolors_orig.shape[0] > 1):
-                        shape = list(self.edgecolors_orig.shape)
-                        shape[0] = vs.shape[1]
-                        super().set_edgecolors(
-                                                np.resize(self.edgecolors_orig,
-                                                          shape)[indx])
-                else:
-                    self.set_offsets(vs[:2].T)
-
-                # the whole 3D ordering is flawed in mplot3d when several
-                # collections are added. We just use normal zorder, but correct
-                # by the projected z-coord of the "center of gravity",
-                # normalized by the projected z-coord of the world coordinates.
-                # In doing so, several Path3DCollections are plotted probably
-                # in the right order (it's not exact) if they have the same
-                # zorder. Still, smaller and larger integer zorders are plotted
-                # below or on top.
-
-                bbox = np.asarray(self.axes.get_w_lims())
-
-                proj = mplot3d.proj3d.proj_transform_clip
-                cz = proj(*(list(np.dot(corners, bbox)) + [renderer.M]))[2]
-
-                return -self.zorder3d + vs[2].mean() / cz.ptp()
-
-            def draw(self, renderer):
-                if self.reflen:
-                    proj_len = projected_length(self.axes, self.reflen)
-                    args = self.axes.transData.frozen().to_values()
-                    factor = proj_len * (args[0] +
-                                         args[3]) * 0.5 * 72.0 / self.figure.dpi
-
-                    self.set_linewidths(self.linewidths_orig2 * factor)
-
-                super().draw(renderer)
-
-
 # matplotlib helper functions.
 
 def _make_figure(dpi, fig_size):
-    fig = Figure()
+    fig = _p.Figure()
     if dpi is not None:
         fig.set_dpi(dpi)
     if fig_size is not None:
@@ -377,10 +77,10 @@ def set_colors(color, collection, cmap, norm=None):
     if (isinstance(color, np.ndarray) and color.dtype == np.dtype('object')):
         color = tuple(color)
 
-    if isinstance(collection, mplot3d.art3d.Line3DCollection):
+    if isinstance(collection, _p.mplot3d.art3d.Line3DCollection):
         length = len(collection._segments3d)  # Once again, matplotlib fault!
 
-    if isarray(color) and len(color) == length:
+    if _p.isarray(color) and len(color) == length:
         try:
             # check if it is an array of floats for color mapping
             color = np.asarray(color, dtype=float)
@@ -393,7 +93,7 @@ def set_colors(color, collection, cmap, norm=None):
         except (TypeError, ValueError):
             pass
 
-    colors = matplotlib.colors.colorConverter.to_rgba_array(color)
+    colors = _p.matplotlib.colors.colorConverter.to_rgba_array(color)
     collection.set_color(colors)
 
 
@@ -414,7 +114,7 @@ def get_symbol(symbols):
 
     paths = []
     for symbol in symbols:
-        if isinstance(symbol, matplotlib.path.Path):
+        if isinstance(symbol, _p.matplotlib.path.Path):
             return symbol
         elif hasattr(symbol, '__getitem__') and len(symbol) == 3:
             kind, n, angle = symbol
@@ -428,13 +128,13 @@ def get_symbol(symbols):
                     radius = sqrt(2 * pi / (n * sin(2 * pi / n)))
 
                 angle = pi * angle / 180
-                patch = matplotlib.patches.RegularPolygon((0, 0), n,
-                                                          radius=radius,
-                                                          orientation=angle)
+                patch = _p.matplotlib.patches.RegularPolygon((0, 0), n,
+                                                             radius=radius,
+                                                             orientation=angle)
             else:
                 raise ValueError("Unknown symbol definition " + str(symbol))
         elif symbol == 'o':
-            patch = matplotlib.patches.Circle((0, 0), 1)
+            patch = _p.matplotlib.patches.Circle((0, 0), 1)
 
         paths.append(patch.get_path().transformed(patch.get_transform()))
 
@@ -495,9 +195,9 @@ def symbols(axes, pos, symbol='o', size=1, reflen=None, facecolor='k',
         size = (size, )
 
     if dim == 2:
-        Collection = PathCollection
+        Collection = _p.PathCollection
     else:
-        Collection = Path3DCollection
+        Collection = _p.Path3DCollection
 
     if len(pos) == 0 or np.all(symbol == 'no symbol') or np.all(size == 0):
         paths = []
@@ -569,9 +269,9 @@ def lines(axes, pos0, pos1, reflen=None, colors='k', linestyles='solid',
     dim = pos0.shape[1]
     assert dim == 2 or dim == 3
     if dim == 2:
-        Collection = LineCollection
+        Collection = _p.LineCollection
     else:
-        Collection = Line3DCollection
+        Collection = _p.Line3DCollection
 
     if (len(pos0) == 0 or
         ('linewidths' in kwargs and kwargs['linewidths'] == 0)):
@@ -631,7 +331,7 @@ def output_fig(fig, output_mode='auto', file=None, savefile_opts=None,
     matplotlib in that the `dpi` attribute of the figure is used by defaul
     instead of the matplotlib config setting.
     """
-    if not mpl_available:
+    if not _p.mpl_available:
         raise RuntimeError('matplotlib is not installed.')
 
     # We import backends and pyplot only at the last possible moment (=now)
@@ -1111,7 +811,7 @@ def plot(sys, num_lead_cells=2, unit='nn',
       its aspect ratio.
 
     """
-    if not mpl_available:
+    if not _p.mpl_available:
         raise RuntimeError("matplotlib was not found, but is required "
                            "for plot()")
 
@@ -1161,7 +861,7 @@ def plot(sys, num_lead_cells=2, unit='nn',
         start_pos = np.apply_along_axis(pos_transform, 1, start_pos)
 
     dim = 3 if (sites_pos.shape[1] == 3) else 2
-    if dim == 3 and not has3d:
+    if dim == 3 and not _p.has3d:
         raise RuntimeError("Installed matplotlib does not support 3d plotting")
     sites_pos = resize_to_dim(sites_pos)
     end_pos = resize_to_dim(end_pos)
@@ -1202,7 +902,7 @@ def plot(sys, num_lead_cells=2, unit='nn',
     def make_proper_site_spec(spec, fancy_indexing=False):
         if callable(spec):
             spec = [spec(i[0]) for i in sites if i[1] is None]
-        if (fancy_indexing and isarray(spec)
+        if (fancy_indexing and _p.isarray(spec)
             and not isinstance(spec, np.ndarray)):
             try:
                 spec = np.asarray(spec)
@@ -1213,7 +913,7 @@ def plot(sys, num_lead_cells=2, unit='nn',
     def make_proper_hop_spec(spec, fancy_indexing=False):
         if callable(spec):
             spec = [spec(*i[0]) for i in hops if i[1] is None]
-        if (fancy_indexing and isarray(spec)
+        if (fancy_indexing and _p.isarray(spec)
             and not isinstance(spec, np.ndarray)):
             try:
                 spec = np.asarray(spec)
@@ -1225,7 +925,7 @@ def plot(sys, num_lead_cells=2, unit='nn',
     if site_symbol is None: site_symbol = defaults['site_symbol'][dim]
     # separate different symbols (not done in 3D, the separation
     # would mess up sorting)
-    if (isarray(site_symbol) and dim != 3 and
+    if (_p.isarray(site_symbol) and dim != 3 and
         (len(site_symbol) != 3 or site_symbol[0] not in ('p', 'P'))):
         symbol_dict = defaultdict(list)
         for i, symbol in enumerate(site_symbol):
@@ -1239,7 +939,7 @@ def plot(sys, num_lead_cells=2, unit='nn',
         fancy_indexing = False
 
     if site_color is None:
-        cycle = (x['color'] for x in matplotlib.rcParams['axes.prop_cycle'])
+        cycle = (x['color'] for x in _p.matplotlib.rcParams['axes.prop_cycle'])
         if isinstance(syst, (builder.FiniteSystem, builder.InfiniteSystem)):
             # Skipping the leads for brevity.
             families = sorted({site.family for site in syst.sites})
@@ -1279,30 +979,30 @@ def plot(sys, num_lead_cells=2, unit='nn',
         try:
             if site_color.ndim == 1 and len(site_color) == n_syst_sites:
                 site_color = np.asarray(site_color, dtype=float)
-                norm = matplotlib.colors.Normalize(site_color.min(),
-                                                   site_color.max())
+                norm = _p.matplotlib.colors.Normalize(site_color.min(),
+                                                      site_color.max())
         except:
             pass
 
     # take spec also for lead, if it's not a list/array, default, otherwise
     if lead_site_symbol is None:
-        lead_site_symbol = (site_symbol if not isarray(site_symbol)
+        lead_site_symbol = (site_symbol if not _p.isarray(site_symbol)
                             else defaults['site_symbol'][dim])
     if lead_site_size is None:
-        lead_site_size = (site_size if not isarray(site_size)
+        lead_site_size = (site_size if not _p.isarray(site_size)
                           else defaults['site_size'][dim])
     if lead_color is None:
         lead_color = defaults['lead_color'][dim]
-    lead_color = matplotlib.colors.colorConverter.to_rgba(lead_color)
+    lead_color = _p.matplotlib.colors.colorConverter.to_rgba(lead_color)
 
     if lead_site_edgecolor is None:
-        lead_site_edgecolor = (site_edgecolor if not isarray(site_edgecolor)
+        lead_site_edgecolor = (site_edgecolor if not _p.isarray(site_edgecolor)
                                else defaults['site_edgecolor'][dim])
     if lead_site_lw is None:
-        lead_site_lw = (site_lw if not isarray(site_lw)
+        lead_site_lw = (site_lw if not _p.isarray(site_lw)
                         else defaults['site_lw'][dim])
     if lead_hop_lw is None:
-        lead_hop_lw = (hop_lw if not isarray(hop_lw)
+        lead_hop_lw = (hop_lw if not _p.isarray(hop_lw)
                        else defaults['hop_lw'][dim])
 
     hop_cmap = None
@@ -1328,11 +1028,11 @@ def plot(sys, num_lead_cells=2, unit='nn',
 
     # plot system sites and hoppings
     for symbol, slc in symbol_slcs:
-        size = site_size[slc] if isarray(site_size) else site_size
-        col = site_color[slc] if isarray(site_color) else site_color
-        edgecol = (site_edgecolor[slc] if isarray(site_edgecolor) else
+        size = site_size[slc] if _p.isarray(site_size) else site_size
+        col = site_color[slc] if _p.isarray(site_color) else site_color
+        edgecol = (site_edgecolor[slc] if _p.isarray(site_edgecolor) else
                    site_edgecolor)
-        lw = site_lw[slc] if isarray(site_lw) else site_lw
+        lw = site_lw[slc] if _p.isarray(site_lw) else site_lw
 
         symbol_coll = symbols(ax, sites_pos[slc], size=size,
                               reflen=reflen, symbol=symbol,
@@ -1344,8 +1044,8 @@ def plot(sys, num_lead_cells=2, unit='nn',
                       zorder=1, cmap=hop_cmap)
 
     # plot lead sites and hoppings
-    norm = matplotlib.colors.Normalize(-0.5, num_lead_cells - 0.5)
-    cmap_from_list = matplotlib.colors.LinearSegmentedColormap.from_list
+    norm = _p.matplotlib.colors.Normalize(-0.5, num_lead_cells - 0.5)
+    cmap_from_list = _p.matplotlib.colors.LinearSegmentedColormap.from_list
     lead_cmap = cmap_from_list(None, [lead_color, (1, 1, 1, lead_color[3])])
 
     for sites_slc, hops_slc in zip(lead_sites_slcs, lead_hops_slcs):
@@ -1543,7 +1243,7 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
       correspond to exactly one pixel.
     """
 
-    if not mpl_available:
+    if not _p.mpl_available:
         raise RuntimeError("matplotlib was not found, but is required "
                            "for map()")
 
@@ -1576,7 +1276,7 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
         fig = None
 
     if cmap is None:
-        cmap = _colormaps.kwant_red
+        cmap = _p._colormaps.kwant_red
 
     # Note that we tell imshow to show the array created by mask_interpolate
     # faithfully and not to interpolate by itself another time.
@@ -1640,7 +1340,7 @@ def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None,
     See `~kwant.physics.Bands` for the calculation of dispersion without plotting.
     """
 
-    if not mpl_available:
+    if not _p.mpl_available:
         raise RuntimeError("matplotlib was not found, but is required "
                            "for bands()")
 
@@ -1715,10 +1415,10 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None,
         A figure with the output if `ax` is not set, else None.
     """
 
-    if not mpl_available:
+    if not _p.mpl_available:
         raise RuntimeError("matplotlib was not found, but is required "
                            "for plot_spectrum()")
-    if y is not None and not has3d:
+    if y is not None and not _p.has3d:
         raise RuntimeError("Installed matplotlib does not support 3d plotting")
 
     if isinstance(syst, system.FiniteSystem):
@@ -1754,7 +1454,7 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None,
 
     # set up axes
     if ax is None:
-        fig = Figure()
+        fig = _p.Figure()
         if dpi is not None:
             fig.set_dpi(dpi)
         if fig_size is not None:
@@ -1996,8 +1696,8 @@ _gamma_expand = np.vectorize(_gamma_expand, otypes=[float])
 
 def _linear_cmap(a, b):
     """Make a colormap that linearly interpolates between the colors a and b."""
-    a = matplotlib.colors.colorConverter.to_rgb(a)
-    b = matplotlib.colors.colorConverter.to_rgb(b)
+    a = _p.matplotlib.colors.colorConverter.to_rgb(a)
+    b = _p.matplotlib.colors.colorConverter.to_rgb(b)
     a_linear = _gamma_expand(a)
     b_linear = _gamma_expand(b)
     color_diff = a_linear - b_linear
@@ -2005,7 +1705,7 @@ def _linear_cmap(a, b):
                * color_diff.reshape((1, -1)))
     palette += b_linear
     palette = _gamma_compress(palette)
-    return matplotlib.colors.ListedColormap(palette)
+    return _p.matplotlib.colors.ListedColormap(palette)
 
 
 def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k',
@@ -2074,7 +1774,7 @@ def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k',
     fig : matplotlib figure
         A figure with the output if `ax` is not set, else None.
     """
-    if not mpl_available:
+    if not _p.mpl_available:
         raise RuntimeError("matplotlib was not found, but is required "
                            "for current()")
 
@@ -2090,8 +1790,8 @@ def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k',
 
     if bgcolor is None:
         if cmap is None:
-            cmap = _colormaps.kwant_red
-        cmap = matplotlib.cm.get_cmap(cmap)
+            cmap = _p._colormaps.kwant_red
+        cmap = _p.matplotlib.cm.get_cmap(cmap)
         bgcolor = cmap(0)[:3]
     elif cmap is not None:
         raise ValueError("The parameters 'cmap' and 'bgcolor' are "
@@ -2129,7 +1829,7 @@ def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k',
     ax.streamplot(X, Y, field[:,:,0], field[:,:,1],
                   density=density, linewidth=linewidth,
                   color=color, cmap=line_cmap, arrowstyle='->',
-                  norm=matplotlib.colors.Normalize(0, 1))
+                  norm=_p.matplotlib.colors.Normalize(0, 1))
 
     ax.set_xlim(*box[0])
     ax.set_ylim(*box[1])
diff --git a/kwant/tests/test_plotter.py b/kwant/tests/test_plotter.py
index ff2e55caa18b693b23eca9f219d4d22afcff54e8..a5d06a36bf13e1a8a549620fda9d98ff762d699d 100644
--- a/kwant/tests/test_plotter.py
+++ b/kwant/tests/test_plotter.py
@@ -36,6 +36,7 @@ except ImportError:
     matplotlib_backend_chosen = False
 
 from kwant import plotter
+from kwant import _plotter  # for mpl_available
 
 
 def test_matplotlib_backend_unset():
@@ -44,7 +45,7 @@ def test_matplotlib_backend_unset():
 
 
 def test_importable_without_matplotlib():
-    prefix, sep, suffix = plotter.__file__.rpartition('.')
+    prefix, sep, suffix = _plotter.__file__.rpartition('.')
     if suffix in ['pyc', 'pyo']:
         suffix = 'py'
     assert suffix == 'py'
@@ -111,7 +112,7 @@ def syst_3d(W=3, r1=2, r2=4, a=1, t=1.0):
     return syst
 
 
-@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.")
+@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
 def test_plot():
     plot = plotter.plot
     syst2d = syst_2d()
@@ -161,7 +162,7 @@ def bad_transform(pos):
     x, y = pos
     return x, y, 0
 
-@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.")
+@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
 def test_map():
     syst = syst_2d()
     with tempfile.TemporaryFile('w+b') as out:
@@ -197,7 +198,7 @@ def test_mask_interpolate():
                       coords, np.ones(2 * len(coords)))
 
 
-@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.")
+@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
 def test_bands():
 
     syst = syst_2d().finalized().leads[0]
@@ -212,7 +213,7 @@ def test_bands():
         plotter.bands(syst, ax=ax, file=out)
 
 
-@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.")
+@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
 def test_spectrum():
 
     def ham_1d(a, b, c):
@@ -423,7 +424,7 @@ def test_current_interpolation():
     assert scipy.stats.linregress(np.log(data))[2] < -0.8
 
 
-@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.")
+@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
 def test_current():
     syst = syst_2d().finalized()
     J = kwant.operator.Current(syst)
diff --git a/kwant/tests/test_wraparound.py b/kwant/tests/test_wraparound.py
index eb981d2798fa1b8fd42e36d26a0cd88039274b11..f59ec8cb15953446b7ce44b745adf6b0c3b98ed7 100644
--- a/kwant/tests/test_wraparound.py
+++ b/kwant/tests/test_wraparound.py
@@ -13,11 +13,11 @@ import tinyarray as ta
 import pytest
 
 import kwant
-from kwant import plotter
+from kwant import _plotter
 from kwant.wraparound import wraparound, plot_2d_bands
 from kwant._common import get_parameters
 
-if plotter.mpl_available:
+if _plotter.mpl_available:
     from mpl_toolkits import mplot3d  # pragma: no flakes
     from matplotlib import pyplot  # pragma: no flakes
 
@@ -201,7 +201,7 @@ def test_symmetry():
                 assert np.all(orig == new)
 
 
-@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.")
+@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
 def test_plot_2d_bands():
     chain = kwant.lattice.chain()
     square = kwant.lattice.square()
diff --git a/pytest.ini b/pytest.ini
index 6c8258daf2ed83d5fb6da872d379c2ad36d42c25..c7f56a4768a2f6f410c50545cc5740b85baef3d7 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -2,5 +2,6 @@
 testpaths = kwant
 flakes-ignore =
     __init__.py UnusedImport
+    kwant/_plotter.py UnusedImport
     graph/tests/test_scotch.py UndefinedName
     graph/tests/test_dissection.py UndefinedName