From a29757a2c48b59c00e522d38825c9c695dd7b5e7 Mon Sep 17 00:00:00 2001 From: Joseph Weston <joseph@weston.cloud> Date: Fri, 23 Feb 2018 17:38:35 +0100 Subject: [PATCH] move functions to _plotter.py, where they are needed --- kwant/_plotter.py | 13 +++++++++++++ kwant/plotter.py | 39 +++++++++++++-------------------------- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/kwant/_plotter.py b/kwant/_plotter.py index 8fc746a1..771e9ba2 100644 --- a/kwant/_plotter.py +++ b/kwant/_plotter.py @@ -39,6 +39,19 @@ except ImportError: mpl_available = False +# Collections that allow for symbols and linewiths to be given in data space +# (not for general use, only implement what's needed for plotter) +def isarray(var): + if hasattr(var, '__getitem__') and not isinstance(var, str): + return True + else: + return False + + +def nparray_if_array(var): + return np.asarray(var) if isarray(var) else var + + if mpl_available: class LineCollection(collections.LineCollection): def __init__(self, segments, reflen=None, **kwargs): diff --git a/kwant/plotter.py b/kwant/plotter.py index 27d06027..2fc6b6ac 100644 --- a/kwant/plotter.py +++ b/kwant/plotter.py @@ -39,19 +39,6 @@ __all__ = ['plot', 'map', 'bands', 'spectrum', 'current', _p = _common.lazy_import('_plotter') -# Collections that allow for symbols and linewiths to be given in data space -# (not for general use, only implement what's needed for plotter) -def isarray(var): - if hasattr(var, '__getitem__') and not isinstance(var, str): - return True - else: - return False - - -def nparray_if_array(var): - return np.asarray(var) if isarray(var) else var - - def _sample_array(array, n_samples, rng=None): rng = _common.ensure_rng(rng) la = len(array) @@ -93,7 +80,7 @@ def set_colors(color, collection, cmap, norm=None): if isinstance(collection, _p.mplot3d.art3d.Line3DCollection): length = len(collection._segments3d) # Once again, matplotlib fault! - if isarray(color) and len(color) == length: + if _p.isarray(color) and len(color) == length: try: # check if it is an array of floats for color mapping color = np.asarray(color, dtype=float) @@ -915,7 +902,7 @@ def plot(sys, num_lead_cells=2, unit='nn', def make_proper_site_spec(spec, fancy_indexing=False): if callable(spec): spec = [spec(i[0]) for i in sites if i[1] is None] - if (fancy_indexing and isarray(spec) + if (fancy_indexing and _p.isarray(spec) and not isinstance(spec, np.ndarray)): try: spec = np.asarray(spec) @@ -926,7 +913,7 @@ def plot(sys, num_lead_cells=2, unit='nn', def make_proper_hop_spec(spec, fancy_indexing=False): if callable(spec): spec = [spec(*i[0]) for i in hops if i[1] is None] - if (fancy_indexing and isarray(spec) + if (fancy_indexing and _p.isarray(spec) and not isinstance(spec, np.ndarray)): try: spec = np.asarray(spec) @@ -938,7 +925,7 @@ def plot(sys, num_lead_cells=2, unit='nn', if site_symbol is None: site_symbol = defaults['site_symbol'][dim] # separate different symbols (not done in 3D, the separation # would mess up sorting) - if (isarray(site_symbol) and dim != 3 and + if (_p.isarray(site_symbol) and dim != 3 and (len(site_symbol) != 3 or site_symbol[0] not in ('p', 'P'))): symbol_dict = defaultdict(list) for i, symbol in enumerate(site_symbol): @@ -999,23 +986,23 @@ def plot(sys, num_lead_cells=2, unit='nn', # take spec also for lead, if it's not a list/array, default, otherwise if lead_site_symbol is None: - lead_site_symbol = (site_symbol if not isarray(site_symbol) + lead_site_symbol = (site_symbol if not _p.isarray(site_symbol) else defaults['site_symbol'][dim]) if lead_site_size is None: - lead_site_size = (site_size if not isarray(site_size) + lead_site_size = (site_size if not _p.isarray(site_size) else defaults['site_size'][dim]) if lead_color is None: lead_color = defaults['lead_color'][dim] lead_color = _p.matplotlib.colors.colorConverter.to_rgba(lead_color) if lead_site_edgecolor is None: - lead_site_edgecolor = (site_edgecolor if not isarray(site_edgecolor) + lead_site_edgecolor = (site_edgecolor if not _p.isarray(site_edgecolor) else defaults['site_edgecolor'][dim]) if lead_site_lw is None: - lead_site_lw = (site_lw if not isarray(site_lw) + lead_site_lw = (site_lw if not _p.isarray(site_lw) else defaults['site_lw'][dim]) if lead_hop_lw is None: - lead_hop_lw = (hop_lw if not isarray(hop_lw) + lead_hop_lw = (hop_lw if not _p.isarray(hop_lw) else defaults['hop_lw'][dim]) hop_cmap = None @@ -1041,11 +1028,11 @@ def plot(sys, num_lead_cells=2, unit='nn', # plot system sites and hoppings for symbol, slc in symbol_slcs: - size = site_size[slc] if isarray(site_size) else site_size - col = site_color[slc] if isarray(site_color) else site_color - edgecol = (site_edgecolor[slc] if isarray(site_edgecolor) else + size = site_size[slc] if _p.isarray(site_size) else site_size + col = site_color[slc] if _p.isarray(site_color) else site_color + edgecol = (site_edgecolor[slc] if _p.isarray(site_edgecolor) else site_edgecolor) - lw = site_lw[slc] if isarray(site_lw) else site_lw + lw = site_lw[slc] if _p.isarray(site_lw) else site_lw symbol_coll = symbols(ax, sites_pos[slc], size=size, reflen=reflen, symbol=symbol, -- GitLab