diff --git a/kwant/plotter.py b/kwant/plotter.py index 335ba85798ab70dbebba74b25fa51cf3f8b208a6..c96a33d4efb7cda5922c3b632571816ff5e279d7 100644 --- a/kwant/plotter.py +++ b/kwant/plotter.py @@ -1305,14 +1305,14 @@ def plot(sys, num_lead_cells=2, unit='nn', site_edgecolor) lw = site_lw[slc] if isarray(site_lw) else site_lw - symbols(ax, sites_pos[slc], size=size, - reflen=reflen, symbol=symbol, - facecolor=col, edgecolor=edgecol, - linewidth=lw, cmap=cmap, norm=norm, zorder=2) + symbol_coll = symbols(ax, sites_pos[slc], size=size, + reflen=reflen, symbol=symbol, + facecolor=col, edgecolor=edgecol, + linewidth=lw, cmap=cmap, norm=norm, zorder=2) end, start = end_pos[: n_sys_hops], start_pos[: n_sys_hops] - lines(ax, end, start, reflen, hop_color, linewidths=hop_lw, zorder=1, - cmap=hop_cmap) + line_coll = lines(ax, end, start, reflen, hop_color, linewidths=hop_lw, + zorder=1, cmap=hop_cmap) # plot lead sites and hoppings norm = matplotlib.colors.Normalize(-0.5, num_lead_cells - 0.5) @@ -1351,10 +1351,11 @@ def plot(sys, num_lead_cells=2, unit='nn', m = (min_ + max_) / 2 ax.auto_scale_xyz(*[(i - w, i + w) for i in m], had_data=True) - if ax.collections[0].get_array() is not None and colorbar: - fig.colorbar(ax.collections[0]) - if ax.collections[1].get_array() is not None and colorbar: - fig.colorbar(ax.collections[1]) + # add separate colorbars for symbols and hoppings if ncessary + if symbol_coll.get_array() is not None and colorbar: + fig.colorbar(symbol_coll) + if line_coll.get_array() is not None and colorbar: + fig.colorbar(line_coll) if fig is not None: return output_fig(fig, file=file, show=show)