# Copyright 2011-2013 Kwant authors.
#
# This file is part of Kwant.  It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
# http://kwant-project.org/license.  A list of Kwant authors can be found in
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.

9
import tempfile
import warnings
import itertools
from math import cos, sin, pi

import numpy as np
import tinyarray as ta
import scipy.integrate
import scipy.sparse
import pytest

import kwant
from kwant import plotter
from kwant._common import ensure_rng

if plotter.mpl_enabled:
    from mpl_toolkits import mplot3d
    from matplotlib import pyplot  # pragma: no flakes
25
26


27
28
def test_importable_without_matplotlib():
    """Check that the plotter module can be executed without matplotlib.

    The source of ``kwant.plotter`` is read back, every occurrence of
    ``matplotlib`` is replaced by a nonexistent module name, and the result
    is executed; exactly one RuntimeWarning about the reduced functionality
    is expected.
    """
    prefix, sep, suffix = plotter.__file__.rpartition('.')
    # Map a compiled-bytecode path back to the corresponding source file.
    if suffix in ['pyc', 'pyo']:
        suffix = 'py'
    assert suffix == 'py'
    fname = sep.join((prefix, suffix))
    # Python source files are UTF-8 by default (PEP 3120); do not depend on
    # the locale's preferred encoding when reading the module back.
    with open(fname, encoding='utf-8') as f:
        code = f.read()
    # The code is exec'd outside its package, so relative imports must be
    # turned into absolute ones.
    code = code.replace('from . import', 'from kwant import')
    code = code.replace('matplotlib', 'totalblimp')

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        exec(code)               # Trigger the warning.
        assert len(w) == 1
        assert issubclass(w[0].category, RuntimeWarning)
        assert "only iterator-providing functions" in str(w[0].message)
44
45


46
def syst_2d(W=3, r1=3, r2=8):
    """Build a 2D square-lattice ring with three attached leads.

    The scattering region contains the sites whose squared distance from the
    origin lies strictly between ``r1**2`` and ``r2**2``.  Leads of width
    ``|y| < W / 2`` are attached in this order: one with onsite terms only
    (no hoppings), one with symmetry vector ``lat.vec((-1, 0))``, and the
    reversed copy of the latter.
    """
    hopping = 1.0
    lat = kwant.lattice.square(1, norbs=1)
    syst = kwant.Builder()

    def ring(pos):
        x, y = pos
        return r1 ** 2 < x ** 2 + y ** 2 < r2 ** 2

    syst[lat.shape(ring, (0, r1 + 1))] = 4 * hopping
    syst[lat.neighbors()] = -hopping

    symmetry = kwant.TranslationalSymmetry(lat.vec((-1, 0)))
    left_lead = kwant.Builder(symmetry)
    bare_lead = kwant.Builder(symmetry)

    def in_lead(pos):
        return -W / 2 < pos[1] < W / 2

    left_lead[lat.shape(in_lead, (0, 0))] = 4 * hopping
    bare_lead[lat.shape(in_lead, (0, 0))] = 4 * hopping
    # The hopping-free lead is attached first, so it gets lead index 0.
    syst.attach_lead(bare_lead)
    left_lead[lat.neighbors()] = -hopping
    right_lead = left_lead.reversed()

    syst.attach_lead(left_lead)
    syst.attach_lead(right_lead)
    return syst
73
74


75
def syst_3d(W=3, r1=2, r2=4, a=1, t=1.0):
    """Build a 3D cubic-lattice ring slab with two attached leads.

    The scattering region keeps sites whose in-plane squared radius lies
    strictly between ``r1**2`` and ``r2**2`` and whose ``|z| < 2``.  The
    leads (symmetry vector ``lat.vec((-1, 0, 0))`` and its reverse) contain
    the sites with ``|y| < W / 2`` and ``|z| < 2``.
    """
    lat = kwant.lattice.general(((a, 0, 0), (0, a, 0), (0, 0, a)))
    syst = kwant.Builder()

    def ring(pos):
        x, y, z = pos
        inside_annulus = r1 ** 2 < x ** 2 + y ** 2 < r2 ** 2
        return inside_annulus and abs(z) < 2

    syst[lat.shape(ring, (0, -r2 + 1, 0))] = 4 * t
    syst[lat.neighbors()] = -t

    lead = kwant.Builder(kwant.TranslationalSymmetry(lat.vec((-1, 0, 0))))

    def in_lead(pos):
        return (-W / 2 < pos[1] < W / 2) and abs(pos[2]) < 2

    lead[lat.shape(in_lead, (0, 0, 0))] = 4 * t
    lead[lat.neighbors()] = -t
    mirrored = lead.reversed()

    syst.attach_lead(lead)
    syst.attach_lead(mirrored)
    return syst
96
97


@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
def test_plot():
    """Smoke-test ``plotter.plot`` with various site and hopping colorings.

    Both a 2D and a 3D system are plotted with constant, callable-scalar and
    callable-RGB colors; plots of systems without leads, without hoppings and
    of a finalized system are also produced.
    """
    plot = plotter.plot
    syst2d = syst_2d()
    syst3d = syst_3d()
    color_opts = ['k', (lambda site: site.tag[0]),
                  lambda site: (abs(site.tag[0] / 100),
                                abs(site.tag[1] / 100), 0)]
    with tempfile.TemporaryFile('w+b') as out:
        for color in color_opts:
            for syst in (syst2d, syst3d):
                fig = plot(syst, site_color=color, cmap='binary', file=out)
                # NOTE(review): site tags appear to be integer arrays, so for
                # the callables above this float check is presumably never
                # true and the get_array assertion does not run -- confirm.
                if (color != 'k' and
                    isinstance(color(next(iter(syst2d.sites()))), float)):
                    assert fig.axes[0].collections[0].get_array() is not None
                assert len(fig.axes[0].collections) == (8 if syst is syst2d else
                                                        6)
        color_opts = ['k', (lambda site, site2: site.tag[0]),
                      lambda site, site2: (abs(site.tag[0] / 100),
                                           abs(site.tag[1] / 100), 0)]
        for color in color_opts:
            for syst in (syst2d, syst3d):
                # Plot the system of the current iteration so that hopping
                # colors are exercised for both the 2D and the 3D system.
                fig = plot(syst, hop_color=color, cmap='binary', file=out,
                           fig_size=(2, 10), dpi=30)
                if color != 'k' and isinstance(color(next(iter(syst.sites())),
                                                     None), float):
                    assert fig.axes[0].collections[1].get_array() is not None

        assert isinstance(plot(syst3d, file=out).axes[0], mplot3d.axes3d.Axes3D)

        # A system without leads, then additionally without hoppings.
        syst2d.leads = []
        plot(syst2d, file=out)
        del syst2d[list(syst2d.hoppings())]
        plot(syst2d, file=out)

        plot(syst3d, file=out)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            plot(syst2d.finalized(), file=out)
137

138
139
140
141
142
143
144
def good_transform(pos):
    """Return the 2D position ``pos`` with its two coordinates swapped."""
    first, second = pos
    return second, first

def bad_transform(pos):
    """Embed the 2D position ``pos`` in 3D with a zero third coordinate.

    ``test_map`` uses this to check that ``plotter.map`` rejects transforms
    producing 3D positions.
    """
    first, second = pos
    return first, second, 0
145

@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
def test_map():
    """Check ``plotter.map`` on valid input and on inputs it must reject."""
    syst = syst_2d()

    def first_tag(site):
        return site.tag[0]

    with tempfile.TemporaryFile('w+b') as out:
        # A 2D -> 2D position transform is accepted.
        plotter.map(syst, first_tag, pos_transform=good_transform,
                    file=out, method='linear', a=4, oversampling=4, cmap='flag')

        # A transform producing 3D positions raises ValueError.
        with pytest.raises(ValueError):
            plotter.map(syst, first_tag, pos_transform=bad_transform,
                        file=out)

        # A plain sequence of values is accepted for a finalized system...
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            plotter.map(syst.finalized(), range(len(syst.sites())),
                        file=out)

        # ...but raises ValueError for an unfinalized Builder.
        with pytest.raises(ValueError):
            plotter.map(syst, range(len(syst.sites())), file=out)
161
162
163
164


def test_mask_interpolate():
    """Check the warning and errors produced by ``plotter.mask_interpolate``."""
    # Four points, the first two of which almost coincide.
    coords = np.array([[0, 0], [1e-7, 1e-7], [1, 1], [1, 0]])
    n_points = len(coords)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        plotter.mask_interpolate(coords, np.ones(n_points), a=1)
        assert len(caught) == 1
        last = caught[-1]
        assert issubclass(last.category, RuntimeWarning)
        assert "coinciding" in str(last.message)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        # Without an explicit `a`, a ValueError is expected.
        with pytest.raises(ValueError):
            plotter.mask_interpolate(coords, np.ones(n_points))
        # So is a value array whose length does not match the coordinates.
        with pytest.raises(ValueError):
            plotter.mask_interpolate(coords, np.ones(2 * n_points))
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245


