diff --git a/kwant/plotter.py b/kwant/plotter.py index abd9fe2cc2ba10a973f208d05e24ee440d759c64..4eef52710f3ed0af045de21649aa8faa152be5aa 100644 --- a/kwant/plotter.py +++ b/kwant/plotter.py @@ -37,7 +37,7 @@ try: from matplotlib import collections from matplotlib.backends.backend_agg import FigureCanvasAgg from . import _colormaps - mpl_enabled = True + mpl_available = True try: from mpl_toolkits import mplot3d has3d = True @@ -47,7 +47,7 @@ try: except ImportError: warnings.warn("matplotlib is not available, only iterator-providing " "functions will work.", RuntimeWarning) - mpl_enabled = False + mpl_available = False from . import system, builder, _common @@ -72,7 +72,7 @@ def matplotlib_chores(): pre_1_4_matplotlib = [int(x) for x in ver.split('.')[:2]] < [1, 4] -if mpl_enabled: +if mpl_available: matplotlib_chores() @@ -95,7 +95,7 @@ def _sample_array(array, n_samples, rng=None): return array[rng.choice(range(la), min(n_samples, la))] -if mpl_enabled: +if mpl_available: class LineCollection(collections.LineCollection): def __init__(self, segments, reflen=None, **kwargs): super().__init__(segments, **kwargs) @@ -637,10 +637,9 @@ def output_fig(fig, output_mode='auto', file=None, savefile_opts=None, The output mode to be used. Can be one of the following: 'pyplot' : attach the figure to pyplot, with the same behavior as if pyplot.plot was called to create this figure. - 'ipython' : attach a `FigureCanvasAgg` to the figure and return it. - 'return' : return the figure. - 'file' : same as 'ipython', but also save the figure into a file. - 'auto' : if fname is given, save to a file, else if pyplot + 'return' : attach a `FigureCanvasAgg` to the figure and return it. + 'file' : same as 'return', but also save the figure into a file. + 'auto' : if fname is given, save to a file, otherwise like pyplot is imported, attach to pyplot, otherwise just return. See also the notes below. file : string or a file object @@ -658,17 +657,10 @@ def output_fig(fig, output_mode='auto', file=None, savefile_opts=None, matplotlib in that the `dpi` attribute of the figure is used by defaul instead of the matplotlib config setting. """ - if not mpl_enabled: + if not mpl_available: raise RuntimeError('matplotlib is not installed.') if output_mode == 'auto': - if file is not None: - output_mode = 'file' - else: - try: - matplotlib.pyplot.get_backend() - output_mode = 'pyplot' - except AttributeError: - output_mode = 'pyplot' + output_mode = 'pyplot' if file is None else 'file' if output_mode == 'pyplot': try: fake_fig = matplotlib.pyplot.figure() @@ -685,21 +677,18 @@ def output_fig(fig, output_mode='auto', file=None, savefile_opts=None, pass if show: matplotlib.pyplot.show() - return fig - elif output_mode == 'return': - canvas = FigureCanvasAgg(fig) - fig.canvas = canvas - return fig - elif output_mode == 'file': - canvas = FigureCanvasAgg(fig) - if savefile_opts is None: - savefile_opts = ([], {}) - if 'dpi' not in savefile_opts[1]: - savefile_opts[1]['dpi'] = fig.dpi - canvas.print_figure(file, *savefile_opts[0], **savefile_opts[1]) - return fig + elif output_mode in ['return', 'file']: + fig.canvas = FigureCanvasAgg(fig) + if output_mode == 'file': + fig.canvas = canvas = FigureCanvasAgg(fig) + if savefile_opts is None: + savefile_opts = ([], {}) + if 'dpi' not in savefile_opts[1]: + savefile_opts[1]['dpi'] = fig.dpi + canvas.print_figure(file, *savefile_opts[0], **savefile_opts[1]) else: - assert False, 'Unknown output_mode' + raise ValueError('Unknown output_mode') + return fig # Extracting necessary data from the system. @@ -1139,7 +1128,7 @@ def plot(sys, num_lead_cells=2, unit='nn', its aspect ratio. """ - if not mpl_enabled: + if not mpl_available: raise RuntimeError("matplotlib was not found, but is required " "for plot()") @@ -1555,7 +1544,7 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None, correspond to exactly one pixel. """ - if not mpl_enabled: + if not mpl_available: raise RuntimeError("matplotlib was not found, but is required " "for map()") @@ -1652,7 +1641,7 @@ def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None, See `~kwant.physics.Bands` for the calculation of dispersion without plotting. """ - if not mpl_enabled: + if not mpl_available: raise RuntimeError("matplotlib was not found, but is required " "for bands()") @@ -1727,7 +1716,7 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None, A figure with the output if `ax` is not set, else None. """ - if not mpl_enabled: + if not mpl_available: raise RuntimeError("matplotlib was not found, but is required " "for plot_spectrum()") if y is not None and not has3d: @@ -2067,7 +2056,7 @@ def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k', fig : matplotlib figure A figure with the output if `ax` is not set, else None. """ - if not mpl_enabled: + if not mpl_available: raise RuntimeError("matplotlib was not found, but is required " "for current()") diff --git a/kwant/tests/test_plotter.py b/kwant/tests/test_plotter.py index 11266b4e8e167a46acdb40cb9fcd286c9f74de8f..12ea35ab54900e5fa29aa09e17f06de174dcc1c6 100644 --- a/kwant/tests/test_plotter.py +++ b/kwant/tests/test_plotter.py @@ -105,7 +105,7 @@ def syst_3d(W=3, r1=2, r2=4, a=1, t=1.0): return syst -@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.") +@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.") def test_plot(): plot = plotter.plot syst2d = syst_2d() @@ -155,7 +155,7 @@ def bad_transform(pos): x, y = pos return x, y, 0 -@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.") +@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.") def test_map(): syst = syst_2d() with tempfile.TemporaryFile('w+b') as out: @@ -191,7 +191,7 @@ def test_mask_interpolate(): coords, np.ones(2 * len(coords))) -@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.") +@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.") def test_bands(): syst = syst_2d().finalized().leads[0] @@ -206,7 +206,7 @@ def test_bands(): plotter.bands(syst, ax=ax, file=out) -@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.") +@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.") def test_spectrum(): def ham_1d(a, b, c): @@ -417,7 +417,7 @@ def test_current_interpolation(): assert scipy.stats.linregress(np.log(data))[2] < -0.8 -@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.") +@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.") def test_current(): syst = syst_2d().finalized() J = kwant.operator.Current(syst) diff --git a/kwant/tests/test_wraparound.py b/kwant/tests/test_wraparound.py index dc8ed3a927ea2248d4991d856b9f00310b812474..eb981d2798fa1b8fd42e36d26a0cd88039274b11 100644 --- a/kwant/tests/test_wraparound.py +++ b/kwant/tests/test_wraparound.py @@ -17,7 +17,7 @@ from kwant import plotter from kwant.wraparound import wraparound, plot_2d_bands from kwant._common import get_parameters -if plotter.mpl_enabled: +if plotter.mpl_available: from mpl_toolkits import mplot3d # pragma: no flakes from matplotlib import pyplot # pragma: no flakes @@ -201,7 +201,7 @@ def test_symmetry(): assert np.all(orig == new) -@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.") +@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.") def test_plot_2d_bands(): chain = kwant.lattice.chain() square = kwant.lattice.square()