Commit 151ff59f authored by Joseph Weston's avatar Joseph Weston
Browse files

ensure unique sites are picked when randomly sampling

parent 12929415
......@@ -16,6 +16,7 @@ system in two or three dimensions.
from collections import defaultdict
import warnings
import random
import numpy as np
import tinyarray as ta
from scipy import spatial, interpolate
......@@ -79,6 +80,11 @@ def nparray_if_array(var):
return np.asarray(var) if isarray(var) else var
def _sample_array(array, n_samples):
la = len(array)
return array[random.sample(xrange(la), min(n_samples, la))]
if mpl_enabled:
class LineCollection(collections.LineCollection):
def __init__(self, segments, reflen=None, **kwargs):
......@@ -1174,7 +1180,7 @@ def plot(sys, num_lead_cells=2, unit='nn',
# If no hoppings are present, use for the same purpose distances
# from ten randomly selected points to the remaining points in the
# system.
points = sites_pos[np.random.randint(len(sites_pos), size=10)].T
points = _sample_array(sites_pos, 10).T
distances = (sites_pos.T.reshape(1, -1, dim) -
points.reshape(-1, 1, dim)).reshape(-1, dim)
distances = np.sort(np.sum(distances**2, axis=1))
......@@ -1424,7 +1430,8 @@ def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3):
tree = spatial.cKDTree(coords)
points = coords[np.random.randint(len(coords), size=10)]
# Select 10 sites to compare -- comparing them all is too costly.
points = _sample_array(coords, 10)
min_dist = np.min(tree.query(points, 2)[0][:, 1])
if min_dist < 1e-6 * np.linalg.norm(cmax - cmin):
warnings.warn("Some sites have nearly coinciding positions, "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment