# Copyright 2011-2013 Kwant authors.
#
# This file is part of Kwant.  It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
# http://kwant-project.org/license.  A list of Kwant authors can be found in
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.

"""Plotter module for Kwant.

This module provides iterators useful for any plotter routine, such as a list
of system sites, their coordinates, lead sites at any lead unit cell, etc.  If
`matplotlib` is available, it also provides simple functions for plotting the
system in two or three dimensions.
"""

17
from collections import defaultdict
18
import warnings
19
import numpy as np
20
import tinyarray as ta
21
from scipy import spatial, interpolate
22
23
from math import cos, sin, pi, sqrt

24
25
26
27
# All matplotlib imports must be isolated in a try, because even without
# matplotlib iterators remain useful.  Further, mpl_toolkits used for 3D
# plotting are also imported separately, to ensure that 2D plotting works even
# if 3D does not.
try:
    import matplotlib
    from matplotlib.figure import Figure
    from matplotlib import collections
    from matplotlib.backends.backend_agg import FigureCanvasAgg
    mpl_enabled = True
    try:
        from mpl_toolkits import mplot3d
        has3d = True
    except ImportError:
        warnings.warn("3D plotting not available.", RuntimeWarning)
        has3d = False
except ImportError:
    warnings.warn("matplotlib is not available, only iterator-providing "
                  "functions will work.", RuntimeWarning)
    mpl_enabled = False
    # Without matplotlib there is no 3D support either; defining the flag
    # here keeps `has3d` a valid name in every case.
    has3d = False
44

45
from . import system, builder, physics
46

47
__all__ = ['plot', 'map', 'bands', 'sys_leads_sites', 'sys_leads_hoppings',
           'sys_leads_pos', 'sys_leads_hopping_pos', 'mask_interpolate']


# Collections that allow for symbols and linewidths to be given in data space
# (not for general use, only implement what's needed for plotter)
53
def isarray(var):
    """Return True if `var` is indexable and not a string.

    Strings support indexing but are treated as scalars here, since e.g. a
    single matplotlib color may be given as a string.  Works on both
    Python 2 (`basestring`) and Python 3 (`str`, `bytes`).
    """
    try:
        string_types = (basestring,)  # Python 2
    except NameError:
        string_types = (str, bytes)  # Python 3
    return hasattr(var, '__getitem__') and not isinstance(var, string_types)


def nparray_if_array(var):
    """Convert `var` to a numpy array if `isarray` considers it an array.

    Anything else (scalars, strings, None) is handed back unchanged.
    """
    if not isarray(var):
        return var
    return np.asarray(var)


