diff --git a/kwant/plotter.py b/kwant/plotter.py index 2fc6b6acf2e21e7d65f8b0b3336394c836389858..a570214101e281df3f10db28416636077a8d4b94 100644 --- a/kwant/plotter.py +++ b/kwant/plotter.py @@ -1618,26 +1618,27 @@ def interpolate_current(syst, current, relwidth=None, abswidth=None, n=9): # Define length scale in terms of the bump width. scale = 2 / width + padding = width / 2 lens *= scale # Create field array. field_shape = np.zeros(dim + 1, int) field_shape[dim] = dim for d in range(dim): - field_shape[d] = int(bbox_size[d] * n / width + 1.5*n) + field_shape[d] = int(bbox_size[d] * n / width + n) if field_shape[d] % 2: field_shape[d] += 1 field = np.zeros(field_shape) - region = [np.linspace(bbox_min[d] - 0.75*width, - bbox_max[d] + 0.75*width, + region = [np.linspace(bbox_min[d] - padding, + bbox_max[d] + padding, field_shape[d]) for d in range(dim)] - grid_density = (field_shape[:dim] - 1) / (bbox_max + 1.5*width - bbox_min) + grid_density = (field_shape[:dim] - 1) / (bbox_max + 2*padding - bbox_min) slices = np.empty((len(hops), dim, 2), int) slices[:, :, 0] = np.floor((min_hops - bbox_min) * grid_density) - slices[:, :, 1] = np.ceil((max_hops + 1.5*width - bbox_min) * grid_density) + slices[:, :, 1] = np.ceil((max_hops + 2*padding - bbox_min) * grid_density) # Interpolate the field for each hopping. for i in range(len(current)): diff --git a/kwant/tests/test_plotter.py b/kwant/tests/test_plotter.py index a5d06a36bf13e1a8a549620fda9d98ff762d699d..c937e3214dd08b2cc47ce16721e8bfa542c0d9ef 100644 --- a/kwant/tests/test_plotter.py +++ b/kwant/tests/test_plotter.py @@ -337,6 +337,12 @@ def rotational_currents(g): return null_space_basis +def _border_is_0(field): + borders = [(0, slice(None)), (-1, slice(None)), + (slice(None), 0), (slice(None), -1)] + return all(np.allclose(field[a, b], 0) for a, b in borders) + + def test_current_interpolation(): ## Passing a Builder will raise an error @@ -423,6 +429,28 @@ def test_current_interpolation(): # 3rd value returned from 'linregress' is 'rvalue' assert scipy.stats.linregress(np.log(data))[2] < -0.8 + ## Test that the current is always identically zero at the boundaries of the box + syst = kwant.Builder() + lat = kwant.lattice.square() + syst[[lat(0, 0), lat(1, 0)]] = None + syst[(lat(0, 0), lat(1, 0))] = None + syst = syst.finalized() + current = [1, -1] + + ns = [3, 4, 5, 10, 100] + abswidths = [0.01, 0.1, 1, 10, 100] + relwidths = [0.01, 0.1, 1, 10, 100] + for n, abswidth in itertools.product(ns, abswidths): + field, _ = kwant.plotter.interpolate_current(syst, current, + abswidth=abswidth, n=n) + assert _border_is_0(field) + for n, relwidth in itertools.product(ns, relwidths): + field, _ = kwant.plotter.interpolate_current(syst, current, + relwidth=relwidth, n=n) + assert _border_is_0(field) + + + @pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.") def test_current():