From 667a4c07e8f4fdb4819b8e861a89a2ad555f7b0b Mon Sep 17 00:00:00 2001
From: Michael Wimmer <wimmer@lorentz.leidenuniv.nl>
Date: Wed, 28 Aug 2013 16:29:55 +0200
Subject: [PATCH] bug fix: make site-dependent linewidths work properly

---
 kwant/plotter.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/kwant/plotter.py b/kwant/plotter.py
index 4e19bf2..335ba85 100644
--- a/kwant/plotter.py
+++ b/kwant/plotter.py
@@ -71,15 +71,16 @@ if mpl_enabled:
             self._linewidths_orig = nparray_if_array(linewidths)
 
         def draw(self, renderer):
-            linewidths = self._linewidths_orig
             if self.reflen is not None:
                 # Note: only works for aspect ratio 1!
                 #       72.0 - there is 72 points in an inch
                 factor = (self.axes.transData.frozen().to_values()[0] * 72.0 *
                           self.reflen / self.figure.dpi)
-                linewidths *= factor
+            else:
+                factor = 1
 
-            super(LineCollection, self).set_linewidths(linewidths)
+            super(LineCollection, self).set_linewidths(self._linewidths_orig *
+                                                       factor)
             return super(LineCollection, self).draw(renderer)
 
 
@@ -170,7 +171,6 @@ if mpl_enabled:
                 return -self._zorder3d
 
             def draw(self, renderer):
-                linewidths = self._linewidths_orig
                 if self.reflen:
                     proj_len = projected_length(self.axes, self.reflen)
                     args = self.axes.transData.frozen().to_values()
@@ -181,9 +181,11 @@ if mpl_enabled:
                     #       transformation.
                     factor = proj_len * (args[0] +
                                          args[3]) * 0.5 * 72.0 / self.figure.dpi
-                    linewidths *= factor
+                else:
+                    factor = 1
 
-                super(Line3DCollection, self).set_linewidths(linewidths)
+                super(Line3DCollection, self).set_linewidths(
+                                                self._linewidths_orig * factor)
                 super(Line3DCollection, self).draw(renderer)
 
 
-- 
GitLab