Commit 134a3e82 authored by Adrien Sorgniard's avatar Adrien Sorgniard Committed by Christoph Groth
Browse files

add pos_transform option to plotter.map

parent c2e7fcc4
......@@ -2,9 +2,6 @@ Roughly in order of importance. -*-org-*-
* Document the order of sites/orbitals in finalized builders
* Add a pos_transform argument to kwant.plotter.map
Check whether other arguments from plot() could be useful.
* Add calculation of current density
* Re-design the interface of low level systems
......
......@@ -74,3 +74,8 @@ one writes ::
the error message will be more helpful now.
Please continue reporting confusing error messages on the Kwant mailing list.
New option ``pos_transform`` of `~kwant.plotter.map`
----------------------------------------------------------------
This option which already existed for `kwant.plotter.plot` is now also
available for `kwant.plotter.map`.
......@@ -1442,7 +1442,7 @@ def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3):
def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
method='nearest', oversampling=3, num_lead_cells=0, file=None,
show=True, dpi=None, fig_size=None, ax=None):
show=True, dpi=None, fig_size=None, ax=None, pos_transform=None):
"""Show interpolated map of a function defined for the sites of a system.
Create a pixmap representation of a function of the sites of a system by
......@@ -1489,6 +1489,8 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
If `ax` is not `None`, no new figure is created, but the plot is done
within the existing Axes `ax`. in this case, `file`, `show`, `dpi`
and `fig_size` are ignored.
pos_transform : function or `None`
Transformation to be applied to the site position.
Returns
-------
......@@ -1508,8 +1510,13 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
sites = sys_leads_sites(sys, 0)[0]
coords = sys_leads_pos(sys, sites)
if pos_transform is not None:
coords = np.apply_along_axis(pos_transform, 1, coords)
if coords.shape[1] != 2:
raise ValueError('Only 2D systems can be plotted this way.')
if callable(value):
value = [value(site[0]) for site in sites]
else:
......
......@@ -130,14 +130,24 @@ def test_plot():
warnings.simplefilter("ignore")
plot(sys2d.finalized(), file=out)
def good_transform(pos):
x, y = pos
return y, x
def bad_transform(pos):
x, y = pos
return x, y, 0
def test_map():
if not plotter.mpl_enabled:
raise nose.SkipTest
sys = sys_2d()
with tempfile.TemporaryFile('w+b') as out:
plotter.map(sys, lambda site: site.tag[0], file=out,
method='linear', a=4, oversampling=4, cmap='flag')
plotter.map(sys, lambda site: site.tag[0], pos_transform=good_transform,
file=out, method='linear', a=4, oversampling=4, cmap='flag')
nose.tools.assert_raises(ValueError, plotter.map, sys,
lambda site: site.tag[0],
pos_transform=bad_transform, file=out)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
plotter.map(sys.finalized(), xrange(len(sys.sites())),
......@@ -151,7 +161,7 @@ def test_mask_interpolate():
coords = np.random.rand(10, 2)
coords[5] *= 1e-8
coords[5] += coords[0]
warnings.simplefilter("ignore")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
......@@ -162,5 +172,5 @@ def test_mask_interpolate():
assert_raises(ValueError, plotter.mask_interpolate,
coords, np.ones(len(coords)))
assert_raises(ValueError, plotter.mask_interpolate,
assert_raises(ValueError, plotter.mask_interpolate,
coords, np.ones(2 * len(coords)))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment