diff --git a/kwant/tests/test_wraparound.py b/kwant/tests/test_wraparound.py index 39f994ee5c12aa259aa9e00ac817b14fcccf77c1..a52712e33979a23e4a0123568edf2e654f61918d 100644 --- a/kwant/tests/test_wraparound.py +++ b/kwant/tests/test_wraparound.py @@ -6,14 +6,21 @@ # the file AUTHORS.rst at the top-level directory of this distribution and at # http://kwant-project.org/authors. +import tempfile import itertools import numpy as np import tinyarray as ta +import pytest import kwant -from kwant.wraparound import wraparound +from kwant import plotter +from kwant.wraparound import wraparound, plot_2d_bands from kwant._common import get_parameters +if plotter.mpl_enabled: + from mpl_toolkits import mplot3d # pragma: no flakes + from matplotlib import pyplot # pragma: no flakes + def _simple_syst(lat, E=0, t=1+1j, sym=None): """Create a builder for a simple infinite system.""" @@ -192,3 +199,54 @@ def test_symmetry(): assert np.all(orig(None) == new(None, None, None)) else: assert np.all(orig == new) + + +@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.") +def test_plot_2d_bands(): + chain = kwant.lattice.chain() + square = kwant.lattice.square() + cube = kwant.lattice.general([(1, 0, 0), (0, 1, 0), (0, 0, 1)]) + hc = kwant.lattice.honeycomb() + + syst_1d = kwant.Builder(kwant.TranslationalSymmetry(*chain._prim_vecs)) + syst_1d[chain(0)] = 2 + syst_1d[chain.neighbors()] = -1 + + syst_2d = _simple_syst(square, t=-1) + syst_graphene = _simple_syst(hc, t=-1) + + syst_3d = kwant.Builder(kwant.TranslationalSymmetry(*cube._prim_vecs)) + syst_3d[cube(0, 0, 0)] = 6 + syst_3d[cube.neighbors()] = -1 + + + with tempfile.TemporaryFile('w+b') as out: + # test 2D + plot_2d_bands(wraparound(syst_2d).finalized(), k_x=11, k_y=11, file=out) + plot_2d_bands(wraparound(syst_graphene).finalized(), k_x=11, k_y=11, + file=out) + + # test non-wrapped around system + with pytest.raises(TypeError): + plot_2d_bands(syst_1d.finalized()) + # test incompletely wrapped around system + with pytest.raises(TypeError): + plot_2d_bands(wraparound(syst_2d, keep=0).finalized()) + # test incorrect lattice dimention (1, 3) + with pytest.raises(ValueError): + plot_2d_bands(wraparound(syst_1d).finalized()) + with pytest.raises(ValueError): + plot_2d_bands(wraparound(syst_3d).finalized()) + + # test k_x and k_y differ + with tempfile.TemporaryFile('w+b') as out: + syst = wraparound(syst_2d).finalized() + plot_2d_bands(syst, k_x=11, k_y=15, file=out) + plot_2d_bands(syst, k_x=np.linspace(-np.pi, np.pi, 11), file=out) + plot_2d_bands(syst, k_y=np.linspace(-np.pi, np.pi, 11), file=out) + + syst = wraparound(syst_graphene).finalized() + # test extend_bbox2d + plot_2d_bands(syst, extend_bbox=1.2, k_x=11, k_y=11, file=out) + # test mask Brillouin zone + plot_2d_bands(syst, mask_brillouin_zone=True, k_x=11, k_y=11, file=out) diff --git a/kwant/wraparound.py b/kwant/wraparound.py index 2146430d80de9b01b63d8617c7980cab84e62924..1936486bd85e96f42d0e96cb8e39e335c7eb08ca 100644 --- a/kwant/wraparound.py +++ b/kwant/wraparound.py @@ -9,15 +9,20 @@ import collections import inspect import cmath + import tinyarray as ta +import numpy as np +import scipy.linalg +import scipy.spatial -from . import builder +from . import builder, system, plotter +from .linalg import lll from .builder import herm_conj, HermConjOfFunc from .lattice import TranslationalSymmetry from ._common import get_parameters -__all__ = ['wraparound'] +__all__ = ['wraparound', 'plot_2d_bands'] def _hashable(obj): @@ -306,3 +311,155 @@ def wraparound(builder, keep=None, *, coordinate_names=('x', 'y', 'z')): ret[hop] = vals[0] if len(vals) == 1 else bind_sum(2, *vals) return ret + + +def plot_2d_bands(syst, k_x=31, k_y=31, params=None, + mask_brillouin_zone=False, extend_bbox=0, file=None, + show=True, dpi=None, fig_size=None, ax=None): + """Plot 2D band structure of a wrapped around system. + + This function is primarily useful for systems that have translational + symmetry vectors that are non-orthogonal (e.g. graphene). This function + will properly plot the band structure in an orthonormal basis in k-space, + as opposed to in the basis of reciprocal lattice vectors (which would + produce a "skewed" Brillouin zone). + + If your system has orthogonal lattice vectors, you are probably better + off using `kwant.plotter.spectrum`. + + Parameters + ---------- + syst : `kwant.system.FiniteSystem` + A 2D system that was finalized from a Builder produced by + `kwant.wraparound.wraparound`. Note that this *must* be a finite + system; so `kwant.wraparound.wraparound` should have been called with + ``keep=None``. + k_x, k_y : int or sequence of float, default: 31 + Either a number of sampling points, or a sequence of points at which + the band structure is to be evaluated, in units of inverse length. + params : dict, optional + Dictionary of parameter names and their values, not including the + momentum parameters. + mask_brillouin_zone : bool, default: False + If True, then the band structure will only be plotted over the first + Brillouin zone. By default the band structure is plotted over a + rectangular bounding box that contains the Brillouin zone. + extend_bbox : float, default: 0 + Amount by which to extend the region over which the band structure is + plotted, expressed as a proportion of the Brillouin zone bounding box + length. i.e. ``extend_bbox=0.1`` will extend the region by 10% (in all + directions). + file : string or file object, optional + The output file. If None, output will be shown instead. + show : bool, default: False + Whether ``matplotlib.pyplot.show()`` is to be called, and the output is + to be shown immediately. Defaults to `True`. + dpi : float, optional + Number of pixels per inch. If not set the ``matplotlib`` default is + used. + fig_size : tuple, optional + Figure size `(width, height)` in inches. If not set, the default + ``matplotlib`` value is used. + ax : ``matplotlib.axes.Axes`` instance, optional + If `ax` is not `None`, no new figure is created, but the plot is done + within the existing Axes `ax`. in this case, `file`, `show`, `dpi` + and `fig_size` are ignored. + + Returns + ------- + fig : matplotlib figure + A figure with the output if `ax` is not set, else None. + + Notes + ----- + This function produces plots where the units of momentum are inverse + length. This is contrary to `kwant.plotter.bands`, where the units + of momentum are inverse lattice constant. + + If the lattice vectors for the symmetry of ``syst`` are not orthogonal, + then part of the plotted band structure will be outside the first Brillouin + zone (inside the bounding box of the brillouin zone). Setting + ``mask_brillouin_zone=True`` will cause the plot to be truncated outside of + the first Brillouin zone. + + See Also + -------- + `kwant.plotter.spectrum` + """ + if not hasattr(syst, '_wrapped_symmetry'): + raise TypeError("Expecting a system that was produced by " + "'kwant.wraparound.wraparound'.") + if not isinstance(syst, system.FiniteSystem): + msg = ("All symmetry directions must be wrapped around: specify " + "'keep=None' when calling 'kwant.wraparound.wraparound'.") + raise TypeError(msg) + + params = params or {} + lat_ndim, space_ndim = syst._wrapped_symmetry.periods.shape + + if lat_ndim != 2: + raise ValueError("Expected a system with a 2D translational symmetry.") + if space_ndim != lat_ndim: + raise ValueError("Lattice dimension must equal realspace dimension.") + + # columns of B are lattice vectors + B = np.array(syst._wrapped_symmetry.periods).T + # columns of A are reciprocal lattice vectors + A = B.dot(np.linalg.inv(B.T.dot(B))) + + ## calculate the bounding box for the 1st Brillouin zone + + # Get lattice points that neighbor the origin, in basis of lattice vectors + reduced_vecs, transf = lll.lll(A.T) + neighbors = ta.dot(lll.voronoi(reduced_vecs), transf) + # Add the origin to these points. + klat_points = np.concatenate(([[0] * lat_ndim], neighbors)) + # Transform to cartesian coordinates and rescale. + # Will be used in 'outside_bz' function, later on. + klat_points = 2 * np.pi * np.dot(klat_points, A.T) + # Calculate the Voronoi cell vertices + vor = scipy.spatial.Voronoi(klat_points) + around_origin = vor.point_region[0] + bz_vertices = vor.vertices[vor.regions[around_origin]] + # extract bounding box + k_max = np.max(np.abs(bz_vertices), axis=0) + + ## build grid along each axis, if needed + ks = [] + for k, km in zip((k_x, k_y), k_max): + k = np.array(k) + if not k.shape: + if extend_bbox: + km += km * extend_bbox + k = np.linspace(-km, km, k) + ks.append(k) + + # TODO: It is very inefficient to call 'momentum_to_lattice' once for + # each point (for trivial Hamiltonians 60% of the time is spent + # doing this). We should instead transform the whole grid in one call. + + def momentum_to_lattice(k): + k, residuals = scipy.linalg.lstsq(A, k)[:2] + if np.any(abs(residuals) > 1e-7): + raise RuntimeError("Requested momentum doesn't correspond" + " to any lattice momentum.") + return k + + def ham(k_x, k_y=None, **params): + # transform into the basis of reciprocal lattice vectors + k = momentum_to_lattice([k_x] if k_y is None else [k_x, k_y]) + p = dict(zip(syst._momentum_names, k), **params) + return syst.hamiltonian_submatrix(params=p, sparse=False) + + def outside_bz(k_x, k_y, **_): + dm = scipy.spatial.distance_matrix(klat_points, [[k_x, k_y]]) + return np.argmin(dm) != 0 # is origin no closest 'klat_point' to 'k'? + + fig = plotter.spectrum(ham, + x=('k_x', ks[0]), + y=('k_y', ks[1]) if lat_ndim == 2 else None, + params=params, + mask=(outside_bz if mask_brillouin_zone else None), + file=file, show=show, dpi=dpi, + fig_size=fig_size, ax=ax) + return fig