Commit 9293e703 authored by Joseph Weston's avatar Joseph Weston
Browse files

Merge branch 'import-time'

speed up kwant import from >1s to 0.3s

Closes #181

See merge request kwant/kwant!207
parents 35620fb4 1b18eb5c
@@ -1,222 +1,236 @@
@@ -1,225 +1,239 @@
# Tutorial 2.9. Processing continuum Hamiltonians with discretize
# ===============================================================
#
......@@ -13,6 +13,9 @@
+import _defs
import kwant
#HIDDEN_BEGIN_import
import kwant.continuum
#HIDDEN_END_import
import scipy.sparse.linalg
import scipy.linalg
import numpy as np
......
@@ -1,71 +1,75 @@
@@ -1,72 +1,76 @@
# Comprehensive example: quantum anomalous Hall effect
# ====================================================
#
......@@ -16,6 +16,7 @@
import math
import matplotlib.pyplot
import kwant
import kwant.continuum
# 2 band model exhibiting quantum anomalous Hall effect
......
......@@ -16,7 +16,9 @@ import sys, os
from distutils.util import get_platform
sys.path.insert(0, "../../build/lib.{0}-{1}.{2}".format(
get_platform(), *sys.version_info[:2]))
import kwant
import kwant.continuum # sphinx gets confused with lazy loading
# -- General configuration -----------------------------------------------------
......
......@@ -53,6 +53,11 @@ with :math:`A(x) = \frac{\hbar^2}{2 m(x)}`.
Using `~kwant.continuum.discretize` to obtain a template
........................................................
First we must explicitly import the `kwant.continuum` package:
.. literalinclude:: /code/include/discretize.py
:start-after: #HIDDEN_BEGIN_import
:end-before: #HIDDEN_END_import
The function `kwant.continuum.discretize` takes a symbolic Hamiltonian and
turns it into a `~kwant.builder.Builder` instance with appropriate spatial
......
......@@ -33,7 +33,7 @@ from ._common import KwantDeprecationWarning, UserCodeError
__all__.extend(['KwantDeprecationWarning', 'UserCodeError'])
for module in ['system', 'builder', 'lattice', 'solvers', 'digest', 'rmt',
'operator', 'kpm', 'wraparound', 'continuum']:
'operator', 'kpm', 'wraparound']:
exec('from . import {0}'.format(module))
__all__.append(module)
......@@ -56,6 +56,13 @@ except:
else:
__all__.extend(['plotter', 'plot'])
# Lazy import continuum package for backwards compatibility
from ._common import lazy_import
continuum = lazy_import('continuum', deprecation_warning=True)
__all__.append('continuum')
del lazy_import
def test(verbose=True):
from pytest import main
......
......@@ -6,10 +6,12 @@
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.
import sys
import numpy as np
import numbers
import inspect
import warnings
import importlib
from contextlib import contextmanager
__all__ = ['KwantDeprecationWarning', 'UserCodeError']
......@@ -39,30 +41,6 @@ class UserCodeError(Exception):
pass
class ExtensionUnavailable:
"""Class that replaces unavailable extension modules in the Kwant namespace.
Some extensions for Kwant (e.g. 'kwant.continuum') require additional
dependencies that are not required for core functionality. When the
additional dependencies are not installed an instance of this class will
be inserted into Kwant's root namespace to simulate the presence of the
extension and inform users that they need to install additional
dependencies.
See https://mail.python.org/pipermail/python-ideas/2012-May/014969.html
for more details.
"""
def __init__(self, name, dependencies):
self.name = name
self.dependencies = ', '.join(dependencies)
def __getattr__(self, _):
msg = ("'{}' is not available because one or more of the following "
"dependencies are not installed: {}")
raise RuntimeError(msg.format(self.name, self.dependencies))
def ensure_isinstance(obj, typ, msg=None):
if isinstance(obj, typ):
return
......@@ -123,3 +101,25 @@ def get_parameters(func):
takes_kwargs = any(i.kind is inspect.Parameter.VAR_KEYWORD
for i in pars.values())
return required_params, default_params, takes_kwargs
class lazy_import:
def __init__(self, module, package='kwant', deprecation_warning=False):
if module.startswith('.') and not package:
raise ValueError('Cannot import a relative module without a package.')
self.__module = module
self.__package = package
self.__deprecation_warning = deprecation_warning
def __getattr__(self, name):
if self.__deprecation_warning:
msg = ("Accessing {0} without an explicit import is deprecated. "
"Instead, explicitly 'import {0}'."
).format('.'.join((self.__package, self.__module)))
warnings.warn(msg, KwantDeprecationWarning, stacklevel=2)
relative_module = '.' + self.__module
mod = importlib.import_module(relative_module, self.__package)
# Replace this _LazyModuleProxy with an actual module
package = sys.modules[self.__package]
setattr(package, self.__module, mod)
return getattr(mod, name)
# -*- coding: utf-8 -*-
# Copyright 2011-2018 Kwant authors.
#
# This file is part of Kwant. It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
# http://kwant-project.org/license. A list of Kwant authors can be found in
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.
# This module is imported by plotter.py. It contains all the expensive imports
# that we want to remove from plotter.py
# All matplotlib imports must be isolated in a try, because even without
# matplotlib iterators remain useful. Further, mpl_toolkits used for 3D
# plotting are also imported separately, to ensure that 2D plotting works even
# if 3D does not.
import warnings
from math import sqrt, pi
import numpy as np
try:
import matplotlib
import matplotlib.colors
import matplotlib.cm
from matplotlib.figure import Figure
from matplotlib import collections
from . import _colormaps
mpl_available = True
try:
from mpl_toolkits import mplot3d
has3d = True
except ImportError:
warnings.warn("3D plotting not available.", RuntimeWarning)
has3d = False
except ImportError:
warnings.warn("matplotlib is not available, only iterator-providing "
"functions will work.", RuntimeWarning)
mpl_available = False
# Collections that allow for symbols and linewiths to be given in data space
# (not for general use, only implement what's needed for plotter)
def isarray(var):
if hasattr(var, '__getitem__') and not isinstance(var, str):
return True
else:
return False
def nparray_if_array(var):
return np.asarray(var) if isarray(var) else var
if mpl_available:
class LineCollection(collections.LineCollection):
def __init__(self, segments, reflen=None, **kwargs):
super().__init__(segments, **kwargs)
self.reflen = reflen
def set_linewidths(self, linewidths):
self.linewidths_orig = nparray_if_array(linewidths)
def draw(self, renderer):
if self.reflen is not None:
# Note: only works for aspect ratio 1!
# 72.0 - there is 72 points in an inch
factor = (self.axes.transData.frozen().to_values()[0] * 72.0 *
self.reflen / self.figure.dpi)
else:
factor = 1
super().set_linewidths(self.linewidths_orig *
factor)
return super().draw(renderer)
class PathCollection(collections.PathCollection):
def __init__(self, paths, sizes=None, reflen=None, **kwargs):
super().__init__(paths, sizes=sizes, **kwargs)
self.reflen = reflen
self.linewidths_orig = nparray_if_array(self.get_linewidths())
self.transforms = np.array(
[matplotlib.transforms.Affine2D().scale(x).get_matrix()
for x in sizes])
def get_transforms(self):
return self.transforms
def get_transform(self):
Affine2D = matplotlib.transforms.Affine2D
if self.reflen is not None:
# For the paths, use the data transformation but strip the
# offset (will be added later with offsets)
args = self.axes.transData.frozen().to_values()[:4] + (0, 0)
return Affine2D().from_values(*args).scale(self.reflen)
else:
return Affine2D().scale(self.figure.dpi / 72.0)
def draw(self, renderer):
if self.reflen:
# Note: only works for aspect ratio 1!
factor = (self.axes.transData.frozen().to_values()[0] /
self.figure.dpi * 72.0 * self.reflen)
self.set_linewidths(self.linewidths_orig * factor)
return collections.Collection.draw(self, renderer)
if has3d:
# Sorting is optional.
sort3d = True
# Compute the projection of a 3D length into 2D data coordinates
# for this we use 2 3D half-circles that are projected into 2D.
# (This gives the same length as projecting the full unit sphere.)
phi = np.linspace(0, pi, 21)
xyz = np.c_[np.cos(phi), np.sin(phi), 0 * phi].T.reshape(-1, 1, 21)
unit_sphere = np.bmat([[xyz[0], xyz[2]], [xyz[1], xyz[0]],
[xyz[2], xyz[1]]])
unit_sphere = np.asarray(unit_sphere)
def projected_length(ax, length):
rc = np.array([ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()])
rc = np.apply_along_axis(np.sum, 1, rc) / 2.
rs = unit_sphere * length + rc.reshape(-1, 1)
transform = mplot3d.proj3d.proj_transform
rp = np.asarray(transform(*(list(rs) + [ax.get_proj()]))[:2])
rc[:2] = transform(*(list(rc) + [ax.get_proj()]))[:2]
coords = rp - np.repeat(rc[:2].reshape(-1, 1), len(rs[0]), axis=1)
return sqrt(np.sum(coords**2, axis=0).max())
# Auxiliary array for calculating corners of a cube.
corners = np.zeros((3, 8, 6), np.float_)
corners[0, [0, 1, 2, 3], 0] = corners[0, [4, 5, 6, 7], 1] = \
corners[0, [0, 1, 4, 5], 2] = corners[0, [2, 3, 6, 7], 3] = \
corners[0, [0, 2, 4, 6], 4] = corners[0, [1, 3, 5, 7], 5] = 1.0
class Line3DCollection(mplot3d.art3d.Line3DCollection):
def __init__(self, segments, reflen=None, zorder=0, **kwargs):
super().__init__(segments, **kwargs)
self.reflen = reflen
self.zorder3d = zorder
def set_linewidths(self, linewidths):
self.linewidths_orig = nparray_if_array(linewidths)
def do_3d_projection(self, renderer):
super().do_3d_projection(renderer)
# The whole 3D ordering is flawed in mplot3d when several
# collections are added. We just use normal zorder. Note the
# "-" due to the different logic in the 3d plotting, we still
# want larger zorder values to be plotted on top of smaller
# ones.
return -self.zorder3d
def draw(self, renderer):
if self.reflen:
proj_len = projected_length(self.axes, self.reflen)
args = self.axes.transData.frozen().to_values()
# Note: unlike in the 2D case, where we can enforce equal
# aspect ratio, this (currently) does not work with
# 3D plots in matplotlib. As an approximation, we
# thus scale with the average of the x- and y-axis
# transformation.
factor = proj_len * (args[0] +
args[3]) * 0.5 * 72.0 / self.figure.dpi
else:
factor = 1
super().set_linewidths(
self.linewidths_orig * factor)
super().draw(renderer)
class Path3DCollection(mplot3d.art3d.Patch3DCollection):
def __init__(self, paths, sizes, reflen=None, zorder=0,
offsets=None, **kwargs):
paths = [matplotlib.patches.PathPatch(path) for path in paths]
if offsets is not None:
kwargs['offsets'] = offsets[:, :2]
super().__init__(paths, **kwargs)
if offsets is not None:
self.set_3d_properties(zs=offsets[:, 2], zdir="z")
self.reflen = reflen
self.zorder3d = zorder
self.paths_orig = np.array(paths, dtype='object')
self.linewidths_orig = nparray_if_array(self.get_linewidths())
self.linewidths_orig2 = self.linewidths_orig
self.array_orig = nparray_if_array(self.get_array())
self.facecolors_orig = nparray_if_array(self.get_facecolors())
self.edgecolors_orig = nparray_if_array(self.get_edgecolors())
Affine2D = matplotlib.transforms.Affine2D
self.orig_transforms = np.array(
[Affine2D().scale(x).get_matrix() for x in sizes])
self.transforms = self.orig_transforms
def set_array(self, array):
self.array_orig = nparray_if_array(array)
super().set_array(array)
def set_color(self, colors):
self.facecolors_orig = nparray_if_array(colors)
self.edgecolors_orig = self.facecolors_orig
super().set_color(colors)
def set_edgecolors(self, colors):
colors = matplotlib.colors.colorConverter.to_rgba_array(colors)
self.edgecolors_orig = nparray_if_array(colors)
super().set_edgecolors(colors)
def get_transforms(self):
# this is exact only for an isometric projection, for the
# perspective projection used in mplot3d it's an approximation
return self.transforms
def get_transform(self):
Affine2D = matplotlib.transforms.Affine2D
if self.reflen:
proj_len = projected_length(self.axes, self.reflen)
# For the paths, use the data transformation but strip the
# offset (will be added later with the offsets).
args = self.axes.transData.frozen().to_values()[:4] + (0, 0)
return Affine2D().from_values(*args).scale(proj_len)
else:
return Affine2D().scale(self.figure.dpi / 72.0)
def do_3d_projection(self, renderer):
xs, ys, zs = self._offsets3d
# numpy complains about zero-length index arrays
if len(xs) == 0:
return -self.zorder3d
proj = mplot3d.proj3d.proj_transform_clip
vs = np.array(proj(xs, ys, zs, renderer.M)[:3])
if sort3d:
indx = vs[2].argsort()[::-1]
self.set_offsets(vs[:2, indx].T)
if len(self.paths_orig) > 1:
paths = np.resize(self.paths_orig, (vs.shape[1],))
self.set_paths(paths[indx])
if len(self.orig_transforms) > 1:
self.transforms = self.transforms[indx]
lw_orig = self.linewidths_orig
if (isinstance(lw_orig, np.ndarray) and len(lw_orig) > 1):
self.linewidths_orig2 = np.resize(lw_orig,
(vs.shape[1],))[indx]
# Note: here array, facecolors and edgecolors are
# guaranteed to be 2d numpy arrays or None. (And
# array is the same length as the coordinates)
if self.array_orig is not None:
super(Path3DCollection,
self).set_array(self.array_orig[indx])
if (self.facecolors_orig is not None and
self.facecolors_orig.shape[0] > 1):
shape = list(self.facecolors_orig.shape)
shape[0] = vs.shape[1]
super().set_facecolors(
np.resize(self.facecolors_orig, shape)[indx])
if (self.edgecolors_orig is not None and
self.edgecolors_orig.shape[0] > 1):
shape = list(self.edgecolors_orig.shape)
shape[0] = vs.shape[1]
super().set_edgecolors(
np.resize(self.edgecolors_orig,
shape)[indx])
else:
self.set_offsets(vs[:2].T)
# the whole 3D ordering is flawed in mplot3d when several
# collections are added. We just use normal zorder, but correct
# by the projected z-coord of the "center of gravity",
# normalized by the projected z-coord of the world coordinates.
# In doing so, several Path3DCollections are plotted probably
# in the right order (it's not exact) if they have the same
# zorder. Still, smaller and larger integer zorders are plotted
# below or on top.
bbox = np.asarray(self.axes.get_w_lims())
proj = mplot3d.proj3d.proj_transform_clip
cz = proj(*(list(np.dot(corners, bbox)) + [renderer.M]))[2]
return -self.zorder3d + vs[2].mean() / cz.ptp()
def draw(self, renderer):
if self.reflen:
proj_len = projected_length(self.axes, self.reflen)
args = self.axes.transData.frozen().to_values()
factor = proj_len * (args[0] +
args[3]) * 0.5 * 72.0 / self.figure.dpi
self.set_linewidths(self.linewidths_orig2 * factor)
super().draw(renderer)
......@@ -6,18 +6,15 @@
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.
import sys
from .._common import ExtensionUnavailable
try:
from .discretizer import discretize, discretize_symbolic, build_discretized
from ._common import sympify, lambdify
from ._common import momentum_operators, position_operators
except ImportError:
sys.modules[__name__] = ExtensionUnavailable(__name__, ('sympy',))
except ImportError as error:
msg = ("'kwant.continuum' is not available because one or more of its "
"dependencies is not installed.")
raise ImportError(msg) from error
del sys, ExtensionUnavailable
__all__ = ['discretize', 'discretize_symbolic', 'build_discretized',
'sympify', 'lambdify']
This diff is collapsed.
......@@ -36,6 +36,7 @@ except ImportError:
matplotlib_backend_chosen = False
from kwant import plotter
from kwant import _plotter # for mpl_available
def test_matplotlib_backend_unset():
......@@ -44,7 +45,7 @@ def test_matplotlib_backend_unset():
def test_importable_without_matplotlib():
prefix, sep, suffix = plotter.__file__.rpartition('.')
prefix, sep, suffix = _plotter.__file__.rpartition('.')
if suffix in ['pyc', 'pyo']:
suffix = 'py'
assert suffix == 'py'
......@@ -111,7 +112,7 @@ def syst_3d(W=3, r1=2, r2=4, a=1, t=1.0):
return syst
@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.")
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
def test_plot():
plot = plotter.plot
syst2d = syst_2d()
......@@ -161,7 +162,7 @@ def bad_transform(pos):
x, y = pos
return x, y, 0
@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.")
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
def test_map():
syst = syst_2d()
with tempfile.TemporaryFile('w+b') as out:
......@@ -197,7 +198,7 @@ def test_mask_interpolate():
coords, np.ones(2 * len(coords)))
@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.")
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
def test_bands():
syst = syst_2d().finalized().leads[0]
......@@ -212,7 +213,7 @@ def test_bands():
plotter.bands(syst, ax=ax, file=out)
@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.")
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
def test_spectrum():
def ham_1d(a, b, c):
......@@ -423,7 +424,7 @@ def test_current_interpolation():
assert scipy.stats.linregress(np.log(data))[2] < -0.8
@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.")
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
def test_current():
syst = syst_2d().finalized()
J = kwant.operator.Current(syst)
......
......@@ -13,11 +13,11 @@ import tinyarray as ta
import pytest
import kwant
from kwant import plotter
from kwant import _plotter
from kwant.wraparound import wraparound, plot_2d_bands
from kwant._common import get_parameters
if plotter.mpl_available:
if _plotter.mpl_available:
from mpl_toolkits import mplot3d # pragma: no flakes
from matplotlib import pyplot # pragma: no flakes
......@@ -201,7 +201,7 @@ def test_symmetry():
assert np.all(orig == new)
@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.")
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
def test_plot_2d_bands():
chain = kwant.lattice.chain()
square = kwant.lattice.square()
......
......@@ -2,5 +2,6 @@
testpaths = kwant
flakes-ignore =
__init__.py UnusedImport
kwant/_plotter.py UnusedImport
graph/tests/test_scotch.py UndefinedName
graph/tests/test_dissection.py UndefinedName
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment