Commit 495a3e29 authored by Christoph Groth's avatar Christoph Groth
Browse files

Merge branch 'stable' into 'master'

parents e16d9738 57037c8d
......@@ -16,6 +16,7 @@ system in two or three dimensions.
from collections import defaultdict
import warnings
import random
import numpy as np
import tinyarray as ta
from scipy import spatial, interpolate
......@@ -79,6 +80,11 @@ def nparray_if_array(var):
return np.asarray(var) if isarray(var) else var
def _sample_array(array, n_samples):
la = len(array)
return array[random.sample(range(la), min(n_samples, la))]
if mpl_enabled:
class LineCollection(collections.LineCollection):
def __init__(self, segments, reflen=None, **kwargs):
......@@ -1181,7 +1187,7 @@ def plot(sys, num_lead_cells=2, unit='nn',
# If no hoppings are present, use for the same purpose distances
# from ten randomly selected points to the remaining points in the
# system.
points = sites_pos[np.random.randint(len(sites_pos), size=10)].T
points = _sample_array(sites_pos, 10).T
distances = (sites_pos.T.reshape(1, -1, dim) -
points.reshape(-1, 1, dim)).reshape(-1, dim)
distances = np.sort(np.sum(distances**2, axis=1))
......@@ -1431,7 +1437,8 @@ def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3):
tree = spatial.cKDTree(coords)
points = coords[np.random.randint(len(coords), size=10)]
# Select 10 sites to compare -- comparing them all is too costly.
points = _sample_array(coords, 10)
min_dist = np.min(tree.query(points, 2)[0][:, 1])
if min_dist < 1e-6 * np.linalg.norm(cmax - cmin):
warnings.warn("Some sites have nearly coinciding positions, "
......
......@@ -149,18 +149,15 @@ def test_map():
with warnings.catch_warnings():
warnings.simplefilter("ignore")
plotter.map(syst.finalized(), range(len(syst.sites())),
file=out)
file=out)
pytest.raises(ValueError, plotter.map, syst,
range(len(syst.sites())), file=out)
def test_mask_interpolate():
# A coordinate array with coordinates of two points almost coinciding.
coords = np.random.rand(10, 2)
coords[5] *= 1e-8
coords[5] += coords[0]
coords = np.array([[0, 0], [1e-7, 1e-7], [1, 1], [1, 0]])
warnings.simplefilter("ignore")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
plotter.mask_interpolate(coords, np.ones(len(coords)), a=1)
......@@ -168,7 +165,9 @@ def test_mask_interpolate():
assert issubclass(w[-1].category, RuntimeWarning)
assert "coinciding" in str(w[-1].message)
pytest.raises(ValueError, plotter.mask_interpolate, coords,
np.ones(len(coords)))
pytest.raises(ValueError, plotter.mask_interpolate, coords, np.ones(2 *
len(coords)))
with warnings.catch_warnings():
warnings.simplefilter("ignore")
pytest.raises(ValueError, plotter.mask_interpolate,
coords, np.ones(len(coords)))
pytest.raises(ValueError, plotter.mask_interpolate,
coords, np.ones(2 * len(coords)))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment