diff --git a/.gitignore b/.gitignore index ba94c807d204266b925eb7c25a92bc57bfe42971..5f53200b7e9a1ab828bc6a22ad7a23645c989fc3 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.pyc *.pyo *.so +*.pyd /kwant/*.c /kwant/*/*.c /build @@ -16,3 +17,4 @@ .coverage .eggs/ htmlcov/ +.ipynb_checkpoints/ diff --git a/AUTHORS.rst b/AUTHORS.rst index 6b1bf2b7b2918eb817f08d1460c2e322b723769f..07b31efed8e4582ac647c8a44330742f1e65bcb1 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -21,6 +21,7 @@ Contributors to Kwant include * Mathieu Istas (CEA Grenoble) * Daniel Jaschke (CEA Grenoble) * Thomas Kloss (CEA Grenoble) +* Kelvin Loh (TNO) * Bas Nijholt (TU Delft) * Michał Nowak (TU Delft) * Viacheslav Ostroukh (Leiden University) diff --git a/doc/source/pre/whatsnew/1.5.rst b/doc/source/pre/whatsnew/1.5.rst index 2e89d74791d9afa5739ecf2b32d11ec5216d0722..17b7b3bc2163e567d8cf4dff7732188148d38b96 100644 --- a/doc/source/pre/whatsnew/1.5.rst +++ b/doc/source/pre/whatsnew/1.5.rst @@ -65,6 +65,33 @@ rather than:: kwant.lattice.square() +Plotly for plots +---------------- +Kwant can now use the `Plotly <https://plot.ly/>`_ library when plotting, +though matplotlib is still used by default. Using plotly empowers Kwant +to produce high-quality interactive plots, as well as true 3D support: + +.. jupyter-execute:: + :hide-code: + + import kwant + import kwant.plotter + +.. jupyter-execute:: + + lat = kwant.lattice.cubic(norbs=1) + syst = kwant.Builder() + + def disk(r): + x, y, z = r + return -2 <= z < 2 and 10**2 < x**2 + y**2 < 20**2 + + syst[lat.shape(disk, (15, 0, 0))] = 4 + + kwant.plotter.set_engine("plotly") + kwant.plot(syst); + + Automatic addition of Peierls phase terms to Builders ----------------------------------------------------- Kwant 1.4 introduced `kwant.physics.magnetic_gauge` for computing Peierls diff --git a/doc/source/tutorial/plotting.rst b/doc/source/tutorial/plotting.rst index 48b6ed4796251b64aa9de10922ce911cd3aade86..b30237bc912176464698b2b8f50fc5e83ecc6616 100644 --- a/doc/source/tutorial/plotting.rst +++ b/doc/source/tutorial/plotting.rst @@ -292,6 +292,29 @@ arbitrarily, allowing for a good inspection of the geometry from all sides. does not properly honor the corresponding arguments. By resizing the plot window however one can manually adjust the aspect ratio. +If you also have plotly installed, you can now use the plotly engine for all +the plotting functions within Kwant (except streamplots). The way you would do +it is simple, just set the plotter engine to ``plotly`` and then call the +plotting function as you would do with some minor changes (See note below). For +example, from the previous plot, you would need to just do this: + +.. jupyter-execute:: + + kwant.plotter.set_engine('plotly') # Set to plotly engine + + kwant.plot(syst) + + kwant.plotter.get_engine() # Get the current engine + + kwant.plotter.set_engine('matplotlib') # Set to matplotlib engine + +.. note:: + + By default, the engine would be set to matplotlib if both matplotlib and + plotly are installed, and if either are installed, then, the default would + be the one available in your system. Certain attributes such as dpi or + fig_size or ax are not supported. + Also for 3D it is possible to customize the plot. For example, we can explicitly plot the hoppings as lines, and color sites differently depending on the sublattice: diff --git a/docker/Dockerfile.debian b/docker/Dockerfile.debian index 12ab7d7e42eb7eb032a3ce37b0f20078157e7337..ad235f066518144ebf33c28dd567c46f892acc3b 100644 --- a/docker/Dockerfile.debian +++ b/docker/Dockerfile.debian @@ -16,6 +16,7 @@ RUN echo "deb http://downloads.kwant-project.org/debian/ stable main" >> /etc/ap # all the hard Python dependencies python3-all-dev python3-setuptools python3-pip python3-tk python3-wheel \ python3-numpy python3-scipy python3-matplotlib python3-sympy python3-tinyarray \ + python3-plotly \ # Additional tools for running CI file rsync openssh-client \ && apt-get clean && \ diff --git a/docker/Dockerfile.ubuntu b/docker/Dockerfile.ubuntu index f1c54772ea69224788ff73268ddb0773708888fb..39789a7724ac9570d39a5ba94462eac635a5b4d8 100644 --- a/docker/Dockerfile.ubuntu +++ b/docker/Dockerfile.ubuntu @@ -14,6 +14,7 @@ RUN apt-add-repository -s ppa:kwant-project/ppa && \ # all the hard Python dependencies python3-all-dev python3-setuptools python3-pip python3-tk python3-wheel \ python3-numpy python3-scipy python3-matplotlib python3-sympy python3-tinyarray \ + python3-plotly \ # Additional tools for running CI file rsync openssh-client \ && apt-get clean && \ diff --git a/docker/kwant-latest.yml b/docker/kwant-latest.yml index 72de14af9dfb4a5535109954cd7fe517d2fd34bb..6a6283fbe618c48a27cee07edb09ff5db57d8321 100644 --- a/docker/kwant-latest.yml +++ b/docker/kwant-latest.yml @@ -9,6 +9,7 @@ dependencies: - sympy - matplotlib - qsymm + - plotly # Linear algebra libraraies - mumps - blas #=1.1 openblas diff --git a/docker/kwant-stable.yml b/docker/kwant-stable.yml index 93decc70dc27de7192a8f4af6548207e8322c1e7..bd676136d040f192e372dad731dbc9aee5ca5bab 100644 --- a/docker/kwant-stable.yml +++ b/docker/kwant-stable.yml @@ -9,6 +9,7 @@ dependencies: - sympy=1.1.1 - matplotlib=2.1.1 - qsymm=1.2.6 + - plotly=2.2.2 # Linear algebra libraraies - mumps - blas #=1.1 openblas diff --git a/kwant/_colormaps.py b/kwant/_colormaps.py index 86b11cb52ba83023b68b500fce58aaa5aa1adfe1..d92e4ebf2682d8fa5ae4110ff68ce531b19ffe6a 100644 --- a/kwant/_colormaps.py +++ b/kwant/_colormaps.py @@ -7,8 +7,6 @@ # http://kwant-project.org/authors. import numpy as np -from matplotlib.colors import ListedColormap - kr_data = [[ 0.98916316, 0.98474381, 0.99210697], [ 0.98723538, 0.98138853, 0.98740721], @@ -271,4 +269,4 @@ kr_data = [[ 0.98916316, 0.98474381, 0.99210697], kr_data = np.array(kr_data) kr_data = np.clip(kr_data / kr_data[0], 0, 1) -kwant_red = ListedColormap(kr_data, name="kwant red") +kwant_red = kr_data diff --git a/kwant/_plotter.py b/kwant/_plotter.py index 991fed41473f8ac045b39387c4fd490ce24f7869..ce778424d2bb284520b7edbe0bff201a53538cdf 100644 --- a/kwant/_plotter.py +++ b/kwant/_plotter.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2011-2018 Kwant authors. +# Copyright 2011-2019 Kwant authors. # # This file is part of Kwant. It is subject to the license terms in the file # LICENSE.rst found in the top-level directory of this distribution and at @@ -18,6 +18,19 @@ import warnings from math import sqrt, pi import numpy as np +from enum import Enum + + +try: + __IPYTHON__ + is_ipython_kernel = True +except NameError: + is_ipython_kernel = False + +global mpl_available +global plotly_available +mpl_available = False +plotly_available = False try: import matplotlib @@ -26,7 +39,10 @@ try: from matplotlib.figure import Figure from matplotlib import collections from . import _colormaps + from matplotlib.colors import ListedColormap mpl_available = True + kwant_red_matplotlib = ListedColormap(_colormaps.kwant_red, + name="kwant red") try: from mpl_toolkits import mplot3d has3d = True @@ -34,9 +50,41 @@ try: warnings.warn("3D plotting not available.", RuntimeWarning) has3d = False except ImportError: - warnings.warn("matplotlib is not available, only iterator-providing " - "functions will work.", RuntimeWarning) - mpl_available = False + warnings.warn("matplotlib is not available, if other engines are " + "unavailable, only iterator-providing functions will work", + RuntimeWarning) + + +try: + import plotly.offline as plotly_module + import plotly.graph_objs as plotly_graph_objs + init_notebook_mode_set = False + from . import _colormaps + plotly_available = True + + _cmap_plotly = 255 * _colormaps.kwant_red + _cmap_levels = np.linspace(0, 1, len(_cmap_plotly)) + kwant_red_plotly = [(level, 'rgb({},{},{})'.format(*rgb)) + for level, rgb in zip(_cmap_levels, _cmap_plotly)] +except ImportError: + warnings.warn("plotly is not available, if other engines are unavailable," + " only iterator-providing functions will work", + RuntimeWarning) + +Engines = [] + +if plotly_available: + Engines.append("plotly") + engine = "plotly" + +if mpl_available: + Engines.append("matplotlib") + engine = "matplotlib" + +if not ((mpl_available) or (plotly_available)): + engine = None + +Engines = frozenset(Engines) # Collections that allow for symbols and linewiths to be given in data space @@ -52,6 +100,113 @@ def nparray_if_array(var): return np.asarray(var) if isarray(var) else var +if plotly_available: + + # The converter_map and converter_map_3d converts the common marker symbols + # of matplotlib to the symbols of plotly + converter_map = { + "o": 0, + "v": 6, + "^": 5, + "<": 7, + ">": 8, + "s": 1, + "+": 3, + "x": 4, + "*": 17, + "d": 2, + "h": 14, + "no symbol": -1 + } + + converter_map_3d = { + "o": "circle", + "s": "square", + "+": "cross", + "x": "x", + "d": "diamond", + } + + def error_string(symbol_input, supported): + return 'Input symbol/s \'{}\' not supported. Only the following characters are supported: {}'.format(symbol_input, supported) + + + def convert_symbol_mpl_plotly(mpl_symbol): + if isarray(mpl_symbol): + try: + converted_symbol = [converter_map.get(i) for i in mpl_symbol] + except KeyError: + raise RuntimeError( error_string(mpl_symbol, list(converter_map)) ) + else: + try: + converted_symbol = converter_map.get(mpl_symbol) + except KeyError: + raise RuntimeError( error_string(mpl_symbol, list(converter_map)) ) + return converted_symbol + + + def convert_symbol_mpl_plotly_3d(mpl_symbol): + if isarray(mpl_symbol): + try: + converted_symbol = [converter_map_3d.get(i) for i in mpl_symbol] + except KeyError: + raise RuntimeError( error_string(mpl_symbol, list(converter_map_3d)) ) + else: + try: + converted_symbol = converter_map_3d.get(mpl_symbol) + except KeyError: + raise RuntimeError( error_string(mpl_symbol, list(converter_map_3d)) ) + return converted_symbol + + + def convert_site_size_mpl_plotly(mpl_site_size, plotly_ref_px): + # The conversion is such that we assume matplotlib's marker size is in + # square points (https://matplotlib.org/devdocs/api/_as_gen/matplotlib.pyplot.scatter.html) + # and we need to convert the points to pixels for plotly. + # Hence, 1 pixel = (96.0)/(72.0) point + return np.sqrt(mpl_site_size)*(96.0/72.0)*plotly_ref_px + + + def convert_colormap_mpl_plotly(mpl_rgba): + _cmap_plotly = 255 * np.array(mpl_rgba) + return 'rgba({},{},{},{})'.format(*_cmap_plotly[0:-1], + _cmap_plotly[-1]/255) + + + def convert_cmap_list_mpl_plotly(mpl_cmap_name, N=255): + if isinstance(mpl_cmap_name, str): + cmap_mpl = matplotlib.cm.get_cmap(mpl_cmap_name) + cmap_mpl_arr = matplotlib.colors.makeMappingArray(N, cmap_mpl) + level = np.linspace(0, 1, N) + cmap_plotly_linear = [(level, convert_colormap_mpl_plotly(cmap_mpl)) + for level, cmap_mpl in zip(level, + cmap_mpl_arr)] + else: + assert(isinstance(mpl_cmap_name, list)) + # Do not do any conversion if it's already a list + cmap_plotly_linear = mpl_cmap_name + return cmap_plotly_linear + + + def convert_lead_cmap_mpl_plotly(mpl_lead_cmap_init, mpl_lead_cmap_end, + N=255): + r_levels = np.linspace(mpl_lead_cmap_init[0], + mpl_lead_cmap_end[0], N) * 255 + g_levels = np.linspace(mpl_lead_cmap_init[1], + mpl_lead_cmap_end[1], N) * 255 + b_levels = np.linspace(mpl_lead_cmap_init[2], + mpl_lead_cmap_end[2], N) * 255 + a_levels = np.linspace(mpl_lead_cmap_init[3], + mpl_lead_cmap_end[3], N) + level = np.linspace(0, 1, N) + cmap_plotly_linear = [(level, 'rgba({},{},{},{})'.format(*rgba)) + for level, rgba in zip(level, + zip(r_levels, g_levels, + b_levels, a_levels + ))] + return cmap_plotly_linear + + if mpl_available: class LineCollection(collections.LineCollection): def __init__(self, segments, reflen=None, **kwargs): @@ -321,3 +476,14 @@ if mpl_available: self.set_linewidths(self.linewidths_orig2 * factor) super().draw(renderer) + +if plotly_available: + def matplotlib_to_plotly_cmap(cmap, pl_entries): + h = 1.0/(pl_entries-1) + pl_colorscale = [] + + for k in range(pl_entries): + C = map(np.uint8, np.array(cmap(k*h)[:3])*255) + pl_colorscale.append([k*h, 'rgb'+str((C[0], C[1], C[2]))]) + + return pl_colorscale diff --git a/kwant/plotter.py b/kwant/plotter.py index 31a0aba6658b18eb3c79fccc3c0ffcdb2804bedf..021cf22a6a466727eaf1b5f4144dc7bb24192a23 100644 --- a/kwant/plotter.py +++ b/kwant/plotter.py @@ -30,7 +30,8 @@ from . import system, builder, _common from ._common import deprecate_args -__all__ = ['plot', 'map', 'bands', 'spectrum', 'current', 'density', +__all__ = ['set_engine', 'get_engine', + 'plot', 'map', 'bands', 'spectrum', 'current', 'density', 'interpolate_current', 'interpolate_density', 'streamplot', 'scalarplot', 'sys_leads_sites', 'sys_leads_hoppings', 'sys_leads_pos', @@ -41,6 +42,50 @@ __all__ = ['plot', 'map', 'bands', 'spectrum', 'current', 'density', _p = _common.lazy_import('_plotter') +def set_engine(engine): + """Set the plotting engine to use. + + Parameters + ---------- + engine : str + Options are: 'matplotlib', 'plotly'. + """ + + if ((_p.mpl_available) or (_p.plotly_available)): + try: + assert(engine in _p.Engines) + _p.engine = engine + except: + error_message = "Tried to set an unknown engine \'{}\'.".format( + engine) + error_message += " Supported engines are {}".format( + [e for e in _p.Engines]) + raise RuntimeError(error_message) + else: + warnings.warn("Tried to set \'{}\' but is not " + "available.".format(engine), RuntimeWarning) + + if ((_p.engine == "plotly") and + (not _p.init_notebook_mode_set)): + if (_p.is_ipython_kernel): + _p.init_notebook_mode_set = True + _p.plotly_module.init_notebook_mode(connected=True) + + +def get_engine(): + return _p.engine + + +def _check_incompatible_args_plotly(dpi, fig_size, ax): + assert(_p.engine == "plotly") + if(dpi or fig_size or ax): + raise RuntimeError( + "Plotly engine does not support setting 'dpi', 'fig_size' " + "or 'ax', either leave these parameters unspecified, or " + "select the matplotlib engine with" + "'kwant.plotter.set_engine(\"matplotlib\")'") + + def _sample_array(array, n_samples, rng=None): rng = _common.ensure_rng(rng) la = len(array) @@ -106,13 +151,24 @@ def _maybe_output_fig(fig, file=None, show=True): if fig is None: return - if file is not None: - fig.canvas.print_figure(file, dpi=fig.dpi) - elif show: - # If there was no file provided, pyplot should already be available and - # we can import it safely without additional warnings. - from matplotlib import pyplot - pyplot.show() + if _p.engine == "matplotlib": + if file is not None: + fig.canvas.print_figure(file, dpi=fig.dpi) + elif show: + # If there was no file provided, pyplot should already be available + # and we can import it safely without additional warnings. + from matplotlib import pyplot + pyplot.show() + elif _p.engine == "plotly": + if file is not None: + _p.plotly_module.plot(fig, show_link=False, filename=file, auto_open=False) + if show: + if (_p.is_ipython_kernel): + _p.plotly_module.iplot(fig) + else: + raise RuntimeError('show flag using the plotly engine can ' + 'only be True if and only if called from a ' + 'jupyter/ipython environment.') def set_colors(color, collection, cmap, norm=None): @@ -664,9 +720,13 @@ def sys_leads_hopping_pos(sys, hop_lead_nr): # Useful plot functions (to be extended). - +# The default plotly symbol size is a 6 px +# The keys of 2, and 3 represent the dimension of the system. +# e.g. the default for site_size for kwant system of dim=2 is 0.25, and +# dim=3 is 0.5 defaults = {'site_symbol': {2: 'o', 3: 'o'}, 'site_size': {2: 0.25, 3: 0.5}, + 'plotly_site_size_reference': 6, 'site_color': {2: 'black', 3: 'white'}, 'site_edgecolor': {2: 'black', 3: 'black'}, 'site_lw': {2: 0, 3: 0.1}, @@ -675,7 +735,7 @@ defaults = {'site_symbol': {2: 'o', 3: 'o'}, 'lead_color': {2: 'red', 3: 'red'}} -def plot(sys, num_lead_cells=2, unit='nn', +def plot(sys, num_lead_cells=2, unit=None, site_symbol=None, site_size=None, site_color=None, site_edgecolor=None, site_lw=None, hop_color=None, hop_lw=None, @@ -721,13 +781,13 @@ def plot(sys, num_lead_cells=2, unit='nn', - ``('p', nvert, angle)``: regular polygon with ``nvert`` vertices, rotated by ``angle``. ``angle`` is given in degrees, and ``angle=0`` corresponds to one edge of the polygon pointing upward. The - radius of the inner circle is 1 unit. - - 'no symbol': no symbol is plotted. + radius of the inner circle is 1 unit. [Unsupported by plotly engine] + - 'no symbol': no symbol is plotted. [Unsupported by plotly engine] - 'S', `('P', nvert, angle)`: as the lower-case variants described above, but with an area equal to a circle of radius 1. (Makes the visual size of the symbol equal to the size of a circle with - radius 1). - - matplotlib.path.Path instance. + radius 1). [Unsupported by plotly engine] + - matplotlib.path.Path instance. [Unsupported by plotly engine] Instead of a single symbol, different symbols can be specified for different sites by passing a function that returns a valid @@ -827,9 +887,101 @@ def plot(sys, num_lead_cells=2, unit='nn', its aspect ratio. """ - if not _p.mpl_available: - raise RuntimeError("matplotlib was not found, but is required " - "for plot()") + + # Provide default unit if user did not specify + if _p.engine == "matplotlib": + fig = _plot_matplotlib(sys, num_lead_cells, unit, + site_symbol, site_size, + site_color, site_edgecolor, site_lw, + hop_color, hop_lw, + lead_site_symbol, lead_site_size, lead_color, + lead_site_edgecolor, lead_site_lw, + lead_hop_lw, pos_transform, + cmap, colorbar, file, + show, dpi, fig_size, ax) + elif _p.engine == "plotly": + _check_incompatible_args_plotly(dpi, fig_size, ax) + fig = _plot_plotly(sys, num_lead_cells, unit, + site_symbol, site_size, + site_color, site_edgecolor, site_lw, + hop_color, hop_lw, + lead_site_symbol, lead_site_size, lead_color, + lead_site_edgecolor, lead_site_lw, + lead_hop_lw, pos_transform, + cmap, colorbar, file, + show) + elif _p.engine == None: + raise RuntimeError("Cannot use plot() without a plotting lib installed") + else: + raise RuntimeError("plot() does not support engine '{}'".format(_p.engine)) + + _maybe_output_fig(fig, file=file, show=show) + + return fig + +def _resize_to_dim(array, dim): + if array.shape[1] != dim: + ar = np.zeros((len(array), dim), dtype=float) + ar[:, : min(dim, array.shape[1])] = array[ + :, : min(dim, array.shape[1])] + return ar + else: + return array + + +def _check_length(name, loc): + value = loc[name] + if name in ('site_size', 'site_lw') and isinstance(value, tuple): + raise TypeError('{0} may not be a tuple, use list or ' + 'array instead.'.format(name)) + if isinstance(value, (str, tuple)): + return + try: + if len(value) != loc['n_syst_sites']: + raise ValueError('Length of {0} is not equal to number of ' + 'system sites.'.format(name)) + except TypeError: + pass + +# make all specs proper: either constant or lists/np.arrays: +def _make_proper_site_spec(spec_name, spec, syst, sites, fancy_indexing=False): + if _p.isarray(spec) and isinstance(syst, builder.Builder): + raise TypeError('{} cannot be an array when plotting' + ' a Builder; use a function instead.' + .format(spec_name)) + if callable(spec): + spec = [spec(i[0]) for i in sites if i[1] is None] + if (fancy_indexing and _p.isarray(spec) + and not isinstance(spec, np.ndarray)): + try: + spec = np.asarray(spec) + except: + spec = np.asarray(spec, dtype='object') + return spec + +def _make_proper_hop_spec(spec, hops, fancy_indexing=False): + if callable(spec): + spec = [spec(*i[0]) for i in hops if i[1] is None] + if (fancy_indexing and _p.isarray(spec) + and not isinstance(spec, np.ndarray)): + try: + spec = np.asarray(spec) + except: + spec = np.asarray(spec, dtype='object') + return spec + +def _plot_plotly(sys, num_lead_cells, unit, + site_symbol, site_size, + site_color, site_edgecolor, site_lw, + hop_color, hop_lw, + lead_site_symbol, lead_site_size, lead_color, + lead_site_edgecolor, lead_site_lw, + lead_hop_lw, pos_transform, + cmap, colorbar, file, + show, fig=None): + + if unit == None: + unit = 'pt' syst = sys # for naming consistency inside function bodies # Generate data. @@ -840,35 +992,338 @@ def plot(sys, num_lead_cells=2, unit='nn', n_syst_hops = sum(i[1] is None for i in hops) end_pos, start_pos = sys_leads_hopping_pos(syst, hops) - # Choose plot type. - def resize_to_dim(array): - if array.shape[1] != dim: - ar = np.zeros((len(array), dim), dtype=float) - ar[:, : min(dim, array.shape[1])] = array[ - :, : min(dim, array.shape[1])] - return ar + loc = locals() + + for name in ['site_symbol', 'site_size', 'site_color', 'site_edgecolor', + 'site_lw']: + _check_length(name, loc) + + # Apply transformations to the data + if pos_transform is not None: + sites_pos = np.apply_along_axis(pos_transform, 1, sites_pos) + end_pos = np.apply_along_axis(pos_transform, 1, end_pos) + start_pos = np.apply_along_axis(pos_transform, 1, start_pos) + + dim = 3 if (sites_pos.shape[1] == 3) else 2 + + sites_pos = _resize_to_dim(sites_pos, dim) + end_pos = _resize_to_dim(end_pos, dim) + start_pos = _resize_to_dim(start_pos, dim) + + # Determine the reference length. + if unit != 'pt': + raise RuntimeError('Plotly engine currently only supports ' + 'the pt symbol size unit') + + site_symbol = _make_proper_site_spec('site_symbol', site_symbol, syst, sites) + if site_symbol is None: site_symbol = defaults['site_symbol'][dim] + # separate different symbols (not done in 3D, the separation + # would mess up sorting) + if (_p.isarray(site_symbol) and dim != 3 and + (len(site_symbol) != 3 or site_symbol[0] not in ('p', 'P'))): + symbol_dict = defaultdict(list) + for i, symbol in enumerate(site_symbol): + symbol_dict[symbol].append(i) + symbol_slcs = [] + for symbol, indx in symbol_dict.items(): + symbol_slcs.append((symbol, np.array(indx))) + fancy_indexing = True + else: + symbol_slcs = [(site_symbol, slice(n_syst_sites))] + fancy_indexing = False + + if site_color is None: + cycle = _color_cycle() + if isinstance(syst, (builder.FiniteSystem, builder.InfiniteSystem)): + # Skipping the leads for brevity. + families = sorted({site.family for site in syst.sites}) + color_mapping = dict(zip(families, cycle)) + def site_color(site): + return color_mapping[syst.sites[site].family] + elif isinstance(syst, builder.Builder): + families = sorted({site[0].family for site in sites}) + color_mapping = dict(zip(families, cycle)) + def site_color(site): + return color_mapping[site.family] else: - return array + # Unknown finalized system, no sites access. + site_color = defaults['site_color'][dim] - loc = locals() + site_size = _make_proper_site_spec('site_size',site_size, syst, sites, fancy_indexing) + site_color = _make_proper_site_spec('site_color',site_color, syst, sites, fancy_indexing) + site_edgecolor = _make_proper_site_spec('site_edgecolor',site_edgecolor, syst, sites, + fancy_indexing) + site_lw = _make_proper_site_spec('site_lw',site_lw, syst, sites, fancy_indexing) + + hop_color = _make_proper_hop_spec(hop_color, hops) + hop_lw = _make_proper_hop_spec(hop_lw, hops) + + # Choose defaults depending on dimension, if None was given + if site_size is None: site_size = defaults['site_size'][dim] + if site_edgecolor is None: + site_edgecolor = defaults['site_edgecolor'][dim] + if site_lw is None: site_lw = defaults['site_lw'][dim] + + if hop_color is None: hop_color = defaults['hop_color'][dim] + if hop_lw is None: hop_lw = defaults['hop_lw'][dim] + + if len(symbol_slcs) > 1: + try: + if site_color.ndim == 1 and len(site_color) == n_syst_sites: + site_color = np.asarray(site_color, dtype=float) + except: + pass + + # take spec also for lead, if it's not a list/array, default, otherwise + if lead_site_symbol is None: + lead_site_symbol = (site_symbol if not _p.isarray(site_symbol) + else defaults['site_symbol'][dim]) + if lead_site_size is None: + lead_site_size = (site_size if not _p.isarray(site_size) + else defaults['site_size'][dim]) + if lead_color is None: + lead_color = defaults['lead_color'][dim] + lead_color = _p.matplotlib.colors.colorConverter.to_rgba(lead_color) + + if lead_site_edgecolor is None: + lead_site_edgecolor = (site_edgecolor if not _p.isarray(site_edgecolor) + else defaults['site_edgecolor'][dim]) + if lead_site_lw is None: + lead_site_lw = (site_lw if not _p.isarray(site_lw) + else defaults['site_lw'][dim]) + if lead_hop_lw is None: + lead_hop_lw = (hop_lw if not _p.isarray(hop_lw) + else defaults['hop_lw'][dim]) - def check_length(name): - value = loc[name] - if name in ('site_size', 'site_lw') and isinstance(value, tuple): - raise TypeError('{0} may not be a tuple, use list or ' - 'array instead.'.format(name)) - if isinstance(value, (str, tuple)): - return + hop_cmap = None + if not isinstance(cmap, str): try: - if len(value) != n_syst_sites: - raise ValueError('Length of {0} is not equal to number of ' - 'system sites.'.format(name)) + cmap, hop_cmap = cmap except TypeError: pass + # Plot system sites and hoppings + + # First plot the nodes (sites) of the graph + assert dim == 2 or dim == 3 + site_node_trace, site_edge_trace = [], [] + for symbol, slc in symbol_slcs: + site_symbol_plotly = _p.convert_symbol_mpl_plotly(symbol) + if site_symbol_plotly == -1: + # The kwant documentation supports no symbol as a string argument for site_symbol + # If it evaluates to -1, then the user has specified "no symbol" as the input. + # https://kwant-project.org/doc/1/reference/generated/kwant.plotter.plot + continue + size = site_size[slc] if _p.isarray(site_size) else site_size + col = site_color[slc] if _p.isarray(site_color) else site_color + if _p.isarray(site_edgecolor) or _p.isarray(site_lw): + raise RuntimeError("Plotly engine not currently support an array " + "of linecolors or linewidths. Please restrict " + "to only a constant (i.e. no function or array)" + " site_edgecolor and site_lw property " + "for the entire plot.") + else: + edgecol = site_edgecolor if not isinstance(site_edgecolor, tuple) \ + else _p.convert_colormap_mpl_plotly(site_edgecolor) + lw = site_lw + + if dim == 3: + x, y, z = sites_pos[slc].transpose() + site_node_trace_elem = _p.plotly_graph_objs.Scatter3d(x=x, y=y, + z=z) + site_node_trace_elem.marker.symbol = _p.convert_symbol_mpl_plotly_3d( + symbol) + else: + x, y = sites_pos[slc].transpose() + site_node_trace_elem = _p.plotly_graph_objs.Scatter(x=x, y=y) + site_node_trace_elem.marker.symbol = _p.convert_symbol_mpl_plotly( + symbol) + + site_node_trace_elem.mode = 'markers' + site_node_trace_elem.hoverinfo = 'none' + site_node_trace_elem.marker.showscale = False + site_node_trace_elem.marker.colorscale = \ + _p.convert_cmap_list_mpl_plotly(cmap) + site_node_trace_elem.marker.reversescale = False + marker_color = col if not isinstance(col, tuple) \ + else _p.convert_colormap_mpl_plotly(col) + site_node_trace_elem.marker.color = marker_color + site_node_trace_elem.marker.size = \ + _p.convert_site_size_mpl_plotly(size, + defaults['plotly_site_size_reference']) + + site_node_trace_elem.line.width = lw + site_node_trace_elem.line.color = edgecol + site_node_trace_elem.showlegend = False + + site_node_trace.append(site_node_trace_elem) + + # Now plot the edges (hops) of the graph + end, start = end_pos[: n_syst_hops], start_pos[: n_syst_hops] + + if dim == 3: + x0, y0, z0 = end.transpose() + x1, y1, z1 = start.transpose() + nones = [None] * len(x0) + site_edge_trace_elem = _p.plotly_graph_objs.Scatter3d( + x=np.array([x0, x1, nones]).transpose().flatten(), + y=np.array([y0, y1, nones]).transpose().flatten(), + z=np.array([z0, z1, nones]).transpose().flatten() + ) + else: + x0, y0 = end.transpose() + x1, y1 = start.transpose() + nones = [None] * len(x0) + site_edge_trace_elem = _p.plotly_graph_objs.Scatter( + x=np.array([x0, x1, nones]).transpose().flatten(), + y=np.array([y0, y1, nones]).transpose().flatten() + ) + + if _p.isarray(hop_color) or _p.isarray(hop_lw): + raise RuntimeError("Plotly engine not currently support an array " + "of linecolors or linewidths. Please restrict " + "to only a constant (i.e. no function or array)" + " hop_color and hop_lw property " + "for the entire plot.") + site_edge_trace_elem.line.width = hop_lw + site_edge_trace_elem.line.color = hop_color + site_edge_trace_elem.hoverinfo = 'none' + site_edge_trace_elem.showlegend = False + site_edge_trace_elem.mode = 'lines' + site_edge_trace.append(site_edge_trace_elem) + + # Plot lead sites and edges + + lead_node_trace, lead_edge_trace = [], [] + for sites_slc, hops_slc in zip(lead_sites_slcs, lead_hops_slcs): + lead_site_colors = np.array([i[2] for i in sites[sites_slc]], + dtype=float) + if dim == 3: + + x, y, z = sites_pos[sites_slc].transpose() + lead_node_trace_elem = _p.plotly_graph_objs.Scatter3d(x=x, y=y, + z=z) + lead_node_trace_elem.marker.symbol = \ + _p.convert_symbol_mpl_plotly_3d(lead_site_symbol) + else: + x, y = sites_pos[sites_slc].transpose() + lead_node_trace_elem = _p.plotly_graph_objs.Scatter(x=x, y=y) + lead_site_symbol_plotly = _p.convert_symbol_mpl_plotly(lead_site_symbol) + if lead_site_symbol_plotly == -1: + # The kwant documentation supports no symbol as a string argument for site_symbol + # If it evaluates to -1, then the user has specified "no symbol" as the input. + # https://kwant-project.org/doc/1/reference/generated/kwant.plotter.plot + continue + lead_node_trace_elem.marker.symbol = lead_site_symbol_plotly + + lead_node_trace_elem.mode = 'markers' + lead_node_trace_elem.hoverinfo = 'none' + lead_node_trace_elem.showlegend = False + lead_node_trace_elem.marker.showscale = False + lead_node_trace_elem.marker.reversescale = False + lead_node_trace_elem.marker.color = lead_site_colors + lead_node_trace_elem.marker.colorscale = \ + _p.convert_lead_cmap_mpl_plotly(lead_color, + [1, 1, 1, lead_color[3]]) + lead_node_trace_elem.marker.size = _p.convert_site_size_mpl_plotly( + lead_site_size, + defaults['plotly_site_size_reference']) + + if _p.isarray(lead_site_lw) or _p.isarray(lead_site_edgecolor): + raise RuntimeError("Plotly engine not currently support an array " + "of linecolors or linewidths. Please restrict " + "to only a constant (i.e. no function or array) " + "lead_site_lw and lead_site_edgecolor property " + "for the entire plot.") + lead_node_trace_elem.line.width = lead_site_lw + lead_node_trace_elem.line.color = lead_site_edgecolor + + if lead_node_trace_elem: + lead_node_trace.append(lead_node_trace_elem) + + lead_hop_colors = np.array([i[2] for i in hops[hops_slc]], dtype=float) + + end, start = end_pos[hops_slc], start_pos[hops_slc] + + if dim == 3: + x0, y0, z0 = end.transpose() + x1, y1, z1 = start.transpose() + nones = [None] * len(x0) + lead_edge_trace_elem = _p.plotly_graph_objs.Scatter3d( + x=np.array([x0, x1, nones]).transpose().flatten(), + y=np.array([y0, y1, nones]).transpose().flatten(), + z=np.array([z0, z1, nones]).transpose().flatten() + ) + + else: + x0, y0 = end.transpose() + x1, y1 = start.transpose() + nones = [None] * len(x0) + lead_edge_trace_elem = _p.plotly_graph_objs.Scatter( + x=np.array([x0, x1, nones]).transpose().flatten(), + y=np.array([y0, y1, nones]).transpose().flatten() + ) + + lead_edge_trace_elem.line.width = lead_hop_lw + lead_edge_trace_elem.line.color = _p.convert_colormap_mpl_plotly( + lead_color) + lead_edge_trace_elem.hoverinfo = 'none' + lead_edge_trace_elem.mode = 'lines' + lead_edge_trace_elem.showlegend = False + + lead_edge_trace.append(lead_edge_trace_elem) + + layout = _p.plotly_graph_objs.Layout( + showlegend=False, + hovermode='closest', + xaxis=dict(showgrid=False, zeroline=False, + showticklabels=True), + yaxis=dict(showgrid=False, zeroline=False, + showticklabels=True)) + if fig == None: + full_trace = list(itertools.chain.from_iterable([site_edge_trace, + site_node_trace, lead_edge_trace, + lead_node_trace])) + fig = _p.plotly_graph_objs.Figure(data=full_trace, + layout=layout) + else: + full_trace = list(itertools.chain.from_iterable([lead_edge_trace, + lead_node_trace])) + for trace in full_trace: + try: + fig.add_trace(trace) + except TypeError: + fig.data += [trace] + + return fig + + +def _plot_matplotlib(sys, num_lead_cells, unit, + site_symbol, site_size, + site_color, site_edgecolor, site_lw, + hop_color, hop_lw, + lead_site_symbol, lead_site_size, lead_color, + lead_site_edgecolor, lead_site_lw, + lead_hop_lw, pos_transform, + cmap, colorbar, file, + show, dpi, fig_size, ax): + + if unit == None: + unit = 'nn' + + syst = sys # for naming consistency inside function bodies + # Generate data. + sites, lead_sites_slcs = sys_leads_sites(syst, num_lead_cells) + n_syst_sites = sum(i[1] is None for i in sites) + sites_pos = sys_leads_pos(syst, sites) + hops, lead_hops_slcs = sys_leads_hoppings(syst, num_lead_cells) + n_syst_hops = sum(i[1] is None for i in hops) + end_pos, start_pos = sys_leads_hopping_pos(syst, hops) + + loc = locals() for name in ['site_symbol', 'site_size', 'site_color', 'site_edgecolor', 'site_lw']: - check_length(name) + _check_length(name, loc) # Apply transformations to the data if pos_transform is not None: @@ -879,9 +1334,9 @@ def plot(sys, num_lead_cells=2, unit='nn', dim = 3 if (sites_pos.shape[1] == 3) else 2 if dim == 3 and not _p.has3d: raise RuntimeError("Installed matplotlib does not support 3d plotting") - sites_pos = resize_to_dim(sites_pos) - end_pos = resize_to_dim(end_pos) - start_pos = resize_to_dim(start_pos) + sites_pos = _resize_to_dim(sites_pos, dim) + end_pos = _resize_to_dim(end_pos, dim) + start_pos = _resize_to_dim(start_pos, dim) # Determine the reference length. if unit == 'pt': @@ -914,35 +1369,7 @@ def plot(sys, num_lead_cells=2, unit='nn', except: raise ValueError('Invalid value of unit argument.') - # make all specs proper: either constant or lists/np.arrays: - def make_proper_site_spec(spec_name, spec, fancy_indexing=False): - if _p.isarray(spec) and isinstance(syst, builder.Builder): - raise TypeError('{} cannot be an array when plotting' - ' a Builder; use a function instead.' - .format(spec_name)) - if callable(spec): - spec = [spec(i[0]) for i in sites if i[1] is None] - if (fancy_indexing and _p.isarray(spec) - and not isinstance(spec, np.ndarray)): - try: - spec = np.asarray(spec) - except: - spec = np.asarray(spec, dtype='object') - return spec - - def make_proper_hop_spec(spec, fancy_indexing=False): - if callable(spec): - spec = [spec(*i[0]) for i in hops if i[1] is None] - if (fancy_indexing and _p.isarray(spec) - and not isinstance(spec, np.ndarray)): - try: - spec = np.asarray(spec) - except: - spec = np.asarray(spec, dtype='object') - return spec - - - site_symbol = make_proper_site_spec('site_symbol', site_symbol) + site_symbol = _make_proper_site_spec('site_symbol', site_symbol, syst, sites) if site_symbol is None: site_symbol = defaults['site_symbol'][dim] # separate different symbols (not done in 3D, the separation # would mess up sorting) @@ -976,13 +1403,13 @@ def plot(sys, num_lead_cells=2, unit='nn', # Unknown finalized system, no sites access. site_color = defaults['site_color'][dim] - site_size = make_proper_site_spec('site_size', site_size, fancy_indexing) - site_color = make_proper_site_spec('site_color', site_color, fancy_indexing) - site_edgecolor = make_proper_site_spec('site_edgecolor', site_edgecolor, fancy_indexing) - site_lw = make_proper_site_spec('site_lw', site_lw, fancy_indexing) + site_size = _make_proper_site_spec('site_size', site_size, syst, sites, fancy_indexing) + site_color = _make_proper_site_spec('site_color', site_color, syst, sites, fancy_indexing) + site_edgecolor = _make_proper_site_spec('site_edgecolor', site_edgecolor, syst, sites, fancy_indexing) + site_lw = _make_proper_site_spec('site_lw', site_lw, syst, sites, fancy_indexing) - hop_color = make_proper_hop_spec(hop_color) - hop_lw = make_proper_hop_spec(hop_lw) + hop_color = _make_proper_hop_spec(hop_color, hops) + hop_lw = _make_proper_hop_spec(hop_lw, hops) # Choose defaults depending on dimension, if None was given if site_size is None: site_size = defaults['site_size'][dim] @@ -1111,8 +1538,6 @@ def plot(sys, num_lead_cells=2, unit='nn', if line_coll.get_array() is not None and colorbar and fig is not None: fig.colorbar(line_coll) - _maybe_output_fig(fig, file=file, show=show) - return fig @@ -1187,6 +1612,7 @@ def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3): range(len(cmin))) grid = tuple(np.ogrid[dims]) img = interpolate.griddata(coords, values, grid, method) + img = img.astype(np.float_) mask = np.mgrid[dims].reshape(len(cmin), -1).T # The numerical values in the following line are optimized for the common # case of a square lattice: @@ -1195,7 +1621,17 @@ def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3): # * 0.4 (which is just below sqrt(2) - 1) makes tree.query() exact. mask = tree.query(mask, eps=0.4)[0] > 0.99 * a - return np.ma.masked_array(img, mask), cmin, cmax + masked_result_array = np.ma.masked_array(img, mask) + + try: + if _p.engine != "matplotlib": + result_array = masked_result_array.filled(np.NaN) + else: + result_array = masked_result_array + except AttributeError: + result_array = masked_result_array + + return result_array, img, cmin, cmax def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None, @@ -1274,7 +1710,7 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None, kwant.plotter.density """ - if not _p.mpl_available: + if not (_p.mpl_available or _p.plotly_available): raise RuntimeError("matplotlib was not found, but is required " "for map()") @@ -1296,22 +1732,15 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None, 'for finalized systems.') value = np.array(value) with _common.reraise_warnings(): - img, min, max = mask_interpolate(coords, value, a, method, oversampling) - border = 0.5 * (max - min) / (np.asarray(img.shape) - 1) - min -= border - max += border - if ax is None: - fig = _make_figure(dpi, fig_size, use_pyplot=(file is None)) - ax = fig.add_subplot(1, 1, 1, aspect='equal') - else: - fig = None - - if cmap is None: - cmap = _p._colormaps.kwant_red + img, unmasked_data, _min, _max = mask_interpolate(coords, value, + a, method, oversampling) # Calculate the min/max bounds for the colormap. # User-provided values take precedence. - unmasked_data = img[~img.mask].data.flatten() + if _p.engine != "matplotlib": + unmasked_data = img.ravel() + else: + unmasked_data = img[~img.mask].data.flatten() unmasked_data = unmasked_data[~np.isnan(unmasked_data)] new_vmin, new_vmax = percentile_bound(unmasked_data, vmin, vmax) overflow_pct = 100 * np.sum(unmasked_data > new_vmax) / len(unmasked_data) @@ -1330,9 +1759,85 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None, warnings.warn(''.join(msg), RuntimeWarning, stacklevel=2) vmin, vmax = new_vmin, new_vmax + if _p.engine == "matplotlib": + fig = _map_matplotlib(syst, img, colorbar, _max, _min, vmin, vmax, + overflow_pct, underflow_pct, cmap, num_lead_cells, + background, dpi, fig_size, ax, file) + elif _p.engine == "plotly": + fig = _map_plotly(syst, img, colorbar, _max, _min, vmin, vmax, + overflow_pct, underflow_pct, cmap, num_lead_cells, + background) + elif _p.engine == None: + raise RuntimeError("Cannot use map() without a plotting lib installed") + else: + raise RuntimeError("map() does not support engine '{}'".format(_p.engine)) + + _maybe_output_fig(fig, file=file, show=show) + + return fig + + +def _map_plotly(syst, img, colorbar, _max, _min, vmin, vmax, overflow_pct, + underflow_pct, cmap, num_lead_cells, background): + + border = 0.5 * (_max - _min) / (np.asarray(img.shape) - 1) + _min -= border + _max += border + + if cmap is None: + cmap = _p.kwant_red_plotly + + img = img.T + contour_object = _p.plotly_graph_objs.Heatmap() + contour_object.z = img + contour_object.x = np.linspace(_min[0],_max[0],img.shape[0]) + contour_object.y = np.linspace(_min[1],_max[1],img.shape[1]) + contour_object.zsmooth = False + contour_object.connectgaps = False + cmap = _p.convert_cmap_list_mpl_plotly(cmap) + contour_object.colorscale = cmap + contour_object.zmax = vmax + contour_object.zmin = vmin + contour_object.hoverinfo = 'none' + + contour_object.showscale = colorbar + + fig = _p.plotly_graph_objs.Figure(data=[contour_object]) + fig.layout.plot_bgcolor = background + fig.layout.showlegend = False + + if num_lead_cells: + fig = _plot_plotly(syst, num_lead_cells, site_symbol='no symbol', + hop_lw=0, lead_site_symbol='s', + lead_site_size=0.501, lead_site_lw=0,lead_hop_lw=0, + lead_color='black', colorbar=False, show=False, + fig=fig, unit='pt', site_size=None, site_color=None, + site_edgecolor=None, site_lw=0, hop_color=None, + lead_site_edgecolor=None,pos_transform=None, + cmap=None, file=None) + + return fig + + +def _map_matplotlib(syst, img, colorbar, _max, _min, vmin, vmax, + overflow_pct, underflow_pct, cmap, num_lead_cells, + background, dpi, fig_size, ax, file): + + border = 0.5 * (_max - _min) / (np.asarray(img.shape) - 1) + _min -= border + _max += border + if ax is None: + fig = _make_figure(dpi, fig_size, use_pyplot=(file is None)) + ax = fig.add_subplot(1, 1, 1, aspect='equal') + else: + fig = None + + if cmap is None: + cmap = _p.kwant_red_matplotlib + # Note that we tell imshow to show the array created by mask_interpolate # faithfully and not to interpolate by itself another time. - image = ax.imshow(img.T, extent=(min[0], max[0], min[1], max[1]), + image = ax.imshow(img.T, extent=(_min[0], _max[0], _min[1], _max[1]), origin='lower', interpolation='none', cmap=cmap, vmin=vmin, vmax=vmax) if num_lead_cells: @@ -1353,8 +1858,6 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None, extend = 'max' fig.colorbar(image, extend=extend) - _maybe_output_fig(fig, file=file, show=show) - return fig @@ -1374,27 +1877,37 @@ def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None, Either a number of sampling points on the interval [-pi, pi], or an array of points at which the band structure has to be evaluated. file : string or file object or `None` - The output file. If `None`, output will be shown instead. + The output file. If `None`, output will be shown instead. If plotly is + selected as the engine, the filename has to end with a html extension. show : bool - Whether ``matplotlib.pyplot.show()`` is to be called, and the output is - to be shown immediately. Defaults to `True`. + For matplotlib engine, whether ``matplotlib.pyplot.show()`` is to be + called, and the output is to be shown immediately. + For the plotly engine, a call to ``iplot(fig)`` is made if + show is True. + Defaults to `True` for both engines. dpi : float Number of pixels per inch. If not set the ``matplotlib`` default is used. + Only for matplotlib engine. If the plotly engine is selected and + this argument is not None, then a RuntimeError will be triggered. fig_size : tuple Figure size `(width, height)` in inches. If not set, the default ``matplotlib`` value is used. + Only for matplotlib engine. If the plotly engine is selected and + this argument is not None, then a RuntimeError will be triggered. ax : ``matplotlib.axes.Axes`` instance or `None` If `ax` is not `None`, no new figure is created, but the plot is done within the existing Axes `ax`. in this case, `file`, `show`, `dpi` and `fig_size` are ignored. + Only for matplotlib engine. If the plotly engine is selected and + this argument is not None, then a RuntimeError will be triggered. params : dict, optional Dictionary of parameter names and their values. Mutually exclusive with 'args'. Returns ------- - fig : matplotlib figure + fig : matplotlib figure or plotly Figure object A figure with the output if `ax` is not set, else None. Notes @@ -1402,11 +1915,13 @@ def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None, See `~kwant.physics.Bands` for the calculation of dispersion without plotting. """ - if not _p.mpl_available: - raise RuntimeError("matplotlib was not found, but is required " - "for bands()") - syst = sys # for naming consistency inside function bodies + + if _p.plotly_available: + if _p.engine == "plotly": + _check_incompatible_args_plotly(dpi, fig_size, ax) + + _common.ensure_isinstance(syst, (system.InfiniteSystem, system.InfiniteVectorizedSystem)) momenta = np.array(momenta) @@ -1435,7 +1950,10 @@ def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None, def spectrum(syst, x, y=None, params=None, mask=None, file=None, show=True, dpi=None, fig_size=None, ax=None): - """Plot the spectrum of a Hamiltonian as a function of 1 or 2 parameters + """Plot the spectrum of a Hamiltonian as a function of 1 or 2 parameters. + + This function requires either matplotlib or plotly to be installed. + The default engine uses matplotlib for plotting. Parameters ---------- @@ -1456,32 +1974,69 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None, if the spectrum should not be calculated for the given parameter values. file : string or file object or `None` - The output file. If `None`, output will be shown instead. + The output file. If `None`, output will be shown instead. If plotly is + selected as the engine, the filename has to end with a html extension. show : bool - Whether ``matplotlib.pyplot.show()`` is to be called, and the output is - to be shown immediately. Defaults to `True`. + For matplotlib engine, whether ``matplotlib.pyplot.show()`` is to be + called, and the output is to be shown immediately. + For the plotly engine, a call to ``iplot(fig)`` is made if + show is True. + Defaults to `True` for both engines. dpi : float Number of pixels per inch. If not set the ``matplotlib`` default is used. + Only for matplotlib engine. If the plotly engine is selected and + this argument is not None, then a RuntimeError will be triggered. fig_size : tuple Figure size `(width, height)` in inches. If not set, the default ``matplotlib`` value is used. + Only for matplotlib engine. If the plotly engine is selected and + this argument is not None, then a RuntimeError will be triggered. ax : ``matplotlib.axes.Axes`` instance or `None` If `ax` is not `None`, no new figure is created, but the plot is done within the existing Axes `ax`. in this case, `file`, `show`, `dpi` and `fig_size` are ignored. + Only for matplotlib engine. If the plotly engine is selected and + this argument is not None, then a RuntimeError will be triggered. Returns ------- - fig : matplotlib figure - A figure with the output if `ax` is not set, else None. + fig : matplotlib figure or plotly Figure object """ - if not _p.mpl_available: - raise RuntimeError("matplotlib was not found, but is required " - "for plot_spectrum()") - if y is not None and not _p.has3d: - raise RuntimeError("Installed matplotlib does not support 3d plotting") + params = params or dict() + + if _p.engine == "matplotlib": + return _spectrum_matplotlib(syst, x, y, params, mask, file, + show, dpi, fig_size, ax) + elif _p.engine == "plotly": + _check_incompatible_args_plotly(dpi, fig_size, ax) + return _spectrum_plotly(syst, x, y, params, mask, file, show) + elif _p.engine == None: + raise RuntimeError("Cannot use spectrum() without a plotting lib installed") + else: + raise RuntimeError("spectrum() does not support engine '{}'".format(_p.engine)) + + +def _generate_spectrum(syst, params, mask, x, y): + """Generates the spectrum dataset for the internal plotting + functions of spectrum(). + + Parameters + ---------- + See spectrum(...) documentation. + + Returns + ------- + spectrum : Numpy array + The energies of the system calculated at each coordinate. + planar : bool + True if y is None + array_values : tuple + The coordinates of x, y values of the dataset for plotting. + keys : tuple + Labels for the x and y axes. + """ if system.is_finite(syst): def ham(**kwargs): @@ -1492,9 +2047,9 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None, raise TypeError("Expected 'syst' to be a finite Kwant system " "or a function.") - params = params or dict() - keys = (x[0],) if y is None else (x[0], y[0]) - array_values = (x[1],) if y is None else (x[1], y[1]) + planar = y is None + keys = (x[0],) if planar else (x[0], y[0]) + array_values = (x[1],) if planar else (x[1], y[1]) # calculate spectrum on the grid of points spectrum = [] @@ -1514,10 +2069,84 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None, new_shape = [len(v) for v in array_values] + [-1] spectrum = np.array(spectrum).reshape(new_shape) + return spectrum, planar, array_values, keys + + +def _spectrum_plotly(syst, x, y=None, params=None, mask=None, + file=None, show=True): + """Plot the spectrum of a Hamiltonian as a function of 1 or 2 parameters + using the plotly engine. + + Parameters + ---------- + See spectrum(...) documentation. + + Returns + ------- + fig : plotly Figure / dict + """ + + spectrum, planar, array_values, keys = _generate_spectrum(syst, params, + mask, x, y) + + if planar: + fig = _p.plotly_graph_objs.Figure(data=[ + _p.plotly_graph_objs.Scatter( + x=array_values[0], + y=energies, + ) for energies in spectrum.T + ]) + fig.layout.xaxis.title = keys[0] + fig.layout.yaxis.title = 'Energy' + fig.layout.showlegend = False + else: + fig = _p.plotly_graph_objs.Figure(data=[ + _p.plotly_graph_objs.Surface( + x=array_values[0], + y=array_values[1], + z=energies, + cmax=np.max(spectrum), + cmin=np.min(spectrum), + ) for energies in spectrum.T + ]) + fig.layout.scene.xaxis.title = keys[0] + fig.layout.scene.yaxis.title = keys[1] + fig.layout.scene.zaxis.title = 'Energy' + + fig.layout.title = ( + ', '.join('{} = {}'.format(*kv) for kv in params.items()) + ) + + _maybe_output_fig(fig, file=file, show=show) + + return fig + + +def _spectrum_matplotlib(syst, x, y=None, params=None, mask=None, file=None, + show=True, dpi=None, fig_size=None, ax=None): + """Plot the spectrum of a Hamiltonian as a function of 1 or 2 parameters + using the matplotlib engine. + + Parameters + ---------- + See spectrum(...) documentation. + + Returns + ------- + fig : matplotlib figure + A figure with the output if `ax` is not set, else None. + """ + + if y is not None and not _p.has3d: + raise RuntimeError("Installed matplotlib does not support 3d plotting") + + spectrum, planar, array_values, keys = _generate_spectrum(syst, params, + mask, x, y) + # set up axes if ax is None: fig = _make_figure(dpi, fig_size, use_pyplot=(file is None)) - if y is None: + if planar: ax = fig.add_subplot(1, 1, 1) else: warnings.filterwarnings('ignore', @@ -1525,7 +2154,7 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None, ax = fig.add_subplot(1, 1, 1, projection='3d') warnings.resetwarnings() ax.set_xlabel(keys[0]) - if y is None: + if planar: ax.set_ylabel('Energy') else: ax.set_ylabel(keys[1]) @@ -1541,7 +2170,7 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None, fig = None # actually do the plot - if y is None: + if planar: ax.plot(array_values[0], spectrum) else: if not hasattr(ax, 'plot_surface'): @@ -1922,6 +2551,31 @@ def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k', colorbar=True, file=None, show=True, dpi=None, fig_size=None, ax=None, vmax=None): + if _p.engine == "matplotlib": + fig = _streamplot_matplotlib(field, box, cmap, bgcolor, linecolor, + max_linewidth, min_linewidth, density, colorbar, file, + show, dpi, fig_size, ax, vmax) + elif _p.engine == "plotly": + _check_incompatible_args_plotly(dpi, fig_size, ax) + fig = _streamplot_plotly(field, box, cmap, bgcolor, linecolor, + max_linewidth, min_linewidth, density, + colorbar, file, show, vmax) + elif _p.engine == None: + raise RuntimeError("Cannot use streamplot() without a plotting lib installed") + else: + raise RuntimeError("streamplot() does not support engine '{}'".format(_p.engine)) + _maybe_output_fig(fig, file=file, show=show) + + +def _streamplot_plotly(field, box, cmap, bgcolor, linecolor, + max_linewidth, min_linewidth, density, + colorbar, file, show, vmax): + raise RuntimeError("Streamplot() for plotly engine not implemented yet due to bug from plotly") + + +def _streamplot_matplotlib(field, box, cmap, bgcolor, linecolor, + max_linewidth, min_linewidth, density, colorbar, file, + show, dpi, fig_size, ax, vmax): """Draw streamlines of a flow field in Kwant style Solid colored streamlines are drawn, superimposed on a color plot of @@ -1930,7 +2584,7 @@ def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k', would be thinner than `min_linewidth` are blended in a perceptually correct way into the background color in order to create the illusion of arbitrarily thin lines. (This is done because some plot - backends like PDF do not support lines of arbitrarily thin width.) + engines like PDF do not support lines of arbitrarily thin width.) Internally, this routine uses matplotlib's streamplot. @@ -1983,9 +2637,6 @@ def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k', fig : matplotlib figure A figure with the output if `ax` is not set, else None. """ - if not _p.mpl_available: - raise RuntimeError("matplotlib was not found, but is required " - "for current()") # Matplotlib's "density" is in units of 30 streamlines... density *= 1 / 30 * ta.array(field.shape[:2], int) @@ -1999,7 +2650,7 @@ def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k', if bgcolor is None: if cmap is None: - cmap = _p._colormaps.kwant_red + cmap = _p.kwant_red_matplotlib cmap = _p.matplotlib.cm.get_cmap(cmap) bgcolor = cmap(0)[:3] elif cmap is not None: @@ -2096,20 +2747,66 @@ def scalarplot(field, box, fig : matplotlib figure A figure with the output if ``ax`` is not set, else None. """ - if not _p.mpl_available: - raise RuntimeError("matplotlib was not found, but is required " - "for current()") # Matplotlib plots images like matrices: image[y, x]. We use the opposite # convention: image[x, y]. Hence, it is necessary to transpose. # Also squeeze out the last axis as it is just a scalar field + field = field.squeeze(axis=-1).transpose() if field.ndim != 2: raise ValueError("Only 2D field can be plotted.") + if vmin is None: + vmin = np.min(field) + if vmax is None: + vmax = np.max(field) + + if _p.engine == "matplotlib": + fig = _scalarplot_matplotlib(field, box, cmap, colorbar, + file, show, dpi, fig_size, ax, + vmin, vmax, background) + elif _p.engine == "plotly": + _check_incompatible_args_plotly(dpi, fig_size, ax) + fig = _scalarplot_plotly(field, box, cmap, colorbar, file, + show, vmin, vmax, background) + elif _p.engine == None: + raise RuntimeError("Cannot use scalarplot() without a plotting lib installed") + else: + raise RuntimeError("scalarplot() does not support engine '{}'".format(_p.engine)) + _maybe_output_fig(fig, file=file, show=show) + + return fig + + +def _scalarplot_plotly(field, box, cmap, colorbar, file, + show, vmin, vmax, background): + if cmap is None: - cmap = _p._colormaps.kwant_red + cmap = _p.kwant_red_plotly + + contour_object = _p.plotly_graph_objs.Heatmap() + contour_object.z = field + contour_object.x = np.linspace(*box[0],field.shape[0]) + contour_object.y = np.linspace(*box[1],field.shape[1]) + contour_object.zsmooth = 'best' + contour_object.colorscale = cmap + contour_object.zmax = vmax + contour_object.zmin = vmin + + contour_object.showscale = colorbar + + fig = _p.plotly_graph_objs.Figure(data=[contour_object]) + fig.layout.plot_bgcolor = background + + return fig + + +def _scalarplot_matplotlib(field, box, cmap, colorbar, file, show, dpi, + fig_size, ax, vmin, vmax, background): + + if cmap is None: + cmap = _p.kwant_red_matplotlib cmap = _p.matplotlib.cm.get_cmap(cmap) if ax is None: @@ -2118,11 +2815,6 @@ def scalarplot(field, box, else: fig = None - if vmin is None: - vmin = np.min(field) - if vmax is None: - vmax = np.max(field) - image = ax.imshow(field, cmap=cmap, interpolation='bicubic', extent=[e for c in box for e in c], @@ -2135,8 +2827,6 @@ def scalarplot(field, box, if colorbar and cmap and fig is not None: fig.colorbar(image) - _maybe_output_fig(fig, file=file, show=show) - return fig diff --git a/kwant/tests/test_plotter.py b/kwant/tests/test_plotter.py index 9edc902748c20d14372a881dda0093a68ac27a20..9b2955e5d942ba4f9415dca158889fbc761759eb 100644 --- a/kwant/tests/test_plotter.py +++ b/kwant/tests/test_plotter.py @@ -45,7 +45,7 @@ def test_matplotlib_backend_unset(): assert matplotlib_backend_chosen is False -def test_importable_without_matplotlib(): +def test_importable_without_backends(): prefix, sep, suffix = _plotter.__file__.rpartition('.') if suffix in ['pyc', 'pyo']: suffix = 'py' @@ -55,13 +55,16 @@ def test_importable_without_matplotlib(): code = f.read() code = code.replace(b'from . import', b'from kwant import') code = code.replace(b'matplotlib', b'totalblimp') + code = code.replace(b'plotly', b'plylot') with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") exec(code) # Trigger the warning. - assert len(w) == 1 + assert len(w) == 2 assert issubclass(w[0].category, RuntimeWarning) - assert "only iterator-providing functions" in str(w[0].message) + assert issubclass(w[1].category, RuntimeWarning) + assert "totalblimp is not available" in str(w[0].message) + assert "plylot is not available" in str(w[1].message) def syst_2d(W=3, r1=3, r2=8): @@ -113,18 +116,31 @@ def syst_3d(W=3, r1=2, r2=4, a=1, t=1.0): return syst +def plotter_file_suffix(engine): + # We need this function so that we can add a .html suffix to the output filename. + # This is required because plotly will throw an error if filename is without the suffix. + if engine == "plotly": + return ".html" + else: + return None + + @pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.") -def test_plot(): +def test_matplotlib_plot(): + + plotter.set_engine('matplotlib') plot = plotter.plot syst2d = syst_2d() syst3d = syst_3d() color_opts = ['k', (lambda site: site.tag[0]), lambda site: (abs(site.tag[0] / 100), abs(site.tag[1] / 100), 0)] - with tempfile.TemporaryFile('w+b') as out: + engine = plotter.get_engine() + with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out: + out_filename = out.name for color in color_opts: for syst in (syst2d, syst3d): - fig = plot(syst, site_color=color, cmap='binary', file=out) + fig = plot(syst, site_color=color, cmap='binary', file=out_filename) if (color != 'k' and isinstance(color(next(iter(syst2d.sites()))), float)): assert fig.axes[0].collections[0].get_array() is not None @@ -134,30 +150,66 @@ def test_plot(): abs(site.tag[1] / 100), 0)] for color in color_opts: for syst in (syst2d, syst3d): - fig = plot(syst2d, hop_color=color, cmap='binary', file=out, + fig = plot(syst2d, hop_color=color, cmap='binary', file=out_filename, fig_size=(2, 10), dpi=30) if color != 'k' and isinstance(color(next(iter(syst2d.sites())), None), float): assert fig.axes[0].collections[1].get_array() is not None - assert isinstance(plot(syst3d, file=out).axes[0], mplot3d.axes3d.Axes3D) + assert isinstance(plot(syst3d, file=out_filename).axes[0], mplot3d.axes3d.Axes3D) syst2d.leads = [] - plot(syst2d, file=out) + plot(syst2d, file=out_filename) del syst2d[list(syst2d.hoppings())] - plot(syst2d, file=out) + plot(syst2d, file=out_filename) - plot(syst3d, file=out) + plot(syst3d, file=out_filename) with warnings.catch_warnings(): warnings.simplefilter("ignore") - plot(syst2d.finalized(), file=out) + plot(syst2d.finalized(), file=out_filename) # test 2D projections of 3D systems - plot(syst3d, file=out, pos_transform=lambda pos: pos[:2]) + plot(syst3d, file=out_filename, pos_transform=lambda pos: pos[:2]) + + +@pytest.mark.skipif(not _plotter.plotly_available, reason="Plotly unavailable.") +def test_plotly_plot(): + + plotter.set_engine('plotly') + plot = plotter.plot + syst2d = syst_2d() + syst3d = syst_3d() + color_opts = ['black', (lambda site: site.tag[0]), + lambda site: (abs(site.tag[0] / 100), + abs(site.tag[1] / 100), 0)] + engine = plotter.get_engine() + with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out: + out_filename = out.name + for color in color_opts: + for syst in (syst2d, syst3d): + plot(syst, site_color=color, cmap='binary', file=out_filename, show=False) + + color_opts = ['black', (lambda site, site2: site.tag[0]), + lambda site, site2: (abs(site.tag[0] / 100), + abs(site.tag[1] / 100), 0)] + + syst2d.leads = [] + plot(syst2d, file=out_filename, show=False) + del syst2d[list(syst2d.hoppings())] + plot(syst2d, file=out_filename, show=False) + + plot(syst3d, file=out_filename, show=False) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + plot(syst2d.finalized(), file=out_filename, show=False) + + # test 2D projections of 3D systems + plot(syst3d, file=out_filename, pos_transform=lambda pos: pos[:2], show=False) @pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.") -def test_plot_more_site_families_than_colors(): +@pytest.mark.parametrize("engine", ["plotly", "matplotlib"]) +def test_plot_more_site_families_than_colors(engine): # test against regression reported in # https://gitlab.kwant-project.org/kwant/kwant/issues/257 ncolors = len(pyplot.rcParams['axes.prop_cycle']) @@ -166,17 +218,23 @@ def test_plot_more_site_families_than_colors(): for i in range(ncolors + 1)] for i, lat in enumerate(lattices): syst[lat(i, 0)] = None - with tempfile.TemporaryFile('w+b') as out: - plotter.plot(syst, file=out) + + plotter.set_engine(engine) + with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out: + out_filename = out.name + print(out) + plotter.plot(syst, file=out_filename, show=False) @pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.") -def test_plot_raises_on_bad_site_spec(): +@pytest.mark.parametrize("engine", ["plotly", "matplotlib"]) +def test_plot_raises_on_bad_site_spec(engine): syst = kwant.Builder() lat = kwant.lattice.square(norbs=1) syst[(lat(i, j) for i in range(5) for j in range(5))] = None # Cannot provide site_size as an array when syst is a Builder + plotter.set_engine(engine) with pytest.raises(TypeError): plotter.plot(syst, site_size=[1] * 25) @@ -194,20 +252,24 @@ def bad_transform(pos): return x, y, 0 @pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.") -def test_map(): +@pytest.mark.parametrize("engine", ["plotly", "matplotlib"]) +def test_map(engine): + plotter.set_engine(engine) syst = syst_2d() - with tempfile.TemporaryFile('w+b') as out: + + with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out: + out_filename = out.name plotter.map(syst, lambda site: site.tag[0], pos_transform=good_transform, - file=out, method='linear', a=4, oversampling=4, cmap='flag') + file=out_filename, method='linear', a=4, oversampling=4, cmap='flag', show=False) pytest.raises(ValueError, plotter.map, syst, lambda site: site.tag[0], - pos_transform=bad_transform, file=out) + pos_transform=bad_transform, file=out_filename) with warnings.catch_warnings(): warnings.simplefilter("ignore") plotter.map(syst.finalized(), range(len(syst.sites())), - file=out) + file=out_filename, show=False) pytest.raises(ValueError, plotter.map, syst, - range(len(syst.sites())), file=out) + range(len(syst.sites())), file=out_filename) def test_mask_interpolate(): @@ -230,22 +292,32 @@ def test_mask_interpolate(): @pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.") -def test_bands(): +@pytest.mark.parametrize("engine", ["plotly", "matplotlib"]) +def test_bands(engine): + + plotter.set_engine(engine) syst = syst_2d().finalized().leads[0] - with tempfile.TemporaryFile('w+b') as out: - plotter.bands(syst, file=out) - plotter.bands(syst, fig_size=(10, 10), file=out) - plotter.bands(syst, momenta=np.linspace(0, 2 * np.pi), file=out) + with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out: + out_filename = out.name + plotter.bands(syst, show=False, file=out_filename) + plotter.bands(syst, show=False, momenta=np.linspace(0, 2 * np.pi), file=out_filename) + + if engine == 'matplotlib': + plotter.bands(syst, show=False, fig_size=(10, 10), file=out_filename) + + fig = pyplot.Figure() + ax = fig.add_subplot(1, 1, 1) + plotter.bands(syst, show=False, ax=ax, file=out_filename) - fig = pyplot.Figure() - ax = fig.add_subplot(1, 1, 1) - plotter.bands(syst, ax=ax, file=out) @pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.") -def test_spectrum(): +@pytest.mark.parametrize("engine", ["plotly", "matplotlib"]) +def test_spectrum(engine): + + plotter.set_engine(engine) def ham_1d(a, b, c): return a**2 + b**2 + c**2 @@ -261,38 +333,43 @@ def test_spectrum(): vals = np.linspace(0, 1, 3) - with tempfile.TemporaryFile('w+b') as out: + with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out: + out_filename = out.name for ham in (ham_1d, ham_2d, fsyst): - plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1), file=out) - # test with explicit figsize - plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1), - fig_size=(10, 10), file=out) + plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1), file=out_filename, show=False) + if engine == 'matplotlib': + # test with explicit figsize + plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1), + fig_size=(10, 10), file=out_filename, show=False) for ham in (ham_1d, ham_2d, fsyst): plotter.spectrum(ham, ('a', vals), ('b', 2 * vals), - params=dict(c=1), file=out) - # test with explicit figsize - plotter.spectrum(ham, ('a', vals), ('b', 2 * vals), - params=dict(c=1), fig_size=(10, 10), file=out) - - # test 2D plot and explicitly passing axis - fig = pyplot.figure() - ax = fig.add_subplot(1, 1, 1, projection='3d') - plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals), - params=dict(c=1), ax=ax, file=out) - # explicitly pass axis without 3D support - ax = fig.add_subplot(1, 1, 1) - with pytest.raises(TypeError): + params=dict(c=1), file=out_filename, show=False) + if engine == 'matplotlib': + # test with explicit figsize + plotter.spectrum(ham, ('a', vals), ('b', 2 * vals), + params=dict(c=1), fig_size=(10, 10), file=out_filename, show=False) + + if engine == 'matplotlib': + # test 2D plot and explicitly passing axis + fig = pyplot.figure() + ax = fig.add_subplot(1, 1, 1, projection='3d') plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals), - params=dict(c=1), ax=ax, file=out) + params=dict(c=1), ax=ax, file=out_filename, show=False) + # explicitly pass axis without 3D support + ax = fig.add_subplot(1, 1, 1) + with pytest.raises(TypeError): + plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals), + params=dict(c=1), ax=ax, file=out_filename, show=False) def mask(a, b): return a > 0.5 - with tempfile.TemporaryFile('w+b') as out: + with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out: + out_filename = out.name plotter.spectrum(ham, ('a', vals), ('b', 2 * vals), params=dict(c=1), - mask=mask, file=out) + mask=mask, file=out_filename, show=False) def syst_rect(lat, salt, W=3, L=50): @@ -552,7 +629,7 @@ def test_current(): current = J(kwant.wave_function(syst, energy=1)(1)[0]) # Test good codepath - with tempfile.TemporaryFile('w+b') as out: + with tempfile.NamedTemporaryFile('w+b') as out: plotter.current(syst, current, file=out) fig = pyplot.Figure() diff --git a/setup.py b/setup.py index 85cb54eb613e7de06db60ee17bb27647749a8d38..e97b4a80f07f4bfbc61a96327db72588d61aadf1 100755 --- a/setup.py +++ b/setup.py @@ -585,7 +585,8 @@ def main(): 'tinyarray >= 1.2'], extras_require={ # The oldest versions between: Debian stable, Ubuntu LTS - 'plotting': 'matplotlib >= 2.1.1', + 'plotting': ['matplotlib >= 2.1.1', + 'plotly >= 2.2.2'], 'continuum': 'sympy >= 1.1.1', # qsymm is only packaged on PyPI 'qsymm': 'qsymm >= 1.2.6',