diff --git a/kwant/plotter.py b/kwant/plotter.py
index 42ba2bfb99fc4d3dbc8b40ae9434c1526bb4f5c9..913f24cd7e6bd3ed5fb554b05edaa2cfdda867cc 100644
--- a/kwant/plotter.py
+++ b/kwant/plotter.py
@@ -1156,6 +1156,12 @@ def plot(sys, num_lead_cells=2, unit='nn',
     end_pos = resize_to_dim(end_pos)
     start_pos = resize_to_dim(start_pos)
 
+    # Apply transformations to the data
+    if pos_transform is not None:
+        sites_pos = np.apply_along_axis(pos_transform, 1, sites_pos)
+        end_pos = np.apply_along_axis(pos_transform, 1, end_pos)
+        start_pos = np.apply_along_axis(pos_transform, 1, start_pos)
+
     # Determine the reference length.
     if unit == 'pt':
         reflen = None
@@ -1187,12 +1193,6 @@ def plot(sys, num_lead_cells=2, unit='nn',
         except:
             raise ValueError('Invalid value of unit argument.')
 
-    # Apply transformations to the data
-    if pos_transform is not None:
-        sites_pos = np.apply_along_axis(pos_transform, 1, sites_pos)
-        end_pos = np.apply_along_axis(pos_transform, 1, end_pos)
-        start_pos = np.apply_along_axis(pos_transform, 1, start_pos)
-
     # make all specs proper: either constant or lists/np.arrays:
     def make_proper_site_spec(spec, fancy_indexing=False):
         if callable(spec):