if mpl_enabled:
    class LineCollection(collections.LineCollection):
        """`LineCollection` whose linewidths may be given in data units.

        If `reflen` is None, linewidths are in points as in matplotlib.
        Otherwise they are in units of `reflen` in data space: the widths
        passed to `set_linewidths` are stored and rescaled with the current
        data transformation every time the collection is drawn.
        """

        def __init__(self, segments, reflen=None, **kwargs):
            super(LineCollection, self).__init__(segments, **kwargs)
            self.reflen = reflen

        def set_linewidths(self, linewidths):
            # Only remember the requested widths; the widths handed to
            # matplotlib are computed in `draw`.
            self.linewidths_orig = nparray_if_array(linewidths)

        def draw(self, renderer):
            if self.reflen is not None:
                # Note: only works for aspect ratio 1!
                #       72.0 - there are 72 points in an inch
                factor = (self.axes.transData.frozen().to_values()[0] * 72.0 *
                          self.reflen / self.figure.dpi)
            else:
                factor = 1

            super(LineCollection, self).set_linewidths(self.linewidths_orig *
                                                       factor)
            return super(LineCollection, self).draw(renderer)

    class PathCollection(collections.PathCollection):
        """`PathCollection` whose symbol sizes and linewidths may be given in
        data units.

        Every path is scaled by the corresponding entry of `sizes`.  If
        `reflen` is None, sizes and linewidths are in points; otherwise they
        are in units of `reflen` in data space.
        """

        def __init__(self, paths, sizes=None, reflen=None, **kwargs):
            super(PathCollection, self).__init__(paths, sizes=sizes, **kwargs)

            self.reflen = reflen
            self.linewidths_orig = nparray_if_array(self.get_linewidths())

            # One 3x3 scaling matrix per entry of `sizes`, returned by
            # `get_transforms` as the per-path transforms.
            self.transforms = np.array(
                [matplotlib.transforms.Affine2D().scale(x).get_matrix()
                 for x in sizes])

        def get_transforms(self):
            return self.transforms

        def get_transform(self):
            Affine2D = matplotlib.transforms.Affine2D
            if self.reflen is not None:
                # For the paths, use the data transformation but strip the
                # offset (will be added later with offsets)
                args = self.axes.transData.frozen().to_values()[:4] + (0, 0)
                return Affine2D().from_values(*args).scale(self.reflen)
            else:
                # Sizes are in points: convert points to pixels.
                return Affine2D().scale(self.figure.dpi / 72.0)

        def draw(self, renderer):
            if self.reflen:
                # Note: only works for aspect ratio 1!
                factor = (self.axes.transData.frozen().to_values()[0] /
                          self.figure.dpi * 72.0 * self.reflen)
                self.set_linewidths(self.linewidths_orig * factor)

            # NOTE(review): calls the generic `Collection.draw` rather than
            # `PathCollection.draw`, presumably so that matplotlib's own size
            # handling does not override `get_transforms` — confirm.
            return collections.Collection.draw(self, renderer)

    if has3d:
        # Sorting is optional.
        sort3d = True

        # Compute the projection of a 3D length into 2D data coordinates
        # for this we use 2 3D half-circles that are projected into 2D.
        # (This gives the same length as projecting the full unit sphere.)
        phi = np.linspace(0, pi, 21)
        xyz = np.c_[np.cos(phi), np.sin(phi), 0 * phi].T.reshape(-1, 1, 21)
        # Columns 0-20: half-circle in the xy-plane; columns 21-41:
        # half-circle in the yz-plane.  Plain concatenation yields the same
        # 3 x 42 array as np.bmat without the discouraged np.matrix type.
        unit_sphere = np.concatenate(
            [np.concatenate([xyz[0], xyz[2]], axis=1),
             np.concatenate([xyz[1], xyz[0]], axis=1),
             np.concatenate([xyz[2], xyz[1]], axis=1)], axis=0)

        def projected_length(ax, length):
            """Return the 2D length that a 3D `length` has in the plot of `ax`.

            The half-circles of `unit_sphere`, scaled by `length` and
            centered in the middle of the 3D axes limits, are projected with
            the current view of `ax`.  The result is the largest distance of
            a projected point from the projected center.
            """
            # Center of the 3D axes limits.
            rc = np.array([ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()])
            rc = rc.sum(axis=1) / 2.

            rs = unit_sphere * length + rc.reshape(-1, 1)

            transform = mplot3d.proj3d.proj_transform
            rp = np.asarray(transform(*(list(rs) + [ax.get_proj()]))[:2])
            # Overwrite the first two entries with the projected center.
            rc[:2] = transform(*(list(rc) + [ax.get_proj()]))[:2]

            coords = rp - rc[:2].reshape(-1, 1)
            return sqrt(np.sum(coords**2, axis=0).max())

        # Auxiliary array for calculating the corners of a cube:
        # np.dot(corners, bbox) gives a 3 x 8 array of corner coordinates.
        # NOTE(review): assumes bbox = (xmin, xmax, ymin, ymax, zmin, zmax),
        # the layout used with `Axes3D.get_w_lims()` below — confirm.
        # (Previously all six assignments wrote into the x row, so the y and
        # z coordinates of every "corner" were zero.)
        corners = np.zeros((3, 8, 6), float)
        corners[0, [0, 1, 2, 3], 0] = 1.0  # x = xmin
        corners[0, [4, 5, 6, 7], 1] = 1.0  # x = xmax
        corners[1, [0, 1, 4, 5], 2] = 1.0  # y = ymin
        corners[1, [2, 3, 6, 7], 3] = 1.0  # y = ymax
        corners[2, [0, 2, 4, 6], 4] = 1.0  # z = zmin
        corners[2, [1, 3, 5, 7], 5] = 1.0  # z = zmax

        class Line3DCollection(mplot3d.art3d.Line3DCollection):
            """3D line collection with linewidths optionally in data units.

            If `reflen` is None, linewidths are in points; otherwise in units
            of `reflen` in (projected) data space.  `zorder` controls the
            drawing order relative to other collections (see
            `do_3d_projection`).
            """

            def __init__(self, segments, reflen=None, zorder=0, **kwargs):
                super(Line3DCollection, self).__init__(segments, **kwargs)
                self.reflen = reflen
                self.zorder3d = zorder

            def set_linewidths(self, linewidths):
                # Only remember the requested widths; the widths handed to
                # matplotlib are computed in `draw`.
                self.linewidths_orig = nparray_if_array(linewidths)

            def do_3d_projection(self, renderer):
                super(Line3DCollection, self).do_3d_projection(renderer)
                # The whole 3D ordering is flawed in mplot3d when several
                # collections are added. We just use normal zorder. Note the
                # "-" due to the different logic in the 3d plotting, we still
                # want larger zorder values to be plotted on top of smaller
                # ones.
                return -self.zorder3d

            def draw(self, renderer):
                if self.reflen:
                    proj_len = projected_length(self.axes, self.reflen)
                    args = self.axes.transData.frozen().to_values()
                    # Note: unlike in the 2D case, where we can enforce equal
                    #       aspect ratio, this (currently) does not work with
                    #       3D plots in matplotlib. As an approximation, we
                    #       thus scale with the average of the x- and y-axis
                    #       transformation.
                    factor = proj_len * (args[0] +
                                         args[3]) * 0.5 * 72.0 / self.figure.dpi
                else:
                    factor = 1

                super(Line3DCollection, self).set_linewidths(
                                                self.linewidths_orig * factor)
                super(Line3DCollection, self).draw(renderer)

        class Path3DCollection(mplot3d.art3d.Patch3DCollection):
            """3D collection of symbols with optional depth sorting.

            `paths` are the symbol outlines, `sizes` the scale factor(s)
            applied to them and `offsets` (if given) the 3D positions of the
            symbols.  If `reflen` is None, sizes and linewidths are in
            points, otherwise in units of `reflen` in (projected) data space.
            If the module flag `sort3d` is true, symbols and their
            per-symbol properties are reordered by projected depth on every
            projection.
            """

            def __init__(self, paths, sizes, reflen=None, zorder=0,
                         offsets=None, **kwargs):
                paths = [matplotlib.patches.PathPatch(path) for path in paths]

                if offsets is not None:
                    # The base class gets the 2D offsets; z is attached via
                    # `set_3d_properties` below.
                    kwargs['offsets'] = offsets[:, :2]

                super(Path3DCollection, self).__init__(paths, **kwargs)

                if offsets is not None:
                    self.set_3d_properties(zs=offsets[:, 2], zdir="z")

                self.reflen = reflen
                self.zorder3d = zorder

                # Unsorted copies of the per-symbol properties, from which
                # `do_3d_projection` derives the depth-sorted ones.
                self.paths_orig = np.array(paths, dtype='object')
                self.linewidths_orig = nparray_if_array(self.get_linewidths())
                # Depth-sorted linewidths (unsorted until the first sort).
                self.linewidths_orig2 = self.linewidths_orig
                self.array_orig = nparray_if_array(self.get_array())
                self.facecolors_orig = nparray_if_array(self.get_facecolors())
                self.edgecolors_orig = nparray_if_array(self.get_edgecolors())

                Affine2D = matplotlib.transforms.Affine2D
                # One 3x3 scaling matrix per entry of `sizes`.
                self.orig_transforms = np.array(
                    [Affine2D().scale(x).get_matrix() for x in sizes])
                self.transforms = self.orig_transforms

            def set_array(self, array):
                self.array_orig = nparray_if_array(array)
                super(Path3DCollection, self).set_array(array)

            def set_color(self, colors):
                self.facecolors_orig = nparray_if_array(colors)
                self.edgecolors_orig = self.facecolors_orig
                super(Path3DCollection, self).set_color(colors)

            def set_edgecolors(self, colors):
                colors = matplotlib.colors.colorConverter.to_rgba_array(colors)
                self.edgecolors_orig = nparray_if_array(colors)
                super(Path3DCollection, self).set_edgecolors(colors)

            def get_transforms(self):
                # this is exact only for an isometric projection, for the
                # perspective projection used in mplot3d it's an approximation
                return self.transforms

            def get_transform(self):
                Affine2D = matplotlib.transforms.Affine2D
                if self.reflen:
                    proj_len = projected_length(self.axes, self.reflen)

                    # For the paths, use the data transformation but strip the
                    # offset (will be added later with the offsets).
                    args = self.axes.transData.frozen().to_values()[:4] + (0, 0)
                    return Affine2D().from_values(*args).scale(proj_len)
                else:
                    return Affine2D().scale(self.figure.dpi / 72.0)

            def do_3d_projection(self, renderer):
                xs, ys, zs = self._offsets3d

                # numpy complains about zero-length index arrays
                if len(xs) == 0:
                    return -self.zorder3d

                proj = mplot3d.proj3d.proj_transform_clip
                vs = np.array(proj(xs, ys, zs, renderer.M)[:3])

                if sort3d:
                    num = vs.shape[1]
                    indx = vs[2].argsort()[::-1]

                    def sorted_like_offsets(values):
                        # Cycle `values` along its first axis to one entry
                        # per symbol (keeping trailing dimensions, e.g. the
                        # 3x3 of a transform), then apply the depth order.
                        values = np.asarray(values)
                        return np.resize(values,
                                         (num,) + values.shape[1:])[indx]

                    self.set_offsets(vs[:2, indx].T)

                    if len(self.paths_orig) > 1:
                        self.set_paths(sorted_like_offsets(self.paths_orig))

                    if len(self.orig_transforms) > 1:
                        # Previously resized to shape (num,), which flattened
                        # the 3x3 matrices into scalars.
                        self.transforms = sorted_like_offsets(
                            self.orig_transforms)

                    lw_orig = self.linewidths_orig
                    if isinstance(lw_orig, np.ndarray) and len(lw_orig) > 1:
                        self.linewidths_orig2 = sorted_like_offsets(lw_orig)

                    # NOTE(review): assumes `array_orig` (if not None) has one
                    # entry per symbol, and facecolors/edgecolors are 2d
                    # numpy arrays (one RGBA row per color) or None.
                    if self.array_orig is not None:
                        super(Path3DCollection,
                              self).set_array(self.array_orig[indx])

                    if (self.facecolors_orig is not None and
                            self.facecolors_orig.shape[0] > 1):
                        super(Path3DCollection, self).set_facecolors(
                            sorted_like_offsets(self.facecolors_orig))

                    if (self.edgecolors_orig is not None and
                            self.edgecolors_orig.shape[0] > 1):
                        super(Path3DCollection, self).set_edgecolors(
                            sorted_like_offsets(self.edgecolors_orig))
                else:
                    self.set_offsets(vs[:2].T)

                # the whole 3D ordering is flawed in mplot3d when several
                # collections are added. We just use normal zorder, but correct
                # by the projected z-coord of the "center of gravity",
                # normalized by the projected z-coord of the world coordinates.
                # In doing so, several Path3DCollections are plotted probably
                # in the right order (it's not exact) if they have the same
                # zorder. Still, smaller and larger integer zorders are plotted
                # below or on top.

                bbox = np.asarray(self.axes.get_w_lims())

                proj = mplot3d.proj3d.proj_transform_clip
                cz = proj(*(list(np.dot(corners, bbox)) + [renderer.M]))[2]

                # np.ptp instead of the ndarray method, removed in NumPy 2.
                return -self.zorder3d + vs[2].mean() / np.ptp(cz)

            def draw(self, renderer):
                if self.reflen:
                    proj_len = projected_length(self.axes, self.reflen)
                    args = self.axes.transData.frozen().to_values()
                    factor = proj_len * (args[0] +
                                         args[3]) * 0.5 * 72.0 / self.figure.dpi
                else:
                    factor = 1

                # Always apply the (possibly depth-sorted) linewidths, as
                # Line3DCollection does; otherwise, without `reflen`, sorted
                # per-symbol widths would never be applied and would no
                # longer match the reordered symbols.
                self.set_linewidths(self.linewidths_orig2 * factor)

                super(Path3DCollection, self).draw(renderer)


