Commit a9a4770a authored by Anton Akhmerov's avatar Anton Akhmerov
Browse files

remove show option from plotting functions

This makes Kwant play nicer with pyplot, and allows to modify the figures
directly using pyplot functions.

Closes gitlab issue #233
parent 13cc0ee3
Pipeline #15238 passed with stages
in 42 minutes and 9 seconds
......@@ -41,11 +41,11 @@
#HIDDEN_BEGIN_pejz
def main():
lead = make_lead().finalized()
- kwant.plotter.bands(lead, show=False)
- kwant.plotter.bands(lead)
- pyplot.xlabel("momentum [(lattice constant)^-1]")
- pyplot.ylabel("energy [t]")
- pyplot.show()
+ fig = kwant.plotter.bands(lead, show=False)
+ fig = kwant.plotter.bands(lead)
+ pyplot.xlabel("momentum [(lattice constant)^-1]",
+ fontsize=_defs.mpl_label_size)
+ pyplot.ylabel("energy [t]", fontsize=_defs.mpl_label_size)
......
......@@ -63,7 +63,7 @@
ham = syst.hamiltonian_submatrix(params=dict(V=potential), sparse=True)
evecs = scipy.sparse.linalg.eigsh(ham, k=10, which='SM')[1]
kwant.plotter.map(syst, abs(evecs[:, n])**2, show=False)
kwant.plotter.map(syst, abs(evecs[:, n])**2)
#HIDDEN_END_plot_eigenstate
- plt.show()
+ save_figure('discretizer_gs')
......@@ -111,7 +111,7 @@
#HIDDEN_BEGIN_plot_qsh_band
kwant.plotter.bands(syst.leads[0], params=params,
momenta=np.linspace(-0.3, 0.3, 201), show=False)
momenta=np.linspace(-0.3, 0.3, 201))
#HIDDEN_END_plot_qsh_band
plt.grid()
......
......@@ -399,7 +399,7 @@
def plot_and_label_modes(lead, E):
# Plot the different modes
pmodes, _ = lead.modes(energy=E)
kwant.plotter.bands(lead, show=False)
kwant.plotter.bands(lead)
for i, k in enumerate(pmodes.momenta):
plt.plot(k, E, 'ko')
plt.annotate(str(i), xy=(k, E), xytext=(-5, 8),
......
......@@ -178,3 +178,18 @@ data would appear near the bottom of the color scale, and all of the features
would be washed out by the presence of the peak. Now `~kwant.plotter.map`
employs a heuristic for setting the colorscale when there are outliers,
and will emit a warning when this is detected.
Simplified integration with pyplot
----------------------------------
Previously Kwant would only produce plots if the user separately imported
``matplotlib.pyplot``. Futher, Kwant would always call ``pyplot.show()``
unless ``show=False`` argument was supplied. In the new version Kwant
automatically imports pyplot on first use of plotting functionality and does not
show the figures automatically. The latter allows to easier modify the figures
produced by Kwant::
kwant.plotter.plot(syst)
pyplot.title("System plot")
The ``show`` keyword argument was also removed from all plotting functions.
If you need to restore old behavior call ``pyplot.show()`` after the plotting.
......@@ -82,7 +82,7 @@ def _make_figure(dpi, fig_size, use_pyplot=False):
return fig
def _maybe_output_fig(fig, file=None, show=True):
def _maybe_output_fig(fig, file=None):
"""Output a matplotlib figure using a given output mode.
Parameters
......@@ -92,9 +92,6 @@ def _maybe_output_fig(fig, file=None, show=True):
file : string or a file object
The name of the target file or the target file itself
(opened for writing).
show : bool
Whether to call ``matplotlib.pyplot.show()``. Only has an effect if
not saving to a file.
Notes
-----
......@@ -107,11 +104,6 @@ def _maybe_output_fig(fig, file=None, show=True):
if file is not None:
fig.canvas.print_figure(file, dpi=fig.dpi)
elif show:
# If there was no file provided, pyplot should already be available and
# we can import it safely without additional warnings.
from matplotlib import pyplot
pyplot.show()
def set_colors(color, collection, cmap, norm=None):
......@@ -682,7 +674,7 @@ def plot(sys, num_lead_cells=2, unit='nn',
lead_site_edgecolor=None, lead_site_lw=None,
lead_hop_lw=None, pos_transform=None,
cmap='gray', colorbar=True, file=None,
show=True, dpi=None, fig_size=None, ax=None):
dpi=None, fig_size=None, ax=None):
"""Plot a system in 2 or 3 dimensions.
An alias exists for this common name: ``kwant.plot``.
......@@ -780,9 +772,6 @@ def plot(sys, num_lead_cells=2, unit='nn',
provided.
file : string or file object or `None`
The output file. If `None`, output will be shown instead.
show : bool
Whether ``matplotlib.pyplot.show()`` is to be called, and the output is
to be shown immediately. Defaults to `True`.
dpi : float or `None`
Number of pixels per inch. If not set the ``matplotlib`` default is
used.
......@@ -791,7 +780,7 @@ def plot(sys, num_lead_cells=2, unit='nn',
``matplotlib`` value is used.
ax : ``matplotlib.axes.Axes`` instance or `None`
If `ax` is not `None`, no new figure is created, but the plot is done
within the existing Axes `ax`. in this case, `file`, `show`, `dpi`
within the existing Axes `ax`. in this case, `file`, `dpi`
and `fig_size` are ignored.
Returns
......@@ -1101,7 +1090,7 @@ def plot(sys, num_lead_cells=2, unit='nn',
if line_coll.get_array() is not None and colorbar and fig is not None:
fig.colorbar(line_coll)
_maybe_output_fig(fig, file=file, show=show)
_maybe_output_fig(fig, file=file)
return fig
......@@ -1190,7 +1179,7 @@ def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3):
def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
method='nearest', oversampling=3, num_lead_cells=0, file=None,
show=True, dpi=None, fig_size=None, ax=None, pos_transform=None,
dpi=None, fig_size=None, ax=None, pos_transform=None,
background='#e0e0e0'):
"""Show interpolated map of a function defined for the sites of a system.
......@@ -1236,12 +1225,9 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
the position of leads. Defaults to 0.
file : string or file object or `None`
The output file. If `None`, output will be shown instead.
show : bool
Whether ``matplotlib.pyplot.show()`` is to be called, and the output is
to be shown immediately. Defaults to `True`.
ax : ``matplotlib.axes.Axes`` instance or `None`
If `ax` is not `None`, no new figure is created, but the plot is done
within the existing Axes `ax`. in this case, `file`, `show`, `dpi`
within the existing Axes `ax`. in this case, `file`, `dpi`
and `fig_size` are ignored.
pos_transform : function or `None`
Transformation to be applied to the site position.
......@@ -1343,12 +1329,12 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
extend = 'max'
fig.colorbar(image, extend=extend)
_maybe_output_fig(fig, file=file, show=show)
_maybe_output_fig(fig, file=file)
return fig
def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None,
def bands(sys, args=(), momenta=65, file=None, dpi=None,
fig_size=None, ax=None, *, params=None):
"""Plot band structure of a translationally invariant 1D system.
......@@ -1364,9 +1350,6 @@ def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None,
array of points at which the band structure has to be evaluated.
file : string or file object or `None`
The output file. If `None`, output will be shown instead.
show : bool
Whether ``matplotlib.pyplot.show()`` is to be called, and the output is
to be shown immediately. Defaults to `True`.
dpi : float
Number of pixels per inch. If not set the ``matplotlib`` default is
used.
......@@ -1375,7 +1358,7 @@ def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None,
``matplotlib`` value is used.
ax : ``matplotlib.axes.Axes`` instance or `None`
If `ax` is not `None`, no new figure is created, but the plot is done
within the existing Axes `ax`. in this case, `file`, `show`, `dpi`
within the existing Axes `ax`. in this case, `file`, `dpi`
and `fig_size` are ignored.
params : dict, optional
Dictionary of parameter names and their values. Mutually exclusive
......@@ -1415,15 +1398,15 @@ def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None,
def h_k(k):
# H_k = H_0 + V e^-ik + V^\dagger e^ik
mat = hop * cmath.exp(-1j * k)
mat += mat.conjugate().transpose() + ham
mat += mat.conjugate().transpose() + ham
return mat
return spectrum(h_k, ('k', momenta), file=file, show=show, dpi=dpi,
return spectrum(h_k, ('k', momenta), file=file, dpi=dpi,
fig_size=fig_size, ax=ax)
def spectrum(syst, x, y=None, params=None, mask=None, file=None,
show=True, dpi=None, fig_size=None, ax=None):
dpi=None, fig_size=None, ax=None):
"""Plot the spectrum of a Hamiltonian as a function of 1 or 2 parameters
Parameters
......@@ -1446,9 +1429,6 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None,
values.
file : string or file object or `None`
The output file. If `None`, output will be shown instead.
show : bool
Whether ``matplotlib.pyplot.show()`` is to be called, and the output is
to be shown immediately. Defaults to `True`.
dpi : float
Number of pixels per inch. If not set the ``matplotlib`` default is
used.
......@@ -1457,8 +1437,8 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None,
``matplotlib`` value is used.
ax : ``matplotlib.axes.Axes`` instance or `None`
If `ax` is not `None`, no new figure is created, but the plot is done
within the existing Axes `ax`. in this case, `file`, `show`, `dpi`
and `fig_size` are ignored.
within the existing Axes `ax`. in this case, `file`, `dpi` and
`fig_size` are ignored.
Returns
-------
......@@ -1544,7 +1524,7 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None,
spec = spectrum[:, :, i].transpose() # row-major to x-y ordering
ax.plot_surface(*(grid + [spec]), cstride=1, rstride=1)
_maybe_output_fig(fig, file=file, show=show)
_maybe_output_fig(fig, file=file)
return fig
......@@ -1909,8 +1889,7 @@ def _linear_cmap(a, b):
def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k',
max_linewidth=3, min_linewidth=1, density=2/9,
colorbar=True, file=None,
show=True, dpi=None, fig_size=None, ax=None,
vmax=None):
dpi=None, fig_size=None, ax=None, vmax=None):
"""Draw streamlines of a flow field in Kwant style
Solid colored streamlines are drawn, superimposed on a color plot of
......@@ -1949,9 +1928,6 @@ def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k',
provided.
file : string or file object or `None`
The output file. If `None`, output will be shown instead.
show : bool
Whether ``matplotlib.pyplot.show()`` is to be called, and the output is
to be shown immediately. Defaults to `True`.
dpi : float or `None`
Number of pixels per inch. If not set the ``matplotlib`` default is
used.
......@@ -1960,8 +1936,8 @@ def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k',
``matplotlib`` value is used.
ax : ``matplotlib.axes.Axes`` instance or `None`
If `ax` is not `None`, no new figure is created, but the plot is done
within the existing Axes `ax`. in this case, `file`, `show`, `dpi`
and `fig_size` are ignored.
within the existing Axes `ax`. in this case, `file`, `dpi` and
`fig_size` are ignored.
vmax : float or `None`
The upper saturation limit for the colormap; flows higher than
this will saturate. Note that there is no corresponding vmin
......@@ -2035,13 +2011,13 @@ def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k',
if colorbar and cmap and fig is not None:
fig.colorbar(image)
_maybe_output_fig(fig, file=file, show=show)
_maybe_output_fig(fig, file=file)
return fig
def scalarplot(field, box,
cmap=None, colorbar=True, file=None, show=True,
cmap=None, colorbar=True, file=None,
dpi=None, fig_size=None, ax=None, vmin=None, vmax=None,
background='#e0e0e0'):
"""Draw a scalar field in Kwant style
......@@ -2062,9 +2038,6 @@ def scalarplot(field, box,
provided.
file : string or file object, optional
The output file. If not provided, output will be shown instead.
show : bool, default: True
Whether ``matplotlib.pyplot.show()`` is to be called, and the output is
to be shown immediately.
dpi : float, optional
Number of pixels per inch. If not set the ``matplotlib`` default is
used.
......@@ -2073,8 +2046,8 @@ def scalarplot(field, box,
``matplotlib`` value is used.
ax : ``matplotlib.axes.Axes`` instance, optional
If ``ax`` is provided, no new figure is created, but the plot is done
within the existing Axes ``ax``. in this case, ``file``, ``show``,
``dpi`` and ``fig_size`` are ignored.
within the existing Axes ``ax``. in this case, ``file``, ``dpi`` and
``fig_size`` are ignored.
vmin, vmax : float, optional
The lower/upper saturation limit for the colormap.
background : matplotlib color spec
......@@ -2124,7 +2097,7 @@ def scalarplot(field, box,
if colorbar and cmap and fig is not None:
fig.colorbar(image)
_maybe_output_fig(fig, file=file, show=show)
_maybe_output_fig(fig, file=file)
return fig
......
......@@ -309,7 +309,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
def plot_2d_bands(syst, k_x=31, k_y=31, params=None,
mask_brillouin_zone=False, extend_bbox=0, file=None,
show=True, dpi=None, fig_size=None, ax=None):
dpi=None, fig_size=None, ax=None):
"""Plot 2D band structure of a wrapped around system.
This function is primarily useful for systems that have translational
......@@ -345,9 +345,6 @@ def plot_2d_bands(syst, k_x=31, k_y=31, params=None,
directions).
file : string or file object, optional
The output file. If None, output will be shown instead.
show : bool, default: False
Whether ``matplotlib.pyplot.show()`` is to be called, and the output is
to be shown immediately. Defaults to `True`.
dpi : float, optional
Number of pixels per inch. If not set the ``matplotlib`` default is
used.
......@@ -356,8 +353,8 @@ def plot_2d_bands(syst, k_x=31, k_y=31, params=None,
``matplotlib`` value is used.
ax : ``matplotlib.axes.Axes`` instance, optional
If `ax` is not `None`, no new figure is created, but the plot is done
within the existing Axes `ax`. in this case, `file`, `show`, `dpi`
and `fig_size` are ignored.
within the existing Axes `ax`. in this case, `file`, `dpi` and
`fig_size` are ignored.
Returns
-------
......@@ -454,6 +451,5 @@ def plot_2d_bands(syst, k_x=31, k_y=31, params=None,
y=('k_y', ks[1]) if lat_ndim == 2 else None,
params=params,
mask=(outside_bz if mask_brillouin_zone else None),
file=file, show=show, dpi=dpi,
fig_size=fig_size, ax=ax)
file=file, dpi=dpi, fig_size=fig_size, ax=ax)
return fig
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment