From 7a328741d286002554291bdc70a51f8941a0e41d Mon Sep 17 00:00:00 2001
From: Christoph Groth <christoph.groth@cea.fr>
Date: Fri, 6 Apr 2012 18:03:25 +0200
Subject: [PATCH] break up plot into smaller functions, optimize

---
 kwant/plotter.py | 309 ++++++++++++++++++++++-------------------------
 1 file changed, 145 insertions(+), 164 deletions(-)

diff --git a/kwant/plotter.py b/kwant/plotter.py
index a69c7e45..a5da086d 100644
--- a/kwant/plotter.py
+++ b/kwant/plotter.py
@@ -2,6 +2,7 @@
 
 from math import sqrt, pi, sin, cos, tan
 from numpy import dot, add, subtract
+import numpy as np
 import warnings
 import cairo
 try:
@@ -267,6 +268,126 @@ class Polygon(object):
             ctx.stroke()
 
 
+def iterate_lead_sites_builder(system, lead_copies):
+    for lead in system.leads:
+        if not isinstance(lead, kwant.builder.BuilderLead):
+            continue
+        sym = lead.builder.symmetry
+        shift = sym.which(lead.neighbors[0]) + 1
+
+        for i in xrange(lead_copies):
+            for site in lead.builder.sites():
+                yield sym.act(shift + i, site), i
+
+
+def iterate_lead_hoppings_builder(system, lead_copies):
+    for lead in system.leads:
+        if not isinstance(lead, kwant.builder.BuilderLead):
+            continue
+        sym = lead.builder.symmetry
+        shift = sym.which(lead.neighbors[0]) + 1
+
+        for i in xrange(lead_copies):
+            for site1, site2 in lead.builder.hoppings():
+                shift1 = sym.which(site1)[0]
+                shift2 = sym.which(site2)[0]
+                if shift1 >= shift2:
+                    yield (sym.act(shift + i, site1),
+                           sym.act(shift + i, site2),
+                           i + shift1, i  + shift2)
+                else:
+                    # Note: this makes sure that hoppings beyond the unit
+                    #       cell are always ordered such that they are into
+                    #       the previous slice
+                    yield (sym.act(shift + i - 1, site1),
+                           sym.act(shift + i - 1, site2),
+                           i - 1 + shift1, i - 1 + shift2)
+
+
+def iterate_scattreg_sites_builder(system):
+    for site in system.sites():
+        yield site
+
+
+def iterate_scattreg_hoppings_builder(system):
+    for hopping in system.hoppings():
+        yield hopping
+
+
+def empty_generator(*args, **kwds):
+    return
+    yield
+
+
+def iterate_scattreg_sites_llsys(system):
+    return xrange(system.graph.num_nodes)
+
+
+def iterate_scattreg_hoppings_llsys(system):
+    for i in xrange(system.graph.num_nodes):
+        for j in system.graph.out_neighbors(i):
+            # Only yield half of the hoppings (as builder does)
+            if i < j:
+                yield i, j
+
+
+def extent(pos, sites):
+    """Figure out the extent of the system."""
+    minx = miny = inf = float('inf')
+    maxx = maxy = float('-inf')
+    for site in sites:
+        try:
+            x, y = pos(site)
+        except TypeError:
+            raise RuntimeError("Only 2 dimensions are supported by plot")
+        minx = min(x, minx)
+        maxx = max(x, maxx)
+        miny = min(y, miny)
+        maxy = max(y, maxy)
+    if minx == inf:
+        warnings.warn("Plotting empty system");
+        return 0, 1, 0, 1
+    return minx, maxx, miny, maxy
+
+
+def typical_distance(pos, hoppings, sites):
+    min_sq_dist = inf = float('inf')
+    for site1, site2 in hoppings:
+        tmp = subtract(pos(site1), pos(site2))
+        sq_dist = dot(tmp, tmp)
+        if 0 < sq_dist < min_sq_dist:
+            min_sq_dist = sq_dist
+
+    # If there were no hoppings, then we can only find the distance by checking
+    # the distances between all pairs sites (potentially slow).  To speed this
+    # only look at the distances between 10 chosen sites and all the remaining
+    # sites.  This simple heuristics works well in practice and is fast enough.
+    if min_sq_dist == inf:
+        first = True
+        positions = list(pos(site) for site in sites)
+
+        for site1 in positions[:: max(len(positions) // 10, 1)]:
+            for site2 in positions:
+                tmp = subtract(site1, site2)
+                sq_dist = dot(tmp, tmp)
+                if 0 < sq_dist < min_sq_dist:
+                    min_sq_dist = sq_dist
+
+    # If min_sq_dist ist still 0, all sites sit at the same spot In this case I
+    # can just use any value for dist (rangex and rangey will also be 0 then)
+    return sqrt(min_sq_dist) if min_sq_dist != inf else 1
+
+
+def default_pos(system):
+    if isinstance(system, kwant.builder.Builder):
+        return lambda site: site.pos
+    elif isinstance(system, kwant.builder.FiniteSystem):
+        return lambda i: system.site(i).pos
+    else:
+        raise ValueError("`pos` argument needed when plotting"
+                         " systems which are not (finalized) builders")
+
+
 def plot(system, filename=defaultname, fmt=None, a=None,
          width=600, height=None, border=0.1, bcol=white, pos=None,
          symbols=Circle(r=0.3), lines=Line(lw=0.1),
@@ -485,63 +606,6 @@ def plot(system, filename=defaultname, fmt=None, a=None,
       and `blue`.
     """
 
-    def iterate_lead_sites_builder(system, lead_copies):
-        for lead in system.leads:
-            if not isinstance(lead, kwant.builder.BuilderLead):
-                continue
-            sym = lead.builder.symmetry
-            shift = sym.which(lead.neighbors[0]) + 1
-
-            for i in xrange(lead_copies):
-                for site in lead.builder.sites():
-                    yield sym.act(shift + i, site), i
-
-    def iterate_lead_hoppings_builder(system, lead_copies):
-        for lead in system.leads:
-            if not isinstance(lead, kwant.builder.BuilderLead):
-                continue
-            sym = lead.builder.symmetry
-            shift = sym.which(lead.neighbors[0]) + 1
-
-            for i in xrange(lead_copies):
-                for site1, site2 in lead.builder.hoppings():
-                    shift1 = sym.which(site1)[0]
-                    shift2 = sym.which(site2)[0]
-                    if shift1 >= shift2:
-                        yield (sym.act(shift + i, site1),
-                               sym.act(shift + i, site2),
-                               i + shift1, i  + shift2)
-                    else:
-                        # Note: this makes sure that hoppings beyond the unit
-                        #       cell are always ordered such that they are into
-                        #       the previous slice
-                        yield (sym.act(shift + i - 1, site1),
-                               sym.act(shift + i - 1, site2),
-                               i - 1 + shift1, i - 1 + shift2)
-
-    def iterate_scattreg_sites_builder(system):
-        for site in system.sites():
-            yield site
-
-    def iterate_scattreg_hoppings_builder(system):
-        for hopping in system.hoppings():
-            yield hopping
-
-    def empty_generator(*args, **kwds):
-        return
-        yield
-
-    def iterate_scattreg_sites_llsys(system):
-        return xrange(system.graph.num_nodes)
-
-    def iterate_scattreg_hoppings_llsys(system):
-        for i in xrange(system.graph.num_nodes):
-            for j in system.graph.out_neighbors(i):
-                # Only yield half of the hoppings (as builder does)
-                if i < j:
-                    yield i, j
-
-
     def iterate_all_sites(system, lead_copies=0):
         for site in iterate_scattreg_sites(system):
             yield site
@@ -576,22 +640,8 @@ def plot(system, filename=defaultname, fmt=None, a=None,
     if width is None and height is None:
         raise ValueError("One of width and height must be not None")
 
-    if a is None:
-        dist = 0
-    else:
-        if a > 0:
-            dist = a
-        else:
-            raise ValueError("The distance a must be >0")
-
     if pos is None:
-        if is_builder:
-            pos = lambda site: site.pos
-        elif is_lowlevel:
-            pos = lambda i: system.site(i).pos
-        else:
-            raise ValueError("`pos` argument needed when plotting"
-                             " systems which are not (finalized) builders")
+        pos = default_pos(system)
 
     if fmt is None and filename is not None:
         # Try to figure out the format from the filename
@@ -607,9 +657,7 @@ def plot(system, filename=defaultname, fmt=None, a=None,
         raise ValueError("The requested functionality requires the "
                          "Python Image Library (PIL)")
 
-    # symbols and lines may be constant or functions
-    # Here they are wrapped as a function
-
+    # Symbols and lines may be constant or functions.  Wrap them as functions.
     if hasattr(symbols, "__call__"):
         fsymbols = symbols
     elif is_builder and hasattr(symbols, "__getitem__"):
@@ -645,92 +693,29 @@ def plot(system, filename=defaultname, fmt=None, a=None,
     else:
         fllines = lambda x, y : lead_lines
 
-    # Figure out the extent of the system
-    nsites = 0
-    first = True
-    for site in iterate_all_sites(system, len(lead_fading)):
-        sitepos = pos(site)
-        nsites += 1
-
-        if len(sitepos) != 2:
-            raise RuntimeError("Only 2 dimensions are supported by plot")
-
-        if first:
-            minx = maxx = sitepos[0]
-            miny = maxy = sitepos[1]
-            first = False
-        else:
-            minx = min(sitepos[0], minx)
-            maxx = max(sitepos[0], maxx)
-            miny = min(sitepos[1], miny)
-            maxy = max(sitepos[1], maxy)
-
-    if nsites == 0:
-        warnings.warn("Empty system. No output generated");
-        return
-
-    rangex = (maxx - minx) / (1 - 2 * border)
-    rangey = (maxy - miny) / (1 - 2 * border)
+    minx, maxx, miny, maxy = \
+        extent(pos, iterate_all_sites(system, len(lead_fading)))
 
     # If the user gave no typical distance between sites, we need to figure it
     # out ourselves
     # (Note: it is enough to consider one copy of the lead unit cell for
     #        figuring out distances, because of the translational symmetry)
     if a is None:
-        first = True
-        for site1, site2 in iterate_all_hoppings(system, lead_copies=1):
-
-            tmp = subtract(pos(site1), pos(site2))
-            sitedist = sqrt(dot(tmp, tmp))
-
-            if sitedist > 0:
-                if first:
-                    dist = sitedist
-                    first = False
-                else:
-                    dist = min(dist, sitedist)
-
-        # If there were no hoppings, then we can only find the distance
-        # by checking the distance between all sites (potentially slow)
-        if dist == 0:
-            warnings.warn("Finding the typical distance automatically"
-                          "may be slow!")
-
-            first = True
-            # TODO: hm, in this way I will check distances always twice
-            # in principle, it would be enoughto go through all sites
-            # site2 *after* site1
-            # it's only a factor 2 in speed, so not clear if it makes
-            # sense to change it
-            for site1 in iterate_all_sites(system, lead_copies=1):
-                for site2 in iterate_all_sites(system, lead_copies=1):
-
-                    tmp = subtract(pos(site1), pos(site2))
-                    sitedist = sqrt(dot(tmp, tmp))
-
-                    if sitedist > 0:
-                        if first:
-                            dist = sitedist
-                            first = False
-                        else:
-                            dist = min(dist, sitedist)
-
-        # If dist ist still 0, all sites sit at the same spot
-        # In this case I can just use any value for dist
-        # (rangex and rangey will also be 0 then)
-        if dist == 0:
-            dist = 1
+        a = typical_distance(pos, iterate_all_hoppings(system, lead_copies=1),
+                             iterate_all_sites(system, lead_copies=1))
+    elif a <= 0:
+        raise ValueError("The distance a must be >0")
 
     # Use the typical distance, if one of the ranges is 0
     # (e.g. in a one-dimensional system)
-
+    rangex = (maxx - minx) / (1 - 2 * border)
     if rangex == 0:
-        rangex = dist / (1 - 2 * border)
+        rangex = a / (1 - 2 * border)
+    rangey = (maxy - miny) / (1 - 2 * border)
     if rangey == 0:
-        rangey = dist / (1 - 2 * border)
+        rangey = a / (1 - 2 * border)
 
     # Compare with the desired dimensions of the plot
-
     if height is None:
         height = width * rangey / rangex
     elif width is None:
@@ -756,7 +741,6 @@ def plot(system, filename=defaultname, fmt=None, a=None,
     elif fmt == "png" or fmt == "jpg" or fmt is None:
         surface = cairo.ImageSurface(cairo.FORMAT_ARGB32,
                                      int(round(width)), int(round(height)))
-
     ctx = cairo.Context(surface)
 
     # The default background in the image surface is black
@@ -788,15 +772,12 @@ def plot(system, filename=defaultname, fmt=None, a=None,
     ctx.scale(width/rangex, -height/rangey)
     ctx.translate(-minx, -miny)
 
-    # Now draw the system!
-
-    # The lines for the hoppings
-
+    #### Draw the lines for the hoppings.
     for site1, site2 in iterate_scattreg_hoppings(system):
         line = flines(site1, site2)
 
         if line is not None:
-            line._draw_cairo(ctx, pos(site1), pos(site2), dist)
+            line._draw_cairo(ctx, pos(site1), pos(site2), a)
 
     for site1, site2, ucindx1, ucindx2 in \
             iterate_lead_hoppings(system, len(lead_fading)):
@@ -804,7 +785,7 @@ def plot(system, filename=defaultname, fmt=None, a=None,
             line = fllines(site1, site2)
 
             if line is not None:
-                line._draw_cairo(ctx, pos(site1), pos(site2), dist,
+                line._draw_cairo(ctx, pos(site1), pos(site2), a,
                                  fading=(bcol, lead_fading[ucindx1]))
         else:
             if ucindx1 > -1:
@@ -812,44 +793,44 @@ def plot(system, filename=defaultname, fmt=None, a=None,
                 if line is not None:
                     line._draw_cairo(ctx, pos(site1),
                                      0.5 * add(pos(site1), pos(site2)),
-                                     dist, fading=(bcol, lead_fading[ucindx1]))
+                                     a, fading=(bcol, lead_fading[ucindx1]))
             else:
                 #one end of the line is in the system
                 line = flines(site1, site2)
                 if line is not None:
                     line._draw_cairo(ctx, pos(site1),
-                                     0.5 * add(pos(site1), pos(site2)), dist)
+                                     0.5 * add(pos(site1), pos(site2)), a)
 
             if ucindx2 > -1:
                 line = fllines(site2, site1)
                 if line is not None:
                     line._draw_cairo(ctx, pos(site2),
                                      0.5 * add(pos(site1), pos(site2)),
-                                     dist, fading=(bcol, lead_fading[ucindx2]))
+                                     a, fading=(bcol, lead_fading[ucindx2]))
             else:
-                #one end of the line is in the system
+                # One end of the line is in the system
                 line = flines(site2, site1)
                 if line is not None:
                     line._draw_cairo(ctx, pos(site2),
-                                     0.5 * add(pos(site1), pos(site2)), dist)
-    # the symbols for the sites
+                                     0.5 * add(pos(site1), pos(site2)), a)
 
+    #### Draw the symbols for the sites.
     for site in iterate_scattreg_sites(system):
         symbol = fsymbols(site)
 
         if symbol is not None:
-            symbol._draw_cairo(ctx, pos(site), dist)
+            symbol._draw_cairo(ctx, pos(site), a)
 
     for site, ucindx in iterate_lead_sites(system,
                                            lead_copies=len(lead_fading)):
         symbol = flsymbols(site)
 
         if symbol is not None:
-            symbol._draw_cairo(ctx, pos(site), dist,
+            symbol._draw_cairo(ctx, pos(site), a,
                                fading=(bcol, lead_fading[ucindx]))
 
 
-    # Show or save the picture, if necessary (depends on format)
+    # Show or save the picture, if necessary (depends on format).
     if fmt == None:
         im = Image.frombuffer("RGBA",
                               (surface.get_width(), surface.get_height()),
-- 
GitLab