Skip to content
Snippets Groups Projects
Commit 486232e4 authored by Anton Akhmerov's avatar Anton Akhmerov Committed by Christoph Groth
Browse files

provide axes argument to every plotting function

parent 4ec4e7cf
No related branches found
No related tags found
No related merge requests found
......@@ -1059,12 +1059,12 @@ def plot(sys, num_lead_cells=2, unit='nn',
ax : `matplotlib.axes.Axes` instance or `None`
If `ax` is not `None`, no new figure is created, but the plot is done
within the existing Axes `ax`. in this case, `file`, `show`, `dpi`
and `fig_size` are ingored.
and `fig_size` are ignored.
Returns
-------
result : matplotlib figure or axes instance
A figure with the output if ``ax==None``, otherwise `ax`.
fig : matplotlib figure
A figure with the output if `ax` is not set, else None.
Notes
-----
......@@ -1323,7 +1323,8 @@ def plot(sys, num_lead_cells=2, unit='nn',
if ax.collections[1].get_array() is not None and colorbar:
fig.colorbar(ax.collections[1])
return output_fig(fig, file=file, show=show) if fig else ax
if fig is not None:
return output_fig(fig, file=file, show=show)
def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3):
......@@ -1398,7 +1399,7 @@ def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3):
def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
method='nearest', oversampling=3, num_lead_cells=0, file=None,
show=True, dpi=None, fig_size=None):
show=True, dpi=None, fig_size=None, ax=None):
"""Show interpolated map of a function defined for the sites of a system.
Create a pixmap representation of a function of the sites of a system by
......@@ -1440,11 +1441,15 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
show : bool
Whether `matplotlib.pyplot.show()` is to be called, and the output is
to be shown immediately. Defaults to `True`.
ax : `matplotlib.axes.Axes` instance or `None`
If `ax` is not `None`, no new figure is created, but the plot is done
within the existing Axes `ax`. in this case, `file`, `show`, `dpi`
and `fig_size` are ignored.
Returns
-------
fig : matplotlib figure
A figure with the output.
A figure with the output if `ax` is not set, else None.
Notes
-----
......@@ -1467,13 +1472,16 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
border = 0.5 * (max - min) / (np.asarray(img.shape) - 1)
min -= border
max += border
fig = Figure()
if dpi is not None:
fig.set_dpi(dpi)
if fig_size is not None:
fig.set_figwidth(fig_size[0])
fig.set_figheight(fig_size[1])
ax = fig.add_subplot(1, 1, 1, aspect='equal', adjustable='datalim')
if ax is None:
fig = Figure()
if dpi is not None:
fig.set_dpi(dpi)
if fig_size is not None:
fig.set_figwidth(fig_size[0])
fig.set_figheight(fig_size[1])
ax = fig.add_subplot(1, 1, 1, aspect='equal', adjustable='datalim')
else:
fig = None
# Note that we tell imshow to show the array created by mask_interpolate
# faithfully and not to interpolate by itself another time.
......@@ -1487,11 +1495,12 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
if colorbar:
fig.colorbar(image)
return output_fig(fig, file=file, show=show)
if fig is not None:
return output_fig(fig, file=file, show=show)
def bands(sys, momenta=65, args=(), file=None, show=True, dpi=None,
fig_size=None):
fig_size=None, ax=None):
"""Plot band structure of a translationally invariant 1D system.
Parameters
......@@ -1514,11 +1523,15 @@ def bands(sys, momenta=65, args=(), file=None, show=True, dpi=None,
fig_size : tuple
Figure size `(width, height)` in inches. If not set, the default
`matplotlib` value is used.
ax : `matplotlib.axes.Axes` instance or `None`
If `ax` is not `None`, no new figure is created, but the plot is done
within the existing Axes `ax`. in this case, `file`, `show`, `dpi`
and `fig_size` are ignored.
Returns
-------
fig : matplotlib figure
A figure with the output.
A figure with the output if `ax` is not set, else None.
Notes
-----
......@@ -1531,15 +1544,20 @@ def bands(sys, momenta=65, args=(), file=None, show=True, dpi=None,
bands = physics.Bands(sys, args=args)
energies = [bands(k) for k in momenta]
fig = Figure()
if dpi is not None:
fig.set_dpi(dpi)
if fig_size is not None:
fig.set_figwidth(fig_size[0])
fig.set_figheight(fig_size[1])
ax = fig.add_subplot(1, 1, 1)
if ax is None:
fig = Figure()
if dpi is not None:
fig.set_dpi(dpi)
if fig_size is not None:
fig.set_figwidth(fig_size[0])
fig.set_figheight(fig_size[1])
ax = fig.add_subplot(1, 1, 1)
else:
fig = None
ax.plot(momenta, energies)
return output_fig(fig, file=file, show=show)
if fig is not None:
return output_fig(fig, file=file, show=show)
# TODO (Anton): Fix plotting of parts of the system using color = np.nan.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment