# Copyright 2011-2013 Kwant authors.
#
# This file is part of Kwant.  It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
# http://kwant-project.org/license.  A list of Kwant authors can be found in
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.

import itertools
import tempfile
import warnings
from math import cos, sin

import numpy as np
import pytest
import scipy.integrate
import scipy.sparse
import scipy.stats
import tinyarray as ta

import kwant
from kwant import plotter
from kwant._common import ensure_rng

if plotter.mpl_enabled:
    from mpl_toolkits import mplot3d
    from matplotlib import pyplot  # pragma: no flakes


def test_importable_without_matplotlib():
    """Check that the plotter source still executes without matplotlib.

    The source file of ``plotter`` is read from disk, its relative imports
    are rewritten to absolute ``kwant`` imports, and every occurrence of
    ``matplotlib`` is replaced by a placeholder module name (presumably one
    that cannot be imported).  Executing the rewritten source must emit
    exactly one ``RuntimeWarning`` mentioning "only iterator-providing
    functions".
    """
    # Locate the '.py' source even when the module was loaded from bytecode.
    prefix, sep, suffix = plotter.__file__.rpartition('.')
    if suffix in ['pyc', 'pyo']:
        suffix = 'py'
    assert suffix == 'py'
    fname = sep.join((prefix, suffix))
    with open(fname, 'rb') as f:
        code = f.read()
    # The code is exec'd outside of the package, so relative imports would
    # not resolve.
    code = code.replace(b'from . import', b'from kwant import')
    code = code.replace(b'matplotlib', b'totalblimp')

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        exec(code)               # Trigger the warning.
        assert len(w) == 1
        assert issubclass(w[0].category, RuntimeWarning)
        assert "only iterator-providing functions" in str(w[0].message)


def syst_2d(W=3, r1=3, r2=8):
    """Return an unfinalized ring-shaped 2D test system with three leads.

    Parameters
    ----------
    W : number
        Leads contain the sites with ``-W / 2 < y < W / 2``.
    r1, r2 : number
        Inner and outer radius of the ring; a site at ``(x, y)`` belongs to
        the scattering region when ``r1**2 < x**2 + y**2 < r2**2``.

    Returns
    -------
    kwant.Builder
        Square-lattice builder (one orbital per site) with on-site value
        ``4 * t`` and nearest-neighbour hopping ``-t``.  Leads are attached
        in the order ``lead2``, ``lead0``, ``lead1``.
    """
    a = 1
    t = 1.0
    lat = kwant.lattice.square(a, norbs=1)
    syst = kwant.Builder()

    def ring(pos):
        (x, y) = pos
        rsq = x ** 2 + y ** 2
        return r1 ** 2 < rsq < r2 ** 2

    def lead_shape(pos):
        return -W / 2 < pos[1] < W / 2

    syst[lat.shape(ring, (0, r1 + 1))] = 4 * t
    syst[lat.neighbors()] = -t

    sym_lead0 = kwant.TranslationalSymmetry(lat.vec((-1, 0)))
    lead0 = kwant.Builder(sym_lead0)
    lead2 = kwant.Builder(sym_lead0)

    lead0[lat.shape(lead_shape, (0, 0))] = 4 * t
    lead2[lat.shape(lead_shape, (0, 0))] = 4 * t
    # 'lead2' is attached with on-site terms only: no hoppings are ever
    # assigned to it.
    syst.attach_lead(lead2)
    lead0[lat.neighbors()] = - t
    lead1 = lead0.reversed()
    syst.attach_lead(lead0)
    syst.attach_lead(lead1)
    return syst


def syst_3d(W=3, r1=2, r2=4, a=1, t=1.0):
    """Return an unfinalized ring-shaped 3D test system with two leads.

    Parameters
    ----------
    W : number
        Leads contain the sites with ``-W / 2 < y < W / 2`` and ``|z| < 2``.
    r1, r2 : number
        Inner and outer radius of the ring in the x-y plane; a site belongs
        to the scattering region when ``r1**2 < x**2 + y**2 < r2**2`` and
        ``|z| < 2``.
    a : number
        Lattice constant of the cubic lattice.
    t : number
        Hopping magnitude; on-site values are ``4 * t`` and hoppings ``-t``.

    Returns
    -------
    kwant.Builder
        Builder with ``lead0`` and its reversed copy attached, in that order.
    """
    lat = kwant.lattice.general(((a, 0, 0), (0, a, 0), (0, 0, a)))
    syst = kwant.Builder()

    def ring(pos):
        (x, y, z) = pos
        rsq = x ** 2 + y ** 2
        return (r1 ** 2 < rsq < r2 ** 2) and abs(z) < 2

    def lead_shape(pos):
        return (-W / 2 < pos[1] < W / 2) and abs(pos[2]) < 2

    syst[lat.shape(ring, (0, -r2 + 1, 0))] = 4 * t
    syst[lat.neighbors()] = - t

    sym_lead0 = kwant.TranslationalSymmetry(lat.vec((-1, 0, 0)))
    lead0 = kwant.Builder(sym_lead0)
    lead0[lat.shape(lead_shape, (0, 0, 0))] = 4 * t
    lead0[lat.neighbors()] = - t
    lead1 = lead0.reversed()
    syst.attach_lead(lead0)
    syst.attach_lead(lead1)
    return syst


@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
def test_plot():
    """Exercise ``plotter.plot`` on the 2D and 3D test systems.

    Covers constant and callable site/hopping colours, the type of the axes
    produced for a 3D system, plotting after removing leads and hoppings,
    and plotting a finalized system.  All output goes to a temporary file.
    """
    plot = plotter.plot
    syst2d = syst_2d()
    syst3d = syst_3d()

    # A fixed colour, a callable returning one number per site, and a
    # callable returning a 3-tuple per site.
    color_opts = ['k', (lambda site: site.tag[0]),
                  lambda site: (abs(site.tag[0] / 100),
                                abs(site.tag[1] / 100), 0)]
    with tempfile.TemporaryFile('w+b') as out:
        for color in color_opts:
            for syst in (syst2d, syst3d):
                fig = plot(syst, site_color=color, cmap='binary', file=out)
                # NOTE(review): this branch only runs when the colour
                # function returns a 'float' for a site of 'syst2d' --
                # confirm that any entry of 'color_opts' actually does.
                if (color != 'k' and
                        isinstance(color(next(iter(syst2d.sites()))), float)):
                    assert fig.axes[0].collections[0].get_array() is not None
                assert len(fig.axes[0].collections) == (
                    8 if syst is syst2d else 6)

        # Same three kinds of colour specification, now per hopping.
        color_opts = ['k', (lambda site, site2: site.tag[0]),
                      lambda site, site2: (abs(site.tag[0] / 100),
                                           abs(site.tag[1] / 100), 0)]
        for color in color_opts:
            # NOTE(review): the loop variable 'syst' is not used below, so
            # 'syst2d' is plotted twice and 'syst3d' never gets a hop_color.
            # Left unchanged; confirm whether 'plot(syst, ...)' was meant.
            for syst in (syst2d, syst3d):
                fig = plot(syst2d, hop_color=color, cmap='binary', file=out,
                           fig_size=(2, 10), dpi=30)
                if color != 'k' and isinstance(
                        color(next(iter(syst2d.sites())), None), float):
                    assert fig.axes[0].collections[1].get_array() is not None

        assert isinstance(plot(syst3d, file=out).axes[0], mplot3d.axes3d.Axes3D)

        # Plot again without leads, then without any hoppings.
        syst2d.leads = []
        plot(syst2d, file=out)
        del syst2d[list(syst2d.hoppings())]
        plot(syst2d, file=out)

        plot(syst3d, file=out)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            plot(syst2d.finalized(), file=out)

def good_transform(pos):
    """Return the two coordinates of *pos* in swapped order."""
    first, second = pos
    return (second, first)

def bad_transform(pos):
    """Return the two coordinates of *pos* followed by a constant 0.

    The result has three components although the input has two.
    """
    (first, second) = pos
    return (first, second, 0)

@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
def test_map():
    """Exercise ``plotter.map`` with valid and invalid inputs.

    Valid: a callable value with a dimension-preserving ``pos_transform``,
    and a sequence of values on a finalized system.  Expected to raise
    ``ValueError``: a ``pos_transform`` that returns three coordinates, and
    a sequence of values passed together with an unfinalized system.
    """
    syst = syst_2d()
    with tempfile.TemporaryFile('w+b') as out:
        plotter.map(syst, lambda site: site.tag[0], pos_transform=good_transform,
                    file=out, method='linear', a=4, oversampling=4, cmap='flag')

        with pytest.raises(ValueError):
            plotter.map(syst, lambda site: site.tag[0],
                        pos_transform=bad_transform, file=out)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            plotter.map(syst.finalized(), range(len(syst.sites())),
                        file=out)

        with pytest.raises(ValueError):
            plotter.map(syst, range(len(syst.sites())), file=out)


def test_mask_interpolate():
    """Check warnings and errors of ``plotter.mask_interpolate``.

    With an explicit ``a=1`` the call must emit exactly one
    ``RuntimeWarning`` mentioning "coinciding".  Without ``a``, and with a
    value array twice as long as the coordinate array, a ``ValueError`` is
    expected.
    """
    # A coordinate array with coordinates of two points almost coinciding.
    coords = np.array([[0, 0], [1e-7, 1e-7], [1, 1], [1, 0]])

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        plotter.mask_interpolate(coords, np.ones(len(coords)), a=1)
        assert len(w) == 1
        assert issubclass(w[-1].category, RuntimeWarning)
        assert "coinciding" in str(w[-1].message)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        with pytest.raises(ValueError):
            plotter.mask_interpolate(coords, np.ones(len(coords)))
        with pytest.raises(ValueError):
            plotter.mask_interpolate(coords, np.ones(2 * len(coords)))


@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
def test_bands():
    """Smoke-test ``plotter.bands`` on the first lead of the 2D test system."""
    lead = syst_2d().finalized().leads[0]

    with tempfile.TemporaryFile('w+b') as out:
        # Default call, explicit figure size, explicit momenta -- in that
        # order.
        for options in ({},
                        {'fig_size': (10, 10)},
                        {'momenta': np.linspace(0, 2 * np.pi)}):
            plotter.bands(lead, file=out, **options)

        # Drawing into axes supplied by the caller.
        figure = pyplot.Figure()
        axes = figure.add_subplot(1, 1, 1)
        plotter.bands(lead, ax=axes, file=out)


@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
def test_spectrum():
    """Smoke-test ``plotter.spectrum`` with one and two swept parameters.

    Three kinds of Hamiltonian are used: a function returning a scalar, a
    function returning a 2x2 array, and a finalized three-site chain whose
    value functions take the parameters ``a``, ``b`` and ``c``.
    """

    def ham_1d(a, b, c):
        return a**2 + b**2 + c**2

    def ham_2d(a, b, c):
        return np.eye(2) * (a**2 + b**2 + c**2)

    def onsite(site, a, b):
        return a + b

    def hopping(site1, site2, c):
        return c

    lat = kwant.lattice.chain()
    builder = kwant.Builder()
    builder[(lat(i) for i in range(3))] = onsite
    builder[lat.neighbors()] = hopping
    fsyst = builder.finalized()

    vals = np.linspace(0, 1, 3)
    hamiltonians = (ham_1d, ham_2d, fsyst)
    # Each call is made once with default and once with explicit figure size.
    size_options = ({}, {'fig_size': (10, 10)})

    with tempfile.TemporaryFile('w+b') as out:

        # Sweep over 'a' only.
        for ham in hamiltonians:
            for extra in size_options:
                plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1),
                                 file=out, **extra)

        # Sweep over 'a' and 'b'.
        for ham in hamiltonians:
            for extra in size_options:
                plotter.spectrum(ham, ('a', vals), ('b', 2 * vals),
                                 params=dict(c=1), file=out, **extra)

        # Two swept parameters drawn into explicitly passed 3D axes.
        fig = pyplot.Figure()
        axes_3d = fig.add_subplot(1, 1, 1, projection='3d')
        plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals),
                         params=dict(c=1), ax=axes_3d, file=out)
        # Explicitly passed axes without 3D support must be rejected.
        axes_2d = fig.add_subplot(1, 1, 1)
        with pytest.raises(TypeError):
            plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals),
                             params=dict(c=1), ax=axes_2d, file=out)

    def mask(a, b):
        return a > 0.5

    with tempfile.TemporaryFile('w+b') as out:
        # 'fsyst' is the Hamiltonian the loops above finished on.
        plotter.spectrum(fsyst, ('a', vals), ('b', 2 * vals),
                         params=dict(c=1), mask=mask, file=out)


def syst_rect(lat, salt, W=3, L=50):
    """Return an unfinalized rectangular system on *lat* with two leads.

    Parameters
    ----------
    lat : lattice
        Lattice providing ``lat(i, j)``, ``lat.neighbors()`` and
        ``lat.vec(...)``.
    salt : str
        Salt passed to ``kwant.digest.gauss`` for the on-site disorder.
    W, L : int
        The region contains the sites with tags ``(i, j)`` where
        ``-(L // 2) <= i <= L // 2`` and ``-(W // 2) <= j <= W // 2``.

    Returns
    -------
    kwant.Builder
        Builder with on-site values ``4 + 0.1 * gauss(...)``, hoppings
        ``-1``, and a lead plus its reversed copy attached.
    """
    syst = kwant.Builder()

    ll = L//2
    ww = W//2

    def onsite(site):
        return 4 + 0.1 * kwant.digest.gauss(repr(site.tag), salt=salt)

    syst[(lat(i, j) for i in range(-ll, ll+1)
         for j in range(-ww, ww+1))] = onsite
    syst[lat.neighbors()] = -1

    sym = kwant.TranslationalSymmetry(lat.vec((-1, 0)))
    lead = kwant.Builder(sym)
    lead[(lat(0, j) for j in range(-ww, ww + 1))] = 4
    lead[lat.neighbors()] = -1

    syst.attach_lead(lead)
    syst.attach_lead(lead.reversed())

    return syst


