diff --git a/kwant/tests/test_wraparound.py b/kwant/tests/test_wraparound.py
index 39f994ee5c12aa259aa9e00ac817b14fcccf77c1..a52712e33979a23e4a0123568edf2e654f61918d 100644
--- a/kwant/tests/test_wraparound.py
+++ b/kwant/tests/test_wraparound.py
@@ -6,14 +6,21 @@
 # the file AUTHORS.rst at the top-level directory of this distribution and at
 # http://kwant-project.org/authors.
 
+import tempfile
 import itertools
 import numpy as np
 import tinyarray as ta
+import pytest
 
 import kwant
-from kwant.wraparound import wraparound
+from kwant import plotter
+from kwant.wraparound import wraparound, plot_2d_bands
 from kwant._common import get_parameters
 
+if plotter.mpl_enabled:
+    from mpl_toolkits import mplot3d  # pragma: no flakes
+    from matplotlib import pyplot  # pragma: no flakes
+
 
 def _simple_syst(lat, E=0, t=1+1j, sym=None):
     """Create a builder for a simple infinite system."""
@@ -192,3 +199,54 @@ def test_symmetry():
                 assert np.all(orig(None) == new(None, None, None))
             else:
                 assert np.all(orig == new)
+
+
+@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
+def test_plot_2d_bands():
+    chain = kwant.lattice.chain()
+    square = kwant.lattice.square()
+    cube = kwant.lattice.general([(1, 0, 0), (0, 1, 0), (0, 0, 1)])
+    hc = kwant.lattice.honeycomb()
+
+    syst_1d = kwant.Builder(kwant.TranslationalSymmetry(*chain._prim_vecs))
+    syst_1d[chain(0)] = 2
+    syst_1d[chain.neighbors()] = -1
+
+    syst_2d = _simple_syst(square, t=-1)
+    syst_graphene = _simple_syst(hc, t=-1)
+
+    syst_3d = kwant.Builder(kwant.TranslationalSymmetry(*cube._prim_vecs))
+    syst_3d[cube(0, 0, 0)] = 6
+    syst_3d[cube.neighbors()] = -1
+
+
+    with tempfile.TemporaryFile('w+b') as out:
+        # test 2D
+        plot_2d_bands(wraparound(syst_2d).finalized(), k_x=11, k_y=11, file=out)
+        plot_2d_bands(wraparound(syst_graphene).finalized(), k_x=11, k_y=11,
+                   file=out)
+
+    # test non-wrapped around system
+    with pytest.raises(TypeError):
+        plot_2d_bands(syst_1d.finalized())
+    # test incompletely wrapped around system
+    with pytest.raises(TypeError):
+        plot_2d_bands(wraparound(syst_2d, keep=0).finalized())
+    # test incorrect lattice dimention (1, 3)
+    with pytest.raises(ValueError):
+        plot_2d_bands(wraparound(syst_1d).finalized())
+    with pytest.raises(ValueError):
+        plot_2d_bands(wraparound(syst_3d).finalized())
+
+    # test k_x and k_y differ
+    with tempfile.TemporaryFile('w+b') as out:
+        syst = wraparound(syst_2d).finalized()
+        plot_2d_bands(syst, k_x=11, k_y=15, file=out)
+        plot_2d_bands(syst, k_x=np.linspace(-np.pi, np.pi, 11), file=out)
+        plot_2d_bands(syst, k_y=np.linspace(-np.pi, np.pi, 11), file=out)
+
+        syst = wraparound(syst_graphene).finalized()
+        # test extend_bbox2d
+        plot_2d_bands(syst, extend_bbox=1.2, k_x=11, k_y=11, file=out)
+        # test mask Brillouin zone
+        plot_2d_bands(syst, mask_brillouin_zone=True, k_x=11, k_y=11, file=out)
diff --git a/kwant/wraparound.py b/kwant/wraparound.py
index 2146430d80de9b01b63d8617c7980cab84e62924..1936486bd85e96f42d0e96cb8e39e335c7eb08ca 100644
--- a/kwant/wraparound.py
+++ b/kwant/wraparound.py
@@ -9,15 +9,20 @@
 import collections
 import inspect
 import cmath