@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
def test_bands():
    """Smoke-test ``plotter.bands`` with default and explicit options."""
    lead = syst_2d().finalized().leads[0]

    with tempfile.TemporaryFile('w+b') as out:
        plotter.bands(lead, file=out)
        plotter.bands(lead, fig_size=(10, 10), file=out)
        plotter.bands(lead, momenta=np.linspace(0, 2 * np.pi), file=out)

        # Draw into an axes supplied by the caller.
        figure = pyplot.Figure()
        axes = figure.add_subplot(1, 1, 1)
        plotter.bands(lead, ax=axes, file=out)


@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
def test_spectrum():
    """Smoke-test ``plotter.spectrum`` for one- and two-parameter sweeps.

    Hamiltonians are given as a scalar-valued function, a matrix-valued
    function and a finalized Kwant system.  Passing an explicit axes (with
    and without 3D support) and the ``mask`` option are also covered.
    """

    def ham_1d(a, b, c):
        return a**2 + b**2 + c**2

    def ham_2d(a, b, c):
        return np.eye(2) * (a**2 + b**2 + c**2)

    lat = kwant.lattice.chain()
    syst = kwant.Builder()
    syst[(lat(i) for i in range(3))] = lambda site, a, b: a + b
    syst[lat.neighbors()] = lambda site1, site2, c: c
    fsyst = syst.finalized()

    vals = np.linspace(0, 1, 3)
    hamiltonians = (ham_1d, ham_2d, fsyst)

    with tempfile.TemporaryFile('w+b') as out:

        # Sweep over a single parameter.
        for ham in hamiltonians:
            plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1), file=out)
            # test with explicit figsize
            plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1),
                             fig_size=(10, 10), file=out)

        # Sweep over two parameters.
        for ham in hamiltonians:
            plotter.spectrum(ham, ('a', vals), ('b', 2 * vals),
                             params=dict(c=1), file=out)
            # test with explicit figsize
            plotter.spectrum(ham, ('a', vals), ('b', 2 * vals),
                             params=dict(c=1), fig_size=(10, 10), file=out)

        # test 2D plot and explicitly passing axis
        fig = pyplot.Figure()
        ax = fig.add_subplot(1, 1, 1, projection='3d')
        plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals),
                         params=dict(c=1), ax=ax, file=out)
        # Explicitly pass an axis without 3D support.  It is created on a
        # fresh figure so it cannot be confused with the 3D axes above.
        fig_2d = pyplot.Figure()
        ax = fig_2d.add_subplot(1, 1, 1)
        with pytest.raises(TypeError):
            plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals),
                             params=dict(c=1), ax=ax, file=out)

    def mask(a, b):
        return a > 0.5

    with tempfile.TemporaryFile('w+b') as out:
        # Name the finalized system explicitly instead of relying on the loop
        # variable 'ham' left over from the sweeps above.
        plotter.spectrum(fsyst, ('a', vals), ('b', 2 * vals), params=dict(c=1),
                         mask=mask, file=out)
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418


def syst_rect(lat, salt, W=3, L=50):
    """Build a disordered rectangular strip on ``lat`` with two leads.

    The sites ``lat(i, j)`` with ``-L//2 <= i <= L//2`` and
    ``-W//2 <= j <= W//2`` get onsite energy ``4`` plus ``0.1`` times
    ``kwant.digest.gauss(site.tag, salt=salt)``.  Neighbors are coupled by
    ``-1``.  A lead with symmetry vector ``lat.vec((-1, 0))`` (onsite 4,
    hopping -1) and its reversed copy are attached.
    """
    half_length = L // 2
    half_width = W // 2
    columns = range(-half_length, half_length + 1)
    rows = range(-half_width, half_width + 1)

    def onsite(site):
        return 4 + 0.1 * kwant.digest.gauss(site.tag, salt=salt)

    syst = kwant.Builder()
    syst[(lat(i, j) for i in columns for j in rows)] = onsite
    syst[lat.neighbors()] = -1

    lead = kwant.Builder(kwant.TranslationalSymmetry(lat.vec((-1, 0))))
    lead[(lat(0, j) for j in rows)] = 4
    lead[lat.neighbors()] = -1

    syst.attach_lead(lead)
    syst.attach_lead(lead.reversed())
    return syst


def div(F):
    """Calculate the divergence of a vector field F.

    ``F`` is sampled on a regular grid with unit spacing (the ``np.gradient``
    default).  Its last axis holds the vector components, so ``F.shape`` must
    be ``grid_shape + (d,)`` with ``d == len(grid_shape)``.  Returns an array
    of shape ``grid_shape``.

    Raises
    ------
    ValueError
        If the number of components does not match the grid dimension.
    """
    F = np.asarray(F)
    if F.ndim == 0 or F.shape[-1] != F.ndim - 1:
        raise ValueError("F must have shape grid_shape + (len(grid_shape),), "
                         "got {}".format(F.shape))
    # Differentiate each component only along its own axis.  Passing 'axis'
    # also keeps the 1D case correct: for 1D input np.gradient returns a bare
    # array rather than a list, so indexing its result would yield a scalar.
    return sum(np.gradient(F[..., i], axis=i) for i in range(F.shape[-1]))


def rotational_currents(g):
    """Return a basis of divergence-free currents for a closed graph.

    Given the graph 'g' of a Kwant system, returns a sequence of arrays
    which are linearly independent, divergence-free currents on the graph.

    ``g`` must provide ``num_nodes``, ``num_edges`` and iteration over
    ``(a, b)`` node pairs, with every hopping present in both directions
    (two edges per hopping).  The result has shape ``(k, g.num_edges)``,
    where ``k`` is the dimension of the divergence-free subspace.  ``k`` may
    be 0, e.g. for a graph without cycles.
    """
    num_hoppings = g.num_edges // 2
    # 'A' represents the set of expressions that give the net current flow
    # into the system sites: column i holds +1/-1 at the two ends of hopping i.
    A = np.zeros((g.num_nodes, num_hoppings))
    # 'perm' maps currents on a graph with only one edge per hopping to the
    # proper Kwant graph (two edges per hopping); the reverse edge of a
    # hopping carries the opposite sign.
    hoppings = {}
    perm_data = np.zeros(g.num_edges)
    perm_rows = np.arange(g.num_edges)
    perm_cols = np.zeros(g.num_edges, dtype=int)
    for k, (a, b) in enumerate(g):
        hop = frozenset((a, b))
        if hop not in hoppings:
            i = len(hoppings)
            A[a, i] = 1
            A[b, i] = -1
            hoppings[hop] = i
            perm_data[k] = 1
        else:
            i = hoppings[hop]
            perm_data[k] = -1
        perm_cols[k] = i

    # Sparse indices must be integers; recent SciPy versions reject float
    # index arrays when inferring the shape, so pass the shape explicitly.
    perm = scipy.sparse.coo_matrix((perm_data, (perm_rows, perm_cols)),
                                   shape=(g.num_edges, num_hoppings))

    # The rows of Vh beyond the numerical rank span the right null space
    # of 'A'.
    _, S, Vh = np.linalg.svd(A)
    tol = S.max() * max(A.shape) * np.finfo(S.dtype).eps if S.size else 0
    rank = int(np.count_nonzero(S > tol))
    # Slice from 'rank' rather than with a negative start: if 'A' has full
    # column rank the null space is empty, and a start of '-0' would wrongly
    # return every row of Vh.
    null_space_basis = Vh[rank:]
    if len(null_space_basis) == 0:
        return np.zeros((0, g.num_edges))
    # Transform null space basis into vectors defined over the full
    # hopping space (both hopping directions).
    return np.asarray(perm.dot(null_space_basis.transpose())).transpose()


