From 21d32fd33e9e051f98f11078233ad3c4d1feb6bd Mon Sep 17 00:00:00 2001
From: Kelvin Loh <kel85uk@gmail.com>
Date: Mon, 10 Dec 2018 15:30:38 +0100
Subject: [PATCH] 2D implementation of the plot() function.

---
 kwant/_plotter.py |  70 ++++++++
 kwant/plotter.py  | 441 +++++++++++++++++++++++++++++++++++++++-------
 2 files changed, 446 insertions(+), 65 deletions(-)

diff --git a/kwant/_plotter.py b/kwant/_plotter.py
index 5cb965d3..4b24936d 100644
--- a/kwant/_plotter.py
+++ b/kwant/_plotter.py
@@ -93,6 +93,76 @@ def nparray_if_array(var):
     return np.asarray(var) if isarray(var) else var
 
 
+if plotly_available:
+
+    converter_map = {
+        "o": 0,
+        "v": 6,
+        "^": 5,
+        "<": 7,
+        ">": 8,
+        "s": 1,
+        "+": 3,
+        "x": 4,
+        "*": 17,
+        "d": 2,
+        "h": 14
+    }
+
+
+    def convert_symbol_mpl_plotly(mpl_symbol):
+        if isarray(mpl_symbol):
+            converted_symbol = [converter_map.get(i) for i in mpl_symbol]
+        else:
+            converted_symbol = converter_map.get(mpl_symbol)
+
+        if converted_symbol == None:
+            raise RuntimeWarning('Input symbol \'{}\' not supported. '
+                            'Only the following are supported: {}'.format(
+                                mpl_symbol, converter_map.keys()))
+        return converted_symbol
+
+
+    def convert_site_size_mpl_plotly(mpl_site_size, plotly_ref_px):
+        return np.sqrt(mpl_site_size)*(96.0/72.0)*plotly_ref_px
+
+
+    def convert_colormap_mpl_plotly(mpl_rgba):
+        _cmap_plotly = 255 * np.array(mpl_rgba)
+        return 'rgba({},{},{},{})'.format(*_cmap_plotly[0:-1],
+                                          _cmap_plotly[-1]/255)
+
+
+    def convert_cmap_list_mpl_plotly(mpl_cmap_name, N=255):
+        cmap_mpl = matplotlib.cm.get_cmap(mpl_cmap_name)
+        cmap_mpl_arr = matplotlib.colors.makeMappingArray(N, cmap_mpl)
+        level = np.linspace(1, 0, N)
+        cmap_plotly_linear = [(level, convert_colormap_mpl_plotly(cmap_mpl))
+                                for level, cmap_mpl in zip(level,
+                                                            cmap_mpl_arr)]
+        return cmap_plotly_linear
+
+
+    def convert_lead_cmap_mpl_plotly(mpl_lead_cmap_init, mpl_lead_cmap_end,
+                                     N=255):
+        r_levels = np.linspace(mpl_lead_cmap_init[0],
+                               mpl_lead_cmap_end[0], N) * 255
+        g_levels = np.linspace(mpl_lead_cmap_init[1],
+                               mpl_lead_cmap_end[1], N) * 255
+        b_levels = np.linspace(mpl_lead_cmap_init[2],
+                               mpl_lead_cmap_end[2], N) * 255
+        a_levels = np.linspace(mpl_lead_cmap_init[3],
+                               mpl_lead_cmap_end[3], N)
+        level = np.linspace(1, 0, N)
+        cmap_plotly_linear = [(level, 'rgba({},{},{},{})'.format(*rgba))
+                                    for level, rgba in zip(level, zip(r_levels,
+                                                                      g_levels,
+                                                                      b_levels,
+                                                                      a_levels
+                                                                      ))]
+        return cmap_plotly_linear
+
+
 if mpl_available:
     class LineCollection(collections.LineCollection):
         def __init__(self, segments, reflen=None, **kwargs):
diff --git a/kwant/plotter.py b/kwant/plotter.py
index 13c26611..a38579a6 100644
--- a/kwant/plotter.py
+++ b/kwant/plotter.py
@@ -724,18 +724,21 @@ def sys_leads_hopping_pos(sys, hop_lead_nr):
 
 
 # Useful plot functions (to be extended).
