From 7a328741d286002554291bdc70a51f8941a0e41d Mon Sep 17 00:00:00 2001 From: Christoph Groth <christoph.groth@cea.fr> Date: Fri, 6 Apr 2012 18:03:25 +0200 Subject: [PATCH] break up plot into smaller functions, optimize --- kwant/plotter.py | 309 ++++++++++++++++++++++------------------------- 1 file changed, 145 insertions(+), 164 deletions(-) diff --git a/kwant/plotter.py b/kwant/plotter.py index a69c7e45..a5da086d 100644 --- a/kwant/plotter.py +++ b/kwant/plotter.py @@ -2,6 +2,7 @@ from math import sqrt, pi, sin, cos, tan from numpy import dot, add, subtract +import numpy as np import warnings import cairo try: @@ -267,6 +268,126 @@ class Polygon(object): ctx.stroke() +def iterate_lead_sites_builder(system, lead_copies): + for lead in system.leads: + if not isinstance(lead, kwant.builder.BuilderLead): + continue + sym = lead.builder.symmetry + shift = sym.which(lead.neighbors[0]) + 1 + + for i in xrange(lead_copies): + for site in lead.builder.sites(): + yield sym.act(shift + i, site), i + + +def iterate_lead_hoppings_builder(system, lead_copies): + for lead in system.leads: + if not isinstance(lead, kwant.builder.BuilderLead): + continue + sym = lead.builder.symmetry + shift = sym.which(lead.neighbors[0]) + 1 + + for i in xrange(lead_copies): + for site1, site2 in lead.builder.hoppings(): + shift1 = sym.which(site1)[0] + shift2 = sym.which(site2)[0] + if shift1 >= shift2: + yield (sym.act(shift + i, site1), + sym.act(shift + i, site2), + i + shift1, i + shift2) + else: + # Note: this makes sure that hoppings beyond the unit + # cell are always ordered such that they are into + # the previous slice + yield (sym.act(shift + i - 1, site1), + sym.act(shift + i - 1, site2), + i - 1 + shift1, i - 1 + shift2) + + +def iterate_scattreg_sites_builder(system): + for site in system.sites(): + yield site + + +def iterate_scattreg_hoppings_builder(system): + for hopping in system.hoppings(): + yield hopping + + +def empty_generator(*args, **kwds): + return + yield + + +def iterate_scattreg_sites_llsys(system): + return xrange(system.graph.num_nodes) + + +def iterate_scattreg_hoppings_llsys(system): + for i in xrange(system.graph.num_nodes): + for j in system.graph.out_neighbors(i): + # Only yield half of the hoppings (as builder does) + if i < j: + yield i, j + + +def extent(pos, sites): + """Figure out the extent of the system.""" + minx = miny = inf = float('inf') + maxx = maxy = float('-inf') + for site in sites: + try: + x, y = pos(site) + except TypeError: + raise RuntimeError("Only 2 dimensions are supported by plot") + minx = min(x, minx) + maxx = max(x, maxx) + miny = min(y, miny) + maxy = max(y, maxy) + if minx == inf: + warnings.warn("Plotting empty system"); + return 0, 1, 0, 1 + return minx, maxx, miny, maxy + + +def typical_distance(pos, hoppings, sites): + min_sq_dist = inf = float('inf') + for site1, site2 in hoppings: + tmp = subtract(pos(site1), pos(site2)) + sq_dist = dot(tmp, tmp) + if 0 < sq_dist < min_sq_dist: + min_sq_dist = sq_dist + + # If there were no hoppings, then we can only find the distance by checking + # the distances between all pairs sites (potentially slow). To speed this + # only look at the distances between 10 chosen sites and all the remaining + # sites. This simple heuristics works well in practice and is fast enough. + if min_sq_dist == inf: + first = True + positions = list(pos(site) for site in sites) + + for site1 in positions[:: max(len(positions) // 10, 1)]: + for site2 in positions: + tmp = subtract(site1, site2) + sq_dist = dot(tmp, tmp) + if 0 < sq_dist < min_sq_dist: + min_sq_dist = sq_dist + + # If min_sq_dist ist still 0, all sites sit at the same spot In this case I + # can just use any value for dist (rangex and rangey will also be 0 then) + return sqrt(min_sq_dist) if min_sq_dist != inf else 1 + + +def default_pos(system): + if isinstance(system, kwant.builder.Builder): + return lambda site: site.pos + elif isinstance(system, kwant.builder.FiniteSystem): + return lambda i: system.site(i).pos + else: + raise ValueError("`pos` argument needed when plotting" + " systems which are not (finalized) builders") + + def plot(system, filename=defaultname, fmt=None, a=None, width=600, height=None, border=0.1, bcol=white, pos=None, symbols=Circle(r=0.3), lines=Line(lw=0.1), @@ -485,63 +606,6 @@ def plot(system, filename=defaultname, fmt=None, a=None, and `blue`. """ - def iterate_lead_sites_builder(system, lead_copies): - for lead in system.leads: - if not isinstance(lead, kwant.builder.BuilderLead): - continue - sym = lead.builder.symmetry - shift = sym.which(lead.neighbors[0]) + 1 - - for i in xrange(lead_copies): - for site in lead.builder.sites(): - yield sym.act(shift + i, site), i - - def iterate_lead_hoppings_builder(system, lead_copies): - for lead in system.leads: - if not isinstance(lead, kwant.builder.BuilderLead): - continue - sym = lead.builder.symmetry - shift = sym.which(lead.neighbors[0]) + 1 - - for i in xrange(lead_copies): - for site1, site2 in lead.builder.hoppings(): - shift1 = sym.which(site1)[0] - shift2 = sym.which(site2)[0] - if shift1 >= shift2: - yield (sym.act(shift + i, site1), - sym.act(shift + i, site2), - i + shift1, i + shift2) - else: - # Note: this makes sure that hoppings beyond the unit - # cell are always ordered such that they are into - # the previous slice - yield (sym.act(shift + i - 1, site1), - sym.act(shift + i - 1, site2), - i - 1 + shift1, i - 1 + shift2) - - def iterate_scattreg_sites_builder(system): - for site in system.sites(): - yield site - - def iterate_scattreg_hoppings_builder(system): - for hopping in system.hoppings(): - yield hopping - - def empty_generator(*args, **kwds): - return - yield - - def iterate_scattreg_sites_llsys(system): - return xrange(system.graph.num_nodes) - - def iterate_scattreg_hoppings_llsys(system): - for i in xrange(system.graph.num_nodes): - for j in system.graph.out_neighbors(i): - # Only yield half of the hoppings (as builder does) - if i < j: - yield i, j - - def iterate_all_sites(system, lead_copies=0): for site in iterate_scattreg_sites(system): yield site @@ -576,22 +640,8 @@ def plot(system, filename=defaultname, fmt=None, a=None, if width is None and height is None: raise ValueError("One of width and height must be not None") - if a is None: - dist = 0 - else: - if a > 0: - dist = a - else: - raise ValueError("The distance a must be >0") - if pos is None: - if is_builder: - pos = lambda site: site.pos - elif is_lowlevel: - pos = lambda i: system.site(i).pos - else: - raise ValueError("`pos` argument needed when plotting" - " systems which are not (finalized) builders") + pos = default_pos(system) if fmt is None and filename is not None: # Try to figure out the format from the filename @@ -607,9 +657,7 @@ def plot(system, filename=defaultname, fmt=None, a=None, raise ValueError("The requested functionality requires the " "Python Image Library (PIL)") - # symbols and lines may be constant or functions - # Here they are wrapped as a function - + # Symbols and lines may be constant or functions. Wrap them as functions. if hasattr(symbols, "__call__"): fsymbols = symbols elif is_builder and hasattr(symbols, "__getitem__"): @@ -645,92 +693,29 @@ def plot(system, filename=defaultname, fmt=None, a=None, else: fllines = lambda x, y : lead_lines - # Figure out the extent of the system - nsites = 0 - first = True - for site in iterate_all_sites(system, len(lead_fading)): - sitepos = pos(site) - nsites += 1 - - if len(sitepos) != 2: - raise RuntimeError("Only 2 dimensions are supported by plot") - - if first: - minx = maxx = sitepos[0] - miny = maxy = sitepos[1] - first = False - else: - minx = min(sitepos[0], minx) - maxx = max(sitepos[0], maxx) - miny = min(sitepos[1], miny) - maxy = max(sitepos[1], maxy) - - if nsites == 0: - warnings.warn("Empty system. No output generated"); - return - - rangex = (maxx - minx) / (1 - 2 * border) - rangey = (maxy - miny) / (1 - 2 * border) + minx, maxx, miny, maxy = \ + extent(pos, iterate_all_sites(system, len(lead_fading))) # If the user gave no typical distance between sites, we need to figure it # out ourselves # (Note: it is enough to consider one copy of the lead unit cell for # figuring out distances, because of the translational symmetry) if a is None: - first = True - for site1, site2 in iterate_all_hoppings(system, lead_copies=1): - - tmp = subtract(pos(site1), pos(site2)) - sitedist = sqrt(dot(tmp, tmp)) - - if sitedist > 0: - if first: - dist = sitedist - first = False - else: - dist = min(dist, sitedist) - - # If there were no hoppings, then we can only find the distance - # by checking the distance between all sites (potentially slow) - if dist == 0: - warnings.warn("Finding the typical distance automatically" - "may be slow!") - - first = True - # TODO: hm, in this way I will check distances always twice - # in principle, it would be enoughto go through all sites - # site2 *after* site1 - # it's only a factor 2 in speed, so not clear if it makes - # sense to change it - for site1 in iterate_all_sites(system, lead_copies=1): - for site2 in iterate_all_sites(system, lead_copies=1): - - tmp = subtract(pos(site1), pos(site2)) - sitedist = sqrt(dot(tmp, tmp)) - - if sitedist > 0: - if first: - dist = sitedist - first = False - else: - dist = min(dist, sitedist) - - # If dist ist still 0, all sites sit at the same spot - # In this case I can just use any value for dist - # (rangex and rangey will also be 0 then) - if dist == 0: - dist = 1 + a = typical_distance(pos, iterate_all_hoppings(system, lead_copies=1), + iterate_all_sites(system, lead_copies=1)) + elif a <= 0: + raise ValueError("The distance a must be >0") # Use the typical distance, if one of the ranges is 0 # (e.g. in a one-dimensional system) - + rangex = (maxx - minx) / (1 - 2 * border) if rangex == 0: - rangex = dist / (1 - 2 * border) + rangex = a / (1 - 2 * border) + rangey = (maxy - miny) / (1 - 2 * border) if rangey == 0: - rangey = dist / (1 - 2 * border) + rangey = a / (1 - 2 * border) # Compare with the desired dimensions of the plot - if height is None: height = width * rangey / rangex elif width is None: @@ -756,7 +741,6 @@ def plot(system, filename=defaultname, fmt=None, a=None, elif fmt == "png" or fmt == "jpg" or fmt is None: surface = cairo.ImageSurface(cairo.FORMAT_ARGB32, int(round(width)), int(round(height))) - ctx = cairo.Context(surface) # The default background in the image surface is black @@ -788,15 +772,12 @@ def plot(system, filename=defaultname, fmt=None, a=None, ctx.scale(width/rangex, -height/rangey) ctx.translate(-minx, -miny) - # Now draw the system! - - # The lines for the hoppings - + #### Draw the lines for the hoppings. for site1, site2 in iterate_scattreg_hoppings(system): line = flines(site1, site2) if line is not None: - line._draw_cairo(ctx, pos(site1), pos(site2), dist) + line._draw_cairo(ctx, pos(site1), pos(site2), a) for site1, site2, ucindx1, ucindx2 in \ iterate_lead_hoppings(system, len(lead_fading)): @@ -804,7 +785,7 @@ def plot(system, filename=defaultname, fmt=None, a=None, line = fllines(site1, site2) if line is not None: - line._draw_cairo(ctx, pos(site1), pos(site2), dist, + line._draw_cairo(ctx, pos(site1), pos(site2), a, fading=(bcol, lead_fading[ucindx1])) else: if ucindx1 > -1: @@ -812,44 +793,44 @@ def plot(system, filename=defaultname, fmt=None, a=None, if line is not None: line._draw_cairo(ctx, pos(site1), 0.5 * add(pos(site1), pos(site2)), - dist, fading=(bcol, lead_fading[ucindx1])) + a, fading=(bcol, lead_fading[ucindx1])) else: #one end of the line is in the system line = flines(site1, site2) if line is not None: line._draw_cairo(ctx, pos(site1), - 0.5 * add(pos(site1), pos(site2)), dist) + 0.5 * add(pos(site1), pos(site2)), a) if ucindx2 > -1: line = fllines(site2, site1) if line is not None: line._draw_cairo(ctx, pos(site2), 0.5 * add(pos(site1), pos(site2)), - dist, fading=(bcol, lead_fading[ucindx2])) + a, fading=(bcol, lead_fading[ucindx2])) else: - #one end of the line is in the system + # One end of the line is in the system line = flines(site2, site1) if line is not None: line._draw_cairo(ctx, pos(site2), - 0.5 * add(pos(site1), pos(site2)), dist) - # the symbols for the sites + 0.5 * add(pos(site1), pos(site2)), a) + #### Draw the symbols for the sites. for site in iterate_scattreg_sites(system): symbol = fsymbols(site) if symbol is not None: - symbol._draw_cairo(ctx, pos(site), dist) + symbol._draw_cairo(ctx, pos(site), a) for site, ucindx in iterate_lead_sites(system, lead_copies=len(lead_fading)): symbol = flsymbols(site) if symbol is not None: - symbol._draw_cairo(ctx, pos(site), dist, + symbol._draw_cairo(ctx, pos(site), a, fading=(bcol, lead_fading[ucindx])) - # Show or save the picture, if necessary (depends on format) + # Show or save the picture, if necessary (depends on format). if fmt == None: im = Image.frombuffer("RGBA", (surface.get_width(), surface.get_height()), -- GitLab