def test_current_interpolation():
    """Check ``plotter.interpolate_current`` against exact current sums.

    Covers rejection of an unfinalized Builder, and agreement of the
    integrated interpolated current with the summed current through a cut,
    for several lattice constants, orientations, grid densities and
    broadenings.  Also covers linearity of the interpolation and the
    approximate vanishing of the divergence of an interpolated
    divergence-free current.
    """
    # ``scipy.integrate.simps`` was removed in SciPy 1.14; ``simpson``
    # (available since SciPy 1.6) replaces it.  Fall back for older SciPy.
    simpson = getattr(scipy.integrate, 'simpson', None)
    if simpson is None:
        simpson = scipy.integrate.simps

    ## Passing a Builder will raise an error
    pytest.raises(TypeError, plotter.interpolate_current, syst_2d(), None)

    def R(theta):
        return ta.array([[cos(theta), -sin(theta)], [sin(theta), cos(theta)]])

    def make_lattice(a, theta):
        x = ta.dot(R(theta), (a, 0))
        y = ta.dot(R(theta), (0, a))
        return kwant.lattice.general([x, y], norbs=1)

    angles = (0, pi/6, pi/4)
    lat_constants = (1, 2)

    ## Check current through cross section is same for different lattice
    ## parameters and orientations of the system wrt. the discretization grid
    for a, theta in itertools.product(lat_constants, angles):
        lat = make_lattice(a, theta)
        syst = syst_rect(lat, salt='0').finalized()
        psi = kwant.wave_function(syst, energy=3)(0)

        def cut(a, b):
            return b.tag[0] < 0 and a.tag[0] >= 0

        J = kwant.operator.Current(syst).bind()
        J_cut = kwant.operator.Current(syst, where=cut, sum=True).bind()
        (x, y), j0 = plotter.interpolate_current(syst, J(psi[0]), gauss_range=5)

        # slice field perpendicular to a cut along the y axis
        y_axis = (np.argmin(np.abs(x)), slice(None), 0)
        ## Integrate and compare with summed current.
        assert np.isclose(J_cut(psi[0]), simpson(j0[y_axis], x=y))

    ## Check that taking a finer grid or changing the broadening does not
    ## affect the total integrated current.  This reuses the system, wave
    ## function and operators from the last iteration of the loop above.
    n_s = (3, 5)
    sigma_s = (1, 0.5)

    for n, sigma in zip(n_s, sigma_s):
        (x, y), j0 = plotter.interpolate_current(syst, J(psi[0]), n=n,
                                                 sigma=sigma, gauss_range=5)
        # slice field perpendicular to a cut along the y axis
        y_axis = (np.argmin(np.abs(x)), slice(None), 0)
        assert np.isclose(J_cut(psi[0]), simpson(j0[y_axis], x=y))

    ### Tests on a divergence-free current (closed system)

    lat = kwant.lattice.general([(1, 0), (0.5, np.sqrt(3) / 2)])
    syst = kwant.Builder()
    sites = [lat(0, 0), lat(1, 0), lat(0, 1), lat(2, 2)]
    syst[sites] = None
    syst[((s, t) for s, t in itertools.product(sites, sites) if s != t)] = None
    del syst[lat(0, 0), lat(2, 2)]
    syst = syst.finalized()

    # generate random divergence-free currents
    Js = rotational_currents(syst.graph)
    rng = ensure_rng(3)
    J0 = sum(rng.rand(len(Js))[:, None] * Js)
    J1 = sum(rng.rand(len(Js))[:, None] * Js)

    # Sanity check that the divergence on the graph is 0
    divergence = np.zeros(len(syst.sites))
    for (a, _), current in zip(syst.graph, J0):
        divergence[a] += current
    assert np.allclose(divergence, 0)

    _, j0 = plotter.interpolate_current(syst, J0)
    _, j1 = plotter.interpolate_current(syst, J1)

    ## Test linearity of interpolation.
    _, j_tot = plotter.interpolate_current(syst, J0 + 2 * J1)
    assert np.allclose(j_tot, j0 + 2 * j1)

    ## Test that divergence of interpolated current is approximately zero.
    # For currents not aligned with the interpolation grid this is only
    # 1/a**2 accurate.
    _, j = plotter.interpolate_current(syst, J0, n=20, gauss_range=4)
    div_j = np.max(np.abs(div(j)))
    assert np.isclose(div_j, 0, atol=1E-4)


@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
def test_current():
    """Smoke-test ``plotter.current`` with and without an explicit axes."""
    syst = syst_2d().finalized()
    J = kwant.operator.Current(syst)
    # NOTE(review): wave_function(...)(1)[0] is presumably the first
    # scattering state incoming from lead 1 at energy 1 -- confirm.
    current = J(kwant.wave_function(syst, energy=1)(1)[0])

    # Test good codepath
    with tempfile.TemporaryFile('w+b') as out:
        plotter.current(syst, current, file=out)

        # Draw into an axes supplied by the caller.
        fig = pyplot.Figure()
        ax = fig.add_subplot(1, 1, 1)
        plotter.current(syst, current, ax=ax, file=out)