Commit e3509c3e authored by Joseph Weston's avatar Joseph Weston
Browse files

factor out 'plotter.bands' into 'plotter.spectrum'.

'plotter.spectrum' can be used to plot the spectrum of
a Hamiltonian as a function of arbitrary parameters. This is
generically useful functionality, so we factor it out.
parent 920b09b4
# Copyright 2011-2013 Kwant authors.
# Copyright 2011-2017 Kwant authors.
#
# This file is part of Kwant. It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
......@@ -15,7 +15,10 @@ system in two or three dimensions.
"""
from collections import defaultdict
import itertools
import functools
import warnings
import cmath
import numpy as np
import tinyarray as ta
from scipy import spatial, interpolate
......@@ -42,10 +45,11 @@ except ImportError:
"functions will work.", RuntimeWarning)
mpl_enabled = False
from . import system, builder, physics, _common
from . import system, builder, _common
__all__ = ['plot', 'map', 'bands', 'sys_leads_sites', 'sys_leads_hoppings',
'sys_leads_pos', 'sys_leads_hopping_pos', 'mask_interpolate']
__all__ = ['plot', 'map', 'bands', 'spectrum', 'sys_leads_sites',
'sys_leads_hoppings', 'sys_leads_pos', 'sys_leads_hopping_pos',
'mask_interpolate']
# TODO: Remove the following once we depend on matplotlib >= 1.4.1.
......@@ -1639,13 +1643,114 @@ def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None,
"for bands()")
syst = sys # for naming consistency inside function bodies
_common.ensure_isinstance(syst, system.InfiniteSystem)
momenta = np.array(momenta)
if momenta.ndim != 1:
momenta = np.linspace(-np.pi, np.pi, momenta)
bands = physics.Bands(syst, args, params=params)
energies = [bands(k) for k in momenta]
# expand out the contents of 'physics.Bands' to get the H(k),
# because 'spectrum' already does the diagonalisation.
ham = syst.cell_hamiltonian(args, params=params)
if not np.allclose(ham, ham.conjugate().transpose()):
raise ValueError('The cell Hamiltonian is not Hermitian.')
_hop = syst.inter_cell_hopping(args, params=params)
hop = np.empty(ham.shape, dtype=complex)
hop[:, :_hop.shape[1]] = _hop
hop[:, _hop.shape[1]:] = 0
def h_k(k):
# H_k = H_0 + V e^-ik + V^\dagger e^ik
mat = hop * cmath.exp(-1j * k)
mat += mat.conjugate().transpose() + ham
return mat
return spectrum(h_k, ('k', momenta), file=file, show=show, dpi=dpi,
fig_size=fig_size, ax=ax)
def spectrum(syst, x, y=None, params=None, mask=None, file=None,
show=True, dpi=None, fig_size=None, ax=None):
"""Plot the spectrum of a Hamiltonian as a function of 1 or 2 parameters
Parameters
----------
syst : `kwant.system.FiniteSystem` or callable
If a function, then it must take named parameters and return the
Hamiltonian as a dense matrix.
x : pair ``(name, values)``
Parameter to ``ham`` that will be varied. Consists of the
parameter name, and a sequence of parameter values.
y : pair ``(name, values)``, optional
Used for 3D plots (same as ``x``). If provided, then the cartesian
product of the ``x`` values and these values will be used as a grid
over which to evaluate the spectrum.
params : dict, optional
The rest of the parameters to ``ham``, which will be kept constant.
mask : callable, optional
Takes the parameters specified by ``x`` and ``y`` and returns True
if the spectrum should not be calculated for the given parameter
values.
file : string or file object or `None`
The output file. If `None`, output will be shown instead.
show : bool
Whether ``matplotlib.pyplot.show()`` is to be called, and the output is
to be shown immediately. Defaults to `True`.
dpi : float
Number of pixels per inch. If not set the ``matplotlib`` default is
used.
fig_size : tuple
Figure size `(width, height)` in inches. If not set, the default
``matplotlib`` value is used.
ax : ``matplotlib.axes.Axes`` instance or `None`
If `ax` is not `None`, no new figure is created, but the plot is done
within the existing Axes `ax`. in this case, `file`, `show`, `dpi`
and `fig_size` are ignored.
Returns
-------
fig : matplotlib figure
A figure with the output if `ax` is not set, else None.
"""
if not mpl_enabled:
raise RuntimeError("matplotlib was not found, but is required "
"for plot_spectrum()")
if y is not None and not has3d:
raise RuntimeError("Installed matplotlib does not support 3d plotting")
if isinstance(syst, system.FiniteSystem):
def ham(**kwargs):
return syst.hamiltonian_submatrix(params=kwargs, sparse=False)
elif callable(syst):
ham = syst
else:
raise TypeError("Expected 'syst' to be a finite Kwant system "
"or a function.")
params = params or dict()
keys = (x[0],) if y is None else (x[0], y[0])
array_values = (x[1],) if y is None else (x[1], y[1])
# calculate spectrum on the grid of points
spectrum = []
bound_ham = functools.partial(ham, **params)
for point in itertools.product(*array_values):
p = dict(zip(keys, point))
if mask and mask(**p):
spectrum.append(None)
else:
h_p = np.atleast_2d(bound_ham(**p))
spectrum.append(np.linalg.eigvalsh(h_p))
# massage masked grid points into a list of NaNs of the appropriate length
n_eigvals = len(next(filter(lambda s: s is not None, spectrum)))
nan_list = [np.nan] * n_eigvals
spectrum = [nan_list if s is None else s for s in spectrum]
# make into a numpy array and reshape
new_shape = [len(v) for v in array_values] + [-1]
spectrum = np.array(spectrum).reshape(new_shape)
# set up axes
if ax is None:
fig = Figure()
if dpi is not None:
......@@ -1653,10 +1758,32 @@ def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None,
if fig_size is not None:
fig.set_figwidth(fig_size[0])
fig.set_figheight(fig_size[1])
ax = fig.add_subplot(1, 1, 1)
projection = '3d' if y is not None else None
ax = fig.add_subplot(1, 1, 1, projection=projection)
ax.set_xlabel(keys[0])
if y is None:
ax.set_ylabel('Energy')
else:
ax.set_ylabel(keys[1])
ax.set_zlabel('Energy')
ax.set_title(', '.join('{} = {}'.format(*kv) for kv in params.items()))
else:
fig = None
ax.plot(momenta, energies)
# actually do the plot
if y is None:
ax.plot(array_values[0], spectrum)
else:
if not hasattr(ax, 'plot_surface'):
msg = ("When providing an axis for plotting over a 2D domain the "
"axis should be created with 'projection=\"3d\"")
raise TypeError(msg)
# plot_surface cannot directly handle rank-3 values, so we
# explicitly loop over the last axis
grid = np.meshgrid(*array_values)
for i in range(spectrum.shape[-1]):
spec = spectrum[:, :, i].transpose() # row-major to x-y ordering
ax.plot_surface(*(grid + [spec]), cstride=1, rstride=1)
if fig is not None:
return output_fig(fig, file=file, show=show)
......
......@@ -171,3 +171,69 @@ def test_mask_interpolate():
coords, np.ones(len(coords)))
pytest.raises(ValueError, plotter.mask_interpolate,
coords, np.ones(2 * len(coords)))
@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
def test_bands():
syst = syst_2d().finalized().leads[0]
with tempfile.TemporaryFile('w+b') as out:
plotter.bands(syst, file=out)
plotter.bands(syst, fig_size=(10, 10), file=out)
plotter.bands(syst, momenta=np.linspace(0, 2 * np.pi), file=out)
fig = pyplot.Figure()
ax = fig.add_subplot(1, 1, 1)
plotter.bands(syst, ax=ax, file=out)
@pytest.mark.skipif(not plotter.mpl_enabled, reason="No matplotlib available.")
def test_spectrum():
def ham_1d(a, b, c):
return a**2 + b**2 + c**2
def ham_2d(a, b, c):
return np.eye(2) * (a**2 + b**2 + c**2)
lat = kwant.lattice.chain()
syst = kwant.Builder()
syst[(lat(i) for i in range(3))] = lambda site, a, b: a + b
syst[lat.neighbors()] = lambda site1, site2, c: c
fsyst = syst.finalized()
vals = np.linspace(0, 1, 3)
with tempfile.TemporaryFile('w+b') as out:
for ham in (ham_1d, ham_2d, fsyst):
plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1), file=out)
# test with explicit figsize
plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1),
fig_size=(10, 10), file=out)
for ham in (ham_1d, ham_2d, fsyst):
plotter.spectrum(ham, ('a', vals), ('b', 2 * vals),
params=dict(c=1), file=out)
# test with explicit figsize
plotter.spectrum(ham, ('a', vals), ('b', 2 * vals),
params=dict(c=1), fig_size=(10, 10), file=out)
# test 2D plot and explicitly passing axis
fig = pyplot.Figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')
plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals),
params=dict(c=1), ax=ax, file=out)
# explicitly pass axis without 3D support
ax = fig.add_subplot(1, 1, 1)
with pytest.raises(TypeError):
plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals),
params=dict(c=1), ax=ax, file=out)
def mask(a, b):
return a > 0.5
with tempfile.TemporaryFile('w+b') as out:
plotter.spectrum(ham, ('a', vals), ('b', 2 * vals), params=dict(c=1),
mask=mask, file=out)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment