diff --git a/doc/source/code/figure/band_structure.py.diff b/doc/source/code/figure/band_structure.py.diff index 8a2652072aea317f4a6e5090497c37b43afdef15..c886e6c5d9f6dc0fd2f1c6fdcff4a3b14073b748 100644 --- a/doc/source/code/figure/band_structure.py.diff +++ b/doc/source/code/figure/band_structure.py.diff @@ -41,11 +41,11 @@ #HIDDEN_BEGIN_pejz def main(): lead = make_lead().finalized() -- kwant.plotter.bands(lead, show=False) +- kwant.plotter.bands(lead) - pyplot.xlabel("momentum [(lattice constant)^-1]") - pyplot.ylabel("energy [t]") - pyplot.show() -+ fig = kwant.plotter.bands(lead, show=False) ++ fig = kwant.plotter.bands(lead) + pyplot.xlabel("momentum [(lattice constant)^-1]", + fontsize=_defs.mpl_label_size) + pyplot.ylabel("energy [t]", fontsize=_defs.mpl_label_size) diff --git a/doc/source/code/figure/discretize.py.diff b/doc/source/code/figure/discretize.py.diff index aebaf3e03c10d291d2b25789c1145b90d6210a15..1fcb2a6682aab59c89eabcbd18a9f68d3472d8fc 100644 --- a/doc/source/code/figure/discretize.py.diff +++ b/doc/source/code/figure/discretize.py.diff @@ -63,7 +63,7 @@ ham = syst.hamiltonian_submatrix(params=dict(V=potential), sparse=True) evecs = scipy.sparse.linalg.eigsh(ham, k=10, which='SM')[1] - kwant.plotter.map(syst, abs(evecs[:, n])**2, show=False) + kwant.plotter.map(syst, abs(evecs[:, n])**2) #HIDDEN_END_plot_eigenstate - plt.show() + save_figure('discretizer_gs') @@ -111,7 +111,7 @@ #HIDDEN_BEGIN_plot_qsh_band kwant.plotter.bands(syst.leads[0], params=params, - momenta=np.linspace(-0.3, 0.3, 201), show=False) + momenta=np.linspace(-0.3, 0.3, 201)) #HIDDEN_END_plot_qsh_band plt.grid() diff --git a/doc/source/code/figure/faq.py.diff b/doc/source/code/figure/faq.py.diff index 7c8889dceec701572fde707e7c5aed3d5cf9f427..81a096d4ad35df494d136d959e9fe567ac9241d8 100644 --- a/doc/source/code/figure/faq.py.diff +++ b/doc/source/code/figure/faq.py.diff @@ -399,7 +399,7 @@ def plot_and_label_modes(lead, E): # Plot the different modes pmodes, _ = lead.modes(energy=E) - kwant.plotter.bands(lead, show=False) + kwant.plotter.bands(lead) for i, k in enumerate(pmodes.momenta): plt.plot(k, E, 'ko') plt.annotate(str(i), xy=(k, E), xytext=(-5, 8), diff --git a/doc/source/pre/whatsnew/1.4.rst b/doc/source/pre/whatsnew/1.4.rst index 42b6b5c5ad5daa8d135ce54fb9b4dbdb8da091ee..88c039688103b892c2f5da281f0b8f1e9ffa3720 100644 --- a/doc/source/pre/whatsnew/1.4.rst +++ b/doc/source/pre/whatsnew/1.4.rst @@ -109,3 +109,18 @@ data would appear near the bottom of the color scale, and all of the features would be washed out by the presence of the peak. Now `~kwant.plotter.map` employs a heuristic for setting the colorscale when there are outliers, and will emit a warning when this is detected. + +Simplified integration with pyplot +---------------------------------- +Previously Kwant would only produce plots if the user separately imported +``matplotlib.pyplot``. Futher, Kwant would always call ``pyplot.show()`` +unless ``show=False`` argument was supplied. In the new version Kwant +automatically imports pyplot on first use of plotting functionality and does not +show the figures automatically. The latter allows to easier modify the figures +produced by Kwant:: + + kwant.plotter.plot(syst) + pyplot.title("System plot") + +The ``show`` keyword argument was also removed from all plotting functions. +If you need to restore old behavior call ``pyplot.show()`` after the plotting. diff --git a/kwant/plotter.py b/kwant/plotter.py index 545cd29a3246bdd06c0bcce9967bdf3c40f68e90..ab189e1b8277058b4ed57b50b1ee11ec7ef9ebf4 100644 --- a/kwant/plotter.py +++ b/kwant/plotter.py @@ -76,7 +76,7 @@ def _make_figure(dpi, fig_size, use_pyplot=False): return fig -def _maybe_output_fig(fig, file=None, show=True): +def _maybe_output_fig(fig, file=None): """Output a matplotlib figure using a given output mode. Parameters @@ -86,9 +86,6 @@ def _maybe_output_fig(fig, file=None, show=True): file : string or a file object The name of the target file or the target file itself (opened for writing). - show : bool - Whether to call ``matplotlib.pyplot.show()``. Only has an effect if - not saving to a file. Notes ----- @@ -101,11 +98,6 @@ def _maybe_output_fig(fig, file=None, show=True): if file is not None: fig.canvas.print_figure(file, dpi=fig.dpi) - elif show: - # If there was no file provided, pyplot should already be available and - # we can import it safely without additional warnings. - from matplotlib import pyplot - pyplot.show() def set_colors(color, collection, cmap, norm=None): @@ -676,7 +668,7 @@ def plot(sys, num_lead_cells=2, unit='nn', lead_site_edgecolor=None, lead_site_lw=None, lead_hop_lw=None, pos_transform=None, cmap='gray', colorbar=True, file=None, - show=True, dpi=None, fig_size=None, ax=None): + dpi=None, fig_size=None, ax=None): """Plot a system in 2 or 3 dimensions. An alias exists for this common name: ``kwant.plot``. @@ -774,9 +766,6 @@ def plot(sys, num_lead_cells=2, unit='nn', provided. file : string or file object or `None` The output file. If `None`, output will be shown instead. - show : bool - Whether ``matplotlib.pyplot.show()`` is to be called, and the output is - to be shown immediately. Defaults to `True`. dpi : float or `None` Number of pixels per inch. If not set the ``matplotlib`` default is used. @@ -785,7 +774,7 @@ def plot(sys, num_lead_cells=2, unit='nn', ``matplotlib`` value is used. ax : ``matplotlib.axes.Axes`` instance or `None` If `ax` is not `None`, no new figure is created, but the plot is done - within the existing Axes `ax`. in this case, `file`, `show`, `dpi` + within the existing Axes `ax`. in this case, `file`, `dpi` and `fig_size` are ignored. Returns @@ -1095,7 +1084,7 @@ def plot(sys, num_lead_cells=2, unit='nn', if line_coll.get_array() is not None and colorbar and fig is not None: fig.colorbar(line_coll) - _maybe_output_fig(fig, file=file, show=show) + _maybe_output_fig(fig, file=file) return fig @@ -1184,7 +1173,7 @@ def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3): def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None, method='nearest', oversampling=3, num_lead_cells=0, file=None, - show=True, dpi=None, fig_size=None, ax=None, pos_transform=None, + dpi=None, fig_size=None, ax=None, pos_transform=None, background='#e0e0e0'): """Show interpolated map of a function defined for the sites of a system. @@ -1230,12 +1219,9 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None, the position of leads. Defaults to 0. file : string or file object or `None` The output file. If `None`, output will be shown instead. - show : bool - Whether ``matplotlib.pyplot.show()`` is to be called, and the output is - to be shown immediately. Defaults to `True`. ax : ``matplotlib.axes.Axes`` instance or `None` If `ax` is not `None`, no new figure is created, but the plot is done - within the existing Axes `ax`. in this case, `file`, `show`, `dpi` + within the existing Axes `ax`. in this case, `file`, `dpi` and `fig_size` are ignored. pos_transform : function or `None` Transformation to be applied to the site position. @@ -1336,12 +1322,12 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None, extend = 'max' fig.colorbar(image, extend=extend) - _maybe_output_fig(fig, file=file, show=show) + _maybe_output_fig(fig, file=file) return fig -def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None, +def bands(sys, args=(), momenta=65, file=None, dpi=None, fig_size=None, ax=None, *, params=None): """Plot band structure of a translationally invariant 1D system. @@ -1357,9 +1343,6 @@ def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None, array of points at which the band structure has to be evaluated. file : string or file object or `None` The output file. If `None`, output will be shown instead. - show : bool - Whether ``matplotlib.pyplot.show()`` is to be called, and the output is - to be shown immediately. Defaults to `True`. dpi : float Number of pixels per inch. If not set the ``matplotlib`` default is used. @@ -1368,7 +1351,7 @@ def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None, ``matplotlib`` value is used. ax : ``matplotlib.axes.Axes`` instance or `None` If `ax` is not `None`, no new figure is created, but the plot is done - within the existing Axes `ax`. in this case, `file`, `show`, `dpi` + within the existing Axes `ax`. in this case, `file`, `dpi` and `fig_size` are ignored. params : dict, optional Dictionary of parameter names and their values. Mutually exclusive @@ -1408,15 +1391,15 @@ def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None, def h_k(k): # H_k = H_0 + V e^-ik + V^\dagger e^ik mat = hop * cmath.exp(-1j * k) - mat += mat.conjugate().transpose() + ham + mat += mat.conjugate().transpose() + ham return mat - return spectrum(h_k, ('k', momenta), file=file, show=show, dpi=dpi, + return spectrum(h_k, ('k', momenta), file=file, dpi=dpi, fig_size=fig_size, ax=ax) def spectrum(syst, x, y=None, params=None, mask=None, file=None, - show=True, dpi=None, fig_size=None, ax=None): + dpi=None, fig_size=None, ax=None): """Plot the spectrum of a Hamiltonian as a function of 1 or 2 parameters Parameters @@ -1439,9 +1422,6 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None, values. file : string or file object or `None` The output file. If `None`, output will be shown instead. - show : bool - Whether ``matplotlib.pyplot.show()`` is to be called, and the output is - to be shown immediately. Defaults to `True`. dpi : float Number of pixels per inch. If not set the ``matplotlib`` default is used. @@ -1450,8 +1430,8 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None, ``matplotlib`` value is used. ax : ``matplotlib.axes.Axes`` instance or `None` If `ax` is not `None`, no new figure is created, but the plot is done - within the existing Axes `ax`. in this case, `file`, `show`, `dpi` - and `fig_size` are ignored. + within the existing Axes `ax`. in this case, `file`, `dpi` and + `fig_size` are ignored. Returns ------- @@ -1531,7 +1511,7 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None, spec = spectrum[:, :, i].transpose() # row-major to x-y ordering ax.plot_surface(*(grid + [spec]), cstride=1, rstride=1) - _maybe_output_fig(fig, file=file, show=show) + _maybe_output_fig(fig, file=file) return fig @@ -1896,8 +1876,7 @@ def _linear_cmap(a, b): def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k', max_linewidth=3, min_linewidth=1, density=2/9, colorbar=True, file=None, - show=True, dpi=None, fig_size=None, ax=None, - vmax=None): + dpi=None, fig_size=None, ax=None, vmax=None): """Draw streamlines of a flow field in Kwant style Solid colored streamlines are drawn, superimposed on a color plot of @@ -1936,9 +1915,6 @@ def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k', provided. file : string or file object or `None` The output file. If `None`, output will be shown instead. - show : bool - Whether ``matplotlib.pyplot.show()`` is to be called, and the output is - to be shown immediately. Defaults to `True`. dpi : float or `None` Number of pixels per inch. If not set the ``matplotlib`` default is used. @@ -1947,8 +1923,8 @@ def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k', ``matplotlib`` value is used. ax : ``matplotlib.axes.Axes`` instance or `None` If `ax` is not `None`, no new figure is created, but the plot is done - within the existing Axes `ax`. in this case, `file`, `show`, `dpi` - and `fig_size` are ignored. + within the existing Axes `ax`. in this case, `file`, `dpi` and + `fig_size` are ignored. vmax : float or `None` The upper saturation limit for the colormap; flows higher than this will saturate. Note that there is no corresponding vmin @@ -2022,13 +1998,13 @@ def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k', if colorbar and cmap and fig is not None: fig.colorbar(image) - _maybe_output_fig(fig, file=file, show=show) + _maybe_output_fig(fig, file=file) return fig def scalarplot(field, box, - cmap=None, colorbar=True, file=None, show=True, + cmap=None, colorbar=True, file=None, dpi=None, fig_size=None, ax=None, vmin=None, vmax=None, background='#e0e0e0'): """Draw a scalar field in Kwant style @@ -2049,9 +2025,6 @@ def scalarplot(field, box, provided. file : string or file object, optional The output file. If not provided, output will be shown instead. - show : bool, default: True - Whether ``matplotlib.pyplot.show()`` is to be called, and the output is - to be shown immediately. dpi : float, optional Number of pixels per inch. If not set the ``matplotlib`` default is used. @@ -2060,8 +2033,8 @@ def scalarplot(field, box, ``matplotlib`` value is used. ax : ``matplotlib.axes.Axes`` instance, optional If ``ax`` is provided, no new figure is created, but the plot is done - within the existing Axes ``ax``. in this case, ``file``, ``show``, - ``dpi`` and ``fig_size`` are ignored. + within the existing Axes ``ax``. in this case, ``file``, ``dpi`` and + ``fig_size`` are ignored. vmin, vmax : float, optional The lower/upper saturation limit for the colormap. background : matplotlib color spec @@ -2111,7 +2084,7 @@ def scalarplot(field, box, if colorbar and cmap and fig is not None: fig.colorbar(image) - _maybe_output_fig(fig, file=file, show=show) + _maybe_output_fig(fig, file=file) return fig diff --git a/kwant/wraparound.py b/kwant/wraparound.py index d707c265a5f4c3f0f3ae67070eccc621bb05f1dd..ebed4d479009f6a5dfee7e077e01aae9f9692b44 100644 --- a/kwant/wraparound.py +++ b/kwant/wraparound.py @@ -309,7 +309,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'): def plot_2d_bands(syst, k_x=31, k_y=31, params=None, mask_brillouin_zone=False, extend_bbox=0, file=None, - show=True, dpi=None, fig_size=None, ax=None): + dpi=None, fig_size=None, ax=None): """Plot 2D band structure of a wrapped around system. This function is primarily useful for systems that have translational @@ -345,9 +345,6 @@ def plot_2d_bands(syst, k_x=31, k_y=31, params=None, directions). file : string or file object, optional The output file. If None, output will be shown instead. - show : bool, default: False - Whether ``matplotlib.pyplot.show()`` is to be called, and the output is - to be shown immediately. Defaults to `True`. dpi : float, optional Number of pixels per inch. If not set the ``matplotlib`` default is used. @@ -356,8 +353,8 @@ def plot_2d_bands(syst, k_x=31, k_y=31, params=None, ``matplotlib`` value is used. ax : ``matplotlib.axes.Axes`` instance, optional If `ax` is not `None`, no new figure is created, but the plot is done - within the existing Axes `ax`. in this case, `file`, `show`, `dpi` - and `fig_size` are ignored. + within the existing Axes `ax`. in this case, `file`, `dpi` and + `fig_size` are ignored. Returns ------- @@ -454,6 +451,5 @@ def plot_2d_bands(syst, k_x=31, k_y=31, params=None, y=('k_y', ks[1]) if lat_ndim == 2 else None, params=params, mask=(outside_bz if mask_brillouin_zone else None), - file=file, show=show, dpi=dpi, - fig_size=fig_size, ax=ax) + file=file, dpi=dpi, fig_size=fig_size, ax=ax) return fig