From 134a3e8266f5e612edc9423047499ecbbc116dba Mon Sep 17 00:00:00 2001
From: Adrien Sorgniard <adrien.sorgniard@ecl13.ec-lyon.fr>
Date: Mon, 18 May 2015 12:51:54 +0200
Subject: [PATCH] add pos_transform option to plotter.map

---
 TODO                            |  3 ---
 doc/source/pre/whatsnew/1.1.rst |  5 +++++
 kwant/plotter.py                |  9 ++++++++-
 kwant/tests/test_plotter.py     | 18 ++++++++++++++----
 4 files changed, 27 insertions(+), 8 deletions(-)

diff --git a/TODO b/TODO
index 907aa9bf..f253a9d2 100644
--- a/TODO
+++ b/TODO
@@ -2,9 +2,6 @@ Roughly in order of importance.                                     -*-org-*-
 
 * Document the order of sites/orbitals in finalized builders
 
-* Add a pos_transform argument to kwant.plotter.map
-  Check whether other arguments from plot() could be useful.
-
 * Add calculation of current density
 
 * Re-design the interface of low level systems
diff --git a/doc/source/pre/whatsnew/1.1.rst b/doc/source/pre/whatsnew/1.1.rst
index 703c3829..babf9b06 100644
--- a/doc/source/pre/whatsnew/1.1.rst
+++ b/doc/source/pre/whatsnew/1.1.rst
@@ -74,3 +74,8 @@ one writes ::
 the error message will be more helpful now.
 
 Please continue reporting confusing error messages on the Kwant mailing list.
+
+New option ``pos_transform`` of `~kwant.plotter.map`
+----------------------------------------------------------------
+This option which already existed for `kwant.plotter.plot` is now also
+available for `kwant.plotter.map`.
diff --git a/kwant/plotter.py b/kwant/plotter.py
index c9e8d370..52e06f2e 100644
--- a/kwant/plotter.py
+++ b/kwant/plotter.py
@@ -1442,7 +1442,7 @@ def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3):
 
 def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
         method='nearest', oversampling=3, num_lead_cells=0, file=None,
-        show=True, dpi=None, fig_size=None, ax=None):
+        show=True, dpi=None, fig_size=None, ax=None, pos_transform=None):
     """Show interpolated map of a function defined for the sites of a system.
 
     Create a pixmap representation of a function of the sites of a system by
@@ -1489,6 +1489,8 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
         If `ax` is not `None`, no new figure is created, but the plot is done
         within the existing Axes `ax`. in this case, `file`, `show`, `dpi`
         and `fig_size` are ignored.
+    pos_transform : function or `None`
+        Transformation to be applied to the site position.
 
     Returns
     -------
@@ -1508,8 +1510,13 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
 
     sites = sys_leads_sites(sys, 0)[0]
     coords = sys_leads_pos(sys, sites)
+
+    if pos_transform is not None:
+        coords = np.apply_along_axis(pos_transform, 1, coords)
+
     if coords.shape[1] != 2:
         raise ValueError('Only 2D systems can be plotted this way.')
+
     if callable(value):
         value = [value(site[0]) for site in sites]
     else:
diff --git a/kwant/tests/test_plotter.py b/kwant/tests/test_plotter.py
index 08a6fbfe..05bdd3af 100644
--- a/kwant/tests/test_plotter.py
+++ b/kwant/tests/test_plotter.py
@@ -130,14 +130,24 @@ def test_plot():
             warnings.simplefilter("ignore")
             plot(sys2d.finalized(), file=out)
 
+def good_transform(pos):
+    x, y = pos
+    return y, x
+
+def bad_transform(pos):
+    x, y = pos
+    return x, y, 0
 
 def test_map():
     if not plotter.mpl_enabled:
         raise nose.SkipTest
     sys = sys_2d()
     with tempfile.TemporaryFile('w+b') as out:
-        plotter.map(sys, lambda site: site.tag[0], file=out,
-                          method='linear', a=4, oversampling=4, cmap='flag')
+        plotter.map(sys, lambda site: site.tag[0], pos_transform=good_transform,
+                    file=out, method='linear', a=4, oversampling=4, cmap='flag')
+        nose.tools.assert_raises(ValueError, plotter.map, sys,
+                                 lambda site: site.tag[0],
+                                 pos_transform=bad_transform, file=out)
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
             plotter.map(sys.finalized(), xrange(len(sys.sites())),
@@ -151,7 +161,7 @@ def test_mask_interpolate():
     coords = np.random.rand(10, 2)
     coords[5] *= 1e-8
     coords[5] += coords[0]
-    
+
     warnings.simplefilter("ignore")
     with warnings.catch_warnings(record=True) as w:
         warnings.simplefilter("always")
@@ -162,5 +172,5 @@ def test_mask_interpolate():
 
     assert_raises(ValueError, plotter.mask_interpolate,
                   coords, np.ones(len(coords)))
-    assert_raises(ValueError, plotter.mask_interpolate, 
+    assert_raises(ValueError, plotter.mask_interpolate,
                   coords, np.ones(2 * len(coords)))
-- 
GitLab