From 21d32fd33e9e051f98f11078233ad3c4d1feb6bd Mon Sep 17 00:00:00 2001 From: Kelvin Loh <kel85uk@gmail.com> Date: Mon, 10 Dec 2018 15:30:38 +0100 Subject: [PATCH] 2D implementation of the plot() function. --- kwant/_plotter.py | 70 ++++++++ kwant/plotter.py | 441 +++++++++++++++++++++++++++++++++++++++------- 2 files changed, 446 insertions(+), 65 deletions(-) diff --git a/kwant/_plotter.py b/kwant/_plotter.py index 5cb965d3..4b24936d 100644 --- a/kwant/_plotter.py +++ b/kwant/_plotter.py @@ -93,6 +93,76 @@ def nparray_if_array(var): return np.asarray(var) if isarray(var) else var +if plotly_available: + + converter_map = { + "o": 0, + "v": 6, + "^": 5, + "<": 7, + ">": 8, + "s": 1, + "+": 3, + "x": 4, + "*": 17, + "d": 2, + "h": 14 + } + + + def convert_symbol_mpl_plotly(mpl_symbol): + if isarray(mpl_symbol): + converted_symbol = [converter_map.get(i) for i in mpl_symbol] + else: + converted_symbol = converter_map.get(mpl_symbol) + + if converted_symbol == None: + raise RuntimeWarning('Input symbol \'{}\' not supported. ' + 'Only the following are supported: {}'.format( + mpl_symbol, converter_map.keys())) + return converted_symbol + + + def convert_site_size_mpl_plotly(mpl_site_size, plotly_ref_px): + return np.sqrt(mpl_site_size)*(96.0/72.0)*plotly_ref_px + + + def convert_colormap_mpl_plotly(mpl_rgba): + _cmap_plotly = 255 * np.array(mpl_rgba) + return 'rgba({},{},{},{})'.format(*_cmap_plotly[0:-1], + _cmap_plotly[-1]/255) + + + def convert_cmap_list_mpl_plotly(mpl_cmap_name, N=255): + cmap_mpl = matplotlib.cm.get_cmap(mpl_cmap_name) + cmap_mpl_arr = matplotlib.colors.makeMappingArray(N, cmap_mpl) + level = np.linspace(1, 0, N) + cmap_plotly_linear = [(level, convert_colormap_mpl_plotly(cmap_mpl)) + for level, cmap_mpl in zip(level, + cmap_mpl_arr)] + return cmap_plotly_linear + + + def convert_lead_cmap_mpl_plotly(mpl_lead_cmap_init, mpl_lead_cmap_end, + N=255): + r_levels = np.linspace(mpl_lead_cmap_init[0], + mpl_lead_cmap_end[0], N) * 255 + g_levels = np.linspace(mpl_lead_cmap_init[1], + mpl_lead_cmap_end[1], N) * 255 + b_levels = np.linspace(mpl_lead_cmap_init[2], + mpl_lead_cmap_end[2], N) * 255 + a_levels = np.linspace(mpl_lead_cmap_init[3], + mpl_lead_cmap_end[3], N) + level = np.linspace(1, 0, N) + cmap_plotly_linear = [(level, 'rgba({},{},{},{})'.format(*rgba)) + for level, rgba in zip(level, zip(r_levels, + g_levels, + b_levels, + a_levels + ))] + return cmap_plotly_linear + + if mpl_available: class LineCollection(collections.LineCollection): def __init__(self, segments, reflen=None, **kwargs): diff --git a/kwant/plotter.py b/kwant/plotter.py index 13c26611..a38579a6 100644 --- a/kwant/plotter.py +++ b/kwant/plotter.py @@ -724,18 +724,21 @@ def sys_leads_hopping_pos(sys, hop_lead_nr): # Useful plot functions (to be extended). - +# The default plotly symbol size is a 6 px defaults = {'site_symbol': {2: 'o', 3: 'o'}, 'site_size': {2: 0.25, 3: 0.5}, + 'plotly_site_size_reference': 6, 'site_color': {2: 'black', 3: 'white'}, 'site_edgecolor': {2: 'black', 3: 'black'}, 'site_lw': {2: 0, 3: 0.1}, 'hop_color': {2: 'black', 3: 'black'}, 'hop_lw': {2: 0.1, 3: 0}, - 'lead_color': {2: 'red', 3: 'red'}} + 'lead_color': {2: 'red', 3: 'red'}, + 'unit': {_p.Backends.plotly: 'pt', + _p.Backends.matplotlib: 'nn'}} -def plot(sys, num_lead_cells=2, unit='nn', +def plot(sys, num_lead_cells=2, unit=None, site_symbol=None, site_size=None, site_color=None, site_edgecolor=None, site_lw=None, hop_color=None, hop_lw=None, @@ -887,10 +890,104 @@ def plot(sys, num_lead_cells=2, unit='nn', its aspect ratio. """ - if not _p.mpl_available: - raise RuntimeError("matplotlib was not found, but is required " + + # Provide default unit if user did not specify + if unit == None: + unit = defaults['unit'][get_backend()] + if get_backend() == _p.Backends.matplotlib: + fig = _plot_matplotlib(sys, num_lead_cells, unit, + site_symbol, site_size, + site_color, site_edgecolor, site_lw, + hop_color, hop_lw, + lead_site_symbol, lead_site_size, lead_color, + lead_site_edgecolor, lead_site_lw, + lead_hop_lw, pos_transform, + cmap, colorbar, file, + show, dpi, fig_size, ax) + elif get_backend() == _p.Backends.plotly: + _check_incompatible_args_plotly(dpi, fig_size, ax) + fig = _plot_plotly(sys, num_lead_cells, unit, + site_symbol, site_size, + site_color, site_edgecolor, site_lw, + hop_color, hop_lw, + lead_site_symbol, lead_site_size, lead_color, + lead_site_edgecolor, lead_site_lw, + lead_hop_lw, pos_transform, + cmap, colorbar, file, + show) + else: + raise RuntimeError("Backend not supported by plot().") + + _maybe_output_fig(fig, file=file, show=show) + + return fig + +def _resize_to_dim(array, dim): + if array.shape[1] != dim: + ar = np.zeros((len(array), dim), dtype=float) + ar[:, : min(dim, array.shape[1])] = array[ + :, : min(dim, array.shape[1])] + return ar + else: + return array + + +def _check_length(name, loc): + value = loc[name] + if name in ('site_size', 'site_lw') and isinstance(value, tuple): + raise TypeError('{0} may not be a tuple, use list or ' + 'array instead.'.format(name)) + if isinstance(value, (str, tuple)): + return + try: + if len(value) != loc['n_syst_sites']: + raise ValueError('Length of {0} is not equal to number of ' + 'system sites.'.format(name)) + except TypeError: + pass + +# make all specs proper: either constant or lists/np.arrays: +def _make_proper_site_spec(spec_name, spec, fancy_indexing=False): + if _p.isarray(spec) and isinstance(syst, builder.Builder): + raise TypeError('{} cannot be an array when plotting' + ' a Builder; use a function instead.' + .format(spec_name)) + if callable(spec): + spec = [spec(i[0]) for i in sites if i[1] is None] + if (fancy_indexing and _p.isarray(spec) + and not isinstance(spec, np.ndarray)): + try: + spec = np.asarray(spec) + except: + spec = np.asarray(spec, dtype='object') + return spec + +def _make_proper_hop_spec(spec, fancy_indexing=False): + if callable(spec): + spec = [spec(*i[0]) for i in hops if i[1] is None] + if (fancy_indexing and _p.isarray(spec) + and not isinstance(spec, np.ndarray)): + try: + spec = np.asarray(spec) + except: + spec = np.asarray(spec, dtype='object') + return spec + +def _plot_plotly(sys, num_lead_cells, unit, + site_symbol, site_size, + site_color, site_edgecolor, site_lw, + hop_color, hop_lw, + lead_site_symbol, lead_site_size, lead_color, + lead_site_edgecolor, lead_site_lw, + lead_hop_lw, pos_transform, + cmap, colorbar, file, + show): + + if not _p.plotly_available: + raise RuntimeError("plotly was not found, but is required " "for plot()") + print('In _plot_plotly') syst = sys # for naming consistency inside function bodies # Generate data. sites, lead_sites_slcs = sys_leads_sites(syst, num_lead_cells) @@ -900,35 +997,279 @@ def plot(sys, num_lead_cells=2, unit='nn', n_syst_hops = sum(i[1] is None for i in hops) end_pos, start_pos = sys_leads_hopping_pos(syst, hops) - # Choose plot type. - def resize_to_dim(array): - if array.shape[1] != dim: - ar = np.zeros((len(array), dim), dtype=float) - ar[:, : min(dim, array.shape[1])] = array[ - :, : min(dim, array.shape[1])] - return ar + loc = locals() + + for name in ['site_symbol', 'site_size', 'site_color', 'site_edgecolor', + 'site_lw']: + _check_length(name, loc) + + # Apply transformations to the data + if pos_transform is not None: + sites_pos = np.apply_along_axis(pos_transform, 1, sites_pos) + end_pos = np.apply_along_axis(pos_transform, 1, end_pos) + start_pos = np.apply_along_axis(pos_transform, 1, start_pos) + + dim = 3 if (sites_pos.shape[1] == 3) else 2 + + sites_pos = _resize_to_dim(sites_pos, dim) + end_pos = _resize_to_dim(end_pos, dim) + start_pos = _resize_to_dim(start_pos, dim) + + # Determine the reference length. + if unit != 'pt': + raise RuntimeError('Plotly backend currently only supports ' + 'the pt symbol size unit') + + + site_symbol = _make_proper_site_spec('site_symbol', site_symbol, sites) + if site_symbol is None: site_symbol = defaults['site_symbol'][dim] + # separate different symbols (not done in 3D, the separation + # would mess up sorting) + if (_p.isarray(site_symbol) and dim != 3 and + (len(site_symbol) != 3 or site_symbol[0] not in ('p', 'P'))): + symbol_dict = defaultdict(list) + for i, symbol in enumerate(site_symbol): + symbol_dict[symbol].append(i) + symbol_slcs = [] + for symbol, indx in symbol_dict.items(): + symbol_slcs.append((symbol, np.array(indx))) + fancy_indexing = True + else: + symbol_slcs = [(site_symbol, slice(n_syst_sites))] + fancy_indexing = False + + if site_color is None: + cycle = _color_cycle() + if isinstance(syst, (builder.FiniteSystem, builder.InfiniteSystem)): + # Skipping the leads for brevity. + families = sorted({site.family for site in syst.sites}) + color_mapping = dict(zip(families, cycle)) + def site_color(site): + return color_mapping[syst.sites[site].family] + elif isinstance(syst, builder.Builder): + families = sorted({site[0].family for site in sites}) + color_mapping = dict(zip(families, cycle)) + def site_color(site): + return color_mapping[site.family] else: - return array + # Unknown finalized system, no sites access. + site_color = defaults['site_color'][dim] - loc = locals() + site_size = _make_proper_site_spec('site_size',site_size, sites, fancy_indexing) + site_color = _make_proper_site_spec('site_color',site_color, sites, fancy_indexing) + site_edgecolor = _make_proper_site_spec('site_edgecolor',site_edgecolor, sites, + fancy_indexing) + site_lw = _make_proper_site_spec('site_lw',site_lw, sites, fancy_indexing) - def check_length(name): - value = loc[name] - if name in ('site_size', 'site_lw') and isinstance(value, tuple): - raise TypeError('{0} may not be a tuple, use list or ' - 'array instead.'.format(name)) - if isinstance(value, (str, tuple)): - return + hop_color = _make_proper_hop_spec(hop_color, hops) + hop_lw = _make_proper_hop_spec(hop_lw, hops) + + # Choose defaults depending on dimension, if None was given + if site_size is None: site_size = defaults['site_size'][dim] + if site_edgecolor is None: + site_edgecolor = defaults['site_edgecolor'][dim] + if site_lw is None: site_lw = defaults['site_lw'][dim] + + if hop_color is None: hop_color = defaults['hop_color'][dim] + if hop_lw is None: hop_lw = defaults['hop_lw'][dim] + + # if symbols are split up into different collections, + # the colormapping will fail without normalization + norm = None + if len(symbol_slcs) > 1: try: - if len(value) != n_syst_sites: - raise ValueError('Length of {0} is not equal to number of ' - 'system sites.'.format(name)) + if site_color.ndim == 1 and len(site_color) == n_syst_sites: + site_color = np.asarray(site_color, dtype=float) + norm = _p.matplotlib.colors.Normalize(site_color.min(), + site_color.max()) + except: + pass + + # take spec also for lead, if it's not a list/array, default, otherwise + if lead_site_symbol is None: + lead_site_symbol = (site_symbol if not _p.isarray(site_symbol) + else defaults['site_symbol'][dim]) + if lead_site_size is None: + lead_site_size = (site_size if not _p.isarray(site_size) + else defaults['site_size'][dim]) + if lead_color is None: + lead_color = defaults['lead_color'][dim] + lead_color = _p.matplotlib.colors.colorConverter.to_rgba(lead_color) + + if lead_site_edgecolor is None: + lead_site_edgecolor = (site_edgecolor if not _p.isarray(site_edgecolor) + else defaults['site_edgecolor'][dim]) + if lead_site_lw is None: + lead_site_lw = (site_lw if not _p.isarray(site_lw) + else defaults['site_lw'][dim]) + if lead_hop_lw is None: + lead_hop_lw = (hop_lw if not _p.isarray(hop_lw) + else defaults['hop_lw'][dim]) + + hop_cmap = None + if not isinstance(cmap, str): + try: + cmap, hop_cmap = cmap except TypeError: pass + # plot system sites and hoppings + site_node_trace, site_edge_trace = [], [] + for symbol, slc in symbol_slcs: + size = site_size[slc] if _p.isarray(site_size) else site_size + col = site_color[slc] if _p.isarray(site_color) else site_color + edgecol = (site_edgecolor[slc] if _p.isarray(site_edgecolor) else + site_edgecolor) + lw = site_lw[slc] if _p.isarray(site_lw) else site_lw + + site_symbol_plotly = _p.convert_symbol_mpl_plotly(symbol) + site_node_trace_elem = _p.plotly_graph_objs.Scatter( + x=[], + y=[], + text=[], + mode='markers', + hoverinfo='text', + marker=dict( + showscale=False, + colorscale=_p.convert_cmap_list_mpl_plotly(cmap), + reversescale=True, + color=col, + size=_p.convert_site_size_mpl_plotly(size, + defaults['plotly_site_size_reference']), + symbol=site_symbol_plotly, + line=dict(width=lw, + color=edgecol) + )) + + + for i in range(len(sites_pos[slc])): + x, y = sites_pos[slc][i] + site_node_trace_elem['x'] += tuple([x]) + site_node_trace_elem['y'] += tuple([y]) + + site_node_trace.append(site_node_trace_elem) + + end, start = end_pos[: n_syst_hops], start_pos[: n_syst_hops] + dim = end.shape[1] + assert dim == 2 or dim == 3 + if dim == 2: + site_edge_trace_elem = _p.plotly_graph_objs.Scatter( + x=[], + y=[], + line=dict(width=hop_lw,color=hop_color), + hoverinfo='none', + mode='lines') + for i in range(len(end)): + x0, y0 = end[i] + x1, y1 = start[i] + site_edge_trace_elem['x'] += tuple([x0, x1, None]) + site_edge_trace_elem['y'] += tuple([y0, y1, None]) + site_edge_trace.append(site_edge_trace_elem) + else: + raise RuntimeError('dim=3 is unsupported yet in plotly backend') + + # Make conversion of colormap + + lead_site_symbol_plotly = _p.convert_symbol_mpl_plotly(lead_site_symbol) + + lead_node_trace, lead_edge_trace = [], [] + for sites_slc, hops_slc in zip(lead_sites_slcs, lead_hops_slcs): + lead_site_colors = np.array([i[2] for i in sites[sites_slc]], + dtype=float) + lead_node_trace_elem = _p.plotly_graph_objs.Scatter( + x=[], + y=[], + text=[], + mode='markers', + hoverinfo='text', + marker=dict( + showscale=False, + reversescale=True, + color=lead_site_colors, + colorscale=_p.convert_lead_cmap_mpl_plotly( + lead_color, [1,1,1,lead_color[3]]), + size=_p.convert_site_size_mpl_plotly( + lead_site_size, + defaults['plotly_site_size_reference']), + symbol=lead_site_symbol_plotly, + line=dict(width=lead_site_lw, + color=lead_site_edgecolor) + )) + for i in range(len(sites_pos[sites_slc])): + x, y = sites_pos[sites_slc][i] + lead_node_trace_elem['x'] += tuple([x]) + lead_node_trace_elem['y'] += tuple([y]) + lead_node_trace.append(lead_node_trace_elem) + lead_hop_colors = np.array([i[2] for i in hops[hops_slc]], dtype=float) + # Note: the previous version of the code had in addition this + # line in the 3D case: + # lead_hop_colors = 1 / np.sqrt(1. + lead_hop_colors) + # Uses lead_cmap for the colormap + # 1) Make each line a scatter object. Takes a lot of memory but should work + # 2) Get the color from the previous object + end, start = end_pos[hops_slc], start_pos[hops_slc] + if dim == 2: + lead_edge_trace_elem = _p.plotly_graph_objs.Scatter( + x=[], + y=[], + line=dict(width=lead_hop_lw, + color='red'), + hoverinfo='none', + mode='lines') + for i in range(len(end)): + x0, y0 = end[i] + x1, y1 = start[i] + lead_edge_trace_elem['x'] += tuple([x0, x1, None]) + lead_edge_trace_elem['y'] += tuple([y0, y1, None]) + + lead_edge_trace.append(lead_edge_trace_elem) + else: + raise RuntimeError('dim=3 is unsupported yet in plotly backend') + + layout = _p.plotly_graph_objs.Layout( + showlegend=False, + hovermode='closest', + xaxis=dict(showgrid=False, zeroline=False, + showticklabels=True), + yaxis=dict(showgrid=False, zeroline=False, + showticklabels=True)) + full_trace = list(itertools.chain.from_iterable([site_edge_trace, + site_node_trace, lead_edge_trace, + lead_node_trace])) + fig = _p.plotly_graph_objs.Figure(data=full_trace, + layout=layout) + + return fig + + +def _plot_matplotlib(sys, num_lead_cells, unit, + site_symbol, site_size, + site_color, site_edgecolor, site_lw, + hop_color, hop_lw, + lead_site_symbol, lead_site_size, lead_color, + lead_site_edgecolor, lead_site_lw, + lead_hop_lw, pos_transform, + cmap, colorbar, file, + show, dpi, fig_size, ax): + + if not _p.mpl_available: + raise RuntimeError("matplotlib was not found, but is required " + "for plot()") + + print('In _plot_matplotlib') + syst = sys # for naming consistency inside function bodies + # Generate data. + sites, lead_sites_slcs = sys_leads_sites(syst, num_lead_cells) + n_syst_sites = sum(i[1] is None for i in sites) + sites_pos = sys_leads_pos(syst, sites) + hops, lead_hops_slcs = sys_leads_hoppings(syst, num_lead_cells) + n_syst_hops = sum(i[1] is None for i in hops) + end_pos, start_pos = sys_leads_hopping_pos(syst, hops) + + loc = locals() for name in ['site_symbol', 'site_size', 'site_color', 'site_edgecolor', 'site_lw']: - check_length(name) + _check_length(name, loc) # Apply transformations to the data if pos_transform is not None: @@ -939,9 +1280,9 @@ def plot(sys, num_lead_cells=2, unit='nn', dim = 3 if (sites_pos.shape[1] == 3) else 2 if dim == 3 and not _p.has3d: raise RuntimeError("Installed matplotlib does not support 3d plotting") - sites_pos = resize_to_dim(sites_pos) - end_pos = resize_to_dim(end_pos) - start_pos = resize_to_dim(start_pos) + sites_pos = _resize_to_dim(sites_pos, dim) + end_pos = _resize_to_dim(end_pos, dim) + start_pos = _resize_to_dim(start_pos, dim) # Determine the reference length. if unit == 'pt': @@ -974,35 +1315,7 @@ def plot(sys, num_lead_cells=2, unit='nn', except: raise ValueError('Invalid value of unit argument.') - # make all specs proper: either constant or lists/np.arrays: - def make_proper_site_spec(spec_name, spec, fancy_indexing=False): - if _p.isarray(spec) and isinstance(syst, builder.Builder): - raise TypeError('{} cannot be an array when plotting' - ' a Builder; use a function instead.' - .format(spec_name)) - if callable(spec): - spec = [spec(i[0]) for i in sites if i[1] is None] - if (fancy_indexing and _p.isarray(spec) - and not isinstance(spec, np.ndarray)): - try: - spec = np.asarray(spec) - except: - spec = np.asarray(spec, dtype='object') - return spec - - def make_proper_hop_spec(spec, fancy_indexing=False): - if callable(spec): - spec = [spec(*i[0]) for i in hops if i[1] is None] - if (fancy_indexing and _p.isarray(spec) - and not isinstance(spec, np.ndarray)): - try: - spec = np.asarray(spec) - except: - spec = np.asarray(spec, dtype='object') - return spec - - - site_symbol = make_proper_site_spec('site_symbol', site_symbol) + site_symbol = _make_proper_site_spec('site_symbol', site_symbol) if site_symbol is None: site_symbol = defaults['site_symbol'][dim] # separate different symbols (not done in 3D, the separation # would mess up sorting) @@ -1036,13 +1349,13 @@ def plot(sys, num_lead_cells=2, unit='nn', # Unknown finalized system, no sites access. site_color = defaults['site_color'][dim] - site_size = make_proper_site_spec('site_size', site_size, fancy_indexing) - site_color = make_proper_site_spec('site_color', site_color, fancy_indexing) - site_edgecolor = make_proper_site_spec('site_edgecolor', site_edgecolor, fancy_indexing) - site_lw = make_proper_site_spec('site_lw', site_lw, fancy_indexing) + site_size = _make_proper_site_spec('site_size', site_size, fancy_indexing) + site_color = _make_proper_site_spec('site_color', site_color, fancy_indexing) + site_edgecolor = _make_proper_site_spec('site_edgecolor', site_edgecolor, fancy_indexing) + site_lw = _make_proper_site_spec('site_lw', site_lw, fancy_indexing) - hop_color = make_proper_hop_spec(hop_color) - hop_lw = make_proper_hop_spec(hop_lw) + hop_color = _make_proper_hop_spec(hop_color, hops) + hop_lw = _make_proper_hop_spec(hop_lw, hops) # Choose defaults depending on dimension, if None was given if site_size is None: site_size = defaults['site_size'][dim] @@ -1171,8 +1484,6 @@ def plot(sys, num_lead_cells=2, unit='nn', if line_coll.get_array() is not None and colorbar and fig is not None: fig.colorbar(line_coll) - _maybe_output_fig(fig, file=file, show=show) - return fig -- GitLab