# test_plotter.py: tests for the kwant.plotter module.
import tempfile
import nose

import kwant
from kwant import plotter

# The 3D-axes and pyplot imports are only attempted when kwant's plotter
# reports that matplotlib support is enabled.
if plotter._mpl_enabled:
    from mpl_toolkits import mplot3d
    from matplotlib import pyplot


def sys_2d(W=3, r1=3, r2=8, a=1, t=1.0):
    """Return a 2D square-lattice ring Builder with two leads attached.

    A site at position ``(x, y)`` belongs to the scattering region when
    ``r1**2 < x**2 + y**2 < r2**2``.  Every site gets onsite value ``4 * t``
    and every nearest-neighbor hopping gets ``-t``.

    The first lead is invariant under translation by ``lat.vec((-1, 0))``;
    its cross-section is the set of sites with ``-1 < x < 1`` and
    ``-W / 2 < y < W / 2``.  The second lead is ``lead0.reversed()``.

    ``a`` is passed to ``kwant.lattice.Square``; ``a`` and ``t`` default to
    the values this fixture has always used.

    NOTE(review): without true division (the Python 2 default), ``-W / 2``
    and ``W / 2`` are floor divisions, so W=3 gives ``-2 < y < 1``.  Confirm
    whether that asymmetric lead cross-section is intended.
    """
    lat = kwant.lattice.Square(a)
    sys = kwant.Builder()

    def ring(pos):
        (x, y) = pos
        rsq = x ** 2 + y ** 2
        return r1 ** 2 < rsq < r2 ** 2

    # (0, r1 + 1) lies inside the ring whenever r1 + 1 < r2; it is the start
    # point handed to lat.shape.
    sys[lat.shape(ring, (0, r1 + 1))] = 4 * t
    for hopping in lat.nearest:
        sys[sys.possible_hoppings(*hopping)] = - t

    sym_lead0 = kwant.TranslationalSymmetry(lat.vec((-1, 0)))
    lead0 = kwant.Builder(sym_lead0)

    def lead_shape(pos):
        (x, y) = pos
        return (-1 < x < 1) and (-W / 2 < y < W / 2)

    lead0[lat.shape(lead_shape, (0, 0))] = 4 * t
    for hopping in lat.nearest:
        lead0[lead0.possible_hoppings(*hopping)] = - t
    lead1 = lead0.reversed()
    sys.attach_lead(lead0)
    sys.attach_lead(lead1)
    return sys


def sys_3d(W=3, r1=2, r2=4, a=1, t=1.0):
    """Return a 3D simple-cubic ring Builder with two leads attached.

    The lattice has primitive vectors ``(a, 0, 0)``, ``(0, a, 0)`` and
    ``(0, 0, a)``.  A site at ``(x, y, z)`` belongs to the scattering region
    when ``r1**2 < x**2 + y**2 < r2**2`` and ``abs(z) < 2``.  Every site gets
    onsite value ``4 * t`` and every nearest-neighbor hopping gets ``-t``.

    The first lead is invariant under translation by ``lat.vec((-1, 0, 0))``;
    its cross-section is the set of sites with ``-1 < x < 1``,
    ``-W / 2 < y < W / 2`` and ``abs(z) < 2``.  The second lead is
    ``lead0.reversed()``.

    NOTE(review): without true division (the Python 2 default), ``-W / 2``
    and ``W / 2`` are floor divisions; confirm the intended lead width.
    """
    lat = kwant.make_lattice(((a, 0, 0), (0, a, 0), (0, 0, a)))
    # Hopping triples (delta, lattice, lattice) along the three primitive
    # directions; each is unpacked into possible_hoppings below.
    lat.nearest = (((1, 0, 0), lat, lat), ((0, 1, 0), lat, lat),
                   ((0, 0, 1), lat, lat))
    sys = kwant.Builder()

    def ring(pos):
        (x, y, z) = pos
        rsq = x ** 2 + y ** 2
        return (r1 ** 2 < rsq < r2 ** 2) and abs(z) < 2

    # (0, -r2 + 1, 0) lies inside the ring whenever r1 < r2 - 1; it is the
    # start point handed to lat.shape.
    sys[lat.shape(ring, (0, -r2 + 1, 0))] = 4 * t
    for hopping in lat.nearest:
        sys[sys.possible_hoppings(*hopping)] = - t

    sym_lead0 = kwant.TranslationalSymmetry(lat.vec((-1, 0, 0)))
    lead0 = kwant.Builder(sym_lead0)

    def lead_shape(pos):
        (x, y, z) = pos
        return (-1 < x < 1) and (-W / 2 < y < W / 2) and abs(z) < 2

    lead0[lat.shape(lead_shape, (0, 0, 0))] = 4 * t
    for hopping in lat.nearest:
        lead0[lead0.possible_hoppings(*hopping)] = - t
    lead1 = lead0.reversed()
    sys.attach_lead(lead0)
    sys.attach_lead(lead1)
    return sys


def test_plot():
    """Smoke-test ``plotter.plot`` on the 2D and 3D fixture systems.

    Covers three kinds of site and hopping color for both systems: a fixed
    color, a scalar-valued color function and an RGB-tuple color function.
    It also checks the axes type of a 3D plot, plots a system with no leads
    and then with no hoppings, and writes a plot to a file.  Raises
    ``nose.SkipTest`` when ``plotter._mpl_enabled`` is false.
    """
    import numbers

    if not plotter._mpl_enabled:
        raise nose.SkipTest
    plot = plotter.plot
    sys2d = sys_2d()
    sys3d = sys_3d()

    # Site tags here are presumably integer lattice coordinates, so the
    # scalar option likely returns an int; checking against numbers.Real
    # (rather than float) keeps the colormap assertions below reachable.
    color_opts = ['k', (lambda site: site.tag[0]),
                  lambda site: (abs(site.tag[0] / 100),
                                abs(site.tag[1] / 100), 0)]
    for color in color_opts:
        for sys in (sys2d, sys3d):
            fig = plot(sys, site_color=color, cmap='binary', show=False)
            if callable(color):
                probe = color(next(iter(sys.sites())))
                if isinstance(probe, numbers.Real):
                    # A scalar color should leave a data array on the
                    # site collection.
                    assert fig.axes[0].collections[0].get_array() is not None
            assert len(fig.axes[0].collections) == 4

    # Hopping colors: the same three kinds, as functions of both ends.
    color_opts = ['k', (lambda site, site2: site.tag[0]),
                  lambda site, site2: (abs(site.tag[0] / 100),
                                       abs(site.tag[1] / 100), 0)]
    for color in color_opts:
        for sys in (sys2d, sys3d):
            fig = plot(sys, hop_color=color, cmap='binary', show=False,
                       fig_size=(2, 10), dpi=30)
            if callable(color):
                probe = color(next(iter(sys.sites())), None)
                if isinstance(probe, numbers.Real):
                    assert fig.axes[0].collections[1].get_array() is not None

    assert isinstance(plot(sys3d, show=False).axes[0], mplot3d.axes3d.Axes3D)

    # Degenerate inputs: first drop the leads, then all hoppings.
    sys2d.leads = []
    plot(sys2d, show=False)
    del sys2d[list(sys2d.hoppings())]
    plot(sys2d, show=False)

    # Render into an open binary file object via the `file` argument.
    with tempfile.TemporaryFile('w+b') as output:
        plot(sys3d, file=output)


def test_map():
    """Smoke-test ``plotter.map`` on a Builder and on a finalized system.

    Raises ``nose.SkipTest`` when ``plotter._mpl_enabled`` is false, so a
    missing matplotlib skips this test instead of erroring it.
    """
    if not plotter._mpl_enabled:
        raise nose.SkipTest
    sys = sys_2d()
    with tempfile.TemporaryFile('w+b') as output:
        # Builder plus a per-site value function and extra keyword options.
        plotter.map(sys, lambda site: site.tag[0], file=output,
                    method='linear', a=4, oversampling=4, cmap='flag')
        # Finalized system plus a sequence with one value per site.
        plotter.map(sys.finalized(), xrange(len(sys.sites())), file=output)
        # The same kind of sequence with an unfinalized Builder must raise
        # ValueError.
        nose.tools.assert_raises(ValueError, plotter.map,
                                 sys, xrange(len(sys.sites())), file=output)