diff --git a/kwant/plotter.py b/kwant/plotter.py index 12d340e24c7d2baa98108ccbedc8cbebe78ed8f1..301d6edec3cbe5c2b7a1930fd774a795d35a50a1 100644 --- a/kwant/plotter.py +++ b/kwant/plotter.py @@ -1873,9 +1873,12 @@ def interpolate_density(syst, density, relwidth=None, abswidth=None, n=9, (bbox_min, bbox_max), width, padding, field) if mask: + # Field is zero when we are > 0.5*width from any site (as bump has + # finite support), so we mask positions a little further than this. field = _mask(field, - boundaries, - np.array([s.pos for s in syst.sites])) + box=boundaries, + coords=np.array([s.pos for s in syst.sites]), + cutoff=0.6*width) return field, boundaries @@ -2183,21 +2186,15 @@ def current(syst, current, relwidth=0.05, **kwargs): **kwargs) -def _mask(field, box, coords): +def _mask(field, box, coords, cutoff): tree = spatial.cKDTree(coords) - # Select 10 sites to compare -- comparing them all is too costly. - points = _sample_array(coords, 10) - min_dist = np.min(tree.query(points, 2)[0][:, 1]) - # Build the mask initially as a 2D array dims = tuple(slice(boxmin, boxmax, 1j * shape) for (boxmin, boxmax), shape in zip(box, field.shape)) mask = np.mgrid[dims].reshape(len(box), -1).T - # '0.4' (which is just below sqrt(2) - 1) makes tree.query() exact - # in the common case of a square lattice. - mask = tree.query(mask, eps=0.4)[0] > min_dist + mask = tree.query(mask, distance_upper_bound=cutoff)[0] == np.inf return np.ma.masked_array(field, mask)