# matplotlib helper functions.

336
def set_colors(color, collection, cmap, norm=None):
    """Process a color specification to a format accepted by collections.

    Parameters
    ----------
    color : color specification
        A single matplotlib color, a sequence of colors, or a sequence of
        floats with one entry per element of `collection`.  In the latter
        case the floats are color-mapped using `cmap` and `norm`.
    collection : instance of a subclass of `matplotlib.collections.Collection`
        Collection to which the color is added.
    cmap : `matplotlib` color map specification or None
        Color map to be used if colors are specified as floats.
    norm : `matplotlib` color norm
        Norm to be used if colors are specified as floats.
    """
    length = max(len(collection.get_paths()), len(collection.get_offsets()))

    # matplotlib gets confused if dtype='object'
    if isinstance(color, np.ndarray) and color.dtype == np.dtype('object'):
        color = tuple(color)

    # 3D line collections do not report their length via paths/offsets.
    # `mplot3d` only exists if 3D plotting could be imported.
    if has3d and isinstance(collection, mplot3d.art3d.Line3DCollection):
        length = len(collection._segments3d)  # Once again, matplotlib fault!

    if isarray(color) and len(color) == length:
        try:
            # check if it is an array of floats for color mapping
            color = np.asarray(color, dtype=float)
            if color.ndim == 1:
                collection.set_array(color)
                collection.set_cmap(cmap)
                collection.set_norm(norm)
                collection.set_color(None)
                return
        except (TypeError, ValueError):
            pass

    colors = matplotlib.colors.colorConverter.to_rgba_array(color)
    collection.set_color(colors)


