Commit 50359cd9 authored by Christoph Groth's avatar Christoph Groth Committed by Joseph Weston
Browse files

refactor current plotter API

parent fdc028c0
......@@ -51,9 +51,11 @@ except ImportError:
from . import system, builder, _common
__all__ = ['plot', 'map', 'bands', 'spectrum', 'current',
'interpolate_current', 'sys_leads_sites', 'sys_leads_hoppings',
'sys_leads_pos', 'sys_leads_hopping_pos', 'mask_interpolate']
'interpolate_current', 'streamplot',
'sys_leads_sites', 'sys_leads_hoppings', 'sys_leads_pos',
'sys_leads_hopping_pos', 'mask_interpolate']
# TODO: Remove the following once we depend on matplotlib >= 1.4.1.
......@@ -373,6 +375,16 @@ if mpl_enabled:
# matplotlib helper functions.
def _make_figure(dpi, fig_size):
fig = Figure()
if dpi is not None:
fig.set_dpi(dpi)
if fig_size is not None:
fig.set_figwidth(fig_size[0])
fig.set_figheight(fig_size[1])
return fig
def set_colors(color, collection, cmap, norm=None):
"""Process a color specification to a format accepted by collections.
......@@ -1314,13 +1326,7 @@ def plot(sys, num_lead_cells=2, unit='nn',
# make a new figure unless axes specified
if not ax:
fig = Figure()
if dpi is not None:
fig.set_dpi(dpi)
if fig_size is not None:
fig.set_figwidth(fig_size[0])
fig.set_figheight(fig_size[1])
fig = _make_figure(dpi, fig_size)
if dim == 2:
ax = fig.add_subplot(1, 1, 1, aspect='equal')
ax.set_xmargin(0.05)
......@@ -1575,12 +1581,7 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
min -= border
max += border
if ax is None:
fig = Figure()
if dpi is not None:
fig.set_dpi(dpi)
if fig_size is not None:
fig.set_figwidth(fig_size[0])
fig.set_figheight(fig_size[1])
fig = _make_figure(dpi, fig_size)
ax = fig.add_subplot(1, 1, 1, aspect='equal')
else:
fig = None
......@@ -1801,7 +1802,7 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None,
return output_fig(fig, file=file, show=show)
def interpolate_current(syst, current, width=4, limit=None, n=9, a=None):
def interpolate_current(syst, current, relwidth=None, abswidth=None, n=9):
"""Interpolate currents in a system onto a regular grid.
The system graph together with current intensities defines a "discrete"
......@@ -1823,22 +1824,21 @@ def interpolate_current(syst, current, width=4, limit=None, n=9, a=None):
current : '1D array of float'
Must contain the intensity on each hoppings in the same order that they
appear in syst.graph.
width : float
(Minimum) width of the bumps used to generate the field, in units of
``a``. See also `limit`.
limit : float or `None`
Absolute resolution limit. The effective value of `width` is limited
to the length of the longest side of the bounding box divided by this
number.
relwidth : float or `None`
Relative width of the bumps used to generate the field, as a fraction
of the length of the longest side of the bounding box. This argument
is only used if `abswidth` is not given.
abswidth : float or `None`
Absolute width ot the bumps used to generate the field. Takes
precedence over `relwidth`. If neither is given, the bump width is set
to four times the length of the shortest hopping.
n : int
Number of points the grid must have over the width of the bump.
a : float
By default, the length of the shortest hopping.
Returns
-------
region : list of the coordinates of the grid for each dimension
field : value of the generated field on the grid points
box : the start/end points of the bounding box: ((x0, x1), (y0, y1))
"""
if not isinstance(syst, builder.FiniteSystem):
......@@ -1877,12 +1877,24 @@ def interpolate_current(syst, current, width=4, limit=None, n=9, a=None):
dirs = hops[:, 1] - hops[:, 0]
lens = np.sqrt(np.sum(dirs * dirs, -1))
dirs /= lens[:, None]
if a == None:
a = min(lens)
width *= a
if limit is not None:
width = max(np.max(bbox_size) / limit, width)
if abswidth is None:
if relwidth is None:
unique_lens = np.unique(lens)
longest = unique_lens[-1]
for shortest_nonzero in unique_lens:
if shortest_nonzero / longest < 1e-5:
break
width = 4 * shortest_nonzero
else:
width = relwidth * np.max(bbox_size)
else:
width = abswidth
# TODO: Generalize 'factor' prefactor to arbitrary dimensions and remove
# this check. This check is done here to keep changes local
if dim != 2:
raise ValueError("'interpolate_current' only works for 2D systems.")
factor = (3 / np.pi) / (width / 2)
scale = 2 / width
lens *= scale
......@@ -1947,7 +1959,7 @@ def interpolate_current(syst, current, width=4, limit=None, n=9, a=None):
field[field_slice] += dirs[i] * magns[..., None]
# 'field' contains contributions from both hoppings (i, j) and (j, i)
return region, field
return field, ((region[0][0], region[0][-1]), (region[1][0], region[1][-1]))
def _gamma_compress(linear):
......@@ -1972,13 +1984,31 @@ def _gamma_expand(corrected):
_gamma_expand = np.vectorize(_gamma_expand, otypes=[float])
def _streamplot(field, box, colorbar=True, cmap=None, file=None,
show=True, dpi=None, fig_size=None, ax=None,
linecolor='k', max_linewidth=3, min_linewidth=1, density=1):
def _linear_cmap(a, b):
"""Make a colormap that linearly interpolates between the colors a and b."""
a = matplotlib.colors.colorConverter.to_rgb(a)
b = matplotlib.colors.colorConverter.to_rgb(b)
a_linear = _gamma_expand(a)
b_linear = _gamma_expand(b)
color_diff = a_linear - b_linear
palette = (np.linspace(0, 1, 256).reshape((-1, 1))
* color_diff.reshape((1, -1)))
palette += b_linear
palette = _gamma_compress(palette)
return matplotlib.colors.ListedColormap(palette)
def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k',
max_linewidth=3, min_linewidth=1, density=2/9,
colorbar=True, file=None,
show=True, dpi=None, fig_size=None, ax=None):
if not mpl_enabled:
raise RuntimeError("matplotlib was not found, but is required "
"for current()")
# Matplotlib's "density" is in units of 30 streamlines...
density *= 1 / 30 * ta.array(field.shape[:2], int)
# Matplotlib plots images like matrices: image[y, x]. We use the opposite
# convention: image[x, y]. Hence, it is necessary to transpose.
field = field.transpose(1, 0, 2)
......@@ -1986,18 +2016,17 @@ def _streamplot(field, box, colorbar=True, cmap=None, file=None,
if field.shape[-1] != 2 or field.ndim != 3:
raise ValueError("Only 2D field can be plotted.")
# The default colormap is extremely ugly with streamplot.
if cmap is None:
cmap = _colormaps.kwant_red
cmap = matplotlib.cm.get_cmap(cmap)
if bgcolor is None:
if cmap is None:
cmap = _colormaps.kwant_red
cmap = matplotlib.cm.get_cmap(cmap)
bgcolor = cmap(0)[:3]
elif cmap is not None:
raise ValueError("The parameters 'cmap' and 'bgcolor' are "
"mutually exclusive.")
if ax is None:
fig = Figure()
if dpi is not None:
fig.set_dpi(dpi)
if fig_size is not None:
fig.set_figwidth(fig_size[0])
fig.set_figheight(fig_size[1])
fig = _make_figure(dpi, fig_size)
ax = fig.add_subplot(1, 1, 1, aspect='equal')
else:
fig = None
......@@ -2007,27 +2036,20 @@ def _streamplot(field, box, colorbar=True, cmap=None, file=None,
speed = np.linalg.norm(field, axis=-1)
image = ax.imshow(speed, cmap=cmap,
interpolation='bicubic',
extent=[e for c in box for e in c],
origin='lower')
if cmap is None:
ax.set_axis_bgcolor(bgcolor)
else:
image = ax.imshow(speed, cmap=cmap,
interpolation='bicubic',
extent=[e for c in box for e in c],
origin='lower')
linewidth = max_linewidth / (np.max(speed) or 1) * speed
color = linewidth / min_linewidth
linewidth[linewidth < min_linewidth] = min_linewidth
color[color > 1] = 1
# Make a colormap that linearly interpolates between the line color and
# the background color.
linecolor = matplotlib.colors.colorConverter.to_rgb(linecolor)
linecolor_linear = _gamma_expand(linecolor)
bgcolor_linear = _gamma_expand(cmap(0)[:3])
color_diff = linecolor_linear - bgcolor_linear
line_cmap = (np.linspace(0, 1, 256).reshape((-1, 1))
* color_diff.reshape((1, -1)))
line_cmap += bgcolor_linear
line_cmap = _gamma_compress(line_cmap)
line_cmap = matplotlib.colors.ListedColormap(line_cmap)
line_cmap = _linear_cmap(linecolor, cmap(0))
ax.streamplot(X, Y, field[:,:,0], field[:,:,1],
density=density, linewidth=linewidth,
......@@ -2036,16 +2058,13 @@ def _streamplot(field, box, colorbar=True, cmap=None, file=None,
ax.set_xlim(*box[0])
ax.set_ylim(*box[1])
if colorbar and fig is not None:
fig.colorbar(image)
if fig is not None:
if colorbar and bgcolor is None:
fig.colorbar(image)
return output_fig(fig, file=file, show=show)
def current(syst, current, width=4, limit=30, n=9, a=None, density=2,
colorbar=True, cmap=None, file=None, show=True, dpi=None,
fig_size=None, ax=None, linecolor='k', max_linewidth=3):
def current(syst, current, relwidth=0.03, **kwargs):
"""Show an interpolated current defined for the hoppings of a system.
The system graph together with current intensities defines a "discrete"
......@@ -2068,40 +2087,11 @@ def current(syst, current, width=4, limit=30, n=9, a=None, density=2,
Sequence of values defining currents on each hopping of the system.
Ordered in the same way as ``syst.graph``. This typically will be
the result of evaluating a `~kwant.operator.Current` operator.
width : float
(Minimum) width of the bumps used to generate the field, in units of
``a``. See also `limit`.
limit : float or `None`
Absolute resolution limit. The effective value of `width` is limited
to the length of the longest side of the bounding box divided by this
number.
n : int
Number of points the grid over the width of a bump.
a : float
A reference length. Be default, the length of the shortest hopping.
density : float
The number of streamlines per bump width.
colorbar : bool, optional
Whether to show a color bar if numerical data has to be plotted.
Defaults to `True`. If `ax` is provided, the colorbar is never plotted.
cmap : ``matplotlib`` color map or `None`
The color map used for sites and optionally hoppings, if `None`,
``matplotlib`` default is used.
file : string or file object or `None`
The output file. If `None`, output will be shown instead.
show : bool
Whether ``matplotlib.pyplot.show()`` is to be called, and the output is
to be shown immediately. Defaults to `True`.
dpi : float or `None`
Number of pixels per inch. If not set the ``matplotlib`` default is
used.
fig_size : tuple or `None`
Figure size `(width, height)` in inches. If not set, the default
``matplotlib`` value is used.
ax : ``matplotlib.axes.Axes`` instance or `None`
If `ax` is not `None`, no new figure is created, but the plot is done
within the existing Axes `ax`. in this case, `file`, `show`, `dpi`
and `fig_size` are ignored.
relwidth : float or `None`
Relative width of the bumps used to generate the field, as a fraction
of the length of the longest side of the bounding box.
**kwargs : various
Keyword args to be passed verbatim to `~kwant.plotter.streamplot`.
Returns
-------
......@@ -2109,24 +2099,8 @@ def current(syst, current, width=4, limit=30, n=9, a=None, density=2,
A figure with the output if `ax` is not set, else None.
"""
if not mpl_enabled:
raise RuntimeError("matplotlib was not found, but is required "
"for current()")
region, field = interpolate_current(syst, current, width=width, limit=limit,
n=n, a=a)
line_resolution = density / n * ta.array(field.shape[:2], int)
# "density" is matplotlib's name for the number of streamlines divided by
# 30.
return _streamplot(field, ((region[0][0], region[0][-1]),
(region[1][0], region[1][-1])),
colorbar=colorbar, cmap=cmap, file=file,
show=show, dpi=dpi, fig_size=fig_size, ax=ax,
linecolor=linecolor, max_linewidth=max_linewidth,
density=line_resolution / 30)
return streamplot(*interpolate_current(syst, current, relwidth),
**kwargs)
# TODO (Anton): Fix plotting of parts of the system using color = np.nan.
......
......@@ -352,7 +352,10 @@ def test_current_interpolation():
data = []
for n in [4, 6, 8, 11, 16]:
(x, y), j0 = plotter.interpolate_current(syst, J(psi[0]), n=n, width=width)
j0, box = plotter.interpolate_current(syst, J(psi[0]),
n=n, abswidth=width)
x, y = (np.linspace(mn, mx, shape)
for (mn, mx), shape in zip(box, j0.shape))
# slice field perpendicular to a cut along the y axis
y_axis = (np.argmin(np.abs(x)), slice(None), 0)
J_interp = scipy.integrate.simps(j0[y_axis], y)
......@@ -383,19 +386,19 @@ def test_current_interpolation():
divergence[a] += current
assert np.allclose(divergence, 0)
_, j0 = plotter.interpolate_current(syst, J0)
_, j1 = plotter.interpolate_current(syst, J1)
j0, _ = plotter.interpolate_current(syst, J0)
j1, _ = plotter.interpolate_current(syst, J1)
## Test linearity of interpolation.
_, j_tot = plotter.interpolate_current(syst, J0 + 2 * J1)
j_tot, _ = plotter.interpolate_current(syst, J0 + 2 * J1)
assert np.allclose(j_tot, j0 + 2 * j1)
## Test that divergence of interpolated current approaches zero as we make
## the interpolation finer.
data = []
for n in [4, 6, 8, 11, 16]:
grid, j = plotter.interpolate_current(syst, J0, n=n)
dx = [g[1] - g[0] for g in grid]
j, box = plotter.interpolate_current(syst, J0, n=n)
dx = [(mx - mn) / (shape - 1) for (mn, mx), shape in zip(box, j.shape)]
div_j = np.max(np.abs(div(j, dx)))
data.append((n, div_j))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment