From 7fb1d60936ba8a744d274d9ae01a2f16bfe3ccc9 Mon Sep 17 00:00:00 2001
From: Christoph Groth <christoph.groth@cea.fr>
Date: Thu, 5 Apr 2012 15:12:37 +0200
Subject: [PATCH] make plotter work with general sequences of numbers as
 positions

---
 kwant/plotter.py | 31 ++++++++++++++++---------------
 1 file changed, 16 insertions(+), 15 deletions(-)

diff --git a/kwant/plotter.py b/kwant/plotter.py
index 8b95f489..b9bbb6fa 100644
--- a/kwant/plotter.py
+++ b/kwant/plotter.py
@@ -1,6 +1,7 @@
 """kwant.plotter docstring"""
 
 from math import sqrt, pi, sin, cos, tan
+from numpy import dot, add, subtract
 import warnings
 import cairo
 try:
@@ -346,7 +347,7 @@ def plot(system, filename=defaultname, fmt=None, a=None,
         transparant symbol], but then again there is no reason for
         having a white box behind everything)
     pos : callable or None, optional
-        When passed a site should return its (2D) position as a numpy array of
+        When passed a site should return its (2D) position as a sequence of
         length 2. If None, the method pos() of the site is used.
         Defaults to None.
     symbols : {symbol_like, callable, dict, None}, optional
@@ -410,7 +411,7 @@ def plot(system, filename=defaultname, fmt=None, a=None,
          _draw_cairo(ctx, pos, reflen[, fading])
 
       which draws the symbol onto the cairo context `ctx`
-      at the position `pos` (passed as a numpy array of length 2).
+      at the position `pos` (passed as a sequence of length 2).
       `reflen` is the reference length, allowing the symbol to use
       relative sizes. (Note though that `pos` is in **absolute** cairo
       coordinates).
@@ -433,7 +434,7 @@ def plot(system, filename=defaultname, fmt=None, a=None,
 
       which draws the something (typically a line of some sort) onto
       the cairo context `ctx` connecting the position `pos1` and
-      `pos2` (passed as a numpy arrays of length 2).  `reflen` is the
+      `pos2` (passed as sequences of length 2).  `reflen` is the
       reference length, allowing the line to use relative sizes. (Note
       though that `pos1` and `pos2` are in **absolute** cairo
       coordinates).
@@ -691,9 +692,8 @@ def plot(system, filename=defaultname, fmt=None, a=None,
         first = True
         for site1, site2 in iterate_all_hoppings(system, lead_copies=1):
 
-            # TODO: can I assume always numpy?
-            sitedist = sqrt((pos(site1)[0] - pos(site2)[0])**2 +
-                            (pos(site1)[1] - pos(site2)[1])**2)
+            tmp = subtract(pos(site1), pos(site2))
+            sitedist = sqrt(dot(tmp, tmp))
 
             if sitedist > 0:
                 if first:
@@ -717,9 +717,8 @@ def plot(system, filename=defaultname, fmt=None, a=None,
             for site1 in iterate_all_sites(system, lead_copies=1):
                 for site2 in iterate_all_sites(system, lead_copies=1):
 
-                    # TODO: can I assume always numpy?
-                    sitedist = sqrt((pos(site1)[0] - pos(site2)[0])**2 +
-                                    (pos(site1)[1] - pos(site2)[1])**2)
+                    tmp = subtract(pos(site1), pos(site2))
+                    sitedist = sqrt(dot(tmp, tmp))
 
                     if sitedist > 0:
                         if first:
@@ -823,26 +822,28 @@ def plot(system, filename=defaultname, fmt=None, a=None,
             if ucindx1 > -1:
                 line = fllines(site1, site2)
                 if line is not None:
-                    line._draw_cairo(ctx, pos(site1), (pos(site1)+pos(site2))/2,
+                    line._draw_cairo(ctx, pos(site1),
+                                     0.5 * add(pos(site1), pos(site2)),
                                      dist, fading=(bcol, lead_fading[ucindx1]))
             else:
                 #one end of the line is in the system
                 line = flines(site1, site2)
                 if line is not None:
-                    line._draw_cairo(ctx, pos(site1), (pos(site1)+pos(site2))/2,
-                                     dist)
+                    line._draw_cairo(ctx, pos(site1),
+                                     0.5 * add(pos(site1), pos(site2)), dist)
 
             if ucindx2 > -1:
                 line = fllines(site2, site1)
                 if line is not None:
-                    line._draw_cairo(ctx, pos(site2), (pos(site1)+pos(site2))/2,
+                    line._draw_cairo(ctx, pos(site2),
+                                     0.5 * add(pos(site1), pos(site2)),
                                      dist, fading=(bcol, lead_fading[ucindx2]))
             else:
                 #one end of the line is in the system
                 line = flines(site2, site1)
                 if line is not None:
-                    line._draw_cairo(ctx, pos(site2), (pos(site1)+pos(site2))/2,
-                                     dist)
+                    line._draw_cairo(ctx, pos(site2),
+                                     0.5 * add(pos(site1), pos(site2)), dist)
     # the symbols for the sites
 
     for site in iterate_system_sites(system):
-- 
GitLab