+
 import tinyarray as ta
+import numpy as np
+import scipy.linalg
+import scipy.spatial
 
-from . import builder
+from . import builder, system, plotter
+from .linalg import lll
 from .builder import herm_conj, HermConjOfFunc
 from .lattice import TranslationalSymmetry
 from ._common import get_parameters
 
 
-__all__ = ['wraparound']
+__all__ = ['wraparound', 'plot_2d_bands']
 
 
 def _hashable(obj):
@@ -306,3 +311,155 @@ def wraparound(builder, keep=None, *, coordinate_names=('x', 'y', 'z')):
         ret[hop] = vals[0] if len(vals) == 1 else bind_sum(2, *vals)
 
     return ret
+
+
+def plot_2d_bands(syst, k_x=31, k_y=31, params=None,
+                  mask_brillouin_zone=False, extend_bbox=0, file=None,
+                  show=True, dpi=None, fig_size=None, ax=None):
+    """Plot 2D band structure of a wrapped around system.
+
+    This function is primarily useful for systems that have translational
+    symmetry vectors that are non-orthogonal (e.g. graphene). This function
+    will properly plot the band structure in an orthonormal basis in k-space,
+    as opposed to in the basis of reciprocal lattice vectors (which would
+    produce a "skewed" Brillouin zone).
+
+    If your system has orthogonal lattice vectors, you are probably better
+    off using `kwant.plotter.spectrum`.
+
+    Parameters
+    ----------
+    syst : `kwant.system.FiniteSystem`
+        A 2D system that was finalized from a Builder produced by
+        `kwant.wraparound.wraparound`.  Note that this *must* be a finite
+        system; so `kwant.wraparound.wraparound` should have been called with
+        ``keep=None``.
+    k_x, k_y : int or sequence of float, default: 31
+        Either a number of sampling points, or a sequence of points at which
+        the band structure is to be evaluated, in units of inverse length.
+    params : dict, optional
+        Dictionary of parameter names and their values, not including the
+        momentum parameters.
+    mask_brillouin_zone : bool, default: False
+        If True, then the band structure will only be plotted over the first
+        Brillouin zone. By default the band structure is plotted over a
+        rectangular bounding box that contains the Brillouin zone.
+    extend_bbox : float, default: 0
+        Amount by which to extend the region over which the band structure is
+        plotted, expressed as a proportion of the Brillouin zone bounding box
+        length. i.e. ``extend_bbox=0.1`` will extend the region by 10% (in all
+        directions).
+    file : string or file object, optional
+        The output file.  If None, output will be shown instead.
+    show : bool, default: False
+        Whether ``matplotlib.pyplot.show()`` is to be called, and the output is
+        to be shown immediately.  Defaults to `True`.
+    dpi : float, optional
+        Number of pixels per inch.  If not set the ``matplotlib`` default is
+        used.
+    fig_size : tuple, optional
+        Figure size `(width, height)` in inches.  If not set, the default
+        ``matplotlib`` value is used.
+    ax : ``matplotlib.axes.Axes`` instance, optional
+        If `ax` is not `None`, no new figure is created, but the plot is done
+        within the existing Axes `ax`. in this case, `file`, `show`, `dpi`
+        and `fig_size` are ignored.
+
+    Returns
+    -------
+    fig : matplotlib figure
+        A figure with the output if `ax` is not set, else None.
+
+    Notes
+    -----
+    This function produces plots where the units of momentum are inverse
+    length. This is contrary to `kwant.plotter.bands`, where the units
+    of momentum are inverse lattice constant.
+
+    If the lattice vectors for the symmetry of ``syst`` are not orthogonal,
+    then part of the plotted band structure will be outside the first Brillouin
+    zone (inside the bounding box of the brillouin zone). Setting
+    ``mask_brillouin_zone=True`` will cause the plot to be truncated outside of
+    the first Brillouin zone.
+
+    See Also
+    --------
+    `kwant.plotter.spectrum`
+    """
+    if not hasattr(syst, '_wrapped_symmetry'):
+        raise TypeError("Expecting a system that was produced by "
+                        "'kwant.wraparound.wraparound'.")
+    if not isinstance(syst, system.FiniteSystem):
+        msg = ("All symmetry directions must be wrapped around: specify "
+               "'keep=None' when calling 'kwant.wraparound.wraparound'.")
+        raise TypeError(msg)
+
+    params = params or {}
+    lat_ndim, space_ndim = syst._wrapped_symmetry.periods.shape
+
+    if lat_ndim != 2:
+        raise ValueError("Expected a system with a 2D translational symmetry.")
+    if space_ndim != lat_ndim:
+        raise ValueError("Lattice dimension must equal realspace dimension.")
+
+    # columns of B are lattice vectors
+    B = np.array(syst._wrapped_symmetry.periods).T
+    # columns of A are reciprocal lattice vectors
+    A = B.dot(np.linalg.inv(B.T.dot(B)))
+
+    ## calculate the bounding box for the 1st Brillouin zone
+
+    # Get lattice points that neighbor the origin, in basis of lattice vectors
+    reduced_vecs, transf = lll.lll(A.T)
+    neighbors = ta.dot(lll.voronoi(reduced_vecs), transf)
+    # Add the origin to these points.
+    klat_points = np.concatenate(([[0] * lat_ndim], neighbors))
+    # Transform to cartesian coordinates and rescale.
+    # Will be used in 'outside_bz' function, later on.
+    klat_points = 2 * np.pi * np.dot(klat_points, A.T)
+    # Calculate the Voronoi cell vertices
+    vor = scipy.spatial.Voronoi(klat_points)
+    around_origin = vor.point_region[0]
+    bz_vertices = vor.vertices[vor.regions[around_origin]]
+    # extract bounding box
+    k_max = np.max(np.abs(bz_vertices), axis=0)
+
+    ## build grid along each axis, if needed
+    ks = []
+    for k, km in zip((k_x, k_y), k_max):
+        k = np.array(k)
+        if not k.shape:
+            if extend_bbox:
+                km += km * extend_bbox
+            k = np.linspace(-km, km, k)
+        ks.append(k)
+
+    # TODO: It is very inefficient to call 'momentum_to_lattice' once for
+    #       each point (for trivial Hamiltonians 60% of the time is spent
+    #       doing this). We should instead transform the whole grid in one call.
+
+    def momentum_to_lattice(k):
+        k, residuals = scipy.linalg.lstsq(A, k)[:2]
+        if np.any(abs(residuals) > 1e-7):
+            raise RuntimeError("Requested momentum doesn't correspond"
+                               " to any lattice momentum.")
+        return k
+
+    def ham(k_x, k_y=None, **params):
+        # transform into the basis of reciprocal lattice vectors
+        k = momentum_to_lattice([k_x] if k_y is None else [k_x, k_y])
+        p = dict(zip(syst._momentum_names, k), **params)
+        return syst.hamiltonian_submatrix(params=p, sparse=False)
+
+    def outside_bz(k_x, k_y, **_):
+        dm = scipy.spatial.distance_matrix(klat_points, [[k_x, k_y]])
+        return np.argmin(dm) != 0  # is origin no closest 'klat_point' to 'k'?
+
+    fig = plotter.spectrum(ham,
+                           x=('k_x', ks[0]),
+                           y=('k_y', ks[1]) if lat_ndim == 2 else None,
+                           params=params,
+                           mask=(outside_bz if mask_brillouin_zone else None),
+                           file=file, show=show, dpi=dpi,
+                           fig_size=fig_size, ax=ax)
+    return fig