Commit 151ff59f authored by Joseph Weston's avatar Joseph Weston
Browse files

ensure unique sites are picked when randomly sampling

parent 12929415
...@@ -16,6 +16,7 @@ system in two or three dimensions. ...@@ -16,6 +16,7 @@ system in two or three dimensions.
from collections import defaultdict from collections import defaultdict
import warnings import warnings
import random
import numpy as np import numpy as np
import tinyarray as ta import tinyarray as ta
from scipy import spatial, interpolate from scipy import spatial, interpolate
...@@ -79,6 +80,11 @@ def nparray_if_array(var): ...@@ -79,6 +80,11 @@ def nparray_if_array(var):
return np.asarray(var) if isarray(var) else var return np.asarray(var) if isarray(var) else var
def _sample_array(array, n_samples):
la = len(array)
return array[random.sample(xrange(la), min(n_samples, la))]
if mpl_enabled: if mpl_enabled:
class LineCollection(collections.LineCollection): class LineCollection(collections.LineCollection):
def __init__(self, segments, reflen=None, **kwargs): def __init__(self, segments, reflen=None, **kwargs):
...@@ -1174,7 +1180,7 @@ def plot(sys, num_lead_cells=2, unit='nn', ...@@ -1174,7 +1180,7 @@ def plot(sys, num_lead_cells=2, unit='nn',
# If no hoppings are present, use for the same purpose distances # If no hoppings are present, use for the same purpose distances
# from ten randomly selected points to the remaining points in the # from ten randomly selected points to the remaining points in the
# system. # system.
points = sites_pos[np.random.randint(len(sites_pos), size=10)].T points = _sample_array(sites_pos, 10).T
distances = (sites_pos.T.reshape(1, -1, dim) - distances = (sites_pos.T.reshape(1, -1, dim) -
points.reshape(-1, 1, dim)).reshape(-1, dim) points.reshape(-1, 1, dim)).reshape(-1, dim)
distances = np.sort(np.sum(distances**2, axis=1)) distances = np.sort(np.sum(distances**2, axis=1))
...@@ -1424,7 +1430,8 @@ def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3): ...@@ -1424,7 +1430,8 @@ def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3):
tree = spatial.cKDTree(coords) tree = spatial.cKDTree(coords)
points = coords[np.random.randint(len(coords), size=10)] # Select 10 sites to compare -- comparing them all is too costly.
points = _sample_array(coords, 10)
min_dist = np.min(tree.query(points, 2)[0][:, 1]) min_dist = np.min(tree.query(points, 2)[0][:, 1])
if min_dist < 1e-6 * np.linalg.norm(cmax - cmin): if min_dist < 1e-6 * np.linalg.norm(cmax - cmin):
warnings.warn("Some sites have nearly coinciding positions, " warnings.warn("Some sites have nearly coinciding positions, "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment