# -*- coding: utf-8 -*-
# Copyright 2011-2018 Kwant authors.
#
# This file is part of Kwant.  It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
# http://kwant-project.org/license.  A list of Kwant authors can be found in
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.

"""Plotter module for Kwant.

This module provides iterators useful for any plotter routine, such as a list
of system sites, their coordinates, lead sites at any lead unit cell, etc.  If
`matplotlib` is available, it also provides simple functions for plotting the
system in two or three dimensions.
"""

18
from collections import defaultdict
19
import sys
20
21
import itertools
import functools
22
import warnings
23
import cmath
24
import numpy as np
25
import tinyarray as ta
26
from scipy import spatial, interpolate
27
28
from math import cos, sin, pi, sqrt

29
from . import system, builder, _common
30
from ._common import deprecate_args
31

Christoph Groth's avatar
Christoph Groth committed
32

33
34
35
__all__ = [
    'plot', 'map', 'bands', 'spectrum', 'current', 'density',
    'interpolate_current', 'interpolate_density',
    'streamplot', 'scalarplot',
    'sys_leads_sites', 'sys_leads_hoppings', 'sys_leads_pos',
    'sys_leads_hopping_pos', 'mask_interpolate',
]


# All the expensive imports are done in _plotter.py.  The module is loaded
# lazily so that importing Kwant itself stays fast.
_p = _common.lazy_import('_plotter')

Anton Akhmerov's avatar
Anton Akhmerov committed
43

44
45
def _sample_array(array, n_samples, rng=None):
    """Return a random selection of at most ``n_samples`` entries of ``array``.

    Entries are drawn without replacement, so the result holds
    ``min(n_samples, len(array))`` distinct entries.

    Parameters
    ----------
    array : sequence supporting indexing with an array of integers
        Presumably a NumPy array (it is indexed with the array of indices
        returned by ``rng.choice``).
    n_samples : int
        Maximal number of entries to draw.
    rng : optional
        Passed to ``_common.ensure_rng`` to obtain the random number
        generator (presumably a seed, a random state, or None).
    """
    rng = _common.ensure_rng(rng)
    n_items = len(array)
    n_drawn = min(n_samples, n_items)
    picked = rng.choice(range(n_items), n_drawn, replace=False)
    return array[picked]
48
49


50
51
# matplotlib helper functions.

52
53
54
55
56
57
def _color_cycle():
    """Return an endless iterator over the colors of matplotlib's color cycle.

    The colors are taken from the ``axes.prop_cycle`` rc parameter.
    """
    prop_cycle = _p.matplotlib.rcParams['axes.prop_cycle']
    cycle_colors = (entry['color'] for entry in prop_cycle)
    return itertools.cycle(cycle_colors)


58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def _make_figure(dpi, fig_size, use_pyplot=False):
    """Create a matplotlib figure, optionally through ``pyplot``.

    Parameters
    ----------
    dpi : float or None
        Resolution of the figure; matplotlib's default is kept if None.
    fig_size : (width, height) or None
        Size passed to ``set_figwidth``/``set_figheight``; matplotlib's
        default is kept if None.
    use_pyplot : bool
        If True the figure is created with ``pyplot.figure()``.  Otherwise a
        bare figure with an Agg canvas is created.

    Returns
    -------
    fig : matplotlib figure
    """
    # Importing matplotlib.backends (directly or through pyplot) selects the
    # matplotlib backend for good, so warn if nothing has imported the
    # backends yet.
    if 'matplotlib.backends' not in sys.modules:
        warnings.warn(
            "Kwant's plotting functions have\nthe side effect of "
            "selecting the matplotlib backend. To avoid this "
            "warning,\nimport matplotlib.pyplot, "
            "matplotlib.backends or call matplotlib.use().",
            RuntimeWarning, stacklevel=3
        )

    if not use_pyplot:
        # Import the backend only at the last possible moment (=now).
        from matplotlib.backends.backend_agg import FigureCanvasAgg
        fig = _p.Figure()
        fig.canvas = FigureCanvasAgg(fig)
    else:
        # pyplot is imported only at the last possible moment (=now) because
        # importing it selects the matplotlib backend.
        from matplotlib import pyplot
        fig = pyplot.figure()

    if dpi is not None:
        fig.set_dpi(dpi)
    if fig_size is not None:
        width, height = fig_size[0], fig_size[1]
        fig.set_figwidth(width)
        fig.set_figheight(height)
    return fig


86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def _maybe_output_fig(fig, file=None, show=True):
    """Output a matplotlib figure using a given output mode.

    Parameters
    ----------
    fig : matplotlib.figure.Figure instance
        The figure to be output.
    file : string or a file object
        The name of the target file or the target file itself
        (opened for writing).
    show : bool
        Whether to call ``matplotlib.pyplot.show()``.  Only has an effect if
        not saving to a file.

    Notes
    -----
    The behavior of this function producing a file is different from that of
    matplotlib in that the `dpi` attribute of the figure is used by defaul
    instead of the matplotlib config setting.
    """
    if fig is None:
        return

    if file is not None:
        fig.canvas.print_figure(file, dpi=fig.dpi)
    elif show:
        # If there was no file provided, pyplot should already be available and
        # we can import it safely without additional warnings.
        from matplotlib import pyplot
        pyplot.show()


