Skip to content
Snippets Groups Projects
Commit d935312e authored by Joseph Weston's avatar Joseph Weston
Browse files

add 2D band structure plotting to 'wraparound'

This function is adapted to plotting band structure in the case
where the translation symmetry vectors are not orthogonal
(e.g. graphene). It can only handle a very limited subset of
systems (2D lattices in 2D realspace), and most of the time
use of 'plotter.spectrum' should be preferred.
parent 5042d96a
No related branches found
No related tags found
No related merge requests found
Pipeline #
......@@ -6,14 +6,21 @@
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.
import tempfile
import itertools
import numpy as np
import tinyarray as ta
import pytest
import kwant
from kwant.wraparound import wraparound
from kwant import plotter
from kwant.wraparound import wraparound, plot_2d_bands
from kwant._common import get_parameters
if plotter.mpl_enabled:
from mpl_toolkits import mplot3d # pragma: no flakes
from matplotlib import pyplot # pragma: no flakes
def _simple_syst(lat, E=0, t=1+1j, sym=None):
"""Create a builder for a simple infinite system."""
......@@ -192,3 +199,54 @@ def test_symmetry():
assert np.all(orig(None) == new(None, None, None))
else:
assert np.all(orig == new)
@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
def test_plot_2d_bands():
chain = kwant.lattice.chain()
square = kwant.lattice.square()
cube = kwant.lattice.general([(1, 0, 0), (0, 1, 0), (0, 0, 1)])
hc = kwant.lattice.honeycomb()
syst_1d = kwant.Builder(kwant.TranslationalSymmetry(*chain._prim_vecs))
syst_1d[chain(0)] = 2
syst_1d[chain.neighbors()] = -1
syst_2d = _simple_syst(square, t=-1)
syst_graphene = _simple_syst(hc, t=-1)
syst_3d = kwant.Builder(kwant.TranslationalSymmetry(*cube._prim_vecs))
syst_3d[cube(0, 0, 0)] = 6
syst_3d[cube.neighbors()] = -1
with tempfile.TemporaryFile('w+b') as out:
# test 2D
plot_2d_bands(wraparound(syst_2d).finalized(), k_x=11, k_y=11, file=out)
plot_2d_bands(wraparound(syst_graphene).finalized(), k_x=11, k_y=11,
file=out)
# test non-wrapped around system
with pytest.raises(TypeError):
plot_2d_bands(syst_1d.finalized())
# test incompletely wrapped around system
with pytest.raises(TypeError):
plot_2d_bands(wraparound(syst_2d, keep=0).finalized())
# test incorrect lattice dimention (1, 3)
with pytest.raises(ValueError):
plot_2d_bands(wraparound(syst_1d).finalized())
with pytest.raises(ValueError):
plot_2d_bands(wraparound(syst_3d).finalized())
# test k_x and k_y differ
with tempfile.TemporaryFile('w+b') as out:
syst = wraparound(syst_2d).finalized()
plot_2d_bands(syst, k_x=11, k_y=15, file=out)
plot_2d_bands(syst, k_x=np.linspace(-np.pi, np.pi, 11), file=out)
plot_2d_bands(syst, k_y=np.linspace(-np.pi, np.pi, 11), file=out)
syst = wraparound(syst_graphene).finalized()
# test extend_bbox2d
plot_2d_bands(syst, extend_bbox=1.2, k_x=11, k_y=11, file=out)
# test mask Brillouin zone
plot_2d_bands(syst, mask_brillouin_zone=True, k_x=11, k_y=11, file=out)
......@@ -9,15 +9,20 @@
import collections
import inspect
import cmath
import tinyarray as ta
import numpy as np
import scipy.linalg
import scipy.spatial
from . import builder
from . import builder, system, plotter
from .linalg import lll
from .builder import herm_conj, HermConjOfFunc
from .lattice import TranslationalSymmetry
from ._common import get_parameters
__all__ = ['wraparound']
__all__ = ['wraparound', 'plot_2d_bands']
def _hashable(obj):
......@@ -306,3 +311,155 @@ def wraparound(builder, keep=None, *, coordinate_names=('x', 'y', 'z')):
ret[hop] = vals[0] if len(vals) == 1 else bind_sum(2, *vals)
return ret
def plot_2d_bands(syst, k_x=31, k_y=31, params=None,
mask_brillouin_zone=False, extend_bbox=0, file=None,
show=True, dpi=None, fig_size=None, ax=None):
"""Plot 2D band structure of a wrapped around system.
This function is primarily useful for systems that have translational
symmetry vectors that are non-orthogonal (e.g. graphene). This function
will properly plot the band structure in an orthonormal basis in k-space,
as opposed to in the basis of reciprocal lattice vectors (which would
produce a "skewed" Brillouin zone).
If your system has orthogonal lattice vectors, you are probably better
off using `kwant.plotter.spectrum`.
Parameters
----------
syst : `kwant.system.FiniteSystem`
A 2D system that was finalized from a Builder produced by
`kwant.wraparound.wraparound`. Note that this *must* be a finite
system; so `kwant.wraparound.wraparound` should have been called with
``keep=None``.
k_x, k_y : int or sequence of float, default: 31
Either a number of sampling points, or a sequence of points at which
the band structure is to be evaluated, in units of inverse length.
params : dict, optional
Dictionary of parameter names and their values, not including the
momentum parameters.
mask_brillouin_zone : bool, default: False
If True, then the band structure will only be plotted over the first
Brillouin zone. By default the band structure is plotted over a
rectangular bounding box that contains the Brillouin zone.
extend_bbox : float, default: 0
Amount by which to extend the region over which the band structure is
plotted, expressed as a proportion of the Brillouin zone bounding box
length. i.e. ``extend_bbox=0.1`` will extend the region by 10% (in all
directions).
file : string or file object, optional
The output file. If None, output will be shown instead.
show : bool, default: False
Whether ``matplotlib.pyplot.show()`` is to be called, and the output is
to be shown immediately. Defaults to `True`.
dpi : float, optional
Number of pixels per inch. If not set the ``matplotlib`` default is
used.
fig_size : tuple, optional
Figure size `(width, height)` in inches. If not set, the default
``matplotlib`` value is used.
ax : ``matplotlib.axes.Axes`` instance, optional
If `ax` is not `None`, no new figure is created, but the plot is done
within the existing Axes `ax`. in this case, `file`, `show`, `dpi`
and `fig_size` are ignored.
Returns
-------
fig : matplotlib figure
A figure with the output if `ax` is not set, else None.
Notes
-----
This function produces plots where the units of momentum are inverse
length. This is contrary to `kwant.plotter.bands`, where the units
of momentum are inverse lattice constant.
If the lattice vectors for the symmetry of ``syst`` are not orthogonal,
then part of the plotted band structure will be outside the first Brillouin
zone (inside the bounding box of the brillouin zone). Setting
``mask_brillouin_zone=True`` will cause the plot to be truncated outside of
the first Brillouin zone.
See Also
--------
`kwant.plotter.spectrum`
"""
if not hasattr(syst, '_wrapped_symmetry'):
raise TypeError("Expecting a system that was produced by "
"'kwant.wraparound.wraparound'.")
if not isinstance(syst, system.FiniteSystem):
msg = ("All symmetry directions must be wrapped around: specify "
"'keep=None' when calling 'kwant.wraparound.wraparound'.")
raise TypeError(msg)
params = params or {}
lat_ndim, space_ndim = syst._wrapped_symmetry.periods.shape
if lat_ndim != 2:
raise ValueError("Expected a system with a 2D translational symmetry.")
if space_ndim != lat_ndim:
raise ValueError("Lattice dimension must equal realspace dimension.")
# columns of B are lattice vectors
B = np.array(syst._wrapped_symmetry.periods).T
# columns of A are reciprocal lattice vectors
A = B.dot(np.linalg.inv(B.T.dot(B)))
## calculate the bounding box for the 1st Brillouin zone
# Get lattice points that neighbor the origin, in basis of lattice vectors
reduced_vecs, transf = lll.lll(A.T)
neighbors = ta.dot(lll.voronoi(reduced_vecs), transf)
# Add the origin to these points.
klat_points = np.concatenate(([[0] * lat_ndim], neighbors))
# Transform to cartesian coordinates and rescale.
# Will be used in 'outside_bz' function, later on.
klat_points = 2 * np.pi * np.dot(klat_points, A.T)
# Calculate the Voronoi cell vertices
vor = scipy.spatial.Voronoi(klat_points)
around_origin = vor.point_region[0]
bz_vertices = vor.vertices[vor.regions[around_origin]]
# extract bounding box
k_max = np.max(np.abs(bz_vertices), axis=0)
## build grid along each axis, if needed
ks = []
for k, km in zip((k_x, k_y), k_max):
k = np.array(k)
if not k.shape:
if extend_bbox:
km += km * extend_bbox
k = np.linspace(-km, km, k)
ks.append(k)
# TODO: It is very inefficient to call 'momentum_to_lattice' once for
# each point (for trivial Hamiltonians 60% of the time is spent
# doing this). We should instead transform the whole grid in one call.
def momentum_to_lattice(k):
k, residuals = scipy.linalg.lstsq(A, k)[:2]
if np.any(abs(residuals) > 1e-7):
raise RuntimeError("Requested momentum doesn't correspond"
" to any lattice momentum.")
return k
def ham(k_x, k_y=None, **params):
# transform into the basis of reciprocal lattice vectors
k = momentum_to_lattice([k_x] if k_y is None else [k_x, k_y])
p = dict(zip(syst._momentum_names, k), **params)
return syst.hamiltonian_submatrix(params=p, sparse=False)
def outside_bz(k_x, k_y, **_):
dm = scipy.spatial.distance_matrix(klat_points, [[k_x, k_y]])
return np.argmin(dm) != 0 # is origin no closest 'klat_point' to 'k'?
fig = plotter.spectrum(ham,
x=('k_x', ks[0]),
y=('k_y', ks[1]) if lat_ndim == 2 else None,
params=params,
mask=(outside_bz if mask_brillouin_zone else None),
file=file, show=show, dpi=dpi,
fig_size=fig_size, ax=ax)
return fig
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment