From d935312ea14b3cdb83d12f4d62865c2c2dd718ca Mon Sep 17 00:00:00 2001
From: Joseph Weston <joseph.weston08@gmail.com>
Date: Wed, 5 Apr 2017 15:28:04 +0200
Subject: [PATCH] add 2D band structure plotting to 'wraparound'

This function is adapted to plotting band structure in the case
where the translation symmetry vectors are not orthogonal
(e.g. graphene). It can only handle a very limited subset of
systems (2D lattices in 2D realspace), and most of the time
use of 'plotter.spectrum' should be preferred.
---
 kwant/tests/test_wraparound.py |  60 +++++++++++-
 kwant/wraparound.py            | 161 ++++++++++++++++++++++++++++++++-
 2 files changed, 218 insertions(+), 3 deletions(-)

diff --git a/kwant/tests/test_wraparound.py b/kwant/tests/test_wraparound.py
index 39f994ee..a52712e3 100644
--- a/kwant/tests/test_wraparound.py
+++ b/kwant/tests/test_wraparound.py
@@ -6,14 +6,21 @@
 # the file AUTHORS.rst at the top-level directory of this distribution and at
 # http://kwant-project.org/authors.
 
+import tempfile
 import itertools
 import numpy as np
 import tinyarray as ta
+import pytest
 
 import kwant
-from kwant.wraparound import wraparound
+from kwant import plotter
+from kwant.wraparound import wraparound, plot_2d_bands
 from kwant._common import get_parameters
 
+if plotter.mpl_enabled:
+    from mpl_toolkits import mplot3d  # pragma: no flakes
+    from matplotlib import pyplot  # pragma: no flakes
+
 
 def _simple_syst(lat, E=0, t=1+1j, sym=None):
     """Create a builder for a simple infinite system."""
@@ -192,3 +199,54 @@ def test_symmetry():
                 assert np.all(orig(None) == new(None, None, None))
             else:
                 assert np.all(orig == new)
+
+
+@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
+def test_plot_2d_bands():
+    chain = kwant.lattice.chain()
+    square = kwant.lattice.square()
+    cube = kwant.lattice.general([(1, 0, 0), (0, 1, 0), (0, 0, 1)])
+    hc = kwant.lattice.honeycomb()
+
+    syst_1d = kwant.Builder(kwant.TranslationalSymmetry(*chain._prim_vecs))
+    syst_1d[chain(0)] = 2
+    syst_1d[chain.neighbors()] = -1
+
+    syst_2d = _simple_syst(square, t=-1)
+    syst_graphene = _simple_syst(hc, t=-1)
+
+    syst_3d = kwant.Builder(kwant.TranslationalSymmetry(*cube._prim_vecs))
+    syst_3d[cube(0, 0, 0)] = 6
+    syst_3d[cube.neighbors()] = -1
+
+
+    with tempfile.TemporaryFile('w+b') as out:
+        # test 2D
+        plot_2d_bands(wraparound(syst_2d).finalized(), k_x=11, k_y=11, file=out)
+        plot_2d_bands(wraparound(syst_graphene).finalized(), k_x=11, k_y=11,
+                   file=out)
+
+    # test non-wrapped around system
+    with pytest.raises(TypeError):
+        plot_2d_bands(syst_1d.finalized())
+    # test incompletely wrapped around system
+    with pytest.raises(TypeError):
+        plot_2d_bands(wraparound(syst_2d, keep=0).finalized())
+    # test incorrect lattice dimention (1, 3)
+    with pytest.raises(ValueError):
+        plot_2d_bands(wraparound(syst_1d).finalized())
+    with pytest.raises(ValueError):
+        plot_2d_bands(wraparound(syst_3d).finalized())
+
+    # test k_x and k_y differ
+    with tempfile.TemporaryFile('w+b') as out:
+        syst = wraparound(syst_2d).finalized()
+        plot_2d_bands(syst, k_x=11, k_y=15, file=out)
+        plot_2d_bands(syst, k_x=np.linspace(-np.pi, np.pi, 11), file=out)
+        plot_2d_bands(syst, k_y=np.linspace(-np.pi, np.pi, 11), file=out)
+
+        syst = wraparound(syst_graphene).finalized()
+        # test extend_bbox2d
+        plot_2d_bands(syst, extend_bbox=1.2, k_x=11, k_y=11, file=out)
+        # test mask Brillouin zone
+        plot_2d_bands(syst, mask_brillouin_zone=True, k_x=11, k_y=11, file=out)
diff --git a/kwant/wraparound.py b/kwant/wraparound.py
index 2146430d..1936486b 100644
--- a/kwant/wraparound.py
+++ b/kwant/wraparound.py
@@ -9,15 +9,20 @@
 import collections
 import inspect
 import cmath
