diff --git a/kwant/plotter.py b/kwant/plotter.py index 42ba2bfb99fc4d3dbc8b40ae9434c1526bb4f5c9..913f24cd7e6bd3ed5fb554b05edaa2cfdda867cc 100644 --- a/kwant/plotter.py +++ b/kwant/plotter.py @@ -1156,6 +1156,12 @@ def plot(sys, num_lead_cells=2, unit='nn', end_pos = resize_to_dim(end_pos) start_pos = resize_to_dim(start_pos) + # Apply transformations to the data + if pos_transform is not None: + sites_pos = np.apply_along_axis(pos_transform, 1, sites_pos) + end_pos = np.apply_along_axis(pos_transform, 1, end_pos) + start_pos = np.apply_along_axis(pos_transform, 1, start_pos) + # Determine the reference length. if unit == 'pt': reflen = None @@ -1187,12 +1193,6 @@ def plot(sys, num_lead_cells=2, unit='nn', except: raise ValueError('Invalid value of unit argument.') - # Apply transformations to the data - if pos_transform is not None: - sites_pos = np.apply_along_axis(pos_transform, 1, sites_pos) - end_pos = np.apply_along_axis(pos_transform, 1, end_pos) - start_pos = np.apply_along_axis(pos_transform, 1, start_pos) - # make all specs proper: either constant or lists/np.arrays: def make_proper_site_spec(spec, fancy_indexing=False): if callable(spec):