test_plotter.py 5.03 KB
Newer Older
# Copyright 2011-2013 Kwant authors.
#
# This file is part of Kwant.  It is subject to the license terms in the
# LICENSE file found in the top-level directory of this distribution and at
# http://kwant-project.org/license.  A list of Kwant authors can be found in
# the AUTHORS file at the top-level directory of this distribution and at
# http://kwant-project.org/authors.

import tempfile
import warnings

import nose

import kwant
from kwant import plotter

# matplotlib is optional: only import its modules when the plotter reports
# that matplotlib support is available.
if plotter.mpl_enabled:
    from mpl_toolkits import mplot3d
    from matplotlib import pyplot


def test_importable_without_matplotlib():
    """Check that the plotter source still executes when matplotlib is absent.

    The source file of ``kwant.plotter`` is read from disk, every occurrence
    of ``matplotlib`` is replaced by the name of a non-existent module, and
    the result is executed.  Exactly one ``RuntimeWarning`` mentioning "only
    iterator-providing functions" must be emitted.
    """
    prefix, sep, suffix = plotter.__file__.rpartition('.')
    # The module may have been loaded from bytecode (.pyc, or .pyo under
    # ``python -O``); point at the corresponding source file instead.
    if suffix in ('pyc', 'pyo'):
        suffix = 'py'
    assert suffix == 'py'
    fname = sep.join((prefix, suffix))
    with open(fname) as f:
        code = f.read()
    # The code is executed outside of its package, so relative imports have
    # to be turned into absolute ones.
    code = code.replace('from . import', 'from kwant import')
    # Every matplotlib import now refers to a module that does not exist.
    code = code.replace('matplotlib', 'totalblimp')

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        # Run the module code in a fresh namespace, as a real import would.
        # The call form ``exec(code, globals)`` is valid in Python 2 and 3.
        exec(code, {'__name__': plotter.__name__})   # Trigger the warning.
        nose.tools.assert_equal(len(w), 1)
        assert issubclass(w[0].category, RuntimeWarning)
        assert "only iterator-providing functions" in str(w[0].message)


def sys_2d(W=3, r1=3, r2=8, a=1, t=1.0):
    """Build a 2D ring on a square lattice with three attached leads.

    Parameters
    ----------
    W : int
        Width of the leads along y; lead sites satisfy ``|y| < W / 2``.
    r1, r2 : int
        Inner and outer radius of the ring.
    a : number
        Lattice constant of the square lattice (default 1).
    t : float
        Hopping amplitude; on-site energies are ``4 * t`` (default 1.0).

    Returns
    -------
    kwant.Builder
        The ring with leads attached in this order: ``lead2`` (on-site
        energies only, no hoppings), ``lead0`` (pointing in -x) and
        ``lead1`` (its reversed copy).
    """
    lat = kwant.lattice.square(a)
    syst = kwant.Builder()

    def ring(pos):
        (x, y) = pos
        rsq = x ** 2 + y ** 2
        return r1 ** 2 < rsq < r2 ** 2

    syst[lat.shape(ring, (0, r1 + 1))] = 4 * t
    syst[lat.neighbors()] = -t

    sym_lead0 = kwant.TranslationalSymmetry(lat.vec((-1, 0)))
    lead0 = kwant.Builder(sym_lead0)
    lead2 = kwant.Builder(sym_lead0)

    def lead_shape(pos):
        # True division keeps the cross-section symmetric about y = 0 for
        # odd W as well; with Python 2 integer division ``-W / 2`` floors
        # (W=3 would give -2 < y < 1).
        return -W / 2. < pos[1] < W / 2.

    # lead2 gets on-site energies only and is attached without hoppings,
    # presumably to exercise plotting such a lead -- confirm intent.
    lead2[lat.shape(lead_shape, (0, 0))] = 4 * t
    syst.attach_lead(lead2)

    lead0[lat.shape(lead_shape, (0, 0))] = 4 * t
    lead0[lat.neighbors()] = -t
    lead1 = lead0.reversed()
    syst.attach_lead(lead0)
    syst.attach_lead(lead1)
    return syst


def sys_3d(W=3, r1=2, r2=4, a=1, t=1.0):
    """Build a 3D ring slab on a simple cubic lattice with two attached leads.

    Parameters
    ----------
    W : int
        Width of the leads along y; lead sites satisfy ``|y| < W / 2`` and
        ``|z| < 2``.
    r1, r2 : int
        Inner and outer radius of the ring in the xy-plane; the ring is
        restricted to ``|z| < 2``.
    a : number
        Lattice constant of the cubic lattice.
    t : float
        Hopping amplitude; on-site energies are ``4 * t``.

    Returns
    -------
    kwant.Builder
        The ring with ``lead0`` (pointing in -x) and its reversed copy
        ``lead1`` attached, in that order.
    """
    lat = kwant.lattice.general(((a, 0, 0), (0, a, 0), (0, 0, a)))
    syst = kwant.Builder()

    def ring(pos):
        (x, y, z) = pos
        rsq = x ** 2 + y ** 2
        return (r1 ** 2 < rsq < r2 ** 2) and abs(z) < 2

    syst[lat.shape(ring, (0, -r2 + 1, 0))] = 4 * t
    syst[lat.neighbors()] = -t

    sym_lead0 = kwant.TranslationalSymmetry(lat.vec((-1, 0, 0)))
    lead0 = kwant.Builder(sym_lead0)

    def lead_shape(pos):
        # True division keeps the cross-section symmetric about y = 0 for
        # odd W; Python 2 integer division would floor ``-W / 2``.
        return (-W / 2. < pos[1] < W / 2.) and abs(pos[2]) < 2

    lead0[lat.shape(lead_shape, (0, 0, 0))] = 4 * t
    lead0[lat.neighbors()] = -t
    lead1 = lead0.reversed()
    syst.attach_lead(lead0)
    syst.attach_lead(lead1)
    return syst


def test_plot():
    """Smoke-test ``plotter.plot`` with several site and hopping colorings.

    Skipped when matplotlib support is unavailable.  All figures are written
    to a temporary file that is discarded afterwards.
    """
    plot = plotter.plot
    if not plotter.mpl_enabled:
        raise nose.SkipTest
    sys2d = sys_2d()
    sys3d = sys_3d()
    site_color_opts = ['k', (lambda site: site.tag[0]),
                       lambda site: (abs(site.tag[0] / 100),
                                     abs(site.tag[1] / 100), 0)]
    hop_color_opts = ['k', (lambda site, site2: site.tag[0]),
                      lambda site, site2: (abs(site.tag[0] / 100),
                                           abs(site.tag[1] / 100), 0)]
    # A sample site used to find out what kind of value a color callable
    # returns (sys2d is not modified until after the coloring loops).
    probe_site = next(iter(sys2d.sites()))

    with tempfile.TemporaryFile('w+b') as out:
        for color in site_color_opts:
            for syst in (sys2d, sys3d):
                fig = plot(syst, site_color=color, cmap='binary', file=out)
                # NOTE(review): site.tag[0] is presumably an integer, in
                # which case this float check never fires -- confirm.
                if color != 'k' and isinstance(color(probe_site), float):
                    assert fig.axes[0].collections[0].get_array() is not None
                # Presumably one site and one hopping collection for the
                # system and for each lead (3 leads in 2D, 2 in 3D).
                assert len(fig.axes[0].collections) == (8 if syst is sys2d else
                                                        6)
        for color in hop_color_opts:
            for syst in (sys2d, sys3d):
                # Plot the loop's own system so hop_color is exercised in 3D
                # too (sys2d used to be plotted on both iterations).
                fig = plot(syst, hop_color=color, cmap='binary', file=out,
                           fig_size=(2, 10), dpi=30)
                if (color != 'k' and
                    isinstance(color(probe_site, None), float)):
                    assert fig.axes[0].collections[1].get_array() is not None

        assert isinstance(plot(sys3d, file=out).axes[0], mplot3d.axes3d.Axes3D)

        # Degenerate systems: first without leads, then without hoppings.
        sys2d.leads = []
        plot(sys2d, file=out)
        del sys2d[list(sys2d.hoppings())]
        plot(sys2d, file=out)

        plot(sys3d, file=out)
        # Warnings raised while plotting the finalized system are silenced;
        # only the absence of errors is checked.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            plot(sys2d.finalized(), file=out)


def test_map():
    """Smoke-test ``plotter.map`` with a value callable and a value sequence.

    Also checks that passing a sequence of values together with an
    unfinalized builder raises ``ValueError``.  Skipped when matplotlib
    support is unavailable.
    """
    if not plotter.mpl_enabled:
        raise nose.SkipTest
    syst = sys_2d()
    num_sites = len(syst.sites())

    with tempfile.TemporaryFile('w+b') as out:
        plotter.map(syst, lambda site: site.tag[0], file=out,
                    method='linear', a=4, oversampling=4, cmap='flag')
        # Warnings are silenced here; only the absence of errors is checked.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            plotter.map(syst.finalized(), range(num_sites), file=out)
        nose.tools.assert_raises(ValueError, plotter.map, syst,
                                 range(num_sites), file=out)