+
 import tinyarray as ta
+import numpy as np
+import scipy.linalg
+import scipy.spatial
 
-from . import builder
+from . import builder, system, plotter
+from .linalg import lll
 from .builder import herm_conj, HermConjOfFunc
 from .lattice import TranslationalSymmetry
 from ._common import get_parameters
 
 
-__all__ = ['wraparound']
+__all__ = ['wraparound', 'plot_2d_bands']
 
 
 def _hashable(obj):
@@ -306,3 +311,155 @@ def wraparound(builder, keep=None, *, coordinate_names=('x', 'y', 'z')):
         ret[hop] = vals[0] if len(vals) == 1 else bind_sum(2, *vals)
 
     return ret
+
+
+def plot_2d_bands(syst, k_x=31, k_y=31, params=None,
+                  mask_brillouin_zone=False, extend_bbox=0, file=None,
+                  show=True, dpi=None, fig_size=None, ax=None):
+    """Plot 2D band structure of a wrapped around system.
+
+    This function is primarily useful for systems that have translational
+    symmetry vectors that are non-orthogonal (e.g. graphene). This function
+    will properly plot the band structure in an orthonormal basis in k-space,
+    as opposed to in the basis of reciprocal lattice vectors (which would
+    produce a "skewed" Brillouin zone).
+
+    If your system has orthogonal lattice vectors, you are probably better
+    off using `kwant.plotter.spectrum`.
+
+    Parameters
+    ----------
+    syst : `kwant.system.FiniteSystem`
+        A 2D system that was finalized from a Builder produced by
+        `kwant.wraparound.wraparound`.  Note that this *must* be a finite
+        system; so `kwant.wraparound.wraparound` should have been called with
+        ``keep=None``.
+    k_x, k_y : int or sequence of float, default: 31
+        Either a number of sampling points, or a sequence of points at which
+        the band structure is to be evaluated, in units of inverse length.
+    params : dict, optional
+        Dictionary of parameter names and their values, not including the
+        momentum parameters.
+    mask_brillouin_zone : bool, default: False
+        If True, then the band structure will only be plotted over the first
+        Brillouin zone. By default the band structure is plotted over a
+        rectangular bounding box that contains the Brillouin zone.
+    extend_bbox : float, default: 0
+        Amount by which to extend the region over which the band structure is
+        plotted, expressed as a proportion of the Brillouin zone bounding box
+        length. i.e. ``extend_bbox=0.1`` will extend the region by 10% (in all
+        directions).
+    file : string or file object, optional
+        The output file.  If None, output will be shown instead.
+    show : bool, default: False
+        Whether ``matplotlib.pyplot.show()`` is to be called, and the output is
+        to be shown immediately.  Defaults to `True`.
+    dpi : float, optional
+        Number of pixels per inch.  If not set the ``matplotlib`` default is
+        used.
+    fig_size : tuple, optional
+        Figure size `(width, height)` in inches.  If not set, the default
+        ``matplotlib`` value is used.
+    ax : ``matplotlib.axes.Axes`` instance, optional
+        If `ax` is not `None`, no new figure is created, but the plot is done
+        within the existing Axes `ax`. in this case, `file`, `show`, `dpi`
+        and `fig_size` are ignored.
+
+    Returns
+    -------
+    fig : matplotlib figure
+        A figure with the output if `ax` is not set, else None.
+
+    Notes
+    -----
+    This function produces plots where the units of momentum are inverse
+    length. This is contrary to `kwant.plotter.bands`, where the units
+    of momentum are inverse lattice constant.
+
+    If the lattice vectors for the symmetry of ``syst`` are not orthogonal,
+    then part of the plotted band structure will be outside the first Brillouin
+    zone (inside the bounding box of the brillouin zone). Setting
+    ``mask_brillouin_zone=True`` will cause the plot to be truncated outside of
+    the first Brillouin zone.
+
+    See Also
+    --------
+    `kwant.plotter.spectrum`
+    """
+    if not hasattr(syst, '_wrapped_symmetry'):
+        raise TypeError("Expecting a system that was produced by "
+                        "'kwant.wraparound.wraparound'.")
+    if not isinstance(syst, system.FiniteSystem):
+        msg = ("All symmetry directions must be wrapped around: specify "
+               "'keep=None' when calling 'kwant.wraparound.wraparound'.")
+        raise TypeError(msg)
+
+    params = params or {}
+    lat_ndim, space_ndim = syst._wrapped_symmetry.periods.shape
+
+    if lat_ndim != 2:
+        raise ValueError("Expected a system with a 2D translational symmetry.")
+    if space_ndim != lat_ndim:
+        raise ValueError("Lattice dimension must equal realspace dimension.")
+
+    # columns of B are lattice vectors
+    B = np.array(syst._wrapped_symmetry.periods).T
+    # columns of A are reciprocal lattice vectors
+    A = B.dot(np.linalg.inv(B.T.dot(B)))
+
+    ## calculate the bounding box for the 1st Brillouin zone
+
+    # Get lattice points that neighbor the origin, in basis of lattice vectors
+    reduced_vecs, transf = lll.lll(A.T)
+    neighbors = ta.dot(lll.voronoi(reduced_vecs), transf)
+    # Add the origin to these points.
+    klat_points = np.concatenate(([[0] * lat_ndim], neighbors))
+    # Transform to cartesian coordinates and rescale.
+    # Will be used in 'outside_bz' function, later on.
+    klat_points = 2 * np.pi * np.dot(klat_points, A.T)
+    # Calculate the Voronoi cell vertices
+    vor = scipy.spatial.Voronoi(klat_points)
+    around_origin = vor.point_region[0]
+    bz_vertices = vor.vertices[vor.regions[around_origin]]
+    # extract bounding box
+    k_max = np.max(np.abs(bz_vertices), axis=0)
+
+    ## build grid along each axis, if needed
+    ks = []
+    for k, km in zip((k_x, k_y), k_max):
+        k = np.array(k)
+        if not k.shape:
+            if extend_bbox:
+                km += km * extend_bbox
+            k = np.linspace(-km, km, k)
+        ks.append(k)
+
+    # TODO: It is very inefficient to call 'momentum_to_lattice' once for
+    #       each point (for trivial Hamiltonians 60% of the time is spent
+    #       doing this). We should instead transform the whole grid in one call.
+
+    def momentum_to_lattice(k):
+        k, residuals = scipy.linalg.lstsq(A, k)[:2]
+        if np.any(abs(residuals) > 1e-7):
+            raise RuntimeError("Requested momentum doesn't correspond"
+                               " to any lattice momentum.")
+        return k
+
+    def ham(k_x, k_y=None, **params):
+        # transform into the basis of reciprocal lattice vectors
+        k = momentum_to_lattice([k_x] if k_y is None else [k_x, k_y])
+        p = dict(zip(syst._momentum_names, k), **params)
+        return syst.hamiltonian_submatrix(params=p, sparse=False)
+
+    def outside_bz(k_x, k_y, **_):
+        dm = scipy.spatial.distance_matrix(klat_points, [[k_x, k_y]])
+        return np.argmin(dm) != 0  # is origin no closest 'klat_point' to 'k'?
+
+    fig = plotter.spectrum(ham,
+                           x=('k_x', ks[0]),
+                           y=('k_y', ks[1]) if lat_ndim == 2 else None,
+                           params=params,
+                           mask=(outside_bz if mask_brillouin_zone else None),
+                           file=file, show=show, dpi=dpi,
+                           fig_size=fig_size, ax=ax)
+    return fig
-- 
GitLab