diff --git a/kwant/plotter.py b/kwant/plotter.py index cc1e9857aff5da4a0230ad0f548bbb4ffcd493d9..c9e8d3706a8c13aeccfb513985128858c4f431c2 100644 --- a/kwant/plotter.py +++ b/kwant/plotter.py @@ -1404,11 +1404,22 @@ def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3): tree = spatial.cKDTree(coords) + points = coords[np.random.randint(len(coords), size=10)] + min_dist = np.min(tree.query(points, 2)[0][:, 1]) + if min_dist < 1e-6 * np.linalg.norm(cmax - cmin): + warnings.warn("Some sites have nearly coinciding positions, " + "interpolation may be confusing.", + RuntimeWarning) + if a is None: - points = coords[np.random.randint(len(coords), size=10)] - a = np.min(tree.query(points, 2)[0][:, 1]) - elif a <= 0: - raise ValueError("The distance a must be strictly positive.") + a = min_dist + + if a < 1e-6 * np.linalg.norm(cmax - cmin): + raise ValueError("The reference distance a is too small.") + + if len(coords) != len(values): + raise ValueError("The number of sites doesn't match the number of" + "provided values.") shape = (((cmax - cmin) / a + 1) * oversampling).round() delta = 0.5 * (oversampling - 1) * a / oversampling diff --git a/kwant/tests/test_plotter.py b/kwant/tests/test_plotter.py index 2830e545387abc06238836d9a5d643af7ff92d1f..08a6fbfea3569a7b9496117eeac1cd75b78108ba 100644 --- a/kwant/tests/test_plotter.py +++ b/kwant/tests/test_plotter.py @@ -9,8 +9,10 @@ import tempfile import warnings import nose +import numpy as np import kwant from kwant import plotter +from nose.tools import assert_raises if plotter.mpl_enabled: from mpl_toolkits import mplot3d from matplotlib import pyplot @@ -142,3 +144,23 @@ def test_map(): file=out) nose.tools.assert_raises(ValueError, plotter.map, sys, xrange(len(sys.sites())), file=out) + + +def test_mask_interpolate(): + # A coordinate array with coordinates of two points almost coinciding. + coords = np.random.rand(10, 2) + coords[5] *= 1e-8 + coords[5] += coords[0] + + warnings.simplefilter("ignore") + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + plotter.mask_interpolate(coords, np.ones(len(coords)), a=1) + assert len(w) == 1 + assert issubclass(w[-1].category, RuntimeWarning) + assert "coinciding" in str(w[-1].message) + + assert_raises(ValueError, plotter.mask_interpolate, + coords, np.ones(len(coords))) + assert_raises(ValueError, plotter.mask_interpolate, + coords, np.ones(2 * len(coords)))