From 24c94298ba7682621049aca5eaae46a09b5d3679 Mon Sep 17 00:00:00 2001
From: Michael Wimmer <wimmer@lorentz.leidenuniv.nl>
Date: Wed, 28 Aug 2013 17:29:13 +0200
Subject: [PATCH] fix colorbar if multiple symbols are used

---
 kwant/plotter.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/kwant/plotter.py b/kwant/plotter.py
index 335ba857..c96a33d4 100644
--- a/kwant/plotter.py
+++ b/kwant/plotter.py
@@ -1305,14 +1305,14 @@ def plot(sys, num_lead_cells=2, unit='nn',
                    site_edgecolor)
         lw = site_lw[slc] if isarray(site_lw) else site_lw
 
-        symbols(ax, sites_pos[slc], size=size,
-                reflen=reflen, symbol=symbol,
-                facecolor=col, edgecolor=edgecol,
-                linewidth=lw, cmap=cmap, norm=norm, zorder=2)
+        symbol_coll = symbols(ax, sites_pos[slc], size=size,
+                              reflen=reflen, symbol=symbol,
+                              facecolor=col, edgecolor=edgecol,
+                              linewidth=lw, cmap=cmap, norm=norm, zorder=2)
 
     end, start = end_pos[: n_sys_hops], start_pos[: n_sys_hops]
-    lines(ax, end, start, reflen, hop_color, linewidths=hop_lw, zorder=1,
-          cmap=hop_cmap)
+    line_coll = lines(ax, end, start, reflen, hop_color, linewidths=hop_lw,
+                      zorder=1, cmap=hop_cmap)
 
     # plot lead sites and hoppings
     norm = matplotlib.colors.Normalize(-0.5, num_lead_cells - 0.5)
@@ -1351,10 +1351,11 @@ def plot(sys, num_lead_cells=2, unit='nn',
         m = (min_ + max_) / 2
         ax.auto_scale_xyz(*[(i - w, i + w) for i in m], had_data=True)
 
-    if ax.collections[0].get_array() is not None and colorbar:
-        fig.colorbar(ax.collections[0])
-    if ax.collections[1].get_array() is not None and colorbar:
-        fig.colorbar(ax.collections[1])
+    # add separate colorbars for symbols and hoppings if ncessary
+    if symbol_coll.get_array() is not None and colorbar:
+        fig.colorbar(symbol_coll)
+    if line_coll.get_array() is not None and colorbar:
+        fig.colorbar(line_coll)
 
     if fig is not None:
         return output_fig(fig, file=file, show=show)
-- 
GitLab