From 82baedca2ad07113729fb34e8d20c541756017a2 Mon Sep 17 00:00:00 2001
From: Joseph Weston <joseph@weston.cloud>
Date: Tue, 13 Mar 2018 18:51:24 +0100
Subject: [PATCH] improve heuristic for setting colormap limits in plotter.map

Previously we used the matplotlib default behavior of setting
the colormap limits (vmin and vmax) to the limits of the plotted
data, if they are not provided by the user. Now we set the
colormap limits to the 2nd and 98th percentiles of the input
data, and stretch each limit by up to 10% of the interval width
in an attempt to include all the data in the limits. If we cannot
do so, we set the limits to the percentiles plus the stretch
and issue a warning.

Also, whenever the data falls outside the colorbar limits
(even when the latter are set by the user) we set the appropriate
ends of the colorbar to be pointy, indicating that the data
extends beyond the colorbar.

Closes #183
---
 doc/source/pre/whatsnew/1.4.rst | 11 +++++++
 kwant/plotter.py                | 58 ++++++++++++++++++++++++++++++++-
 2 files changed, 68 insertions(+), 1 deletion(-)

diff --git a/doc/source/pre/whatsnew/1.4.rst b/doc/source/pre/whatsnew/1.4.rst
index ab00cd14..12c72b3b 100644
--- a/doc/source/pre/whatsnew/1.4.rst
+++ b/doc/source/pre/whatsnew/1.4.rst
@@ -12,3 +12,14 @@ The function `~kwant.plotter.streamplot` has got a new option ``vmax``.  Note
 that this option is not available in `~kwant.plotter.current`.  In order to use
 it, one has to call ``streamplot`` directly as shown in the docstring of
 ``current``.
+
+Improved heuristic for colorscale limits in `kwant.plotter.map`
+---------------------------------------------------------------
+Previously `~kwant.plotter.map` would set the limits for the color scale
+to the extrema of the data being plotted when ``vmin`` and ``vmax`` were
+not provided. This is the behaviour of ``matplotlib.imshow``. When the data
+to be plotted has very sharp and high peaks this would mean that most of the
+data would appear near the bottom of the color scale, and all of the features
+would be washed out by the presence of the peak. Now `~kwant.plotter.map`
+employs a heuristic for setting the colorscale when there are outliers,
+and will emit a warning when such outliers are detected.
diff --git a/kwant/plotter.py b/kwant/plotter.py
index a5702141..405cf280 100644
--- a/kwant/plotter.py
+++ b/kwant/plotter.py
@@ -97,6 +97,34 @@ def set_colors(color, collection, cmap, norm=None):
     collection.set_color(colors)
 
 
+def percentile_bound(data, vmin, vmax, percentile=96, stretch=0.1):
+    """Return bounds that cover the central 'percentile' percent of 'data'.
+
+    A bound supplied via 'vmin' or 'vmax' (i.e. not None) is returned
+    unchanged. A missing bound starts at the matching tail percentile
+    of 'data' (2nd/98th for the default 'percentile=96') and is then
+    pushed outwards by at most 'stretch' times the width of the
+    interval, but never beyond the minimum/maximum of 'data'.
+    """
+    if vmin is not None and vmax is not None:
+        return vmin, vmax
+
+    percentile = (100 - percentile) / 2  # size of each excluded tail
+    percentiles = (0, percentile, 100 - percentile, 100)
+    mn, bound_mn, bound_mx, mx = np.percentile(data.flatten(), percentiles)
+
+    bound_mn = bound_mn if vmin is None else vmin
+    bound_mx = bound_mx if vmax is None else vmax
+
+    # Widen each free bound by 'stretch' times the interval width, but
+    # clip at the data extrema so we never extend past the actual data.
+    stretch = (bound_mx - bound_mn) * stretch
+    out_mn = max(bound_mn - stretch, mn) if vmin is None else vmin
+    out_mx = min(bound_mx + stretch, mx) if vmax is None else vmax
+
+    return (out_mn, out_mx)
+
+
 symbol_dict = {'O': 'o', 's': ('p', 4, 45), 'S': ('P', 4, 45)}
 
 def get_symbol(symbols):
@@ -1278,6 +1306,26 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
     if cmap is None:
         cmap = _p._colormaps.kwant_red
 
+    # Calculate the min/max bounds for the colormap.
+    # User-provided values take precedence.
+    unmasked_data = img[~img.mask].data.flatten()
+    new_vmin, new_vmax = percentile_bound(unmasked_data, vmin, vmax)
+    overflow_pct = 100 * np.sum(unmasked_data > new_vmax) / len(unmasked_data)
+    underflow_pct = 100 * np.sum(unmasked_data < new_vmin) / len(unmasked_data)
+    if (vmin is None and underflow_pct) or (vmax is None and overflow_pct):
+        msg = (
+            'The plotted data contains ',
+            '{:.2f}% of values overflowing upper limit {:g} '
+                .format(overflow_pct, new_vmax)
+                if overflow_pct > 0 else '',
+            'and ' if overflow_pct > 0 and underflow_pct > 0 else '',
+            '{:.2f}% of values underflowing lower limit {:g} '
+                .format(underflow_pct, new_vmin)
+                if underflow_pct > 0 else '',
+        )
+        warnings.warn(''.join(msg), RuntimeWarning, stacklevel=2)
+    vmin, vmax = new_vmin, new_vmax
+
     # Note that we tell imshow to show the array created by mask_interpolate
     # faithfully and not to interpolate by itself another time.
     image = ax.imshow(img.T, extent=(min[0], max[0], min[1], max[1]),
@@ -1291,7 +1339,15 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
     ax.patch.set_facecolor(background)
 
     if colorbar and fig is not None:
-        fig.colorbar(image)
+        # Make the colorbar ends pointy if we saturate the colormap
+        extend = 'neither'
+        if underflow_pct > 0 and overflow_pct > 0:
+            extend = 'both'
+        elif underflow_pct > 0:
+            extend = 'min'
+        elif overflow_pct > 0:
+            extend = 'max'
+        fig.colorbar(image, extend=extend)
 
     if fig is not None:
         return output_fig(fig, file=file, show=show)
-- 
GitLab