def div(F, h):
    """Calculate the divergence of a vector field F over a grid of spacing h.

    Parameters
    ----------
    F : numpy.ndarray
        Vector field sampled on a regular grid.  The last axis indexes the
        vector component, so an n-dimensional field has ``n + 1`` axes and
        ``F.shape[-1] == n``.
    h : sequence of numbers
        Grid spacing along each of the ``n`` grid axes.

    Returns
    -------
    numpy.ndarray
        Array of shape ``F.shape[:-1]`` holding the sum over ``i`` of the
        derivative of component ``i`` along axis ``i``.
    """
    assert len(F.shape[:-1]) == F.shape[-1]
    assert len(h) == F.shape[-1]
    # Request only the derivative along axis 'i'.  Besides avoiding the
    # unused derivatives, this keeps the one-dimensional case correct:
    # without 'axis', np.gradient returns a bare array there and indexing
    # it with [i] would pick a single element.
    return sum(np.gradient(F[..., i], h[i], axis=i)
               for i in range(F.shape[-1]))


def rotational_currents(g):
    """Return a basis of divergence-free currents for a closed graph.

    Given the graph 'g' of a Kwant system, returns a sequence of arrays
    which are linearly independent, divergence-free currents on the graph.

    Parameters
    ----------
    g : graph
        Object providing ``num_nodes`` and ``num_edges`` and iterating over
        its edges as ``(a, b)`` pairs of node indices.  Every hopping is
        expected to be present as two oppositely directed edges.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(k, g.num_edges)``; each row assigns a current to
        every edge, in the iteration order of ``g``.  ``k`` is the dimension
        of the null space found below and is 0 for a graph without cycles.
    """
    #'A' represents the set of expressions that give the net current flow
    # into the system sites. 'perm' is a map from the edges of a graph
    # with only 1 edge per hopping to the proper Kwant graph (2 edges
    # per hopping).
    num_hoppings = g.num_edges // 2
    A = np.zeros((g.num_nodes, num_hoppings))
    hoppings = {}
    perm_data = np.zeros(g.num_edges)
    # Sparse-matrix indices must be integers.
    perm_rows = np.arange(g.num_edges)
    perm_cols = np.zeros(g.num_edges, dtype=int)
    for k, (a, b) in enumerate(g):
        hop = frozenset((a, b))
        i = hoppings.get(hop)
        if i is None:
            # First time this hopping is seen: it defines the positive
            # direction.
            i = hoppings[hop] = len(hoppings)
            A[a, i] = 1
            A[b, i] = -1
            perm_data[k] = 1
        else:
            # Reverse edge of a known hopping carries the opposite current.
            perm_data[k] = -1
        perm_cols[k] = i

    perm = scipy.sparse.coo_matrix((perm_data, (perm_rows, perm_cols)),
                                   shape=(g.num_edges, num_hoppings))

    # Get the row vectors of V with singular value 0. These form
    # a basis for the right null space of 'A'.
    _, S, V = np.linalg.svd(A)
    tol = S.max() * max(A.shape) * np.finfo(S.dtype).eps
    rank = np.count_nonzero(S > tol)
    # Transform null space basis into vectors defined over the full
    # hopping space (both hopping directions).  'V[rank:]' is empty when
    # 'A' has full column rank, whereas 'V[-(len(V) - rank):]' would then
    # be 'V[-0:]', i.e. all of 'V'.
    null_space_basis = V[rank:].transpose()
    return perm.dot(null_space_basis).transpose()


def test_current_interpolation():
    """Check ``plotter.interpolate_current`` against exact results.

    Three properties are tested:

    * passing a Builder raises ``TypeError``;
    * the interpolated current through a cross section approaches the
      current computed with ``kwant.operator.Current`` as the grid is
      refined, for several lattice constants, orientations and widths;
    * for a divergence-free current on a closed system the interpolation is
      linear in the current, and the divergence of the interpolated field
      decreases as the grid is refined.
    """
    # 'simps' is the older name of 'simpson' and is missing from recent
    # SciPy; use whichever of the two is available.
    try:
        simpson = scipy.integrate.simpson
    except AttributeError:
        simpson = scipy.integrate.simps

    ## Passing a Builder will raise an error
    pytest.raises(TypeError, plotter.interpolate_current, syst_2d(), None)

    def R(theta):
        return ta.array([[cos(theta), -sin(theta)], [sin(theta), cos(theta)]])

    def make_lattice(a, theta):
        x = ta.dot(R(theta), (a, 0))
        y = ta.dot(R(theta), (0, a))
        return kwant.lattice.general([x, y], norbs=1)

    ## Check current through cross section is same for different lattice
    ## parameters and orientations of the system wrt. the discretization grid
    for a, theta, width in [(1, 0, 1),
                            (1, 0, 0.5),
                            (2, 0, 1),
                            (1, 0.2, 1),
                            (2, 0.4, 1)]:
        lat = make_lattice(a, theta)
        syst = syst_rect(lat, salt='0').finalized()
        psi = kwant.wave_function(syst, energy=3)(0)

        # Hoppings that cross from tag[0] < 0 to tag[0] >= 0.
        def cut(a, b):
            return b.tag[0] < 0 and a.tag[0] >= 0

        J = kwant.operator.Current(syst).bind()
        J_cut = kwant.operator.Current(syst, where=cut, sum=True).bind()
        J_exact = J_cut(psi[0])

        data = []
        for n in [4, 6, 8, 11, 16]:
            j0, box = plotter.interpolate_current(syst, J(psi[0]),
                                                  n=n, abswidth=width)
            x, y = (np.linspace(mn, mx, shape)
                    for (mn, mx), shape in zip(box, j0.shape))
            # slice field perpendicular to a cut along the y axis
            y_axis = (np.argmin(np.abs(x)), slice(None), 0)
            J_interp = simpson(j0[y_axis], x=y)
            data.append((n, abs(J_interp - J_exact)))
        # 3rd value returned from 'linregress' is 'rvalue'
        assert scipy.stats.linregress(np.log(data))[2] < -0.8

    ### Tests on a divergence-free current (closed system)

    lat = kwant.lattice.general([(1, 0), (0.5, np.sqrt(3) / 2)])
    syst = kwant.Builder()
    sites = [lat(0, 0), lat(1, 0), lat(0, 1), lat(2, 2)]
    syst[sites] = None
    syst[((s, t) for s, t in itertools.product(sites, sites) if s != t)] = None
    del syst[lat(0, 0), lat(2, 2)]
    syst = syst.finalized()

    # generate random divergence-free currents
    Js = rotational_currents(syst.graph)
    rng = ensure_rng(3)
    J0 = sum(rng.rand(len(Js))[:, None] * Js)
    J1 = sum(rng.rand(len(Js))[:, None] * Js)

    # Sanity check that divergence on the graph is 0
    divergence = np.zeros(len(syst.sites))
    for (a, _), current in zip(syst.graph, J0):
        divergence[a] += current
    assert np.allclose(divergence, 0)

    j0, _ = plotter.interpolate_current(syst, J0)
    j1, _ = plotter.interpolate_current(syst, J1)

    ## Test linearity of interpolation.
    j_tot, _ = plotter.interpolate_current(syst, J0 + 2 * J1)
    assert np.allclose(j_tot, j0 + 2 * j1)

    ## Test that divergence of interpolated current approaches zero as we make
    ## the interpolation finer.
    data = []
    for n in [4, 6, 8, 11, 16]:
        j, box = plotter.interpolate_current(syst, J0, n=n)
        dx = [(mx - mn) / (shape - 1) for (mn, mx), shape in zip(box, j.shape)]
        div_j = np.max(np.abs(div(j, dx)))
        data.append((n, div_j))

    # 3rd value returned from 'linregress' is 'rvalue'
    assert scipy.stats.linregress(np.log(data))[2] < -0.8


@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
def test_current():
    """Smoke-test ``plotter.current`` on the finalized 2D test system."""
    fsyst = syst_2d().finalized()
    current_operator = kwant.operator.Current(fsyst)
    psi = kwant.wave_function(fsyst, energy=1)(1)[0]
    current = current_operator(psi)

    # Test good codepath
    with tempfile.TemporaryFile('w+b') as out:
        plotter.current(fsyst, current, file=out)

        # Same plot drawn into axes supplied by the caller.
        figure = pyplot.Figure()
        axes = figure.add_subplot(1, 1, 1)
        plotter.current(fsyst, current, ax=axes, file=out)