diff --git a/kwant/plotter.py b/kwant/plotter.py index 913f24cd7e6bd3ed5fb554b05edaa2cfdda867cc..c0e6ac2438ea5fb89cb96d9ad10e78585c27a03e 100644 --- a/kwant/plotter.py +++ b/kwant/plotter.py @@ -16,6 +16,7 @@ system in two or three dimensions. from collections import defaultdict import warnings +import random import numpy as np import tinyarray as ta from scipy import spatial, interpolate @@ -79,6 +80,11 @@ def nparray_if_array(var): return np.asarray(var) if isarray(var) else var +def _sample_array(array, n_samples): + la = len(array) + return array[random.sample(xrange(la), min(n_samples, la))] + + if mpl_enabled: class LineCollection(collections.LineCollection): def __init__(self, segments, reflen=None, **kwargs): @@ -1174,7 +1180,7 @@ def plot(sys, num_lead_cells=2, unit='nn', # If no hoppings are present, use for the same purpose distances # from ten randomly selected points to the remaining points in the # system. - points = sites_pos[np.random.randint(len(sites_pos), size=10)].T + points = _sample_array(sites_pos, 10).T distances = (sites_pos.T.reshape(1, -1, dim) - points.reshape(-1, 1, dim)).reshape(-1, dim) distances = np.sort(np.sum(distances**2, axis=1)) @@ -1424,7 +1430,8 @@ def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3): tree = spatial.cKDTree(coords) - points = coords[np.random.randint(len(coords), size=10)] + # Select 10 sites to compare -- comparing them all is too costly. + points = _sample_array(coords, 10) min_dist = np.min(tree.query(points, 2)[0][:, 1]) if min_dist < 1e-6 * np.linalg.norm(cmax - cmin): warnings.warn("Some sites have nearly coinciding positions, "