118
def set_colors(color, collection, cmap, norm=None):
    """Process a color specification to a format accepted by collections.

    Parameters
    ----------
    color : color specification
        A single matplotlib color, a sequence of matplotlib colors, or a
        sequence of floats with one entry per element of ``collection``.
        Such floats are mapped to colors through ``cmap`` and ``norm``.
    collection : instance of a subclass of ``matplotlib.collections.Collection``
        Collection to which the color is added.
    cmap : ``matplotlib`` color map specification or None
        Color map to be used if colors are specified as floats.
    norm : ``matplotlib`` color norm
        Norm to be used if colors are specified as floats.
    """
    # matplotlib gets confused if dtype='object'
    if isinstance(color, np.ndarray) and color.dtype == np.dtype('object'):
        color = tuple(color)

    # Number of elements in the collection.
    if _p.has3d and isinstance(collection, _p.mplot3d.art3d.Line3DCollection):
        # Once again, matplotlib fault: for a Line3DCollection the segment
        # count is presumably not reflected by get_paths()/get_offsets().
        n_elements = len(collection._segments3d)
    else:
        n_elements = max(len(collection.get_paths()),
                         len(collection.get_offsets()))

    if _p.isarray(color) and len(color) == n_elements:
        # One entry per element: plain floats are color-mapped.
        try:
            color = np.asarray(color, dtype=float)
            if color.ndim == 1:
                collection.set_array(color)
                collection.set_cmap(cmap)
                collection.set_norm(norm)
                collection.set_color(None)
                return
        except (TypeError, ValueError):
            # Not convertible to floats (e.g. color names): handled below.
            pass

    collection.set_color(
        _p.matplotlib.colors.colorConverter.to_rgba_array(color))
156
157


158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
def percentile_bound(data, vmin, vmax, percentile=96, stretch=0.1):
    """Return the bounds that capture at least 'percentile' of 'data'.

    If 'vmin' or 'vmax' are provided, then the corresponding bound is
    exactly 'vmin' or 'vmax'. First we set the bounds such that the
    provided percentile of the data is within them. Then we try to
    extend the bounds to cover all the data, maximally stretching each
    bound by a factor 'stretch' of the distance between the bounds.

    Parameters
    ----------
    data : array_like
        Values of any shape; they are flattened before use.
    vmin, vmax : float or None
        Fixed lower/upper bound, or None to derive it from ``data``.
    percentile : float
        Percentage of the data that the bounds should capture.  An equal
        share of the remaining data is left out on each side.
    stretch : float
        Maximal extension of each derived bound, as a fraction of
        ``upper - lower``.

    Returns
    -------
    (lower, upper) : tuple
    """
    if vmin is not None and vmax is not None:
        return vmin, vmax

    # Percentage of the data left out on each side of the bounds.
    tail = (100 - percentile) / 2
    # Accept any array_like (e.g. lists), while keeping ndarray subclasses
    # intact as ``data.flatten()`` did.
    values = np.asanyarray(data).ravel()
    mn, bound_mn, bound_mx, mx = np.percentile(
        values, (0, tail, 100 - tail, 100))

    bound_mn = bound_mn if vmin is None else vmin
    bound_mx = bound_mx if vmax is None else vmax

    # Stretch the lower and upper bounds to cover all the data, if
    # we stretch the bound by less than a factor 'stretch'.
    margin = (bound_mx - bound_mn) * stretch
    out_mn = max(bound_mn - margin, mn) if vmin is None else vmin
    out_mx = min(bound_mx + margin, mx) if vmax is None else vmax

    return (out_mn, out_mx)


Anton Akhmerov's avatar
Anton Akhmerov committed
186
# Shorthand symbol names understood by `get_symbol`, mapped to the full
# specification they stand for.
symbol_dict = dict(
    O='o',           # circle of radius 1
    s=('p', 4, 45),  # 4-sided polygon circumscribing the unit circle, 45 deg
    S=('P', 4, 45),  # 4-sided polygon with the area of the unit circle, 45 deg
)
187
188

def get_symbol(symbols):
    """Return the path(s) corresponding to the description in ``symbols``.

    Parameters
    ----------
    symbols : symbol specification or sequence of specifications
        Each specification is one of

        - ``'o'``: a circle of radius 1,
        - a key of ``symbol_dict``, which is replaced by the specification
          it maps to,
        - ``(kind, n, angle)`` with ``kind`` equal to ``'p'`` or ``'P'``: a
          regular polygon with ``n`` vertices rotated by ``angle`` degrees.
          For ``'p'`` the polygon circumscribes the unit circle; for ``'P'``
          it has the same area as the unit circle,
        - a ``matplotlib.path.Path`` instance.

    Returns
    -------
    paths : list of matplotlib.path.Path
        One path per specification.  If a specification (processed in
        order) is itself a ``Path``, that ``Path`` is returned directly
        instead of a list.

    Raises
    ------
    ValueError
        If a specification is not recognized.
    """
    # Figure out if list of symbols or single symbol.  (A one-character
    # string such as 'o' is iterated over and yields itself.)
    if not hasattr(symbols, '__getitem__'):
        symbols = [symbols]
    elif len(symbols) == 3 and symbols[0] in ('p', 'P'):
        # Most likely a polygon specification (at least not a valid other
        # symbol).
        symbols = [symbols]

    # Resolve shorthand names.  All keys of ``symbol_dict`` are strings, and
    # other specifications (lists, arrays) may be unhashable, so only look
    # up strings.
    symbols = [symbol_dict.get(symbol, symbol) if isinstance(symbol, str)
               else symbol for symbol in symbols]

    paths = []
    for symbol in symbols:
        if isinstance(symbol, _p.matplotlib.path.Path):
            return symbol
        elif hasattr(symbol, '__getitem__') and len(symbol) == 3:
            kind, n, angle = symbol
            if kind not in ('p', 'P'):
                raise ValueError("Unknown symbol definition " + str(symbol))
            if kind == 'p':
                # Circumradius for which the inscribed circle has radius 1.
                radius = 1. / cos(pi / n)
            else:
                # make the polygon such that it has area equal
                # to a unit circle
                radius = sqrt(2 * pi / (n * sin(2 * pi / n)))
            patch = _p.matplotlib.patches.RegularPolygon(
                (0, 0), n, radius=radius, orientation=pi * angle / 180)
        elif symbol == 'o':
            patch = _p.matplotlib.patches.Circle((0, 0), 1)
        else:
            # Without this, ``patch`` would be unbound (first entry) or
            # silently reused from the previous entry.
            raise ValueError("Unknown symbol definition " + str(symbol))

        paths.append(patch.get_path().transformed(patch.get_transform()))

    return paths


Anton Akhmerov's avatar
Anton Akhmerov committed
230
231
def symbols(axes, pos, symbol='o', size=1, reflen=None, facecolor='k',
            edgecolor='k', linewidth=None, cmap=None, norm=None, zorder=0,
            **kwargs):
    """Add a collection of symbols (2D or 3D) to an axes instance.

    Parameters
    ----------
    axes : matplotlib.axes.Axes instance
        Axes to which the symbols have to be added.
    pos : 2d or 3d array_like
        Coordinates of each symbol, of shape ``(n, 2)`` or ``(n, 3)``.
    symbol : symbol specification, optional
        Anything accepted by `get_symbol`: ``'o'`` for circles, a shorthand
        such as ``'s'``, a polygon specification ``(kind, n, angle)`` with
        ``kind`` in ``('p', 'P')``, or a ``matplotlib.path.Path``.
        ``'no symbol'`` draws nothing.  Defaults to ``'o'``.
    size : float or 1d array
        Size(s) of the symbols. Defaults to 1.  Nothing is drawn if all
        sizes are zero.
    reflen : float or None, optional
        If ``reflen`` is ``None``, the symbol sizes and linewidths are
        given in points (absolute size in the figure space). If
        ``reflen`` is a number, the symbol sizes and linewidths are
        given in units of ``reflen`` in data space (i.e. scales with the
        scale of the plot). Defaults to ``None``.
    facecolor : color definition, optional
    edgecolor : color definition, optional
        Defines the fill and edge color of the symbol, respectively.
        Either a single object that is a proper matplotlib color
        definition or a sequence of such objects of appropriate
        length.  Defaults to all black.
    linewidth : float, sequence of floats or None, optional
        Width(s) of the symbol edges, passed as ``linewidths`` to the
        collection.
    cmap : ``matplotlib`` color map specification or None
        Color map to be used if colors are specified as floats.
    norm : ``matplotlib`` color norm
        Norm to be used if colors are specified as floats.
    zorder : int
        Order in which different collections are drawn: larger
        ``zorder`` means the collection is drawn over collections with
        smaller ``zorder`` values.
    **kwargs : dict keyword arguments to
        pass to `PathCollection` or `Path3DCollection`, respectively.

    Returns
    -------
    `PathCollection` or `Path3DCollection` instance containing all the
    symbols that were added.
    """
    pos = np.asarray(pos)
    dim = pos.shape[1]
    assert dim == 2 or dim == 3

    # Internally, size must be array_like.  ``np.ndim`` also recognizes
    # NumPy scalars and 0-d arrays, which cannot be indexed with [0].
    if np.ndim(size) == 0:
        size = (size, )

    if dim == 2:
        Collection = _p.PathCollection
    else:
        Collection = _p.Path3DCollection

    # ``size`` is typically a tuple or list here, for which ``size == 0``
    # compares the container itself with 0 (always False), so compare
    # elementwise.
    if (len(pos) == 0 or np.all(symbol == 'no symbol')
            or np.all(np.asarray(size) == 0)):
        paths = []
        pos = np.empty((0, dim))
    else:
        paths = get_symbol(symbol)

    coll = Collection(paths, sizes=size, reflen=reflen, linewidths=linewidth,
                      offsets=pos, transOffset=axes.transData, zorder=zorder)

    set_colors(facecolor, coll, cmap, norm)
    coll.set_edgecolors(edgecolor)

    coll.update(kwargs)

    if dim == 2:
        axes.add_collection(coll)
    else:
        axes.add_collection3d(coll)

    return coll
308
309


