diff --git a/doc/source/code/figure/discretize.py.diff b/doc/source/code/figure/discretize.py.diff index 9ca0df555e3b8553f7772d6f2253e485e6d4df73..aebaf3e03c10d291d2b25789c1145b90d6210a15 100644 --- a/doc/source/code/figure/discretize.py.diff +++ b/doc/source/code/figure/discretize.py.diff @@ -1,4 +1,4 @@ -@@ -1,222 +1,236 @@ +@@ -1,225 +1,239 @@ # Tutorial 2.9. Processing continuum Hamiltonians with discretize # =============================================================== # @@ -13,6 +13,9 @@ +import _defs import kwant + #HIDDEN_BEGIN_import + import kwant.continuum + #HIDDEN_END_import import scipy.sparse.linalg import scipy.linalg import numpy as np diff --git a/doc/source/code/figure/plot_qahe.py.diff b/doc/source/code/figure/plot_qahe.py.diff index d726f98e908feac22a767b4387c8e6b561523bb4..1788f677df857d1b75b8706aa0907fad4b9a467d 100644 --- a/doc/source/code/figure/plot_qahe.py.diff +++ b/doc/source/code/figure/plot_qahe.py.diff @@ -1,4 +1,4 @@ -@@ -1,71 +1,75 @@ +@@ -1,72 +1,76 @@ # Comprehensive example: quantum anomalous Hall effect # ==================================================== # @@ -16,6 +16,7 @@ import math import matplotlib.pyplot import kwant + import kwant.continuum # 2 band model exhibiting quantum anomalous Hall effect diff --git a/doc/source/conf.py b/doc/source/conf.py index f0272556a1cee688b73c60fdea6574d8792ba5ee..9790aaa58434a0eb868d66bf2b2243677611ee4c 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -16,7 +16,9 @@ import sys, os from distutils.util import get_platform sys.path.insert(0, "../../build/lib.{0}-{1}.{2}".format( get_platform(), *sys.version_info[:2])) + import kwant +import kwant.continuum # sphinx gets confused with lazy loading # -- General configuration ----------------------------------------------------- diff --git a/doc/source/tutorial/discretize.rst b/doc/source/tutorial/discretize.rst index d3c58219ea2baf593bc96fc6e9ded1073aabc634..1ed5015cd2203e100d1d4e4cf9f21ae283716f04 100644 --- a/doc/source/tutorial/discretize.rst +++ b/doc/source/tutorial/discretize.rst @@ -53,6 +53,11 @@ with :math:`A(x) = \frac{\hbar^2}{2 m(x)}`. Using `~kwant.continuum.discretize` to obtain a template ........................................................ +First we must explicitly import the `kwant.continuum` package: + +.. literalinclude:: /code/include/discretize.py + :start-after: #HIDDEN_BEGIN_import + :end-before: #HIDDEN_END_import The function `kwant.continuum.discretize` takes a symbolic Hamiltonian and turns it into a `~kwant.builder.Builder` instance with appropriate spatial diff --git a/kwant/__init__.py b/kwant/__init__.py index 1fddab25426261753a81e358a30ddab009ac3171..054657c308607608b1c9c8a893680231b21c59f6 100644 --- a/kwant/__init__.py +++ b/kwant/__init__.py @@ -33,7 +33,7 @@ from ._common import KwantDeprecationWarning, UserCodeError __all__.extend(['KwantDeprecationWarning', 'UserCodeError']) for module in ['system', 'builder', 'lattice', 'solvers', 'digest', 'rmt', - 'operator', 'kpm', 'wraparound', 'continuum']: + 'operator', 'kpm', 'wraparound']: exec('from . import {0}'.format(module)) __all__.append(module) @@ -56,6 +56,13 @@ except: else: __all__.extend(['plotter', 'plot']) +# Lazy import continuum package for backwards compatibility +from ._common import lazy_import + +continuum = lazy_import('continuum', deprecation_warning=True) +__all__.append('continuum') +del lazy_import + def test(verbose=True): from pytest import main diff --git a/kwant/_common.py b/kwant/_common.py index 76e952cf6c0d2ffc437d42ca5c093a5825aac0e9..4cd0de1a3e6d4b084f3434bbfffef08871d51ded 100644 --- a/kwant/_common.py +++ b/kwant/_common.py @@ -6,10 +6,12 @@ # the file AUTHORS.rst at the top-level directory of this distribution and at # http://kwant-project.org/authors. +import sys import numpy as np import numbers import inspect import warnings +import importlib from contextlib import contextmanager __all__ = ['KwantDeprecationWarning', 'UserCodeError'] @@ -39,30 +41,6 @@ class UserCodeError(Exception): pass -class ExtensionUnavailable: - """Class that replaces unavailable extension modules in the Kwant namespace. - - Some extensions for Kwant (e.g. 'kwant.continuum') require additional - dependencies that are not required for core functionality. When the - additional dependencies are not installed an instance of this class will - be inserted into Kwant's root namespace to simulate the presence of the - extension and inform users that they need to install additional - dependencies. - - See https://mail.python.org/pipermail/python-ideas/2012-May/014969.html - for more details. - """ - - def __init__(self, name, dependencies): - self.name = name - self.dependencies = ', '.join(dependencies) - - def __getattr__(self, _): - msg = ("'{}' is not available because one or more of the following " - "dependencies are not installed: {}") - raise RuntimeError(msg.format(self.name, self.dependencies)) - - def ensure_isinstance(obj, typ, msg=None): if isinstance(obj, typ): return @@ -123,3 +101,25 @@ def get_parameters(func): takes_kwargs = any(i.kind is inspect.Parameter.VAR_KEYWORD for i in pars.values()) return required_params, default_params, takes_kwargs + + +class lazy_import: + def __init__(self, module, package='kwant', deprecation_warning=False): + if module.startswith('.') and not package: + raise ValueError('Cannot import a relative module without a package.') + self.__module = module + self.__package = package + self.__deprecation_warning = deprecation_warning + + def __getattr__(self, name): + if self.__deprecation_warning: + msg = ("Accessing {0} without an explicit import is deprecated. " + "Instead, explicitly 'import {0}'." + ).format('.'.join((self.__package, self.__module))) + warnings.warn(msg, KwantDeprecationWarning, stacklevel=2) + relative_module = '.' + self.__module + mod = importlib.import_module(relative_module, self.__package) + # Replace this _LazyModuleProxy with an actual module + package = sys.modules[self.__package] + setattr(package, self.__module, mod) + return getattr(mod, name) diff --git a/kwant/_plotter.py b/kwant/_plotter.py new file mode 100644 index 0000000000000000000000000000000000000000..771e9ba281999efc07f2aece0a7fc0169d19be63 --- /dev/null +++ b/kwant/_plotter.py @@ -0,0 +1,320 @@ +# -*- coding: utf-8 -*- +# Copyright 2011-2018 Kwant authors. +# +# This file is part of Kwant. It is subject to the license terms in the file +# LICENSE.rst found in the top-level directory of this distribution and at +# http://kwant-project.org/license. A list of Kwant authors can be found in +# the file AUTHORS.rst at the top-level directory of this distribution and at +# http://kwant-project.org/authors. + +# This module is imported by plotter.py. It contains all the expensive imports +# that we want to remove from plotter.py + +# All matplotlib imports must be isolated in a try, because even without +# matplotlib iterators remain useful. Further, mpl_toolkits used for 3D +# plotting are also imported separately, to ensure that 2D plotting works even +# if 3D does not. + +import warnings +from math import sqrt, pi +import numpy as np + +try: + import matplotlib + import matplotlib.colors + import matplotlib.cm + from matplotlib.figure import Figure + from matplotlib import collections + from . import _colormaps + mpl_available = True + try: + from mpl_toolkits import mplot3d + has3d = True + except ImportError: + warnings.warn("3D plotting not available.", RuntimeWarning) + has3d = False +except ImportError: + warnings.warn("matplotlib is not available, only iterator-providing " + "functions will work.", RuntimeWarning) + mpl_available = False + + +# Collections that allow for symbols and linewiths to be given in data space +# (not for general use, only implement what's needed for plotter) +def isarray(var): + if hasattr(var, '__getitem__') and not isinstance(var, str): + return True + else: + return False + + +def nparray_if_array(var): + return np.asarray(var) if isarray(var) else var + + +if mpl_available: + class LineCollection(collections.LineCollection): + def __init__(self, segments, reflen=None, **kwargs): + super().__init__(segments, **kwargs) + self.reflen = reflen + + def set_linewidths(self, linewidths): + self.linewidths_orig = nparray_if_array(linewidths) + + def draw(self, renderer): + if self.reflen is not None: + # Note: only works for aspect ratio 1! + # 72.0 - there is 72 points in an inch + factor = (self.axes.transData.frozen().to_values()[0] * 72.0 * + self.reflen / self.figure.dpi) + else: + factor = 1 + + super().set_linewidths(self.linewidths_orig * + factor) + return super().draw(renderer) + + + class PathCollection(collections.PathCollection): + def __init__(self, paths, sizes=None, reflen=None, **kwargs): + super().__init__(paths, sizes=sizes, **kwargs) + + self.reflen = reflen + self.linewidths_orig = nparray_if_array(self.get_linewidths()) + + self.transforms = np.array( + [matplotlib.transforms.Affine2D().scale(x).get_matrix() + for x in sizes]) + + def get_transforms(self): + return self.transforms + + def get_transform(self): + Affine2D = matplotlib.transforms.Affine2D + if self.reflen is not None: + # For the paths, use the data transformation but strip the + # offset (will be added later with offsets) + args = self.axes.transData.frozen().to_values()[:4] + (0, 0) + return Affine2D().from_values(*args).scale(self.reflen) + else: + return Affine2D().scale(self.figure.dpi / 72.0) + + def draw(self, renderer): + if self.reflen: + # Note: only works for aspect ratio 1! + factor = (self.axes.transData.frozen().to_values()[0] / + self.figure.dpi * 72.0 * self.reflen) + self.set_linewidths(self.linewidths_orig * factor) + + return collections.Collection.draw(self, renderer) + + + if has3d: + # Sorting is optional. + sort3d = True + + # Compute the projection of a 3D length into 2D data coordinates + # for this we use 2 3D half-circles that are projected into 2D. + # (This gives the same length as projecting the full unit sphere.) + + phi = np.linspace(0, pi, 21) + xyz = np.c_[np.cos(phi), np.sin(phi), 0 * phi].T.reshape(-1, 1, 21) + unit_sphere = np.bmat([[xyz[0], xyz[2]], [xyz[1], xyz[0]], + [xyz[2], xyz[1]]]) + unit_sphere = np.asarray(unit_sphere) + + def projected_length(ax, length): + rc = np.array([ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()]) + rc = np.apply_along_axis(np.sum, 1, rc) / 2. + + rs = unit_sphere * length + rc.reshape(-1, 1) + + transform = mplot3d.proj3d.proj_transform + rp = np.asarray(transform(*(list(rs) + [ax.get_proj()]))[:2]) + rc[:2] = transform(*(list(rc) + [ax.get_proj()]))[:2] + + coords = rp - np.repeat(rc[:2].reshape(-1, 1), len(rs[0]), axis=1) + return sqrt(np.sum(coords**2, axis=0).max()) + + + # Auxiliary array for calculating corners of a cube. + corners = np.zeros((3, 8, 6), np.float_) + corners[0, [0, 1, 2, 3], 0] = corners[0, [4, 5, 6, 7], 1] = \ + corners[0, [0, 1, 4, 5], 2] = corners[0, [2, 3, 6, 7], 3] = \ + corners[0, [0, 2, 4, 6], 4] = corners[0, [1, 3, 5, 7], 5] = 1.0 + + + class Line3DCollection(mplot3d.art3d.Line3DCollection): + def __init__(self, segments, reflen=None, zorder=0, **kwargs): + super().__init__(segments, **kwargs) + self.reflen = reflen + self.zorder3d = zorder + + def set_linewidths(self, linewidths): + self.linewidths_orig = nparray_if_array(linewidths) + + def do_3d_projection(self, renderer): + super().do_3d_projection(renderer) + # The whole 3D ordering is flawed in mplot3d when several + # collections are added. We just use normal zorder. Note the + # "-" due to the different logic in the 3d plotting, we still + # want larger zorder values to be plotted on top of smaller + # ones. + return -self.zorder3d + + def draw(self, renderer): + if self.reflen: + proj_len = projected_length(self.axes, self.reflen) + args = self.axes.transData.frozen().to_values() + # Note: unlike in the 2D case, where we can enforce equal + # aspect ratio, this (currently) does not work with + # 3D plots in matplotlib. As an approximation, we + # thus scale with the average of the x- and y-axis + # transformation. + factor = proj_len * (args[0] + + args[3]) * 0.5 * 72.0 / self.figure.dpi + else: + factor = 1 + + super().set_linewidths( + self.linewidths_orig * factor) + super().draw(renderer) + + + class Path3DCollection(mplot3d.art3d.Patch3DCollection): + def __init__(self, paths, sizes, reflen=None, zorder=0, + offsets=None, **kwargs): + paths = [matplotlib.patches.PathPatch(path) for path in paths] + + if offsets is not None: + kwargs['offsets'] = offsets[:, :2] + + super().__init__(paths, **kwargs) + + if offsets is not None: + self.set_3d_properties(zs=offsets[:, 2], zdir="z") + + self.reflen = reflen + self.zorder3d = zorder + + self.paths_orig = np.array(paths, dtype='object') + self.linewidths_orig = nparray_if_array(self.get_linewidths()) + self.linewidths_orig2 = self.linewidths_orig + self.array_orig = nparray_if_array(self.get_array()) + self.facecolors_orig = nparray_if_array(self.get_facecolors()) + self.edgecolors_orig = nparray_if_array(self.get_edgecolors()) + + Affine2D = matplotlib.transforms.Affine2D + self.orig_transforms = np.array( + [Affine2D().scale(x).get_matrix() for x in sizes]) + self.transforms = self.orig_transforms + + def set_array(self, array): + self.array_orig = nparray_if_array(array) + super().set_array(array) + + def set_color(self, colors): + self.facecolors_orig = nparray_if_array(colors) + self.edgecolors_orig = self.facecolors_orig + super().set_color(colors) + + def set_edgecolors(self, colors): + colors = matplotlib.colors.colorConverter.to_rgba_array(colors) + self.edgecolors_orig = nparray_if_array(colors) + super().set_edgecolors(colors) + + def get_transforms(self): + # this is exact only for an isometric projection, for the + # perspective projection used in mplot3d it's an approximation + return self.transforms + + def get_transform(self): + Affine2D = matplotlib.transforms.Affine2D + if self.reflen: + proj_len = projected_length(self.axes, self.reflen) + + # For the paths, use the data transformation but strip the + # offset (will be added later with the offsets). + args = self.axes.transData.frozen().to_values()[:4] + (0, 0) + return Affine2D().from_values(*args).scale(proj_len) + else: + return Affine2D().scale(self.figure.dpi / 72.0) + + def do_3d_projection(self, renderer): + xs, ys, zs = self._offsets3d + + # numpy complains about zero-length index arrays + if len(xs) == 0: + return -self.zorder3d + + proj = mplot3d.proj3d.proj_transform_clip + vs = np.array(proj(xs, ys, zs, renderer.M)[:3]) + + if sort3d: + indx = vs[2].argsort()[::-1] + + self.set_offsets(vs[:2, indx].T) + + if len(self.paths_orig) > 1: + paths = np.resize(self.paths_orig, (vs.shape[1],)) + self.set_paths(paths[indx]) + + if len(self.orig_transforms) > 1: + self.transforms = self.transforms[indx] + + lw_orig = self.linewidths_orig + if (isinstance(lw_orig, np.ndarray) and len(lw_orig) > 1): + self.linewidths_orig2 = np.resize(lw_orig, + (vs.shape[1],))[indx] + + # Note: here array, facecolors and edgecolors are + # guaranteed to be 2d numpy arrays or None. (And + # array is the same length as the coordinates) + + if self.array_orig is not None: + super(Path3DCollection, + self).set_array(self.array_orig[indx]) + + if (self.facecolors_orig is not None and + self.facecolors_orig.shape[0] > 1): + shape = list(self.facecolors_orig.shape) + shape[0] = vs.shape[1] + super().set_facecolors( + np.resize(self.facecolors_orig, shape)[indx]) + + if (self.edgecolors_orig is not None and + self.edgecolors_orig.shape[0] > 1): + shape = list(self.edgecolors_orig.shape) + shape[0] = vs.shape[1] + super().set_edgecolors( + np.resize(self.edgecolors_orig, + shape)[indx]) + else: + self.set_offsets(vs[:2].T) + + # the whole 3D ordering is flawed in mplot3d when several + # collections are added. We just use normal zorder, but correct + # by the projected z-coord of the "center of gravity", + # normalized by the projected z-coord of the world coordinates. + # In doing so, several Path3DCollections are plotted probably + # in the right order (it's not exact) if they have the same + # zorder. Still, smaller and larger integer zorders are plotted + # below or on top. + + bbox = np.asarray(self.axes.get_w_lims()) + + proj = mplot3d.proj3d.proj_transform_clip + cz = proj(*(list(np.dot(corners, bbox)) + [renderer.M]))[2] + + return -self.zorder3d + vs[2].mean() / cz.ptp() + + def draw(self, renderer): + if self.reflen: + proj_len = projected_length(self.axes, self.reflen) + args = self.axes.transData.frozen().to_values() + factor = proj_len * (args[0] + + args[3]) * 0.5 * 72.0 / self.figure.dpi + + self.set_linewidths(self.linewidths_orig2 * factor) + + super().draw(renderer) diff --git a/kwant/continuum/__init__.py b/kwant/continuum/__init__.py index 718f50151795830b73df1f9327aa591997d0cbdc..bc2420e9c7998af5cb7f3b33fe72555e778156d4 100644 --- a/kwant/continuum/__init__.py +++ b/kwant/continuum/__init__.py @@ -6,18 +6,15 @@ # the file AUTHORS.rst at the top-level directory of this distribution and at # http://kwant-project.org/authors. -import sys - -from .._common import ExtensionUnavailable - try: from .discretizer import discretize, discretize_symbolic, build_discretized from ._common import sympify, lambdify from ._common import momentum_operators, position_operators -except ImportError: - sys.modules[__name__] = ExtensionUnavailable(__name__, ('sympy',)) +except ImportError as error: + msg = ("'kwant.continuum' is not available because one or more of its " + "dependencies is not installed.") + raise ImportError(msg) from error -del sys, ExtensionUnavailable __all__ = ['discretize', 'discretize_symbolic', 'build_discretized', 'sympify', 'lambdify'] diff --git a/kwant/plotter.py b/kwant/plotter.py index 2112e5b98334cd15e22ef1dcebc09aaa59ed61c9..2fc6b6acf2e21e7d65f8b0b3336394c836389858 100644 --- a/kwant/plotter.py +++ b/kwant/plotter.py @@ -26,29 +26,6 @@ import tinyarray as ta from scipy import spatial, interpolate from math import cos, sin, pi, sqrt -# All matplotlib imports must be isolated in a try, because even without -# matplotlib iterators remain useful. Further, mpl_toolkits used for 3D -# plotting are also imported separately, to ensure that 2D plotting works even -# if 3D does not. -try: - import matplotlib - import matplotlib.colors - import matplotlib.cm - from matplotlib.figure import Figure - from matplotlib import collections - from . import _colormaps - mpl_available = True - try: - from mpl_toolkits import mplot3d - has3d = True - except ImportError: - warnings.warn("3D plotting not available.", RuntimeWarning) - has3d = False -except ImportError: - warnings.warn("matplotlib is not available, only iterator-providing " - "functions will work.", RuntimeWarning) - mpl_available = False - from . import system, builder, _common @@ -57,18 +34,9 @@ __all__ = ['plot', 'map', 'bands', 'spectrum', 'current', 'sys_leads_sites', 'sys_leads_hoppings', 'sys_leads_pos', 'sys_leads_hopping_pos', 'mask_interpolate'] - -# Collections that allow for symbols and linewiths to be given in data space -# (not for general use, only implement what's needed for plotter) -def isarray(var): - if hasattr(var, '__getitem__') and not isinstance(var, str): - return True - else: - return False - - -def nparray_if_array(var): - return np.asarray(var) if isarray(var) else var +# All the expensive imports are done in _plotter.py. We lazy load the module +# to avoid slowing down the initial import of Kwant. +_p = _common.lazy_import('_plotter') def _sample_array(array, n_samples, rng=None): @@ -77,278 +45,10 @@ def _sample_array(array, n_samples, rng=None): return array[rng.choice(range(la), min(n_samples, la))] -if mpl_available: - class LineCollection(collections.LineCollection): - def __init__(self, segments, reflen=None, **kwargs): - super().__init__(segments, **kwargs) - self.reflen = reflen - - def set_linewidths(self, linewidths): - self.linewidths_orig = nparray_if_array(linewidths) - - def draw(self, renderer): - if self.reflen is not None: - # Note: only works for aspect ratio 1! - # 72.0 - there is 72 points in an inch - factor = (self.axes.transData.frozen().to_values()[0] * 72.0 * - self.reflen / self.figure.dpi) - else: - factor = 1 - - super().set_linewidths(self.linewidths_orig * - factor) - return super().draw(renderer) - - - class PathCollection(collections.PathCollection): - def __init__(self, paths, sizes=None, reflen=None, **kwargs): - super().__init__(paths, sizes=sizes, **kwargs) - - self.reflen = reflen - self.linewidths_orig = nparray_if_array(self.get_linewidths()) - - self.transforms = np.array( - [matplotlib.transforms.Affine2D().scale(x).get_matrix() - for x in sizes]) - - def get_transforms(self): - return self.transforms - - def get_transform(self): - Affine2D = matplotlib.transforms.Affine2D - if self.reflen is not None: - # For the paths, use the data transformation but strip the - # offset (will be added later with offsets) - args = self.axes.transData.frozen().to_values()[:4] + (0, 0) - return Affine2D().from_values(*args).scale(self.reflen) - else: - return Affine2D().scale(self.figure.dpi / 72.0) - - def draw(self, renderer): - if self.reflen: - # Note: only works for aspect ratio 1! - factor = (self.axes.transData.frozen().to_values()[0] / - self.figure.dpi * 72.0 * self.reflen) - self.set_linewidths(self.linewidths_orig * factor) - - return collections.Collection.draw(self, renderer) - - - if has3d: - # Sorting is optional. - sort3d = True - - # Compute the projection of a 3D length into 2D data coordinates - # for this we use 2 3D half-circles that are projected into 2D. - # (This gives the same length as projecting the full unit sphere.) - - phi = np.linspace(0, pi, 21) - xyz = np.c_[np.cos(phi), np.sin(phi), 0 * phi].T.reshape(-1, 1, 21) - unit_sphere = np.bmat([[xyz[0], xyz[2]], [xyz[1], xyz[0]], - [xyz[2], xyz[1]]]) - unit_sphere = np.asarray(unit_sphere) - - def projected_length(ax, length): - rc = np.array([ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()]) - rc = np.apply_along_axis(np.sum, 1, rc) / 2. - - rs = unit_sphere * length + rc.reshape(-1, 1) - - transform = mplot3d.proj3d.proj_transform - rp = np.asarray(transform(*(list(rs) + [ax.get_proj()]))[:2]) - rc[:2] = transform(*(list(rc) + [ax.get_proj()]))[:2] - - coords = rp - np.repeat(rc[:2].reshape(-1, 1), len(rs[0]), axis=1) - return sqrt(np.sum(coords**2, axis=0).max()) - - - # Auxiliary array for calculating corners of a cube. - corners = np.zeros((3, 8, 6), np.float_) - corners[0, [0, 1, 2, 3], 0] = corners[0, [4, 5, 6, 7], 1] = \ - corners[0, [0, 1, 4, 5], 2] = corners[0, [2, 3, 6, 7], 3] = \ - corners[0, [0, 2, 4, 6], 4] = corners[0, [1, 3, 5, 7], 5] = 1.0 - - - class Line3DCollection(mplot3d.art3d.Line3DCollection): - def __init__(self, segments, reflen=None, zorder=0, **kwargs): - super().__init__(segments, **kwargs) - self.reflen = reflen - self.zorder3d = zorder - - def set_linewidths(self, linewidths): - self.linewidths_orig = nparray_if_array(linewidths) - - def do_3d_projection(self, renderer): - super().do_3d_projection(renderer) - # The whole 3D ordering is flawed in mplot3d when several - # collections are added. We just use normal zorder. Note the - # "-" due to the different logic in the 3d plotting, we still - # want larger zorder values to be plotted on top of smaller - # ones. - return -self.zorder3d - - def draw(self, renderer): - if self.reflen: - proj_len = projected_length(self.axes, self.reflen) - args = self.axes.transData.frozen().to_values() - # Note: unlike in the 2D case, where we can enforce equal - # aspect ratio, this (currently) does not work with - # 3D plots in matplotlib. As an approximation, we - # thus scale with the average of the x- and y-axis - # transformation. - factor = proj_len * (args[0] + - args[3]) * 0.5 * 72.0 / self.figure.dpi - else: - factor = 1 - - super().set_linewidths( - self.linewidths_orig * factor) - super().draw(renderer) - - - class Path3DCollection(mplot3d.art3d.Patch3DCollection): - def __init__(self, paths, sizes, reflen=None, zorder=0, - offsets=None, **kwargs): - paths = [matplotlib.patches.PathPatch(path) for path in paths] - - if offsets is not None: - kwargs['offsets'] = offsets[:, :2] - - super().__init__(paths, **kwargs) - - if offsets is not None: - self.set_3d_properties(zs=offsets[:, 2], zdir="z") - - self.reflen = reflen - self.zorder3d = zorder - - self.paths_orig = np.array(paths, dtype='object') - self.linewidths_orig = nparray_if_array(self.get_linewidths()) - self.linewidths_orig2 = self.linewidths_orig - self.array_orig = nparray_if_array(self.get_array()) - self.facecolors_orig = nparray_if_array(self.get_facecolors()) - self.edgecolors_orig = nparray_if_array(self.get_edgecolors()) - - Affine2D = matplotlib.transforms.Affine2D - self.orig_transforms = np.array( - [Affine2D().scale(x).get_matrix() for x in sizes]) - self.transforms = self.orig_transforms - - def set_array(self, array): - self.array_orig = nparray_if_array(array) - super().set_array(array) - - def set_color(self, colors): - self.facecolors_orig = nparray_if_array(colors) - self.edgecolors_orig = self.facecolors_orig - super().set_color(colors) - - def set_edgecolors(self, colors): - colors = matplotlib.colors.colorConverter.to_rgba_array(colors) - self.edgecolors_orig = nparray_if_array(colors) - super().set_edgecolors(colors) - - def get_transforms(self): - # this is exact only for an isometric projection, for the - # perspective projection used in mplot3d it's an approximation - return self.transforms - - def get_transform(self): - Affine2D = matplotlib.transforms.Affine2D - if self.reflen: - proj_len = projected_length(self.axes, self.reflen) - - # For the paths, use the data transformation but strip the - # offset (will be added later with the offsets). - args = self.axes.transData.frozen().to_values()[:4] + (0, 0) - return Affine2D().from_values(*args).scale(proj_len) - else: - return Affine2D().scale(self.figure.dpi / 72.0) - - def do_3d_projection(self, renderer): - xs, ys, zs = self._offsets3d - - # numpy complains about zero-length index arrays - if len(xs) == 0: - return -self.zorder3d - - proj = mplot3d.proj3d.proj_transform_clip - vs = np.array(proj(xs, ys, zs, renderer.M)[:3]) - - if sort3d: - indx = vs[2].argsort()[::-1] - - self.set_offsets(vs[:2, indx].T) - - if len(self.paths_orig) > 1: - paths = np.resize(self.paths_orig, (vs.shape[1],)) - self.set_paths(paths[indx]) - - if len(self.orig_transforms) > 1: - self.transforms = self.transforms[indx] - - lw_orig = self.linewidths_orig - if (isinstance(lw_orig, np.ndarray) and len(lw_orig) > 1): - self.linewidths_orig2 = np.resize(lw_orig, - (vs.shape[1],))[indx] - - # Note: here array, facecolors and edgecolors are - # guaranteed to be 2d numpy arrays or None. (And - # array is the same length as the coordinates) - - if self.array_orig is not None: - super(Path3DCollection, - self).set_array(self.array_orig[indx]) - - if (self.facecolors_orig is not None and - self.facecolors_orig.shape[0] > 1): - shape = list(self.facecolors_orig.shape) - shape[0] = vs.shape[1] - super().set_facecolors( - np.resize(self.facecolors_orig, shape)[indx]) - - if (self.edgecolors_orig is not None and - self.edgecolors_orig.shape[0] > 1): - shape = list(self.edgecolors_orig.shape) - shape[0] = vs.shape[1] - super().set_edgecolors( - np.resize(self.edgecolors_orig, - shape)[indx]) - else: - self.set_offsets(vs[:2].T) - - # the whole 3D ordering is flawed in mplot3d when several - # collections are added. We just use normal zorder, but correct - # by the projected z-coord of the "center of gravity", - # normalized by the projected z-coord of the world coordinates. - # In doing so, several Path3DCollections are plotted probably - # in the right order (it's not exact) if they have the same - # zorder. Still, smaller and larger integer zorders are plotted - # below or on top. - - bbox = np.asarray(self.axes.get_w_lims()) - - proj = mplot3d.proj3d.proj_transform_clip - cz = proj(*(list(np.dot(corners, bbox)) + [renderer.M]))[2] - - return -self.zorder3d + vs[2].mean() / cz.ptp() - - def draw(self, renderer): - if self.reflen: - proj_len = projected_length(self.axes, self.reflen) - args = self.axes.transData.frozen().to_values() - factor = proj_len * (args[0] + - args[3]) * 0.5 * 72.0 / self.figure.dpi - - self.set_linewidths(self.linewidths_orig2 * factor) - - super().draw(renderer) - - # matplotlib helper functions. def _make_figure(dpi, fig_size): - fig = Figure() + fig = _p.Figure() if dpi is not None: fig.set_dpi(dpi) if fig_size is not None: @@ -377,10 +77,10 @@ def set_colors(color, collection, cmap, norm=None): if (isinstance(color, np.ndarray) and color.dtype == np.dtype('object')): color = tuple(color) - if isinstance(collection, mplot3d.art3d.Line3DCollection): + if isinstance(collection, _p.mplot3d.art3d.Line3DCollection): length = len(collection._segments3d) # Once again, matplotlib fault! - if isarray(color) and len(color) == length: + if _p.isarray(color) and len(color) == length: try: # check if it is an array of floats for color mapping color = np.asarray(color, dtype=float) @@ -393,7 +93,7 @@ def set_colors(color, collection, cmap, norm=None): except (TypeError, ValueError): pass - colors = matplotlib.colors.colorConverter.to_rgba_array(color) + colors = _p.matplotlib.colors.colorConverter.to_rgba_array(color) collection.set_color(colors) @@ -414,7 +114,7 @@ def get_symbol(symbols): paths = [] for symbol in symbols: - if isinstance(symbol, matplotlib.path.Path): + if isinstance(symbol, _p.matplotlib.path.Path): return symbol elif hasattr(symbol, '__getitem__') and len(symbol) == 3: kind, n, angle = symbol @@ -428,13 +128,13 @@ def get_symbol(symbols): radius = sqrt(2 * pi / (n * sin(2 * pi / n))) angle = pi * angle / 180 - patch = matplotlib.patches.RegularPolygon((0, 0), n, - radius=radius, - orientation=angle) + patch = _p.matplotlib.patches.RegularPolygon((0, 0), n, + radius=radius, + orientation=angle) else: raise ValueError("Unknown symbol definition " + str(symbol)) elif symbol == 'o': - patch = matplotlib.patches.Circle((0, 0), 1) + patch = _p.matplotlib.patches.Circle((0, 0), 1) paths.append(patch.get_path().transformed(patch.get_transform())) @@ -495,9 +195,9 @@ def symbols(axes, pos, symbol='o', size=1, reflen=None, facecolor='k', size = (size, ) if dim == 2: - Collection = PathCollection + Collection = _p.PathCollection else: - Collection = Path3DCollection + Collection = _p.Path3DCollection if len(pos) == 0 or np.all(symbol == 'no symbol') or np.all(size == 0): paths = [] @@ -569,9 +269,9 @@ def lines(axes, pos0, pos1, reflen=None, colors='k', linestyles='solid', dim = pos0.shape[1] assert dim == 2 or dim == 3 if dim == 2: - Collection = LineCollection + Collection = _p.LineCollection else: - Collection = Line3DCollection + Collection = _p.Line3DCollection if (len(pos0) == 0 or ('linewidths' in kwargs and kwargs['linewidths'] == 0)): @@ -631,7 +331,7 @@ def output_fig(fig, output_mode='auto', file=None, savefile_opts=None, matplotlib in that the `dpi` attribute of the figure is used by defaul instead of the matplotlib config setting. """ - if not mpl_available: + if not _p.mpl_available: raise RuntimeError('matplotlib is not installed.') # We import backends and pyplot only at the last possible moment (=now) @@ -1111,7 +811,7 @@ def plot(sys, num_lead_cells=2, unit='nn', its aspect ratio. """ - if not mpl_available: + if not _p.mpl_available: raise RuntimeError("matplotlib was not found, but is required " "for plot()") @@ -1161,7 +861,7 @@ def plot(sys, num_lead_cells=2, unit='nn', start_pos = np.apply_along_axis(pos_transform, 1, start_pos) dim = 3 if (sites_pos.shape[1] == 3) else 2 - if dim == 3 and not has3d: + if dim == 3 and not _p.has3d: raise RuntimeError("Installed matplotlib does not support 3d plotting") sites_pos = resize_to_dim(sites_pos) end_pos = resize_to_dim(end_pos) @@ -1202,7 +902,7 @@ def plot(sys, num_lead_cells=2, unit='nn', def make_proper_site_spec(spec, fancy_indexing=False): if callable(spec): spec = [spec(i[0]) for i in sites if i[1] is None] - if (fancy_indexing and isarray(spec) + if (fancy_indexing and _p.isarray(spec) and not isinstance(spec, np.ndarray)): try: spec = np.asarray(spec) @@ -1213,7 +913,7 @@ def plot(sys, num_lead_cells=2, unit='nn', def make_proper_hop_spec(spec, fancy_indexing=False): if callable(spec): spec = [spec(*i[0]) for i in hops if i[1] is None] - if (fancy_indexing and isarray(spec) + if (fancy_indexing and _p.isarray(spec) and not isinstance(spec, np.ndarray)): try: spec = np.asarray(spec) @@ -1225,7 +925,7 @@ def plot(sys, num_lead_cells=2, unit='nn', if site_symbol is None: site_symbol = defaults['site_symbol'][dim] # separate different symbols (not done in 3D, the separation # would mess up sorting) - if (isarray(site_symbol) and dim != 3 and + if (_p.isarray(site_symbol) and dim != 3 and (len(site_symbol) != 3 or site_symbol[0] not in ('p', 'P'))): symbol_dict = defaultdict(list) for i, symbol in enumerate(site_symbol): @@ -1239,7 +939,7 @@ def plot(sys, num_lead_cells=2, unit='nn', fancy_indexing = False if site_color is None: - cycle = (x['color'] for x in matplotlib.rcParams['axes.prop_cycle']) + cycle = (x['color'] for x in _p.matplotlib.rcParams['axes.prop_cycle']) if isinstance(syst, (builder.FiniteSystem, builder.InfiniteSystem)): # Skipping the leads for brevity. families = sorted({site.family for site in syst.sites}) @@ -1279,30 +979,30 @@ def plot(sys, num_lead_cells=2, unit='nn', try: if site_color.ndim == 1 and len(site_color) == n_syst_sites: site_color = np.asarray(site_color, dtype=float) - norm = matplotlib.colors.Normalize(site_color.min(), - site_color.max()) + norm = _p.matplotlib.colors.Normalize(site_color.min(), + site_color.max()) except: pass # take spec also for lead, if it's not a list/array, default, otherwise if lead_site_symbol is None: - lead_site_symbol = (site_symbol if not isarray(site_symbol) + lead_site_symbol = (site_symbol if not _p.isarray(site_symbol) else defaults['site_symbol'][dim]) if lead_site_size is None: - lead_site_size = (site_size if not isarray(site_size) + lead_site_size = (site_size if not _p.isarray(site_size) else defaults['site_size'][dim]) if lead_color is None: lead_color = defaults['lead_color'][dim] - lead_color = matplotlib.colors.colorConverter.to_rgba(lead_color) + lead_color = _p.matplotlib.colors.colorConverter.to_rgba(lead_color) if lead_site_edgecolor is None: - lead_site_edgecolor = (site_edgecolor if not isarray(site_edgecolor) + lead_site_edgecolor = (site_edgecolor if not _p.isarray(site_edgecolor) else defaults['site_edgecolor'][dim]) if lead_site_lw is None: - lead_site_lw = (site_lw if not isarray(site_lw) + lead_site_lw = (site_lw if not _p.isarray(site_lw) else defaults['site_lw'][dim]) if lead_hop_lw is None: - lead_hop_lw = (hop_lw if not isarray(hop_lw) + lead_hop_lw = (hop_lw if not _p.isarray(hop_lw) else defaults['hop_lw'][dim]) hop_cmap = None @@ -1328,11 +1028,11 @@ def plot(sys, num_lead_cells=2, unit='nn', # plot system sites and hoppings for symbol, slc in symbol_slcs: - size = site_size[slc] if isarray(site_size) else site_size - col = site_color[slc] if isarray(site_color) else site_color - edgecol = (site_edgecolor[slc] if isarray(site_edgecolor) else + size = site_size[slc] if _p.isarray(site_size) else site_size + col = site_color[slc] if _p.isarray(site_color) else site_color + edgecol = (site_edgecolor[slc] if _p.isarray(site_edgecolor) else site_edgecolor) - lw = site_lw[slc] if isarray(site_lw) else site_lw + lw = site_lw[slc] if _p.isarray(site_lw) else site_lw symbol_coll = symbols(ax, sites_pos[slc], size=size, reflen=reflen, symbol=symbol, @@ -1344,8 +1044,8 @@ def plot(sys, num_lead_cells=2, unit='nn', zorder=1, cmap=hop_cmap) # plot lead sites and hoppings - norm = matplotlib.colors.Normalize(-0.5, num_lead_cells - 0.5) - cmap_from_list = matplotlib.colors.LinearSegmentedColormap.from_list + norm = _p.matplotlib.colors.Normalize(-0.5, num_lead_cells - 0.5) + cmap_from_list = _p.matplotlib.colors.LinearSegmentedColormap.from_list lead_cmap = cmap_from_list(None, [lead_color, (1, 1, 1, lead_color[3])]) for sites_slc, hops_slc in zip(lead_sites_slcs, lead_hops_slcs): @@ -1543,7 +1243,7 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None, correspond to exactly one pixel. """ - if not mpl_available: + if not _p.mpl_available: raise RuntimeError("matplotlib was not found, but is required " "for map()") @@ -1576,7 +1276,7 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None, fig = None if cmap is None: - cmap = _colormaps.kwant_red + cmap = _p._colormaps.kwant_red # Note that we tell imshow to show the array created by mask_interpolate # faithfully and not to interpolate by itself another time. @@ -1640,7 +1340,7 @@ def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None, See `~kwant.physics.Bands` for the calculation of dispersion without plotting. """ - if not mpl_available: + if not _p.mpl_available: raise RuntimeError("matplotlib was not found, but is required " "for bands()") @@ -1715,10 +1415,10 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None, A figure with the output if `ax` is not set, else None. """ - if not mpl_available: + if not _p.mpl_available: raise RuntimeError("matplotlib was not found, but is required " "for plot_spectrum()") - if y is not None and not has3d: + if y is not None and not _p.has3d: raise RuntimeError("Installed matplotlib does not support 3d plotting") if isinstance(syst, system.FiniteSystem): @@ -1754,7 +1454,7 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None, # set up axes if ax is None: - fig = Figure() + fig = _p.Figure() if dpi is not None: fig.set_dpi(dpi) if fig_size is not None: @@ -1996,8 +1696,8 @@ _gamma_expand = np.vectorize(_gamma_expand, otypes=[float]) def _linear_cmap(a, b): """Make a colormap that linearly interpolates between the colors a and b.""" - a = matplotlib.colors.colorConverter.to_rgb(a) - b = matplotlib.colors.colorConverter.to_rgb(b) + a = _p.matplotlib.colors.colorConverter.to_rgb(a) + b = _p.matplotlib.colors.colorConverter.to_rgb(b) a_linear = _gamma_expand(a) b_linear = _gamma_expand(b) color_diff = a_linear - b_linear @@ -2005,7 +1705,7 @@ def _linear_cmap(a, b): * color_diff.reshape((1, -1))) palette += b_linear palette = _gamma_compress(palette) - return matplotlib.colors.ListedColormap(palette) + return _p.matplotlib.colors.ListedColormap(palette) def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k', @@ -2074,7 +1774,7 @@ def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k', fig : matplotlib figure A figure with the output if `ax` is not set, else None. """ - if not mpl_available: + if not _p.mpl_available: raise RuntimeError("matplotlib was not found, but is required " "for current()") @@ -2090,8 +1790,8 @@ def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k', if bgcolor is None: if cmap is None: - cmap = _colormaps.kwant_red - cmap = matplotlib.cm.get_cmap(cmap) + cmap = _p._colormaps.kwant_red + cmap = _p.matplotlib.cm.get_cmap(cmap) bgcolor = cmap(0)[:3] elif cmap is not None: raise ValueError("The parameters 'cmap' and 'bgcolor' are " @@ -2129,7 +1829,7 @@ def streamplot(field, box, cmap=None, bgcolor=None, linecolor='k', ax.streamplot(X, Y, field[:,:,0], field[:,:,1], density=density, linewidth=linewidth, color=color, cmap=line_cmap, arrowstyle='->', - norm=matplotlib.colors.Normalize(0, 1)) + norm=_p.matplotlib.colors.Normalize(0, 1)) ax.set_xlim(*box[0]) ax.set_ylim(*box[1]) diff --git a/kwant/tests/test_plotter.py b/kwant/tests/test_plotter.py index ff2e55caa18b693b23eca9f219d4d22afcff54e8..a5d06a36bf13e1a8a549620fda9d98ff762d699d 100644 --- a/kwant/tests/test_plotter.py +++ b/kwant/tests/test_plotter.py @@ -36,6 +36,7 @@ except ImportError: matplotlib_backend_chosen = False from kwant import plotter +from kwant import _plotter # for mpl_available def test_matplotlib_backend_unset(): @@ -44,7 +45,7 @@ def test_matplotlib_backend_unset(): def test_importable_without_matplotlib(): - prefix, sep, suffix = plotter.__file__.rpartition('.') + prefix, sep, suffix = _plotter.__file__.rpartition('.') if suffix in ['pyc', 'pyo']: suffix = 'py' assert suffix == 'py' @@ -111,7 +112,7 @@ def syst_3d(W=3, r1=2, r2=4, a=1, t=1.0): return syst -@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.") +@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.") def test_plot(): plot = plotter.plot syst2d = syst_2d() @@ -161,7 +162,7 @@ def bad_transform(pos): x, y = pos return x, y, 0 -@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.") +@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.") def test_map(): syst = syst_2d() with tempfile.TemporaryFile('w+b') as out: @@ -197,7 +198,7 @@ def test_mask_interpolate(): coords, np.ones(2 * len(coords))) -@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.") +@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.") def test_bands(): syst = syst_2d().finalized().leads[0] @@ -212,7 +213,7 @@ def test_bands(): plotter.bands(syst, ax=ax, file=out) -@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.") +@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.") def test_spectrum(): def ham_1d(a, b, c): @@ -423,7 +424,7 @@ def test_current_interpolation(): assert scipy.stats.linregress(np.log(data))[2] < -0.8 -@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.") +@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.") def test_current(): syst = syst_2d().finalized() J = kwant.operator.Current(syst) diff --git a/kwant/tests/test_wraparound.py b/kwant/tests/test_wraparound.py index eb981d2798fa1b8fd42e36d26a0cd88039274b11..f59ec8cb15953446b7ce44b745adf6b0c3b98ed7 100644 --- a/kwant/tests/test_wraparound.py +++ b/kwant/tests/test_wraparound.py @@ -13,11 +13,11 @@ import tinyarray as ta import pytest import kwant -from kwant import plotter +from kwant import _plotter from kwant.wraparound import wraparound, plot_2d_bands from kwant._common import get_parameters -if plotter.mpl_available: +if _plotter.mpl_available: from mpl_toolkits import mplot3d # pragma: no flakes from matplotlib import pyplot # pragma: no flakes @@ -201,7 +201,7 @@ def test_symmetry(): assert np.all(orig == new) -@pytest.mark.skipif(not plotter.mpl_available, reason="Matplotlib unavailable.") +@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.") def test_plot_2d_bands(): chain = kwant.lattice.chain() square = kwant.lattice.square() diff --git a/pytest.ini b/pytest.ini index 6c8258daf2ed83d5fb6da872d379c2ad36d42c25..c7f56a4768a2f6f410c50545cc5740b85baef3d7 100644 --- a/pytest.ini +++ b/pytest.ini @@ -2,5 +2,6 @@ testpaths = kwant flakes-ignore = __init__.py UnusedImport + kwant/_plotter.py UnusedImport graph/tests/test_scotch.py UndefinedName graph/tests/test_dissection.py UndefinedName