# Copyright 2011-2013 Kwant authors.
#
# This file is part of Kwant.  It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
# http://kwant-project.org/license.  A list of Kwant authors can be found in
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.

import tempfile
import warnings

import numpy as np
import pytest

import kwant
from kwant import plotter

if plotter.mpl_enabled:
    from mpl_toolkits import mplot3d
    from matplotlib import pyplot  # pragma: no flakes


def test_importable_without_matplotlib():
    """Check that the plotter module's code runs when matplotlib is absent.

    The plotter source is read and every occurrence of 'matplotlib' is
    replaced by the module name 'totalblimp' (assumed not to exist).  The
    patched source is then executed.  Exactly one RuntimeWarning mentioning
    "only iterator-providing functions" is expected.
    """
    prefix, sep, suffix = plotter.__file__.rpartition('.')
    # If the module was loaded from bytecode, point back at the .py source.
    if suffix in ['pyc', 'pyo']:
        suffix = 'py'
    assert suffix == 'py'
    fname = sep.join((prefix, suffix))
    # Python source is UTF-8 by default (PEP 3120); don't depend on the
    # platform's locale encoding when reading it.
    with open(fname, encoding='utf-8') as f:
        code = f.read()
    # The code is exec'd outside the kwant package, so rewrite its relative
    # imports as absolute ones.
    code = code.replace('from . import', 'from kwant import')
    # Rename every matplotlib reference so that its imports fail.
    code = code.replace('matplotlib', 'totalblimp')

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        exec(code)               # Trigger the warning.
        assert len(w) == 1
        assert issubclass(w[0].category, RuntimeWarning)
        assert "only iterator-providing functions" in str(w[0].message)


def syst_2d(W=3, r1=3, r2=8):
    """Return an unfinalized 2D ring-shaped builder with three leads.

    The scattering region is the annulus ``r1**2 < x**2 + y**2 < r2**2`` on a
    square lattice.  It has onsite energy ``4 * t`` and hopping ``-t``
    between ``lat.neighbors()``, with ``t = 1.0``.  All leads have the
    cross-section ``-W/2 < y < W/2`` and translational symmetry along
    ``lat.vec((-1, 0))``.

    Leads, in attachment order:
      0. a lead with onsite energies but no hoppings;
      1. a lead with onsite energies and hoppings;
      2. the reversed copy of lead 1.
    """
    a = 1
    t = 1.0
    lat = kwant.lattice.square(a)
    syst = kwant.Builder()

    def ring(pos):
        x, y = pos
        rsq = x ** 2 + y ** 2
        return r1 ** 2 < rsq < r2 ** 2

    def lead_shape(pos):
        return -W / 2 < pos[1] < W / 2

    syst[lat.shape(ring, (0, r1 + 1))] = 4 * t
    syst[lat.neighbors()] = -t

    sym_lead0 = kwant.TranslationalSymmetry(lat.vec((-1, 0)))
    lead0 = kwant.Builder(sym_lead0)
    lead2 = kwant.Builder(sym_lead0)

    lead0[lat.shape(lead_shape, (0, 0))] = 4 * t
    lead2[lat.shape(lead_shape, (0, 0))] = 4 * t
    # lead2 is attached without ever receiving hoppings.
    syst.attach_lead(lead2)
    lead0[lat.neighbors()] = -t
    lead1 = lead0.reversed()
    syst.attach_lead(lead0)
    syst.attach_lead(lead1)
    return syst


def syst_3d(W=3, r1=2, r2=4, a=1, t=1.0):
    """Return an unfinalized 3D ring-shaped builder with two leads.

    The scattering region is the slab ``|z| < 2`` of the annulus
    ``r1**2 < x**2 + y**2 < r2**2`` on a simple cubic lattice with lattice
    constant ``a``.  It has onsite energy ``4 * t`` and hopping ``-t``
    between ``lat.neighbors()``.

    The first lead has cross-section ``-W/2 < y < W/2``, ``|z| < 2`` and
    translational symmetry along ``lat.vec((-1, 0, 0))``.  The second lead is
    its reversed copy.
    """
    lat = kwant.lattice.general(((a, 0, 0), (0, a, 0), (0, 0, a)))
    syst = kwant.Builder()

    def ring(pos):
        x, y, z = pos
        rsq = x ** 2 + y ** 2
        return (r1 ** 2 < rsq < r2 ** 2) and abs(z) < 2

    def lead_shape(pos):
        return (-W / 2 < pos[1] < W / 2) and abs(pos[2]) < 2

    syst[lat.shape(ring, (0, -r2 + 1, 0))] = 4 * t
    syst[lat.neighbors()] = -t

    sym_lead0 = kwant.TranslationalSymmetry(lat.vec((-1, 0, 0)))
    lead0 = kwant.Builder(sym_lead0)
    lead0[lat.shape(lead_shape, (0, 0, 0))] = 4 * t
    lead0[lat.neighbors()] = -t
    lead1 = lead0.reversed()

    syst.attach_lead(lead0)
    syst.attach_lead(lead1)
    return syst


@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
def test_plot():
    """Smoke-test `plotter.plot` on the 2D and 3D test systems.

    Covers:
      * constant and function-valued site colors;
      * constant and function-valued hopping colors;
      * creation of 3D axes;
      * a builder whose leads, and then hoppings, have been removed;
      * a finalized system.

    Figures are written to an anonymous temporary file passed as ``file``.
    """
    plot = plotter.plot
    syst2d = syst_2d()
    syst3d = syst_3d()

    site_color_opts = ['k', (lambda site: site.tag[0]),
                       lambda site: (abs(site.tag[0] / 100),
                                     abs(site.tag[1] / 100), 0)]
    hop_color_opts = ['k', (lambda site, site2: site.tag[0]),
                      lambda site, site2: (abs(site.tag[0] / 100),
                                           abs(site.tag[1] / 100), 0)]

    with tempfile.TemporaryFile('w+b') as out:
        for color in site_color_opts:
            for syst in (syst2d, syst3d):
                fig = plot(syst, site_color=color, cmap='binary', file=out)
                # For a float-valued color function, expect the site
                # collection to carry a color array.
                if (color != 'k' and
                        isinstance(color(next(iter(syst.sites()))), float)):
                    assert fig.axes[0].collections[0].get_array() is not None
                assert len(fig.axes[0].collections) == (8 if syst is syst2d
                                                        else 6)

        for color in hop_color_opts:
            for syst in (syst2d, syst3d):
                fig = plot(syst, hop_color=color, cmap='binary', file=out,
                           fig_size=(2, 10), dpi=30)
                # For a float-valued color function, expect the hopping
                # collection to carry a color array.
                if (color != 'k' and
                        isinstance(color(next(iter(syst.sites())), None),
                                   float)):
                    assert fig.axes[0].collections[1].get_array() is not None

        assert isinstance(plot(syst3d, file=out).axes[0],
                          mplot3d.axes3d.Axes3D)

        # Plot a builder without leads, then one without any hoppings.
        syst2d.leads = []
        plot(syst2d, file=out)
        del syst2d[list(syst2d.hoppings())]
        plot(syst2d, file=out)

        plot(syst3d, file=out)
        # Warnings are silenced for the finalized, hopping-free system.
        # Presumably plotting it warns; TODO confirm which warning.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            plot(syst2d.finalized(), file=out)


def good_transform(pos):
    """Swap the two coordinates of a 2D position: (x, y) -> (y, x)."""
    first, second = pos
    swapped = (second, first)
    return swapped

def bad_transform(pos):
    """Embed a 2D position into 3D: (x, y) -> (x, y, 0).

    It changes the number of coordinates, so it is used as an invalid
    ``pos_transform`` for a 2D map.
    """
    x, y = pos
    return x, y, 0


@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
def test_map():
    """Exercise `plotter.map` with position transforms and value arrays.

    Checks that:
      * a coordinate-swapping transform works;
      * a transform that adds a coordinate raises ValueError;
      * a value sequence is accepted for the finalized system;
      * the same value sequence passed with the unfinalized builder raises
        ValueError.
    """
    syst = syst_2d()
    with tempfile.TemporaryFile('w+b') as out:
        plotter.map(syst, lambda site: site.tag[0],
                    pos_transform=good_transform, file=out, method='linear',
                    a=4, oversampling=4, cmap='flag')

        with pytest.raises(ValueError):
            plotter.map(syst, lambda site: site.tag[0],
                        pos_transform=bad_transform, file=out)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            plotter.map(syst.finalized(), range(len(syst.sites())),
                        file=out)

        with pytest.raises(ValueError):
            plotter.map(syst, range(len(syst.sites())), file=out)


def test_mask_interpolate():
    """Check `plotter.mask_interpolate` with two nearly coinciding points.

    With an explicit ``a``, exactly one RuntimeWarning mentioning
    "coinciding" is expected.  Without ``a``, ValueError is expected.  This
    holds both for matching values and for a value array twice as long as
    ``coords``.
    """
    # A coordinate array with coordinates of two points almost coinciding.
    coords = np.array([[0, 0], [1e-7, 1e-7], [1, 1], [1, 0]])

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        plotter.mask_interpolate(coords, np.ones(len(coords)), a=1)
        assert len(w) == 1
        assert issubclass(w[-1].category, RuntimeWarning)
        assert "coinciding" in str(w[-1].message)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        with pytest.raises(ValueError):
            plotter.mask_interpolate(coords, np.ones(len(coords)))
        with pytest.raises(ValueError):
            plotter.mask_interpolate(coords, np.ones(2 * len(coords)))