# Aliases understood by `get_symbol`: 'O' is the circle 'o', while 's' and
# 'S' are squares (4-gons rotated by 45 degrees) of kind 'p' and 'P'.
symbol_dict = dict(O='o', s=('p', 4, 45), S=('P', 4, 45))


def get_symbol(symbols):
    """Return the paths corresponding to the description in `symbols`.

    Parameters
    ----------
    symbols : symbol specification or sequence of symbol specifications
        A single symbol specification is one of

        - a `matplotlib.path.Path` instance, which is used as is,
        - ``'o'``: a circle of radius 1,
        - a tuple ``(kind, n, angle)`` with ``kind`` equal to ``'p'`` or
          ``'P'``: a regular ``n``-gon rotated by ``angle`` degrees.  For
          ``'p'`` the polygon circumscribes the unit circle, for ``'P'`` it
          has the same area as the unit circle,
        - a key of `symbol_dict`, which is replaced by its value.

    Returns
    -------
    paths : list of `matplotlib.path.Path` instances
        One path per symbol.

    Raises
    ------
    ValueError
        If a symbol specification is not understood.
    """
    # Figure out if list of symbols or single symbol.
    if not hasattr(symbols, '__getitem__'):
        symbols = [symbols]
    elif len(symbols) == 3 and symbols[0] in ('p', 'P'):
        # Most likely a polygon specification (at least not a valid other
        # symbol).
        symbols = [symbols]

    symbols = [symbol_dict[symbol] if symbol in symbol_dict else symbol for
               symbol in symbols]

    paths = []
    for symbol in symbols:
        if isinstance(symbol, matplotlib.path.Path):
            # Previously returned immediately, dropping the other symbols
            # and returning a bare Path instead of a list.
            paths.append(symbol)
            continue
        elif hasattr(symbol, '__getitem__') and len(symbol) == 3:
            kind, n, angle = symbol

            if kind not in ('p', 'P'):
                raise ValueError("Unknown symbol definition " + str(symbol))
            if kind == 'p':
                # Circumradius of an n-gon whose inscribed circle is the
                # unit circle.
                radius = 1. / cos(pi / n)
            else:
                # make the polygon such that it has area equal
                # to a unit circle
                radius = sqrt(2 * pi / (n * sin(2 * pi / n)))

            angle = pi * angle / 180
            patch = matplotlib.patches.RegularPolygon((0, 0), n,
                                                      radius=radius,
                                                      orientation=angle)
        elif symbol == 'o':
            patch = matplotlib.patches.Circle((0, 0), 1)
        else:
            # Previously fell through with `patch` unbound (or stale from
            # the previous iteration).
            raise ValueError("Unknown symbol definition " + str(symbol))

        paths.append(patch.get_path().transformed(patch.get_transform()))

    return paths


def symbols(axes, pos, symbol='o', size=1, reflen=None, facecolor='k',
            edgecolor='k', linewidth=None, cmap=None, norm=None, zorder=0,
            **kwargs):
    """Add a collection of symbols (2D or 3D) to an axes instance.

    Parameters
    ----------
    axes : matplotlib.axes.Axes instance
        Axes to which the symbols have to be added.
    pos : 2d or 3d array_like
        Coordinates of each symbol.
    symbol: symbol definition.
        A symbol specification or a sequence of them, as accepted by
        `get_symbol`.  The string 'no symbol' draws nothing.
    size: float or 1d array
        Size(s) of the symbols. Defaults to 1.  If all sizes are zero,
        nothing is drawn.
    reflen: float or None, optional
        If `reflen` is `None`, the symbol sizes and linewidths are
        given in points (absolute size in the figure space). If
        `reflen` is a number, the symbol sizes and linewidths are
        given in units of `reflen` in data space (i.e. scales with the
        scale of the plot). Defaults to `None`.
    facecolor: color definition, optional
    edgecolor: color definition, optional
        Defines the fill and edge color of the symbol, respectively.
        Either a single object that is a proper matplotlib color
        definition or a sequence of such objects of appropriate
        length.  Defaults to all black.
    linewidth : float or 1d array or None, optional
        Width(s) of the symbol edges, in the units selected by `reflen`.
        Passed to the collection as `linewidths`; if None, the
        collection's default is used.
    cmap : `matplotlib` color map specification or None
        Color map to be used if colors are specified as floats.
    norm : `matplotlib` color norm
        Norm to be used if colors are specified as floats.
    zorder: int
        Order in which different collections are drawn: larger
        `zorder` means the collection is drawn over collections with
        smaller `zorder` values.
    **kwargs : dict keyword arguments to
        pass to `PathCollection` or `Path3DCollection`, respectively.

    Returns
    -------
    `PathCollection` or `Path3DCollection` instance containing all the
    symbols that were added.
    """
    pos = np.asarray(pos)
    dim = pos.shape[1]
    assert dim == 2 or dim == 3

    # Internally, size must be array_like.  np.ndim also catches numpy
    # scalars, on which `size[0]` raises IndexError rather than TypeError.
    if np.ndim(size) == 0:
        size = (size, )

    if dim == 2:
        Collection = PathCollection
    else:
        Collection = Path3DCollection

    # Compare via np.asarray: `(0,) == 0` or `[0, 0] == 0` is simply False
    # for tuples/lists, so the zero-size shortcut never triggered for them.
    if (len(pos) == 0 or np.all(symbol == 'no symbol') or
            np.all(np.asarray(size) == 0)):
        paths = []
        pos = np.empty((0, dim))
    else:
        paths = get_symbol(symbol)

    coll = Collection(paths, sizes=size, reflen=reflen, linewidths=linewidth,
                      offsets=pos, transOffset=axes.transData, zorder=zorder)

    set_colors(facecolor, coll, cmap, norm)
    coll.set_edgecolors(edgecolor)

    coll.update(kwargs)

    if dim == 2:
        axes.add_collection(coll)
    else:
        axes.add_collection3d(coll)

    return coll


