From a29757a2c48b59c00e522d38825c9c695dd7b5e7 Mon Sep 17 00:00:00 2001
From: Joseph Weston <joseph@weston.cloud>
Date: Fri, 23 Feb 2018 17:38:35 +0100
Subject: [PATCH] move functions to _plotter.py, where they are needed

---
 kwant/_plotter.py | 13 +++++++++++++
 kwant/plotter.py  | 39 +++++++++++++--------------------------
 2 files changed, 26 insertions(+), 26 deletions(-)

diff --git a/kwant/_plotter.py b/kwant/_plotter.py
index 8fc746a1..771e9ba2 100644
--- a/kwant/_plotter.py
+++ b/kwant/_plotter.py
@@ -39,6 +39,19 @@ except ImportError:
     mpl_available = False
 
 
+# Collections that allow for symbols and linewiths to be given in data space
+# (not for general use, only implement what's needed for plotter)
+def isarray(var):
+    if hasattr(var, '__getitem__') and not isinstance(var, str):
+        return True
+    else:
+        return False
+
+
+def nparray_if_array(var):
+    return np.asarray(var) if isarray(var) else var
+
+
 if mpl_available:
     class LineCollection(collections.LineCollection):
         def __init__(self, segments, reflen=None, **kwargs):
diff --git a/kwant/plotter.py b/kwant/plotter.py
index 27d06027..2fc6b6ac 100644
--- a/kwant/plotter.py
+++ b/kwant/plotter.py
@@ -39,19 +39,6 @@ __all__ = ['plot', 'map', 'bands', 'spectrum', 'current',
 _p = _common.lazy_import('_plotter')
 
 
-# Collections that allow for symbols and linewiths to be given in data space
-# (not for general use, only implement what's needed for plotter)
-def isarray(var):
-    if hasattr(var, '__getitem__') and not isinstance(var, str):
-        return True
-    else:
-        return False
-
-
-def nparray_if_array(var):
-    return np.asarray(var) if isarray(var) else var
-
-
 def _sample_array(array, n_samples, rng=None):
     rng = _common.ensure_rng(rng)
     la = len(array)
@@ -93,7 +80,7 @@ def set_colors(color, collection, cmap, norm=None):
     if isinstance(collection, _p.mplot3d.art3d.Line3DCollection):
         length = len(collection._segments3d)  # Once again, matplotlib fault!
 
-    if isarray(color) and len(color) == length:
+    if _p.isarray(color) and len(color) == length:
         try:
             # check if it is an array of floats for color mapping
             color = np.asarray(color, dtype=float)
@@ -915,7 +902,7 @@ def plot(sys, num_lead_cells=2, unit='nn',
     def make_proper_site_spec(spec, fancy_indexing=False):
         if callable(spec):
             spec = [spec(i[0]) for i in sites if i[1] is None]
-        if (fancy_indexing and isarray(spec)
+        if (fancy_indexing and _p.isarray(spec)
             and not isinstance(spec, np.ndarray)):
             try:
                 spec = np.asarray(spec)
@@ -926,7 +913,7 @@ def plot(sys, num_lead_cells=2, unit='nn',
     def make_proper_hop_spec(spec, fancy_indexing=False):
         if callable(spec):
             spec = [spec(*i[0]) for i in hops if i[1] is None]
-        if (fancy_indexing and isarray(spec)
+        if (fancy_indexing and _p.isarray(spec)
             and not isinstance(spec, np.ndarray)):
             try:
                 spec = np.asarray(spec)
@@ -938,7 +925,7 @@ def plot(sys, num_lead_cells=2, unit='nn',
     if site_symbol is None: site_symbol = defaults['site_symbol'][dim]
     # separate different symbols (not done in 3D, the separation
     # would mess up sorting)
-    if (isarray(site_symbol) and dim != 3 and
+    if (_p.isarray(site_symbol) and dim != 3 and
         (len(site_symbol) != 3 or site_symbol[0] not in ('p', 'P'))):
         symbol_dict = defaultdict(list)
         for i, symbol in enumerate(site_symbol):
@@ -999,23 +986,23 @@ def plot(sys, num_lead_cells=2, unit='nn',
 
     # take spec also for lead, if it's not a list/array, default, otherwise
     if lead_site_symbol is None:
-        lead_site_symbol = (site_symbol if not isarray(site_symbol)
+        lead_site_symbol = (site_symbol if not _p.isarray(site_symbol)
                             else defaults['site_symbol'][dim])
     if lead_site_size is None:
-        lead_site_size = (site_size if not isarray(site_size)
+        lead_site_size = (site_size if not _p.isarray(site_size)
                           else defaults['site_size'][dim])
     if lead_color is None:
         lead_color = defaults['lead_color'][dim]
     lead_color = _p.matplotlib.colors.colorConverter.to_rgba(lead_color)
 
     if lead_site_edgecolor is None:
-        lead_site_edgecolor = (site_edgecolor if not isarray(site_edgecolor)
+        lead_site_edgecolor = (site_edgecolor if not _p.isarray(site_edgecolor)
                                else defaults['site_edgecolor'][dim])
     if lead_site_lw is None:
-        lead_site_lw = (site_lw if not isarray(site_lw)
+        lead_site_lw = (site_lw if not _p.isarray(site_lw)
                         else defaults['site_lw'][dim])
     if lead_hop_lw is None:
-        lead_hop_lw = (hop_lw if not isarray(hop_lw)
+        lead_hop_lw = (hop_lw if not _p.isarray(hop_lw)
                        else defaults['hop_lw'][dim])
 
     hop_cmap = None
@@ -1041,11 +1028,11 @@ def plot(sys, num_lead_cells=2, unit='nn',
 
     # plot system sites and hoppings
     for symbol, slc in symbol_slcs:
-        size = site_size[slc] if isarray(site_size) else site_size
-        col = site_color[slc] if isarray(site_color) else site_color
-        edgecol = (site_edgecolor[slc] if isarray(site_edgecolor) else
+        size = site_size[slc] if _p.isarray(site_size) else site_size
+        col = site_color[slc] if _p.isarray(site_color) else site_color
+        edgecol = (site_edgecolor[slc] if _p.isarray(site_edgecolor) else
                    site_edgecolor)
-        lw = site_lw[slc] if isarray(site_lw) else site_lw
+        lw = site_lw[slc] if _p.isarray(site_lw) else site_lw
 
         symbol_coll = symbols(ax, sites_pos[slc], size=size,
                               reflen=reflen, symbol=symbol,
-- 
GitLab