310
311
def lines(axes, pos0, pos1, reflen=None, colors='k', linestyles='solid',
          cmap=None, norm=None, zorder=0, **kwargs):
    """Add a collection of line segments (2D or 3D) to an axes instance.

    Parameters
    ----------
    axes : matplotlib.axes.Axes instance
        Axes to which the lines have to be added.
    pos0 : 2d or 3d array_like
        Starting coordinates of each line segment
    pos1 : 2d or 3d array_like
        Ending coordinates of each line segment
    reflen : float or None, optional
        If `reflen` is `None`, the linewidths are given in points (absolute
        size in the figure space). If `reflen` is a number, the linewidths
        are given in units of `reflen` in data space (i.e. scales with
        the scale of the plot). Defaults to `None`.
    colors : color definition, optional
        Either a single object that is a proper matplotlib color definition
        or a sequence of such objects of appropriate length.  Defaults to all
        segments black.
    linestyles : linestyle definition, optional
        Either a single object that is a proper matplotlib line style
        definition or a sequence of such objects of appropriate length.
        Defaults to all segments solid.
    cmap : ``matplotlib`` color map specification or None
        Color map to be used if colors are specified as floats.
    norm : ``matplotlib`` color norm
        Norm to be used if colors are specified as floats.
    zorder : int
        Order in which different collections are drawn: larger
        `zorder` means the collection is drawn over collections with
        smaller `zorder` values.
    **kwargs : dict keyword arguments to
        pass to `LineCollection` or `Line3DCollection`, respectively.
        If ``linewidths`` is given and all its values are zero, an empty
        collection is added.

    Returns
    -------
    `LineCollection` or `Line3DCollection` instance containing all the
    segments that were added.

    Raises
    ------
    ValueError
        If ``pos0`` and ``pos1`` have different shapes.
    """
    pos0 = np.asarray(pos0)
    pos1 = np.asarray(pos1)
    if not pos0.shape == pos1.shape:
        raise ValueError('Incompatible lengths of coordinate arrays.')

    dim = pos0.shape[1]
    assert dim == 2 or dim == 3
    if dim == 2:
        Collection = _p.LineCollection
    else:
        Collection = _p.Line3DCollection

    # ``linewidths`` may be a scalar or one value per segment.  Compare
    # elementwise: ``array == 0`` would raise when used as a truth value.
    linewidths = kwargs.get('linewidths')
    invisible = linewidths is not None and np.all(np.asarray(linewidths) == 0)

    if len(pos0) == 0 or invisible:
        # Nothing to draw: still create and add an empty collection so that
        # a collection is returned.
        coll = Collection([], reflen=reflen, linestyles=linestyles,
                          zorder=zorder)
    else:
        segments = np.c_[pos0, pos1].reshape(pos0.shape[0], 2, dim)
        coll = Collection(segments, reflen=reflen, linestyles=linestyles,
                          zorder=zorder)
        set_colors(colors, coll, cmap, norm)
    coll.update(kwargs)

    if dim == 2:
        axes.add_collection(coll)
    else:
        axes.add_collection3d(coll)

    return coll
386
387


388
# Extracting necessary data from the system.
389

390
def sys_leads_sites(sys, num_lead_cells=2):
    """Return all the sites of the system and of the leads as a list.

    Parameters
    ----------
    sys : kwant.builder.Builder or kwant.system.FiniteSystem instance
        The system, sites of which should be returned.
    num_lead_cells : integer
        The number of times lead sites from each lead should be returned.
        This is useful for showing several unit cells of the lead next to the
        system.

    Returns
    -------
    sites : list of (site, lead_number, copy_number) tuples
        A site is a `~kwant.builder.Site` instance if the system is not
        finalized, and an integer otherwise.  For system sites `lead_number`
        is `None` and `copy_number` is `0`, for leads both are integers.
        System sites come first, followed by the sites of each lead.  All
        copies of one lead site are adjacent.
    lead_cells : list of slices
        `lead_cells[i]` gives the position of all the coordinates of lead
        `i` within `sites`.  The slice is empty for leads that are not shown.

    Raises
    ------
    TypeError
        If ``sys`` is neither a builder nor a finalized system.

    Notes
    -----
    Leads are only supported if they are of the same type as the original
    system, i.e.  sites of `~kwant.builder.BuilderLead` leads are returned
    with an unfinalized system, and sites of ``system.InfiniteSystem`` leads
    are returned with a finalized system.  Leads with an empty interface are
    not shown.
    """
    syst = sys  # for naming consistency within function bodies

    if isinstance(syst, builder.Builder):
        sites = [(site, None, 0) for site in syst.sites()]

        def lead_cell_sites(leadnr, lead):
            # Only builder leads with a non-empty interface are shown.
            if not (hasattr(lead, 'builder') and len(lead.interface)):
                return []
            return [(site, leadnr, cell)
                    for site in lead.builder.sites()
                    for cell in range(num_lead_cells)]

    elif isinstance(syst, system.FiniteSystem):
        sites = [(i, None, 0) for i in range(syst.graph.num_nodes)]

        def lead_cell_sites(leadnr, lead):
            # We will only plot leads with a graph and with a symmetry.
            if not (hasattr(lead, 'graph') and hasattr(lead, 'symmetry')
                    and len(syst.lead_interfaces[leadnr])):
                return []
            return [(site, leadnr, cell)
                    for site in range(lead.cell_size)
                    for cell in range(num_lead_cells)]

    else:
        raise TypeError('Unrecognized system type.')

    lead_cells = []
    for leadnr, lead in enumerate(syst.leads):
        first = len(sites)
        sites.extend(lead_cell_sites(leadnr, lead))
        lead_cells.append(slice(first, len(sites)))
    return sites, lead_cells
444
445


446
447
def sys_leads_pos(sys, site_lead_nr):
    """Return an array of positions of sites in a system.

    Parameters
    ----------
    sys : `kwant.builder.Builder` or `kwant.system.System` instance
        The system, coordinates of sites of which should be returned.
    site_lead_nr : list of `(site, leadnr, copynr)` tuples
        Output of `sys_leads_sites` applied to the system.

    Returns
    -------
    coords : numpy.ndarray of floats
        Array of coordinates of the sites.  A lead site with copy number
        ``copynr`` is translated by ``d + copynr`` lead periods.  Here ``d``
        is one more than the symmetry domain of the lead's first interface
        site.

    Raises
    ------
    ValueError
        If the positions of the sites do not all have the same length.

    Notes
    -----
    This function uses `site.pos` property to get the position of a builder
    site and `sys.pos(sitenr)` for finalized systems.  This function requires
    that all the positions of all the sites have the same dimensionality.
    """
    # Note about efficiency (also applies to sys_leads_hopping_pos)
    # NumPy is really slow when making a NumPy array from a tinyarray
    # (buffer interface seems very slow). It's much faster to first
    # convert to a tuple and then to convert to numpy array ...

    syst = sys  # for naming consistency inside function bodies
    is_builder = isinstance(syst, builder.Builder)
    num_lead_cells = site_lead_nr[-1][2] + 1
    if is_builder:
        pos = np.array(ta.array([i[0].pos for i in site_lead_nr]))
    else:
        def syst_from_lead(lead):
            return syst if lead is None else syst.leads[lead]

        pos = np.array(ta.array([syst_from_lead(i[1]).pos(i[0])
                                 for i in site_lead_nr]))
    if pos.dtype == object:  # Happens if not all the pos are same length.
        raise ValueError("pos attribute of the sites does not have consistent"
                         " values.")
    # Positions may be integer-valued (if a site family's ``pos`` returns
    # integers), but the translations added below are floats.  In-place
    # addition into an integer array raises a casting error, so promote.
    pos = pos.astype(np.promote_types(pos.dtype, float), copy=False)
    dim = pos.shape[1]

    def get_vec_domain(lead_nr):
        """Return the lead's period vector and the first domain to draw."""
        if lead_nr is None:
            return np.zeros((dim,)), 0
        if is_builder:
            sym = syst.leads[lead_nr].builder.symmetry
            try:
                site = syst.leads[lead_nr].interface[0]
            except IndexError:
                return (0, 0)
        else:
            try:
                sym = syst.leads[lead_nr].symmetry
                site = syst.sites[syst.lead_interfaces[lead_nr][0]]
            except (AttributeError, IndexError):
                # empty leads, or leads without symmetry aren't drawn anyways
                return (0, 0)
        dom = sym.which(site)[0] + 1
        # Conversion to numpy array here useful for efficiency
        vec = np.array(sym.periods)[0]
        return vec, dom

    vecs_doms = {i: get_vec_domain(i) for i in range(len(syst.leads))}
    vecs_doms[None] = np.zeros((dim,)), 0
    # Replace each (vector, domain) by the translations of all copies.
    for k, v in vecs_doms.items():
        vecs_doms[k] = [v[0] * i for i in range(v[1], v[1] + num_lead_cells)]
    pos += [vecs_doms[i[1]][i[2]] for i in site_lead_nr]
    return pos


516
def sys_leads_hoppings(sys, num_lead_cells=2):
    """Return all the hoppings of the system and of the leads as a list.

    Parameters
    ----------
    sys : kwant.builder.Builder or kwant.system.System instance
        The system, hoppings of which should be returned.
    num_lead_cells : integer
        The number of times lead hoppings from each lead should be returned.
        This is useful for showing several unit cells of the lead next to the
        system.

    Returns
    -------
    hoppings : list of (hopping, lead_number, copy_number) tuples
        A hopping is a pair of `~kwant.builder.Site` instances if the system
        is not finalized, and a pair of integers ``(i, j)`` with ``i < j``
        otherwise.  For system hoppings `lead_number` is `None` and
        `copy_number` is `0`, for leads both are integers.
    lead_cells : list of slices
        `lead_cells[i]` gives the position of all the coordinates of lead
        `i` within `hoppings`.

    Raises
    ------
    TypeError
        If ``sys`` is neither a builder nor a system.

    Notes
    -----
    Leads are only supported if they are of the same type as the original
    system, i.e.  hoppings of `~kwant.builder.BuilderLead` leads are returned
    with an unfinalized system, and hoppings of
    `~kwant.system.InfiniteSystem` leads are returned with a finalized system.
    """
    syst = sys  # for naming consistency inside function bodies

    if isinstance(syst, builder.Builder):
        def lead_hoppings(lead):
            # Translate each lead hopping so that the site with the larger
            # symmetry domain lands in the fundamental domain and the other
            # one in a domain with non-positive index.  The direction of
            # the hopping is chosen arbitrarily.
            # NOTE(Anton): This may need to be revisited with the future
            # builder format changes.
            sym = lead.symmetry
            for site2, site1 in lead.hoppings():
                shift = max(sym.which(site1)[0], sym.which(site2)[0])
                yield sym.act([-shift], site2), sym.act([-shift], site1)

        hoppings = [(hop, None, 0) for hop in syst.hoppings()]

        def lead_cell_hoppings(leadnr, lead):
            # Only builder leads with a non-empty interface are shown.
            if not (hasattr(lead, 'builder') and len(lead.interface)):
                return []
            return [(hop, leadnr, cell)
                    for hop in lead_hoppings(lead.builder)
                    for cell in range(num_lead_cells)]

    elif isinstance(syst, system.System):
        def ll_hoppings(graph_syst):
            # Each undirected edge (i, j) of the graph is reported once.
            graph = graph_syst.graph
            for i in range(graph.num_nodes):
                for j in graph.out_neighbors(i):
                    if i < j:
                        yield i, j

        hoppings = [(hop, None, 0) for hop in ll_hoppings(syst)]

        def lead_cell_hoppings(leadnr, lead):
            # We will only plot leads with a graph and with a symmetry.
            if not (hasattr(lead, 'graph') and hasattr(lead, 'symmetry')
                    and len(syst.lead_interfaces[leadnr])):
                return []
            return [(hop, leadnr, cell)
                    for hop in ll_hoppings(lead)
                    for cell in range(num_lead_cells)]

    else:
        raise TypeError('Unrecognized system type.')

    lead_cells = []
    for leadnr, lead in enumerate(syst.leads):
        first = len(hoppings)
        hoppings.extend(lead_cell_hoppings(leadnr, lead))
        lead_cells.append(slice(first, len(hoppings)))
    return hoppings, lead_cells
591
592


593
594
def sys_leads_hopping_pos(sys, hop_lead_nr):
    """Return arrays of coordinates of all hoppings in a system.

    Parameters
    ----------
    sys : ``~kwant.builder.Builder`` or ``~kwant.system.System`` instance
        The system, coordinates of sites of which should be returned.
    hop_lead_nr : list of ``(hopping, leadnr, copynr)`` tuples
        Output of `sys_leads_hoppings` applied to the system.

    Returns
    -------
    coords : (end_site, start_site): tuple of NumPy arrays of floats
        The first array holds the coordinates of the first site of each
        hopping, the second array those of the second site.  Lead hoppings
        are translated by lead periods as in `sys_leads_pos`.  If
        ``hop_lead_nr`` is empty, two empty arrays of shape ``(0, 3)`` are
        returned.

    Raises
    ------
    ValueError
        If the positions of the sites are inconsistent.

    Notes
    -----
    This function uses ``site.pos`` property to get the position of a builder
    site and ``sys.pos(sitenr)`` for finalized systems.  This function requires
    that all the positions of all the sites have the same dimensionality.
    """
    syst = sys  # for naming consistency inside function bodies
    is_builder = isinstance(syst, builder.Builder)
    if len(hop_lead_nr) == 0:
        return np.empty((0, 3)), np.empty((0, 3))
    num_lead_cells = hop_lead_nr[-1][2] + 1
    # Each row holds the coordinates of both sites of a hopping, concatenated.
    if is_builder:
        pos = np.array(ta.array([ta.array(tuple(i[0][0].pos) +
                                          tuple(i[0][1].pos)) for i in
                                 hop_lead_nr]))
    else:
        def syst_from_lead(lead):
            return syst if lead is None else syst.leads[lead]

        pos = ta.array([ta.array(tuple(syst_from_lead(i[1]).pos(i[0][0])) +
                                 tuple(syst_from_lead(i[1]).pos(i[0][1])))
                        for i in hop_lead_nr])
        pos = np.array(pos)
    if pos.dtype == object:  # Happens if not all the pos are same length.
        raise ValueError("pos attribute of the sites does not have consistent"
                         " values.")
    # Positions may be integer-valued (if a site family's ``pos`` returns
    # integers), but the translations added below are floats.  In-place
    # addition into an integer array raises a casting error, so promote.
    pos = pos.astype(np.promote_types(pos.dtype, float), copy=False)
    dim = pos.shape[1]

    def get_vec_domain(lead_nr):
        """Return the (doubled) period vector and first domain of a lead."""
        if lead_nr is None:
            return np.zeros((dim,)), 0
        if is_builder:
            sym = syst.leads[lead_nr].builder.symmetry
            try:
                site = syst.leads[lead_nr].interface[0]
            except IndexError:
                return (0, 0)
        else:
            try:
                sym = syst.leads[lead_nr].symmetry
                site = syst.sites[syst.lead_interfaces[lead_nr][0]]
            except (AttributeError, IndexError):
                # empty leads or leads without symmetry are not drawn anyways
                return (0, 0)
        dom = sym.which(site)[0] + 1
        vec = np.array(sym.periods)[0]
        # Both sites of a hopping are translated by the same vector.
        return np.r_[vec, vec], dom

    vecs_doms = {i: get_vec_domain(i) for i in range(len(syst.leads))}
    vecs_doms[None] = np.zeros((dim,)), 0
    # Replace each (vector, domain) by the translations of all copies.
    for k, v in vecs_doms.items():
        vecs_doms[k] = [v[0] * i for i in range(v[1], v[1] + num_lead_cells)]
    pos += [vecs_doms[i[1]][i[2]] for i in hop_lead_nr]
    return np.copy(pos[:, : dim // 2]), np.copy(pos[:, dim // 2:])
664

665

666
667
# Useful plot functions (to be extended).

668
669
670
671
672
673
674
675
# Default values of plotting parameters, indexed by parameter name.  The
# inner keys 2 and 3 presumably select the value for 2D and 3D plots
# respectively (consumed by `plot`; confirm there).
defaults = dict(
    site_symbol={2: 'o', 3: 'o'},
    site_size={2: 0.25, 3: 0.5},
    site_color={2: 'black', 3: 'white'},
    site_edgecolor={2: 'black', 3: 'black'},
    site_lw={2: 0, 3: 0.1},
    hop_color={2: 'black', 3: 'black'},
    hop_lw={2: 0.1, 3: 0},
    lead_color={2: 'red', 3: 'red'},
)
676

677

Christoph Groth's avatar
Christoph Groth committed
678
def plot(sys, num_lead_cells=2, unit='nn',
679
680
681
682
683
684
685
         site_symbol=None, site_size=None,
         site_color=None, site_edgecolor=None, site_lw=None,
         hop_color=None, hop_lw=None,
         lead_site_symbol=None, lead_site_size=None, lead_color=None,
         lead_site_edgecolor=None, lead_site_lw=None,
         lead_hop_lw=None, pos_transform=None,
         cmap='gray', colorbar=True, file=None,
Anton Akhmerov's avatar
Anton Akhmerov committed
686
         show=True, dpi=None, fig_size=None, ax=None):
687
    """Plot a system in 2 or 3 dimensions.
688

689
690
    An alias exists for this common name: ``kwant.plot``.

691
692
693
694
    Parameters
    ----------
    sys : kwant.builder.Builder or kwant.system.FiniteSystem
        A system to be plotted.
695
    num_lead_cells : int
696
        Number of lead copies to be shown with the system.
697
698
    unit : 'nn', 'pt', or float
        The unit used to specify symbol sizes and linewidths.
699
700
        Possible choices are:

701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
        - 'nn': unit is the shortest hopping or a typical nearst neighbor
          distance in the system if there are no hoppings.  This means that
          symbol sizes/linewidths will scale as the zoom level of the figure is
          changed.  Very short distances are discarded before searching for the
          shortest.  This choice means that the symbols will scale if the
          figure is zoomed.
        - 'pt': unit is points (point = 1/72 inch) in figure space.  This means
          that symbols and linewidths will always be drawn with the same size
          independent of zoom level of the plot.
        - float: sizes are given in units of this value in real (system) space,
          and will accordingly scale as the plot is zoomed.

        The default value is 'nn', which allows to ensure that the images
        neighboring sites do not overlap.

716
    site_symbol : symbol specification, function, array, or `None`
717
718
        Symbol used for representing a site in the plot. Can be specified as

719
720
        - 'o': circle with radius of 1 unit.
        - 's': square with inner circle radius of 1 unit.
721
722
        - ``('p', nvert, angle)``: regular polygon with ``nvert`` vertices,
          rotated by ``angle``. ``angle`` is given in degrees, and ``angle=0``
723
724
          corresponds to one edge of the polygon pointing upward. The
          radius of the inner circle is 1 unit.
725
        - 'no symbol': no symbol is plotted.
726
727
        - 'S', `('P', nvert, angle)`: as the lower-case variants described
          above, but with an area equal to a circle of radius 1. (Makes
728
          the visual size of the symbol equal to the size of a circle with
Christoph Groth's avatar
Christoph Groth committed
729
          radius 1).
730
        - matplotlib.path.Path instance.
731
732
733
734
735

        Instead of a single symbol, different symbols can be specified
        for different sites by passing a function that returns a valid
        symbol specification for each site, or by passing an array of
        symbols specifications (only for kwant.system.FiniteSystem).
736
    site_size : number, function, array, or `None`
737
        Relative (linear) size of the site symbol.
738
    site_color : ``matplotlib`` color description, function, array, or `None`
739
740
        A color used for plotting a site in the system. If a colormap is used,
        it should be a function returning single floats or a one-dimensional
741
742
        array of floats. By default sites are colored by their site family,
        using the current matplotlib color cycle.
743
    site_edgecolor : ``matplotlib`` color description, function, array, or `None`
744
745
746
        Color used for plotting the edges of the site symbols. Only
        valid matplotlib color descriptions are allowed (and no
        combination of floats and colormap as for site_color).
747
    site_lw : number, function, array, or `None`
748
        Linewidth of the site symbol edges.
749
    hop_color : ``matplotlib`` color description or a function
750
751
752
753
754
755
756
757
758
759
760
761
762
        Same as `site_color`, but for hoppings.  A function is passed two sites
        in this case. (arrays are not allowed in this case).
    hop_lw : number, function, or `None`
        Linewidth of the hoppings.
    lead_site_symbol : symbol specification or `None`
        Symbol to be used for the leads. See `site_symbol` for allowed
        specifications. Note that for leads, only constants
        (i.e. no functions or arrays) are allowed. If None, then
        `site_symbol` is used if it is constant (i.e. no function or array),
        the default otherwise. The same holds for the other lead properties
        below.
    lead_site_size : number or `None`
        Relative (linear) size of the lead symbol
763
    lead_color : ``matplotlib`` color description or `None`
764
        For the leads, `num_lead_cells` copies of the lead unit cell
765
766
767
768
        are plotted. They are plotted in color fading from `lead_color`
        to white (alpha values in `lead_color` are supported) when moving
        from the system into the lead. Is also applied to the
        hoppings.
769
    lead_site_edgecolor : ``matplotlib`` color description or `None`
770
771
772
773
774
        Color of the symbol edges (no fading done).
    lead_site_lw : number or `None`
        Linewidth of the lead symbols.
    lead_hop_lw : number or `None`
        Linewidth of the lead hoppings.
775
    cmap : ``matplotlib`` color map or a sequence of two color maps or `None`
776
        The color map used for sites and optionally hoppings.
777
    pos_transform : function or `None`
778
779
        Transformation to be applied to the site position.
    colorbar : bool
780
781
        Whether to show a colorbar if colormap is used. Ignored if `ax` is
        provided.
782
783
784
    file : string or file object or `None`
        The output file.  If `None`, output will be shown instead.
    show : bool
785
        Whether ``matplotlib.pyplot.show()`` is to be called, and the output is
786
        to be shown immediately.  Defaults to `True`.
787
    dpi : float or `None`
788
        Number of pixels per inch.  If not set the ``matplotlib`` default is
789
        used.
790
    fig_size : tuple or `None`
791
        Figure size `(width, height)` in inches.  If not set, the default
792
793
        ``matplotlib`` value is used.
    ax : ``matplotlib.axes.Axes`` instance or `None`
794
795
        If `ax` is not `None`, no new figure is created, but the plot is done
        within the existing Axes `ax`. in this case, `file`, `show`, `dpi`
796
        and `fig_size` are ignored.
797

798
799
    Returns
    -------
800
801
    fig : matplotlib figure
        A figure with the output if `ax` is not set, else None.
802

803
804
    Notes
    -----
805
806
807
808
    - If `None` is passed for a plot property, a default value depending on
      the dimension is chosen. Typically, the default values result in
      acceptable plots.

809
810
811
812
    - The meaning of "site" depends on whether the system to be plotted is a
      builder or a low level system.  For builders, a site is a
      kwant.builder.Site object.  For low level systems, a site is an integer
      -- the site number.
813

814
815
816
    - color and symbol definitions may be tuples, but not lists or arrays.
      Arrays of values (linewidths, colors, sizes) may not be tuples.

817
818
819
820
821
822
823
    - The dimensionality of the plot (2D vs 3D) is inferred from the coordinate
      array.  If there are more than three coordinates, only the first three
      are used.  If there is just one coordinate, the second one is padded with
      zeros.

    - The system is scaled to fit the smaller dimension of the figure, given
      its aspect ratio.
824

825
    """
826
    if not _p.mpl_available:
827
828
        raise RuntimeError("matplotlib was not found, but is required "
                           "for plot()")
829

830
    syst = sys  # for naming consistency inside function bodies
831
    # Generate data.
832
833
834
835
836
837
    sites, lead_sites_slcs = sys_leads_sites(syst, num_lead_cells)
    n_syst_sites = sum(i[1] is None for i in sites)
    sites_pos = sys_leads_pos(syst, sites)
    hops, lead_hops_slcs = sys_leads_hoppings(syst, num_lead_cells)
    n_syst_hops = sum(i[1] is None for i in hops)
    end_pos, start_pos = sys_leads_hopping_pos(syst, hops)
838
839
840
841
842

    # Choose plot type.
    def resize_to_dim(array):
        if array.shape[1] != dim:
            ar = np.zeros((len(array), dim), dtype=float)
843
844
            ar[:, : min(dim, array.shape[1])] = array[
                :, : min(dim, array.shape[1])]
845
846
847
848
            return ar
        else:
            return array

849
850
851
852
853
854
855
    loc = locals()

    def check_length(name):
        value = loc[name]
        if name in ('site_size', 'site_lw') and isinstance(value, tuple):
            raise TypeError('{0} may not be a tuple, use list or '
                            'array instead.'.format(name))
Joseph Weston's avatar
Joseph Weston committed
856
        if isinstance(value, (str, tuple)):
857
858
            return
        try:
859
            if len(value) != n_syst_sites:
860
861
862
863
864
865
866
867
868
                raise ValueError('Length of {0} is not equal to number of '
                                 'system sites.'.format(name))
        except TypeError:
            pass

    for name in ['site_symbol', 'site_size', 'site_color', 'site_edgecolor',
                 'site_lw']:
        check_length(name)

869
870
871
872
873
874
    # Apply transformations to the data
    if pos_transform is not None:
        sites_pos = np.apply_along_axis(pos_transform, 1, sites_pos)
        end_pos = np.apply_along_axis(pos_transform, 1, end_pos)
        start_pos = np.apply_along_axis(pos_transform, 1, start_pos)

875
    dim = 3 if (sites_pos.shape[1] == 3) else 2
876
    if dim == 3 and not _p.has3d:
877
        raise RuntimeError("Installed matplotlib does not support 3d plotting")
878
879
880
881
    sites_pos = resize_to_dim(sites_pos)
    end_pos = resize_to_dim(end_pos)
    start_pos = resize_to_dim(start_pos)

882
    # Determine the reference length.
Christoph Groth's avatar
Christoph Groth committed
883
    if unit == 'pt':
884
        reflen = None
Christoph Groth's avatar
Christoph Groth committed
885
    elif unit == 'nn':
886
        if n_syst_hops:
887
888
            # If hoppings are present use their lengths to determine the
            # minimal one.
889
            distances = end_pos - start_pos
890
        else:
891
892
893
            # If no hoppings are present, use for the same purpose distances
            # from ten randomly selected points to the remaining points in the
            # system.
894
            points = _sample_array(sites_pos, 10).T
895
            distances = (sites_pos.reshape(1, -1, dim) -
896
                         points.reshape(-1, 1, dim)).reshape(-1, dim)
897
        distances = np.sort(np.sum(distances**2, axis=1))
898
899
900
901
902
        # Then check if distances are present that are way shorter than the
        # longest one. Then take first distance longer than these short
        # ones. This heuristic will fail for too large systems, or systems with
        # hoppings that vary by orders and orders of magnitude, but for sane
        # cases it will work.
903
904
        long_dist_coord = np.searchsorted(distances, 1e-16 * distances[-1])
        reflen = sqrt(distances[long_dist_coord])
905

906
    else:
907
        # The last allowed value is float-compatible.
908
        try:
909
            reflen = float(unit)
910
        except:
911
            raise ValueError('Invalid value of unit argument.')
912
913
914

    # make all specs proper: either constant or lists/np.arrays:
    def make_proper_site_spec(spec, fancy_indexing=False):
915
        if callable(spec):
916
            spec = [spec(i[0]) for i in sites if i[1] is None]
917
        if (fancy_indexing and _p.isarray(spec)
918
919
920
921
922
923
924
925
            and not isinstance(spec, np.ndarray)):
            try:
                spec = np.asarray(spec)
            except:
                spec = np.asarray(spec, dtype='object')
        return spec

    def make_proper_hop_spec(spec, fancy_indexing=False):
926
        if callable(spec):
927
            spec = [spec(*i[0]) for i in hops if i[1] is None]
928
        if (fancy_indexing and _p.isarray(spec)
929
930
931
932
933
934
935
936
            and not isinstance(spec, np.ndarray)):
            try:
                spec = np.asarray(spec)
            except:
                spec = np.asarray(spec, dtype='object')
        return spec

    site_symbol = make_proper_site_spec(site_symbol)
937
    if site_symbol is None: site_symbol = defaults['site_symbol'][dim]
938
939
    # separate different symbols (not done in 3D, the separation
    # would mess up sorting)
940
    if (_p.isarray(site_symbol) and dim != 3 and
941
942
943
944
945
        (len(site_symbol) != 3 or site_symbol[0] not in ('p', 'P'))):
        symbol_dict = defaultdict(list)
        for i, symbol in enumerate(site_symbol):
            symbol_dict[symbol].append(i)
        symbol_slcs = []
Joseph Weston's avatar
Joseph Weston committed
946
        for symbol, indx in symbol_dict.items():
947
948
949
            symbol_slcs.append((symbol, np.array(indx)))
        fancy_indexing = True