def lines(axes, pos0, pos1, reflen=None, colors='k', linestyles='solid',
          cmap=None, norm=None, zorder=0, **kwargs):
    """Add a collection of line segments (2D or 3D) to an axes instance.

    Parameters
    ----------
    axes : matplotlib.axes.Axes instance
        Axes to which the lines have to be added.
    pos0 : 2d or 3d array_like
        Starting coordinates of each line segment
    pos1 : 2d or 3d array_like
        Ending coordinates of each line segment
    reflen: float or None, optional
        If `reflen` is `None`, the linewidths are given in points (absolute
        size in the figure space). If `reflen` is a number, the linewidths
        are given in units of `reflen` in data space (i.e. scales with
        the scale of the plot). Defaults to `None`.
    colors : color definition, optional
        Either a single object that is a proper matplotlib color definition
        or a sequence of such objects of appropriate length.  Defaults to all
        segments black.
    linestyles :linestyle definition, optional
        Either a single object that is a proper matplotlib line style
        definition or a sequence of such objects of appropriate length.
        Defaults to all segments solid.
    cmap : `matplotlib` color map specification or None
        Color map to be used if colors are specified as floats.
    norm : `matplotlib` color norm
        Norm to be used if colors are specified as floats.
    zorder: int
        Order in which different collections are drawn: larger
        `zorder` means the collection is drawn over collections with
        smaller `zorder` values.
    **kwargs : dict keyword arguments to
        pass to `LineCollection` or `Line3DCollection`, respectively.
        If `linewidths` is given and all widths are zero, an empty
        collection is added.

    Returns
    -------
    `LineCollection` or `Line3DCollection` instance containing all the
    segments that were added.

    Raises
    ------
    ValueError
        If `pos0` and `pos1` do not have the same shape.
    """
    pos0, pos1 = np.asarray(pos0), np.asarray(pos1)
    if not pos0.shape == pos1.shape:
        raise ValueError('Incompatible lengths of coordinate arrays.')

    dim = pos0.shape[1]
    assert dim == 2 or dim == 3
    if dim == 2:
        Collection = LineCollection
    else:
        Collection = Line3DCollection

    # Compare elementwise: for array linewidths, `linewidths == 0` has an
    # ambiguous truth value and raised.
    nothing_to_draw = (len(pos0) == 0 or
                       ('linewidths' in kwargs and
                        np.all(np.asarray(kwargs['linewidths']) == 0)))

    if nothing_to_draw:
        coll = Collection([], reflen=reflen, linestyles=linestyles,
                          zorder=zorder)
    else:
        segments = np.c_[pos0, pos1].reshape(pos0.shape[0], 2, dim)
        coll = Collection(segments, reflen=reflen, linestyles=linestyles,
                          zorder=zorder)
        set_colors(colors, coll, cmap, norm)
    coll.update(kwargs)

    if dim == 2:
        axes.add_collection(coll)
    else:
        axes.add_collection3d(coll)

    return coll


def output_fig(fig, output_mode='auto', file=None, savefile_opts=None,
               show=True):
    """Output a matplotlib figure using a given output mode.

    Parameters
    ----------
    fig : matplotlib.figure.Figure instance
        The figure to be output.
    output_mode : string
        The output mode to be used.  Can be one of the following:
        'pyplot' : attach the figure to pyplot, with the same behavior as if
        pyplot.plot was called to create this figure.
        'ipython' : attach a `FigureCanvasAgg` to the figure and return it.
        'return' : same as 'ipython'.
        'file' : same as 'ipython', but also save the figure into a file.
        'auto' : if `file` is given, save to a file, else if pyplot
        is imported, attach to pyplot, otherwise just return.  See also the
        notes below.
    file : string or a file object
        The name of the target file or the target file itself
        (opened for writing).
    savefile_opts : (list, dict) or None
        args and kwargs passed to `print_figure` of `matplotlib`.  The
        dictionary is not modified.
    show : bool
        Whether to call `matplotlib.pyplot.show()`.  Only has an effect if the
        output uses pyplot.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure that was output.

    Raises
    ------
    RuntimeError
        If matplotlib is not installed, or if 'pyplot' output is requested
        but `matplotlib.pyplot` has not been imported.
    ValueError
        If `output_mode` is not one of the modes listed above.

    Notes
    -----
    The behavior of this function producing a file is different from that of
    matplotlib in that the `dpi` attribute of the figure is used by default
    instead of the matplotlib config setting.
    """
    if not mpl_enabled:
        raise RuntimeError('matplotlib is not installed.')
    if output_mode == 'auto':
        if file is not None:
            output_mode = 'file'
        else:
            try:
                matplotlib.pyplot.get_backend()
                output_mode = 'pyplot'
            except AttributeError:
                # pyplot has not been imported; previously this also chose
                # 'pyplot' and then failed.  Documented fallback: return.
                output_mode = 'return'
    if output_mode == 'pyplot':
        try:
            fake_fig = matplotlib.pyplot.figure()
        except AttributeError:
            msg = ('matplotlib.pyplot is unavailable.  Execute `import '
                   'matplotlib.pyplot` or use a different output mode.')
            raise RuntimeError(msg)
        # Reuse the canvas pyplot created for a dummy figure, pointing it
        # at `fig`.
        fake_fig.canvas.figure = fig
        fig.canvas = fake_fig.canvas
        for ax in fig.axes:
            try:
                ax.mouse_init()  # Make 3D interface interactive.
            except AttributeError:
                pass
        if show:
            matplotlib.pyplot.show()
        return fig
    elif output_mode in ('return', 'ipython'):
        canvas = FigureCanvasAgg(fig)
        fig.canvas = canvas
        return fig
    elif output_mode == 'file':
        canvas = FigureCanvasAgg(fig)
        if savefile_opts is None:
            savefile_opts = ([], {})
        args = savefile_opts[0]
        # Copy, so that the caller's dictionary is not modified.
        kwargs = dict(savefile_opts[1])
        kwargs.setdefault('dpi', fig.dpi)
        canvas.print_figure(file, *args, **kwargs)
        return fig
    else:
        raise ValueError('Unknown output_mode: {0!r}'.format(output_mode))


# Extracting necessary data from the system.
656

657
def sys_leads_sites(sys, num_lead_cells=2):
    """Return all the sites of the system and of the leads as a list.

    Parameters
    ----------
    sys : kwant.builder.Builder or kwant.system.System instance
        The system, sites of which should be returned.
    num_lead_cells : integer
        The number of times lead sites from each lead should be returned.
        This is useful for showing several unit cells of the lead next to the
        system.

    Returns
    -------
    sites : list of (site, lead_number, copy_number) tuples
        A site is a `builder.Site` instance if the system is not finalized,
        and an integer otherwise.  For system sites `lead_number` is `None` and
        `copy_number` is `0`, for leads both are integers.
    lead_cells : list of slices
        `lead_cells[i]` gives the position of all the coordinates of lead
        `i` within `sites`.  Leads whose sites are not returned get an
        empty slice.

    Raises
    ------
    TypeError
        If `sys` is neither a `builder.Builder` nor a
        `system.FiniteSystem`.

    Notes
    -----
    Leads are only supported if they are of the same type as the original
    system, i.e.  sites of `builder.BuilderLead` leads are returned with an
    unfinalized system, and sites of `system.InfiniteSystem` leads are
    returned with a finalized system.
    """
    lead_cells = []
    if isinstance(sys, builder.Builder):
        sites = [(site, None, 0) for site in sys.sites()]
        for leadnr, lead in enumerate(sys.leads):
            start = len(sites)
            # Only leads that have a builder and a non-empty interface.
            if hasattr(lead, 'builder') and len(lead.interface):
                sites.extend(((site, leadnr, i) for site in
                              lead.builder.sites() for i in
                              range(num_lead_cells)))
            lead_cells.append(slice(start, len(sites)))
    elif isinstance(sys, system.FiniteSystem):
        sites = [(i, None, 0) for i in range(sys.graph.num_nodes)]
        for leadnr, lead in enumerate(sys.leads):
            start = len(sites)
            # We will only plot leads with a graph and with a symmetry.
            if (hasattr(lead, 'graph') and hasattr(lead, 'symmetry') and
                    len(sys.lead_interfaces[leadnr])):
                sites.extend(((site, leadnr, i) for site in
                              range(lead.cell_size) for i in
                              range(num_lead_cells)))
            lead_cells.append(slice(start, len(sites)))
    else:
        raise TypeError('Unrecognized system type.')
    return sites, lead_cells


def sys_leads_pos(sys, site_lead_nr):
    """Compute the real-space coordinates of system and lead sites.

    Parameters
    ----------
    sys : `kwant.builder.Builder` or `kwant.system.System` instance
        The system whose site coordinates are requested.
    site_lead_nr : list of `(site, leadnr, copynr)` tuples
        The list produced by `sys_leads_sites` for this system.

    Returns
    -------
    coords : numpy.ndarray of floats
        One row of coordinates per entry of `site_lead_nr`.  System sites
        are left in place; a lead site with copy number `copynr` is shifted
        by `(d + copynr)` times the lead period, where `d` is one more than
        the symmetry domain of the lead's first interface site.

    Raises
    ------
    ValueError
        If the site positions do not all have the same length.

    Notes
    -----
    Builder sites are located through their `site.pos` property, finalized
    sites through `sys.pos(sitenr)` (or the corresponding lead's `pos`).
    All sites must have positions of the same dimensionality.
    """
    # A note on efficiency (it also concerns sys_leads_hopping_pos): NumPy is
    # very slow at building an array out of many tinyarrays (the buffer
    # interface is slow), so the positions are first gathered into a single
    # tinyarray which is then converted to NumPy in one go.
    from_builder = isinstance(sys, builder.Builder)
    # The number of lead copies is read from the last entry.  This assumes the
    # last entry carries the largest copy number, as is the case for the
    # output of `sys_leads_sites`.
    n_copies = site_lead_nr[-1][2] + 1

    if from_builder:
        raw_pos = [site.pos for site, _, _ in site_lead_nr]
    else:
        def owner(lead_nr):
            # Finalized site numbers are local to the system or to one lead.
            return sys if lead_nr is None else sys.leads[lead_nr]

        raw_pos = [owner(lead_nr).pos(site)
                   for site, lead_nr, _ in site_lead_nr]
    pos = np.array(ta.array(raw_pos))

    if pos.dtype == object:  # Happens if not all the pos are same length.
        raise ValueError("pos attribute of the sites does not have consistent"
                         " values.")
    dim = pos.shape[1]

    def period_and_first_domain(lead_nr):
        # Return the lead period and the domain index of its first copy, or
        # (0, 0) for leads that are not drawn (no interface sites, or -- for
        # finalized systems -- no symmetry).
        lead = sys.leads[lead_nr]
        if from_builder:
            symmetry = lead.builder.symmetry
            try:
                anchor = lead.interface[0]
            except IndexError:
                return 0, 0
        else:
            try:
                symmetry = lead.symmetry
                anchor = sys.sites[sys.lead_interfaces[lead_nr][0]]
            except (AttributeError, IndexError):
                # Empty leads or leads without symmetry are not drawn anyway.
                return 0, 0
        first_domain = symmetry.which(anchor)[0] + 1
        # Converting to a NumPy array keeps the scaling below cheap.
        period = np.array(symmetry.periods)[0]
        return period, first_domain

    # shifts[lead_nr][copy_nr] is the offset added to a site of that copy.
    shifts = {}
    for lead_nr in xrange(len(sys.leads)):
        period, first = period_and_first_domain(lead_nr)
        shifts[lead_nr] = [period * n for n in xrange(first, first + n_copies)]
    shifts[None] = [np.zeros((dim,)) * n for n in xrange(n_copies)]

    offsets = [shifts[lead_nr][copy_nr] for _, lead_nr, copy_nr in site_lead_nr]
    pos += offsets
    return pos


