Skip to content
Snippets Groups Projects
Commit bb6328ef authored by Anton Akhmerov's avatar Anton Akhmerov Committed by Christoph Groth
Browse files

improve handling of resolution and data length in mask_interpolate

parent f07c163f
No related branches found
No related tags found
No related merge requests found
......@@ -1404,11 +1404,22 @@ def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3):
tree = spatial.cKDTree(coords)
points = coords[np.random.randint(len(coords), size=10)]
min_dist = np.min(tree.query(points, 2)[0][:, 1])
if min_dist < 1e-6 * np.linalg.norm(cmax - cmin):
warnings.warn("Some sites have nearly coinciding positions, "
"interpolation may be confusing.",
RuntimeWarning)
if a is None:
points = coords[np.random.randint(len(coords), size=10)]
a = np.min(tree.query(points, 2)[0][:, 1])
elif a <= 0:
raise ValueError("The distance a must be strictly positive.")
a = min_dist
if a < 1e-6 * np.linalg.norm(cmax - cmin):
raise ValueError("The reference distance a is too small.")
if len(coords) != len(values):
raise ValueError("The number of sites doesn't match the number of"
"provided values.")
shape = (((cmax - cmin) / a + 1) * oversampling).round()
delta = 0.5 * (oversampling - 1) * a / oversampling
......
......@@ -9,8 +9,10 @@
import tempfile
import warnings
import nose
import numpy as np
import kwant
from kwant import plotter
from nose.tools import assert_raises
if plotter.mpl_enabled:
from mpl_toolkits import mplot3d
from matplotlib import pyplot
......@@ -142,3 +144,23 @@ def test_map():
file=out)
nose.tools.assert_raises(ValueError, plotter.map, sys,
xrange(len(sys.sites())), file=out)
def test_mask_interpolate():
# A coordinate array with coordinates of two points almost coinciding.
coords = np.random.rand(10, 2)
coords[5] *= 1e-8
coords[5] += coords[0]
warnings.simplefilter("ignore")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
plotter.mask_interpolate(coords, np.ones(len(coords)), a=1)
assert len(w) == 1
assert issubclass(w[-1].category, RuntimeWarning)
assert "coinciding" in str(w[-1].message)
assert_raises(ValueError, plotter.mask_interpolate,
coords, np.ones(len(coords)))
assert_raises(ValueError, plotter.mask_interpolate,
coords, np.ones(2 * len(coords)))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment