# Copyright 2011-2013 Kwant authors.
#
# This file is part of Kwant.  It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
# http://kwant-project.org/license.  A list of Kwant authors can be found in
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.

import tempfile
import warnings

import nose
from nose.tools import assert_raises
import numpy as np

import kwant
from kwant import plotter

if plotter.mpl_enabled:
    from mpl_toolkits import mplot3d
    from matplotlib import pyplot


21
22
def test_importable_without_matplotlib():
    """Check that the plotter source runs, with one warning, without matplotlib.

    The source of ``kwant.plotter`` is re-executed with every occurrence of
    ``matplotlib`` replaced by a nonexistent module name, so importing
    matplotlib fails.  Exactly one ``RuntimeWarning`` mentioning
    "only iterator-providing functions" must be emitted.
    """
    prefix, sep, suffix = plotter.__file__.rpartition('.')
    # ``__file__`` may name a compiled file; read the source file next to it.
    if suffix in ['pyc', 'pyo']:
        suffix = 'py'
    assert suffix == 'py'
    fname = sep.join((prefix, suffix))
    # Python source files are UTF-8 by default (PEP 3120); do not depend on
    # the locale's preferred encoding.
    with open(fname, encoding='utf-8') as f:
        code = f.read()
    # Relative imports cannot work when the code is exec'd outside its
    # package, so turn them into absolute ones.
    code = code.replace('from . import', 'from kwant import')
    code = code.replace('matplotlib', 'totalblimp')

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        # Execute in a fresh, module-like namespace so that top-level names
        # of the plotter code neither leak into nor resolve against this
        # test module's globals.
        exec(code, {'__name__': plotter.__name__})  # Trigger the warning.
        nose.tools.assert_equal(len(w), 1)
        assert issubclass(w[0].category, RuntimeWarning)
        assert "only iterator-providing functions" in str(w[0].message)
38
39


40
def syst_2d(W=3, r1=3, r2=8, a=1, t=1.0):
    """Build a 2D square-lattice ring with three attached leads.

    The scattering region consists of the sites with
    ``r1**2 < x**2 + y**2 < r2**2``, with onsite value ``4 * t`` and
    nearest-neighbor hopping ``-t``.  Leads, in attachment order:

    0. a lead of width ``W`` with onsite values but no hoppings
       (presumably to exercise plotting of such a lead);
    1. a lead of width ``W`` with translational symmetry ``lat.vec((-1, 0))``;
    2. the reverse of lead 1.

    Parameters
    ----------
    W : float
        Width of the leads (sites with ``-W/2 < y < W/2``).
    r1, r2 : float
        Inner and outer radius of the ring.
    a : float
        Lattice constant (same meaning as in ``syst_3d``).
    t : float
        Hopping strength (same meaning as in ``syst_3d``).

    Returns
    -------
    kwant.Builder
    """
    lat = kwant.lattice.square(a)
    syst = kwant.Builder()

    def ring(pos):
        (x, y) = pos
        rsq = x ** 2 + y ** 2
        return r1 ** 2 < rsq < r2 ** 2

    syst[lat.shape(ring, (0, r1 + 1))] = 4 * t
    syst[lat.neighbors()] = -t

    sym_lead0 = kwant.TranslationalSymmetry(lat.vec((-1, 0)))
    lead0 = kwant.Builder(sym_lead0)
    lead2 = kwant.Builder(sym_lead0)

    def lead_shape(pos):
        return -W / 2 < pos[1] < W / 2

    lead0[lat.shape(lead_shape, (0, 0))] = 4 * t
    lead2[lat.shape(lead_shape, (0, 0))] = 4 * t
    # lead2 never receives hoppings; it is attached first, so it is lead 0.
    syst.attach_lead(lead2)
    lead0[lat.neighbors()] = -t
    lead1 = lead0.reversed()
    syst.attach_lead(lead0)
    syst.attach_lead(lead1)
    return syst
67
68


69
def syst_3d(W=3, r1=2, r2=4, a=1, t=1.0):
    """Build a 3D cubic-lattice ring slab with two attached leads.

    The scattering region holds the sites with ``r1**2 < x**2 + y**2 < r2**2``
    and ``|z| < 2``, with onsite value ``4 * t`` and nearest-neighbor hopping
    ``-t``.  Lead 0 has cross-section ``-W/2 < y < W/2``, ``|z| < 2`` and
    translational symmetry ``lat.vec((-1, 0, 0))``; lead 1 is its reverse.

    Parameters
    ----------
    W : float
        Width of the leads in y.
    r1, r2 : float
        Inner and outer radius of the ring in the xy plane.
    a : float
        Lattice constant.
    t : float
        Hopping strength.

    Returns
    -------
    kwant.Builder
    """
    basis = ((a, 0, 0), (0, a, 0), (0, 0, a))
    lat = kwant.lattice.general(basis)
    syst = kwant.Builder()

    def in_ring(pos):
        x, y, z = pos
        return r1 ** 2 < x ** 2 + y ** 2 < r2 ** 2 and abs(z) < 2

    def in_lead(pos):
        return -W / 2 < pos[1] < W / 2 and abs(pos[2]) < 2

    syst[lat.shape(in_ring, (0, -r2 + 1, 0))] = 4 * t
    syst[lat.neighbors()] = -t

    lead_sym = kwant.TranslationalSymmetry(lat.vec((-1, 0, 0)))
    left = kwant.Builder(lead_sym)
    left[lat.shape(in_lead, (0, 0, 0))] = 4 * t
    left[lat.neighbors()] = -t
    right = left.reversed()

    for lead in (left, right):
        syst.attach_lead(lead)
    return syst
90
91
92


def test_plot():
    """Smoke-test ``plotter.plot`` with various site and hopping colorings.

    Both the 2D and the 3D test systems are plotted with every site-color
    and every hopping-color option.  Afterwards a system without leads, one
    without hoppings, and a finalized system are plotted.  All output goes
    to a temporary file.
    """
    plot = plotter.plot
    if not plotter.mpl_enabled:
        raise nose.SkipTest

    syst2d = syst_2d()
    syst3d = syst_3d()

    color_opts = ['k', (lambda site: site.tag[0]),
                  lambda site: (abs(site.tag[0] / 100),
                                abs(site.tag[1] / 100), 0)]
    with tempfile.TemporaryFile('w+b') as out:
        for color in color_opts:
            for syst in (syst2d, syst3d):
                fig = plot(syst, site_color=color, cmap='binary', file=out)
                # For a float-valued coloring, the site collection must
                # carry a data array.  Sample a site of the system that was
                # actually plotted.
                if (color != 'k' and
                    isinstance(color(next(iter(syst.sites()))), float)):
                    assert fig.axes[0].collections[0].get_array() is not None
                assert len(fig.axes[0].collections) == (8 if syst is syst2d else
                                                        6)
        color_opts = ['k', (lambda site, site2: site.tag[0]),
                      lambda site, site2: (abs(site.tag[0] / 100),
                                           abs(site.tag[1] / 100), 0)]
        for color in color_opts:
            for syst in (syst2d, syst3d):
                # Plot ``syst`` (previously always ``syst2d``) so that hopping
                # colors are exercised for the 3D system as well.
                fig = plot(syst, hop_color=color, cmap='binary', file=out,
                           fig_size=(2, 10), dpi=30)
                if color != 'k' and isinstance(color(next(iter(syst.sites())),
                                                     None), float):
                    assert fig.axes[0].collections[1].get_array() is not None

        assert isinstance(plot(syst3d, file=out).axes[0], mplot3d.axes3d.Axes3D)

        # Degenerate cases: no leads, then additionally no hoppings.
        syst2d.leads = []
        plot(syst2d, file=out)
        del syst2d[list(syst2d.hoppings())]
        plot(syst2d, file=out)

        plot(syst3d, file=out)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            plot(syst2d.finalized(), file=out)
132

133
134
135
136
137
138
139
def good_transform(pos):
    """Swap the two coordinates of a 2D position; used as a valid ``pos_transform``."""
    first, second = pos
    return second, first

def bad_transform(pos):
    """Append a zero third coordinate to a 2D position.

    ``test_map`` expects ``plotter.map`` to raise ``ValueError`` when given
    this transform.
    """
    first, second = pos
    return first, second, 0
140
141

def test_map():
    """Exercise ``plotter.map`` on a builder and on a finalized system."""
    if not plotter.mpl_enabled:
        raise nose.SkipTest
    syst = syst_2d()

    def tag_x(site):
        return site.tag[0]

    with tempfile.TemporaryFile('w+b') as out:
        # A coordinate-swapping transform with linear interpolation works.
        plotter.map(syst, tag_x, pos_transform=good_transform, file=out,
                    method='linear', a=4, oversampling=4, cmap='flag')
        # bad_transform yields three coordinates; map must raise ValueError.
        assert_raises(ValueError, plotter.map, syst, tag_x,
                      pos_transform=bad_transform, file=out)

        # A sequence of per-site values is accepted for a finalized system...
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            plotter.map(syst.finalized(), range(len(syst.sites())), file=out)
        # ...but rejected for a builder.
        assert_raises(ValueError, plotter.map, syst,
                      range(len(syst.sites())), file=out)
157
158
159
160
161
162
163


def test_mask_interpolate():
    """Check the warning and the errors raised by ``plotter.mask_interpolate``.

    With two almost coinciding points and an explicit ``a``, exactly one
    ``RuntimeWarning`` mentioning "coinciding" is expected.  Without ``a``,
    and with a value array of the wrong length, ``ValueError`` is expected.
    """
    # A coordinate array with coordinates of two points almost coinciding.
    # The generator is seeded so that any failure is reproducible.
    coords = np.random.RandomState(0).rand(10, 2)
    coords[5] *= 1e-8
    coords[5] += coords[0]

    # Filters changed inside catch_warnings() are restored on exit, so the
    # global warnings configuration of the test session is left untouched.
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        plotter.mask_interpolate(coords, np.ones(len(coords)), a=1)
        assert len(w) == 1
        assert issubclass(w[-1].category, RuntimeWarning)
        assert "coinciding" in str(w[-1].message)

    # Silence any warnings emitted on the way to the expected errors, but
    # only within this scope.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        assert_raises(ValueError, plotter.mask_interpolate,
                      coords, np.ones(len(coords)))
        assert_raises(ValueError, plotter.mask_interpolate,
                      coords, np.ones(2 * len(coords)))