781
def sys_leads_hoppings(sys, num_lead_cells=2):
    """Return all the hoppings of the system and of the leads.

    Parameters
    ----------
    sys : kwant.builder.Builder or kwant.system.System instance
        The system, hoppings of which should be returned.
    num_lead_cells : integer
        The number of times lead hoppings from each lead should be returned.
        This is useful for showing several unit cells of the lead next to the
        system.

    Returns
    -------
    hoppings : list of (hopping, lead_number, copy_number) tuples
        A hopping is a pair of sites.  A site is a `builder.Site` instance if
        the system is not finalized, and an integer otherwise.  For system
        hoppings `lead_number` is `None` and `copy_number` is `0`, for leads
        both are integers.
    lead_cells : list of slices
        `lead_cells[i]` gives the positions of all the hoppings of lead `i`
        within `hoppings`.  Leads that are not plotted get an empty slice.

    Raises
    ------
    TypeError
        If `sys` is neither a `builder.Builder` nor a `system.System`.

    Notes
    -----
    Leads are only supported if they are of the same type as the original
    system, i.e.  hoppings of `builder.BuilderLead` leads are returned with an
    unfinalized system, and hoppings of `system.InfiniteSystem` leads are
    returned with a finalized system.
    """
    hoppings = []
    lead_cells = []
    if isinstance(sys, builder.Builder):
        hoppings.extend((hop, None, 0) for hop in sys.hoppings())

        def lead_hoppings(lead):
            sym = lead.symmetry
            for site2, site1 in lead.hoppings():
                shift1 = sym.which(site1)[0]
                shift2 = sym.which(site2)[0]
                # We need to make sure that the hopping is between a site in a
                # fundamental domain and a site with a negative domain.  The
                # direction of the hopping is chosen arbitrarily
                # NOTE(Anton): This may need to be revisited with the future
                # builder format changes.
                shift = max(shift1, shift2)
                yield sym.act([-shift], site2), sym.act([-shift], site1)

        for leadnr, lead in enumerate(sys.leads):
            start = len(hoppings)
            # Only builder leads with a non-empty interface are plotted.
            if hasattr(lead, 'builder') and len(lead.interface):
                hoppings.extend((hop, leadnr, i)
                                for hop in lead_hoppings(lead.builder)
                                for i in xrange(num_lead_cells))
            lead_cells.append(slice(start, len(hoppings)))
    elif isinstance(sys, system.System):
        def ll_hoppings(sys):
            # Keep only i < j so that a hopping is not yielded both as
            # (i, j) and as (j, i).
            for i in xrange(sys.graph.num_nodes):
                for j in sys.graph.out_neighbors(i):
                    if i < j:
                        yield i, j

        hoppings.extend((hop, None, 0) for hop in ll_hoppings(sys))
        for leadnr, lead in enumerate(sys.leads):
            start = len(hoppings)
            # We will only plot leads with a graph and with a symmetry.
            if (hasattr(lead, 'graph') and hasattr(lead, 'symmetry') and
                len(sys.lead_interfaces[leadnr])):
                hoppings.extend((hop, leadnr, i) for hop in ll_hoppings(lead)
                                for i in xrange(num_lead_cells))
            lead_cells.append(slice(start, len(hoppings)))
    else:
        raise TypeError('Unrecognized system type.')
    return hoppings, lead_cells
854
855


