From 134a3e8266f5e612edc9423047499ecbbc116dba Mon Sep 17 00:00:00 2001 From: Adrien Sorgniard <adrien.sorgniard@ecl13.ec-lyon.fr> Date: Mon, 18 May 2015 12:51:54 +0200 Subject: [PATCH] add pos_transform option to plotter.map --- TODO | 3 --- doc/source/pre/whatsnew/1.1.rst | 5 +++++ kwant/plotter.py | 9 ++++++++- kwant/tests/test_plotter.py | 18 ++++++++++++++---- 4 files changed, 27 insertions(+), 8 deletions(-) diff --git a/TODO b/TODO index 907aa9bf..f253a9d2 100644 --- a/TODO +++ b/TODO @@ -2,9 +2,6 @@ Roughly in order of importance. -*-org-*- * Document the order of sites/orbitals in finalized builders -* Add a pos_transform argument to kwant.plotter.map - Check whether other arguments from plot() could be useful. - * Add calculation of current density * Re-design the interface of low level systems diff --git a/doc/source/pre/whatsnew/1.1.rst b/doc/source/pre/whatsnew/1.1.rst index 703c3829..babf9b06 100644 --- a/doc/source/pre/whatsnew/1.1.rst +++ b/doc/source/pre/whatsnew/1.1.rst @@ -74,3 +74,8 @@ one writes :: the error message will be more helpful now. Please continue reporting confusing error messages on the Kwant mailing list. + +New option ``pos_transform`` of `~kwant.plotter.map` +---------------------------------------------------------------- +This option which already existed for `kwant.plotter.plot` is now also +available for `kwant.plotter.map`. diff --git a/kwant/plotter.py b/kwant/plotter.py index c9e8d370..52e06f2e 100644 --- a/kwant/plotter.py +++ b/kwant/plotter.py @@ -1442,7 +1442,7 @@ def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3): def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None, method='nearest', oversampling=3, num_lead_cells=0, file=None, - show=True, dpi=None, fig_size=None, ax=None): + show=True, dpi=None, fig_size=None, ax=None, pos_transform=None): """Show interpolated map of a function defined for the sites of a system. Create a pixmap representation of a function of the sites of a system by @@ -1489,6 +1489,8 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None, If `ax` is not `None`, no new figure is created, but the plot is done within the existing Axes `ax`. in this case, `file`, `show`, `dpi` and `fig_size` are ignored. + pos_transform : function or `None` + Transformation to be applied to the site position. Returns ------- @@ -1508,8 +1510,13 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None, sites = sys_leads_sites(sys, 0)[0] coords = sys_leads_pos(sys, sites) + + if pos_transform is not None: + coords = np.apply_along_axis(pos_transform, 1, coords) + if coords.shape[1] != 2: raise ValueError('Only 2D systems can be plotted this way.') + if callable(value): value = [value(site[0]) for site in sites] else: diff --git a/kwant/tests/test_plotter.py b/kwant/tests/test_plotter.py index 08a6fbfe..05bdd3af 100644 --- a/kwant/tests/test_plotter.py +++ b/kwant/tests/test_plotter.py @@ -130,14 +130,24 @@ def test_plot(): warnings.simplefilter("ignore") plot(sys2d.finalized(), file=out) +def good_transform(pos): + x, y = pos + return y, x + +def bad_transform(pos): + x, y = pos + return x, y, 0 def test_map(): if not plotter.mpl_enabled: raise nose.SkipTest sys = sys_2d() with tempfile.TemporaryFile('w+b') as out: - plotter.map(sys, lambda site: site.tag[0], file=out, - method='linear', a=4, oversampling=4, cmap='flag') + plotter.map(sys, lambda site: site.tag[0], pos_transform=good_transform, + file=out, method='linear', a=4, oversampling=4, cmap='flag') + nose.tools.assert_raises(ValueError, plotter.map, sys, + lambda site: site.tag[0], + pos_transform=bad_transform, file=out) with warnings.catch_warnings(): warnings.simplefilter("ignore") plotter.map(sys.finalized(), xrange(len(sys.sites())), @@ -151,7 +161,7 @@ def test_mask_interpolate(): coords = np.random.rand(10, 2) coords[5] *= 1e-8 coords[5] += coords[0] - + warnings.simplefilter("ignore") with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") @@ -162,5 +172,5 @@ def test_mask_interpolate(): assert_raises(ValueError, plotter.mask_interpolate, coords, np.ones(len(coords))) - assert_raises(ValueError, plotter.mask_interpolate, + assert_raises(ValueError, plotter.mask_interpolate, coords, np.ones(2 * len(coords))) -- GitLab