-
+# The default plotly symbol size is a 6 px
 defaults = {'site_symbol': {2: 'o', 3: 'o'},
             'site_size': {2: 0.25, 3: 0.5},
+            'plotly_site_size_reference': 6,
             'site_color': {2: 'black', 3: 'white'},
             'site_edgecolor': {2: 'black', 3: 'black'},
             'site_lw': {2: 0, 3: 0.1},
             'hop_color': {2: 'black', 3: 'black'},
             'hop_lw': {2: 0.1, 3: 0},
-            'lead_color': {2: 'red', 3: 'red'}}
+            'lead_color': {2: 'red', 3: 'red'},
+            'unit': {_p.Backends.plotly: 'pt',
+                     _p.Backends.matplotlib: 'nn'}}
 
 
-def plot(sys, num_lead_cells=2, unit='nn',
+def plot(sys, num_lead_cells=2, unit=None,
          site_symbol=None, site_size=None,
          site_color=None, site_edgecolor=None, site_lw=None,
          hop_color=None, hop_lw=None,
@@ -887,10 +890,104 @@ def plot(sys, num_lead_cells=2, unit='nn',
       its aspect ratio.
 
     """
-    if not _p.mpl_available:
-        raise RuntimeError("matplotlib was not found, but is required "
+
+    # Provide default unit if user did not specify
+    if unit == None:
+        unit = defaults['unit'][get_backend()]
+    if get_backend() == _p.Backends.matplotlib:
+        fig = _plot_matplotlib(sys, num_lead_cells, unit,
+                        site_symbol, site_size,
+                        site_color, site_edgecolor, site_lw,
+                        hop_color, hop_lw,
+                        lead_site_symbol, lead_site_size, lead_color,
+                        lead_site_edgecolor, lead_site_lw,
+                        lead_hop_lw, pos_transform,
+                        cmap, colorbar, file,
+                        show, dpi, fig_size, ax)
+    elif get_backend() == _p.Backends.plotly:
+        _check_incompatible_args_plotly(dpi, fig_size, ax)
+        fig = _plot_plotly(sys, num_lead_cells, unit,
+                        site_symbol, site_size,
+                        site_color, site_edgecolor, site_lw,
+                        hop_color, hop_lw,
+                        lead_site_symbol, lead_site_size, lead_color,
+                        lead_site_edgecolor, lead_site_lw,
+                        lead_hop_lw, pos_transform,
+                        cmap, colorbar, file,
+                        show)
+    else:
+        raise RuntimeError("Backend not supported by plot().")
+
+    _maybe_output_fig(fig, file=file, show=show)
+
+    return fig
+
+def _resize_to_dim(array, dim):
+    if array.shape[1] != dim:
+        ar = np.zeros((len(array), dim), dtype=float)
+        ar[:, : min(dim, array.shape[1])] = array[
+            :, : min(dim, array.shape[1])]
+        return ar
+    else:
+        return array
+
+
+def _check_length(name, loc):
+    value = loc[name]
+    if name in ('site_size', 'site_lw') and isinstance(value, tuple):
+        raise TypeError('{0} may not be a tuple, use list or '
+                        'array instead.'.format(name))
+    if isinstance(value, (str, tuple)):
+        return
+    try:
+        if len(value) != loc['n_syst_sites']:
+            raise ValueError('Length of {0} is not equal to number of '
+                             'system sites.'.format(name))
+    except TypeError:
+        pass
+
+# make all specs proper: either constant or lists/np.arrays:
+def _make_proper_site_spec(spec_name,  spec, fancy_indexing=False):
+    if _p.isarray(spec) and isinstance(syst, builder.Builder):
+        raise TypeError('{} cannot be an array when plotting'
+                        ' a Builder; use a function instead.'
+                        .format(spec_name))
+    if callable(spec):
+        spec = [spec(i[0]) for i in sites if i[1] is None]
+    if (fancy_indexing and _p.isarray(spec)
+        and not isinstance(spec, np.ndarray)):
+        try:
+            spec = np.asarray(spec)
+        except:
+            spec = np.asarray(spec, dtype='object')
+    return spec
+
+def _make_proper_hop_spec(spec, fancy_indexing=False):
+    if callable(spec):
+        spec = [spec(*i[0]) for i in hops if i[1] is None]
+    if (fancy_indexing and _p.isarray(spec)
+        and not isinstance(spec, np.ndarray)):
+        try:
+            spec = np.asarray(spec)
+        except:
+            spec = np.asarray(spec, dtype='object')
+    return spec
+
+def _plot_plotly(sys, num_lead_cells, unit,
+         site_symbol, site_size,
+         site_color, site_edgecolor, site_lw,
+         hop_color, hop_lw,
+         lead_site_symbol, lead_site_size, lead_color,
+         lead_site_edgecolor, lead_site_lw,
+         lead_hop_lw, pos_transform,
+         cmap, colorbar, file,
+         show):
+
+    if not _p.plotly_available:
+        raise RuntimeError("plotly was not found, but is required "
                            "for plot()")
 
+    print('In _plot_plotly')
     syst = sys  # for naming consistency inside function bodies
     # Generate data.
     sites, lead_sites_slcs = sys_leads_sites(syst, num_lead_cells)
@@ -900,35 +997,279 @@ def plot(sys, num_lead_cells=2, unit='nn',
     n_syst_hops = sum(i[1] is None for i in hops)
     end_pos, start_pos = sys_leads_hopping_pos(syst, hops)
 
-    # Choose plot type.
-    def resize_to_dim(array):
-        if array.shape[1] != dim:
-            ar = np.zeros((len(array), dim), dtype=float)
-            ar[:, : min(dim, array.shape[1])] = array[
-                :, : min(dim, array.shape[1])]
-            return ar
+    loc = locals()
+
+    for name in ['site_symbol', 'site_size', 'site_color', 'site_edgecolor',
+                 'site_lw']:
+        _check_length(name, loc)
+
+    # Apply transformations to the data
+    if pos_transform is not None:
+        sites_pos = np.apply_along_axis(pos_transform, 1, sites_pos)
+        end_pos = np.apply_along_axis(pos_transform, 1, end_pos)
+        start_pos = np.apply_along_axis(pos_transform, 1, start_pos)
+
+    dim = 3 if (sites_pos.shape[1] == 3) else 2
+
+    sites_pos = _resize_to_dim(sites_pos, dim)
+    end_pos = _resize_to_dim(end_pos, dim)
+    start_pos = _resize_to_dim(start_pos, dim)
+
+    # Determine the reference length.
+    if unit != 'pt':
+        raise RuntimeError('Plotly backend currently only supports '
+                         'the pt symbol size unit')
+
+
+    site_symbol = _make_proper_site_spec('site_symbol', site_symbol, sites)
+    if site_symbol is None: site_symbol = defaults['site_symbol'][dim]
+    # separate different symbols (not done in 3D, the separation
+    # would mess up sorting)
+    if (_p.isarray(site_symbol) and dim != 3 and
+        (len(site_symbol) != 3 or site_symbol[0] not in ('p', 'P'))):
+        symbol_dict = defaultdict(list)
+        for i, symbol in enumerate(site_symbol):
+            symbol_dict[symbol].append(i)
+        symbol_slcs = []
+        for symbol, indx in symbol_dict.items():
+            symbol_slcs.append((symbol, np.array(indx)))
+        fancy_indexing = True
+    else:
+        symbol_slcs = [(site_symbol, slice(n_syst_sites))]
+        fancy_indexing = False
+
+    if site_color is None:
+        cycle = _color_cycle()
+        if isinstance(syst, (builder.FiniteSystem, builder.InfiniteSystem)):
+            # Skipping the leads for brevity.
+            families = sorted({site.family for site in syst.sites})
+            color_mapping = dict(zip(families, cycle))
+            def site_color(site):
+                return color_mapping[syst.sites[site].family]
+        elif isinstance(syst, builder.Builder):
+            families = sorted({site[0].family for site in sites})
+            color_mapping = dict(zip(families, cycle))
+            def site_color(site):
+                return color_mapping[site.family]
         else:
-            return array
+            # Unknown finalized system, no sites access.
+            site_color = defaults['site_color'][dim]
 
-    loc = locals()
+    site_size = _make_proper_site_spec('site_size',site_size, sites, fancy_indexing)
+    site_color = _make_proper_site_spec('site_color',site_color, sites, fancy_indexing)
+    site_edgecolor = _make_proper_site_spec('site_edgecolor',site_edgecolor, sites,
+                                            fancy_indexing)
+    site_lw = _make_proper_site_spec('site_lw',site_lw, sites, fancy_indexing)
 
-    def check_length(name):
-        value = loc[name]
-        if name in ('site_size', 'site_lw') and isinstance(value, tuple):
-            raise TypeError('{0} may not be a tuple, use list or '
-                            'array instead.'.format(name))
-        if isinstance(value, (str, tuple)):
-            return
+    hop_color = _make_proper_hop_spec(hop_color, hops)
+    hop_lw = _make_proper_hop_spec(hop_lw, hops)
+
+    # Choose defaults depending on dimension, if None was given
+    if site_size is None: site_size = defaults['site_size'][dim]
+    if site_edgecolor is None:
+        site_edgecolor = defaults['site_edgecolor'][dim]
+    if site_lw is None: site_lw = defaults['site_lw'][dim]
+
+    if hop_color is None: hop_color = defaults['hop_color'][dim]
+    if hop_lw is None: hop_lw = defaults['hop_lw'][dim]
+
+    # if symbols are split up into different collections,
+    # the colormapping will fail without normalization
+    norm = None
+    if len(symbol_slcs) > 1:
         try:
-            if len(value) != n_syst_sites:
-                raise ValueError('Length of {0} is not equal to number of '
-                                 'system sites.'.format(name))
+            if site_color.ndim == 1 and len(site_color) == n_syst_sites:
+                site_color = np.asarray(site_color, dtype=float)
+                norm = _p.matplotlib.colors.Normalize(site_color.min(),
+                                                      site_color.max())
+        except:
+            pass
+
+    # take spec also for lead, if it's not a list/array, default, otherwise
+    if lead_site_symbol is None:
+        lead_site_symbol = (site_symbol if not _p.isarray(site_symbol)
+                            else defaults['site_symbol'][dim])
+    if lead_site_size is None:
+        lead_site_size = (site_size if not _p.isarray(site_size)
+                          else defaults['site_size'][dim])
+    if lead_color is None:
+        lead_color = defaults['lead_color'][dim]
+    lead_color = _p.matplotlib.colors.colorConverter.to_rgba(lead_color)
+
+    if lead_site_edgecolor is None:
+        lead_site_edgecolor = (site_edgecolor if not _p.isarray(site_edgecolor)
+                               else defaults['site_edgecolor'][dim])
+    if lead_site_lw is None:
+        lead_site_lw = (site_lw if not _p.isarray(site_lw)
+                        else defaults['site_lw'][dim])
+    if lead_hop_lw is None:
+        lead_hop_lw = (hop_lw if not _p.isarray(hop_lw)
+                       else defaults['hop_lw'][dim])
+
+    hop_cmap = None
+    if not isinstance(cmap, str):
+        try:
+            cmap, hop_cmap = cmap
         except TypeError:
             pass
+    # plot system sites and hoppings
+    site_node_trace, site_edge_trace = [], []
+    for symbol, slc in symbol_slcs:
+        size = site_size[slc] if _p.isarray(site_size) else site_size
+        col = site_color[slc] if _p.isarray(site_color) else site_color
+        edgecol = (site_edgecolor[slc] if _p.isarray(site_edgecolor) else
+                   site_edgecolor)
+        lw = site_lw[slc] if _p.isarray(site_lw) else site_lw
+
+        site_symbol_plotly = _p.convert_symbol_mpl_plotly(symbol)
+        site_node_trace_elem = _p.plotly_graph_objs.Scatter(
+                        x=[],
+                        y=[],
+                        text=[],
+                        mode='markers',
+                        hoverinfo='text',
+                        marker=dict(
+                            showscale=False,
+                            colorscale=_p.convert_cmap_list_mpl_plotly(cmap),
+                            reversescale=True,
+                            color=col,
+                            size=_p.convert_site_size_mpl_plotly(size,
+                                       defaults['plotly_site_size_reference']),
+                            symbol=site_symbol_plotly,
+                            line=dict(width=lw,
+                                      color=edgecol)
+                            ))
+
+
+        for i in range(len(sites_pos[slc])):
+            x, y = sites_pos[slc][i]
+            site_node_trace_elem['x'] += tuple([x])
+            site_node_trace_elem['y'] += tuple([y])
+
+        site_node_trace.append(site_node_trace_elem)
+
+    end, start = end_pos[: n_syst_hops], start_pos[: n_syst_hops]
+    dim = end.shape[1]
+    assert dim == 2 or dim == 3
+    if dim == 2:
+        site_edge_trace_elem = _p.plotly_graph_objs.Scatter(
+                                x=[],
+                                y=[],
+                                line=dict(width=hop_lw,color=hop_color),
+                                hoverinfo='none',
+                                mode='lines')
+        for i in range(len(end)):
+            x0, y0 = end[i]
+            x1, y1 = start[i]
+            site_edge_trace_elem['x'] += tuple([x0, x1, None])
+            site_edge_trace_elem['y'] += tuple([y0, y1, None])
+        site_edge_trace.append(site_edge_trace_elem)
+    else:
+        raise RuntimeError('dim=3 is unsupported yet in plotly backend')
+
+    # Make conversion of colormap
+
+    lead_site_symbol_plotly = _p.convert_symbol_mpl_plotly(lead_site_symbol)
+
+    lead_node_trace, lead_edge_trace = [], []
+    for sites_slc, hops_slc in zip(lead_sites_slcs, lead_hops_slcs):
+        lead_site_colors = np.array([i[2] for i in sites[sites_slc]],
+                                    dtype=float)
+        lead_node_trace_elem = _p.plotly_graph_objs.Scatter(
+                        x=[],
+                        y=[],
+                        text=[],
+                        mode='markers',
+                        hoverinfo='text',
+                        marker=dict(
+                            showscale=False,
+                            reversescale=True,
+                            color=lead_site_colors,
+                            colorscale=_p.convert_lead_cmap_mpl_plotly(
+                                            lead_color, [1,1,1,lead_color[3]]),
+                            size=_p.convert_site_size_mpl_plotly(
+                                   lead_site_size,
+                                   defaults['plotly_site_size_reference']),
+                            symbol=lead_site_symbol_plotly,
+                            line=dict(width=lead_site_lw,
+                                      color=lead_site_edgecolor)
+                            ))
+        for i in range(len(sites_pos[sites_slc])):
+            x, y = sites_pos[sites_slc][i]
+            lead_node_trace_elem['x'] += tuple([x])
+            lead_node_trace_elem['y'] += tuple([y])
+        lead_node_trace.append(lead_node_trace_elem)
+        lead_hop_colors = np.array([i[2] for i in hops[hops_slc]], dtype=float)
+        # Note: the previous version of the code had in addition this
+        # line in the 3D case:
+        # lead_hop_colors = 1 / np.sqrt(1. + lead_hop_colors)
+        # Uses lead_cmap for the colormap
+        # 1) Make each line a scatter object. Takes a lot of memory but should work
+        # 2) Get the color from the previous object
+        end, start = end_pos[hops_slc], start_pos[hops_slc]
+        if dim == 2:
+            lead_edge_trace_elem = _p.plotly_graph_objs.Scatter(
+                                    x=[],
+                                    y=[],
+                                    line=dict(width=lead_hop_lw,
+                                              color='red'),
+                                    hoverinfo='none',
+                                    mode='lines')
+            for i in range(len(end)):
+                x0, y0 = end[i]
+                x1, y1 = start[i]
+                lead_edge_trace_elem['x'] += tuple([x0, x1, None])
+                lead_edge_trace_elem['y'] += tuple([y0, y1, None])
+
+            lead_edge_trace.append(lead_edge_trace_elem)
+        else:
+            raise RuntimeError('dim=3 is unsupported yet in plotly backend')
+
+    layout = _p.plotly_graph_objs.Layout(
+                                showlegend=False,
+                                hovermode='closest',
+                                xaxis=dict(showgrid=False, zeroline=False,
+                                           showticklabels=True),
+                                yaxis=dict(showgrid=False, zeroline=False,
+                                           showticklabels=True))
+    full_trace = list(itertools.chain.from_iterable([site_edge_trace,
+                                        site_node_trace, lead_edge_trace,
+                                        lead_node_trace]))
+    fig = _p.plotly_graph_objs.Figure(data=full_trace,
+                                      layout=layout)
+
+    return fig
+
+
+def _plot_matplotlib(sys, num_lead_cells, unit,
+         site_symbol, site_size,
+         site_color, site_edgecolor, site_lw,
+         hop_color, hop_lw,
+         lead_site_symbol, lead_site_size, lead_color,
+         lead_site_edgecolor, lead_site_lw,
+         lead_hop_lw, pos_transform,
+         cmap, colorbar, file,
+         show, dpi, fig_size, ax):
+
+    if not _p.mpl_available:
+        raise RuntimeError("matplotlib was not found, but is required "
+                           "for plot()")
+
+    print('In _plot_matplotlib')
+    syst = sys  # for naming consistency inside function bodies
+    # Generate data.
+    sites, lead_sites_slcs = sys_leads_sites(syst, num_lead_cells)
+    n_syst_sites = sum(i[1] is None for i in sites)
+    sites_pos = sys_leads_pos(syst, sites)
+    hops, lead_hops_slcs = sys_leads_hoppings(syst, num_lead_cells)
+    n_syst_hops = sum(i[1] is None for i in hops)
+    end_pos, start_pos = sys_leads_hopping_pos(syst, hops)
+
+    loc = locals()
 
     for name in ['site_symbol', 'site_size', 'site_color', 'site_edgecolor',
                  'site_lw']:
-        check_length(name)
+        _check_length(name, loc)
 
     # Apply transformations to the data
     if pos_transform is not None:
@@ -939,9 +1280,9 @@ def plot(sys, num_lead_cells=2, unit='nn',
     dim = 3 if (sites_pos.shape[1] == 3) else 2
     if dim == 3 and not _p.has3d:
         raise RuntimeError("Installed matplotlib does not support 3d plotting")
-    sites_pos = resize_to_dim(sites_pos)
-    end_pos = resize_to_dim(end_pos)
-    start_pos = resize_to_dim(start_pos)
+    sites_pos = _resize_to_dim(sites_pos, dim)
+    end_pos = _resize_to_dim(end_pos, dim)
+    start_pos = _resize_to_dim(start_pos, dim)
 
     # Determine the reference length.
     if unit == 'pt':
@@ -974,35 +1315,7 @@ def plot(sys, num_lead_cells=2, unit='nn',
         except:
             raise ValueError('Invalid value of unit argument.')
 
-    # make all specs proper: either constant or lists/np.arrays:
-    def make_proper_site_spec(spec_name,  spec, fancy_indexing=False):
-        if _p.isarray(spec) and isinstance(syst, builder.Builder):
-            raise TypeError('{} cannot be an array when plotting'
-                            ' a Builder; use a function instead.'
-                            .format(spec_name))
-        if callable(spec):
-            spec = [spec(i[0]) for i in sites if i[1] is None]
-        if (fancy_indexing and _p.isarray(spec)
-            and not isinstance(spec, np.ndarray)):
-            try:
-                spec = np.asarray(spec)
-            except:
-                spec = np.asarray(spec, dtype='object')
-        return spec
-
-    def make_proper_hop_spec(spec, fancy_indexing=False):
-        if callable(spec):
-            spec = [spec(*i[0]) for i in hops if i[1] is None]
-        if (fancy_indexing and _p.isarray(spec)
-            and not isinstance(spec, np.ndarray)):
-            try:
-                spec = np.asarray(spec)
-            except:
-                spec = np.asarray(spec, dtype='object')
-        return spec
-
-
-    site_symbol = make_proper_site_spec('site_symbol', site_symbol)
+    site_symbol = _make_proper_site_spec('site_symbol', site_symbol)
     if site_symbol is None: site_symbol = defaults['site_symbol'][dim]
     # separate different symbols (not done in 3D, the separation
     # would mess up sorting)
@@ -1036,13 +1349,13 @@ def plot(sys, num_lead_cells=2, unit='nn',
             # Unknown finalized system, no sites access.
             site_color = defaults['site_color'][dim]
 
-    site_size = make_proper_site_spec('site_size', site_size, fancy_indexing)
-    site_color = make_proper_site_spec('site_color', site_color, fancy_indexing)
-    site_edgecolor = make_proper_site_spec('site_edgecolor', site_edgecolor, fancy_indexing)
-    site_lw = make_proper_site_spec('site_lw', site_lw, fancy_indexing)
+    site_size = _make_proper_site_spec('site_size', site_size, fancy_indexing)
+    site_color = _make_proper_site_spec('site_color', site_color, fancy_indexing)
+    site_edgecolor = _make_proper_site_spec('site_edgecolor', site_edgecolor, fancy_indexing)
+    site_lw = _make_proper_site_spec('site_lw', site_lw, fancy_indexing)
 
-    hop_color = make_proper_hop_spec(hop_color)
-    hop_lw = make_proper_hop_spec(hop_lw)
+    hop_color = _make_proper_hop_spec(hop_color, hops)
+    hop_lw = _make_proper_hop_spec(hop_lw, hops)
 
     # Choose defaults depending on dimension, if None was given
     if site_size is None: site_size = defaults['site_size'][dim]
@@ -1171,8 +1484,6 @@ def plot(sys, num_lead_cells=2, unit='nn',
     if line_coll.get_array() is not None and colorbar and fig is not None:
         fig.colorbar(line_coll)
 
-    _maybe_output_fig(fig, file=file, show=show)
-
     return fig
 
 
-- 
GitLab