856
857
def sys_leads_hopping_pos(sys, hop_lead_nr):
    """Return arrays of coordinates of all hoppings in a system.

    Parameters
    ----------
    sys : `kwant.builder.Builder` or `kwant.system.System` instance
        The system, coordinates of hoppings of which should be returned.
    hop_lead_nr : list of `(hopping, leadnr, copynr)` tuples
        Output of `sys_leads_hoppings` applied to the system.

    Returns
    -------
    coords : (end_site, start_site): tuple of NumPy arrays of floats
        `end_site[k]` holds the coordinates of the first site of the k-th
        hopping and `start_site[k]` those of its second site.  Lead hoppings
        are shifted by `(d + copynr)` lead periods, where `d` is one more
        than the symmetry domain of the lead's first interface site.  If
        `hop_lead_nr` is empty, two empty arrays of shape ``(0, 3)`` are
        returned.

    Raises
    ------
    ValueError
        If the site positions do not all have the same length.

    Notes
    -----
    This function uses `site.pos` property to get the position of a builder
    site and `sys.pos(sitenr)` for finalized systems.  This function requires
    that all the positions of all the sites have the same dimensionality.
    """
    if len(hop_lead_nr) == 0:
        return np.empty((0, 3)), np.empty((0, 3))
    is_builder = isinstance(sys, builder.Builder)
    # The number of lead copies is read from the last entry.  This assumes the
    # last entry carries the largest copy number, which holds for the output
    # of `sys_leads_hoppings`.
    num_lead_cells = hop_lead_nr[-1][2] + 1

    # Positions of both endpoints are concatenated into one row per hopping.
    # Gathering everything into a single tinyarray before converting to NumPy
    # is much faster than letting NumPy consume many small tinyarrays.
    if is_builder:
        pos = np.array(ta.array([ta.array(tuple(hop[0].pos) +
                                          tuple(hop[1].pos))
                                 for hop, _, _ in hop_lead_nr]))
    else:
        def sys_from_lead(lead):
            # Finalized site numbers are local to the system or to one lead.
            return sys if lead is None else sys.leads[lead]

        pos = np.array(ta.array([
            ta.array(tuple(sys_from_lead(lead).pos(hop[0])) +
                     tuple(sys_from_lead(lead).pos(hop[1])))
            for hop, lead, _ in hop_lead_nr]))

    if pos.dtype == object:  # Happens if not all the pos are same length.
        raise ValueError("pos attribute of the sites does not have consistent"
                         " values.")
    # Each row holds both endpoints, so its width is twice the spatial
    # dimension.  Integer division keeps the slice bound an int on Python 3.
    width = pos.shape[1]
    half = width // 2

    def get_vec_domain(lead_nr):
        # Return the (doubled) lead period and the domain index of the first
        # copy, or (0, 0) for leads that are not drawn.
        if is_builder:
            sym = sys.leads[lead_nr].builder.symmetry
            try:
                site = sys.leads[lead_nr].interface[0]
            except IndexError:
                return 0, 0
        else:
            try:
                sym = sys.leads[lead_nr].symmetry
                site = sys.sites[sys.lead_interfaces[lead_nr][0]]
            except (AttributeError, IndexError):
                # Empty leads or leads without symmetry are not drawn anyway.
                return 0, 0
        dom = sym.which(site)[0] + 1
        vec = np.array(sym.periods)[0]
        # Both endpoints of a hopping are translated by the same period.
        return np.r_[vec, vec], dom

    def copy_shifts(vec, dom):
        # Offsets for copies 0 .. num_lead_cells - 1 of one lead.
        return [vec * i for i in xrange(dom, dom + num_lead_cells)]

    shifts = dict((nr, copy_shifts(*get_vec_domain(nr)))
                  for nr in xrange(len(sys.leads)))
    shifts[None] = copy_shifts(np.zeros((width,)), 0)

    pos += [shifts[lead][copy_nr] for _, lead, copy_nr in hop_lead_nr]
    return np.copy(pos[:, :half]), np.copy(pos[:, half:])

926

927
928
# Useful plot functions (to be extended).

929
930
931
932
933
934
935
936
# Default values of the plotting parameters.  Every entry maps the plot
# dimension (2 or 3) to the value used in that case.
defaults = dict(
    site_symbol={2: 'o', 3: 'o'},
    site_size={2: 0.25, 3: 0.5},
    site_color={2: 'black', 3: 'white'},
    site_edgecolor={2: 'black', 3: 'black'},
    site_lw={2: 0, 3: 0.1},
    hop_color={2: 'black', 3: 'black'},
    hop_lw={2: 0.1, 3: 0},
    lead_color={2: 'red', 3: 'red'},
)
937

938

Christoph Groth's avatar
Christoph Groth committed
939
def plot(sys, num_lead_cells=2, unit='nn',
940
941
942
943
944
945
946
         site_symbol=None, site_size=None,
         site_color=None, site_edgecolor=None, site_lw=None,
         hop_color=None, hop_lw=None,
         lead_site_symbol=None, lead_site_size=None, lead_color=None,
         lead_site_edgecolor=None, lead_site_lw=None,
         lead_hop_lw=None, pos_transform=None,
         cmap='gray', colorbar=True, file=None,
Anton Akhmerov's avatar
Anton Akhmerov committed
947
         show=True, dpi=None, fig_size=None, ax=None):
948
    """Plot a system in 2 or 3 dimensions.
949

950
951
952
953
    Parameters
    ----------
    sys : kwant.builder.Builder or kwant.system.FiniteSystem
        A system to be plotted.
954
    num_lead_cells : int
955
        Number of lead copies to be shown with the system.
956
957
    unit : 'nn', 'pt', or float
        The unit used to specify symbol sizes and linewidths.
958
959
        Possible choices are:

960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
        - 'nn': unit is the shortest hopping or a typical nearst neighbor
          distance in the system if there are no hoppings.  This means that
          symbol sizes/linewidths will scale as the zoom level of the figure is
          changed.  Very short distances are discarded before searching for the
          shortest.  This choice means that the symbols will scale if the
          figure is zoomed.
        - 'pt': unit is points (point = 1/72 inch) in figure space.  This means
          that symbols and linewidths will always be drawn with the same size
          independent of zoom level of the plot.
        - float: sizes are given in units of this value in real (system) space,
          and will accordingly scale as the plot is zoomed.

        The default value is 'nn', which allows to ensure that the images
        neighboring sites do not overlap.

975
    site_symbol : symbol specification, function, array, or `None`
976
977
        Symbol used for representing a site in the plot. Can be specified as

978
979
980
        - 'o': circle with radius of 1 unit.
        - 's': square with inner circle radius of 1 unit.
        - `('p', nvert, angle)`: regular polygon with `nvert` vertices,
981
982
983
          rotated by `angle`. `angle` is given in degrees, and ``angle=0``
          corresponds to one edge of the polygon pointing upward. The
          radius of the inner circle is 1 unit.
984
        - 'no symbol': no symbol is plotted.
985
986
        - 'S', `('P', nvert, angle)`: as the lower-case variants described
          above, but with an area equal to a circle of radius 1. (Makes
987
          the visual size of the symbol equal to the size of a circle with
Christoph Groth's avatar
Christoph Groth committed
988
          radius 1).
989
        - matplotlib.path.Path instance.
990
991
992
993
994

        Instead of a single symbol, different symbols can be specified
        for different sites by passing a function that returns a valid
        symbol specification for each site, or by passing an array of
        symbols specifications (only for kwant.system.FiniteSystem).
995
    site_size : number, function, array, or `None`
996
        Relative (linear) size of the site symbol.
997
    site_color : `matplotlib` color description, function, array, or `None`
998
999
1000
        A color used for plotting a site in the system. If a colormap is used,
        it should be a function returning single floats or a one-dimensional
        array of floats.
For faster browsing, not all history is shown. View entire blame