Skip to content
Snippets Groups Projects
Commit 7a328741 authored by Christoph Groth's avatar Christoph Groth
Browse files

break up plot into smaller functions, optimize

parent f6947fb8
No related branches found
No related tags found
No related merge requests found
......@@ -2,6 +2,7 @@
from math import sqrt, pi, sin, cos, tan
from numpy import dot, add, subtract
import numpy as np
import warnings
import cairo
try:
......@@ -267,6 +268,126 @@ class Polygon(object):
ctx.stroke()
def iterate_lead_sites_builder(system, lead_copies):
for lead in system.leads:
if not isinstance(lead, kwant.builder.BuilderLead):
continue
sym = lead.builder.symmetry
shift = sym.which(lead.neighbors[0]) + 1
for i in xrange(lead_copies):
for site in lead.builder.sites():
yield sym.act(shift + i, site), i
def iterate_lead_hoppings_builder(system, lead_copies):
for lead in system.leads:
if not isinstance(lead, kwant.builder.BuilderLead):
continue
sym = lead.builder.symmetry
shift = sym.which(lead.neighbors[0]) + 1
for i in xrange(lead_copies):
for site1, site2 in lead.builder.hoppings():
shift1 = sym.which(site1)[0]
shift2 = sym.which(site2)[0]
if shift1 >= shift2:
yield (sym.act(shift + i, site1),
sym.act(shift + i, site2),
i + shift1, i + shift2)
else:
# Note: this makes sure that hoppings beyond the unit
# cell are always ordered such that they are into
# the previous slice
yield (sym.act(shift + i - 1, site1),
sym.act(shift + i - 1, site2),
i - 1 + shift1, i - 1 + shift2)
def iterate_scattreg_sites_builder(system):
for site in system.sites():
yield site
def iterate_scattreg_hoppings_builder(system):
for hopping in system.hoppings():
yield hopping
def empty_generator(*args, **kwds):
return
yield
def iterate_scattreg_sites_llsys(system):
return xrange(system.graph.num_nodes)
def iterate_scattreg_hoppings_llsys(system):
for i in xrange(system.graph.num_nodes):
for j in system.graph.out_neighbors(i):
# Only yield half of the hoppings (as builder does)
if i < j:
yield i, j
def extent(pos, sites):
"""Figure out the extent of the system."""
minx = miny = inf = float('inf')
maxx = maxy = float('-inf')
for site in sites:
try:
x, y = pos(site)
except TypeError:
raise RuntimeError("Only 2 dimensions are supported by plot")
minx = min(x, minx)
maxx = max(x, maxx)
miny = min(y, miny)
maxy = max(y, maxy)
if minx == inf:
warnings.warn("Plotting empty system");
return 0, 1, 0, 1
return minx, maxx, miny, maxy
def typical_distance(pos, hoppings, sites):
min_sq_dist = inf = float('inf')
for site1, site2 in hoppings:
tmp = subtract(pos(site1), pos(site2))
sq_dist = dot(tmp, tmp)
if 0 < sq_dist < min_sq_dist:
min_sq_dist = sq_dist
# If there were no hoppings, then we can only find the distance by checking
# the distances between all pairs sites (potentially slow). To speed this
# only look at the distances between 10 chosen sites and all the remaining
# sites. This simple heuristics works well in practice and is fast enough.
if min_sq_dist == inf:
first = True
positions = list(pos(site) for site in sites)
for site1 in positions[:: max(len(positions) // 10, 1)]:
for site2 in positions:
tmp = subtract(site1, site2)
sq_dist = dot(tmp, tmp)
if 0 < sq_dist < min_sq_dist:
min_sq_dist = sq_dist
# If min_sq_dist ist still 0, all sites sit at the same spot In this case I
# can just use any value for dist (rangex and rangey will also be 0 then)
return sqrt(min_sq_dist) if min_sq_dist != inf else 1
def default_pos(system):
if isinstance(system, kwant.builder.Builder):
return lambda site: site.pos
elif isinstance(system, kwant.builder.FiniteSystem):
return lambda i: system.site(i).pos
else:
raise ValueError("`pos` argument needed when plotting"
" systems which are not (finalized) builders")
def plot(system, filename=defaultname, fmt=None, a=None,
width=600, height=None, border=0.1, bcol=white, pos=None,
symbols=Circle(r=0.3), lines=Line(lw=0.1),
......@@ -485,63 +606,6 @@ def plot(system, filename=defaultname, fmt=None, a=None,
and `blue`.
"""
def iterate_lead_sites_builder(system, lead_copies):
for lead in system.leads:
if not isinstance(lead, kwant.builder.BuilderLead):
continue
sym = lead.builder.symmetry
shift = sym.which(lead.neighbors[0]) + 1
for i in xrange(lead_copies):
for site in lead.builder.sites():
yield sym.act(shift + i, site), i
def iterate_lead_hoppings_builder(system, lead_copies):
for lead in system.leads:
if not isinstance(lead, kwant.builder.BuilderLead):
continue
sym = lead.builder.symmetry
shift = sym.which(lead.neighbors[0]) + 1
for i in xrange(lead_copies):
for site1, site2 in lead.builder.hoppings():
shift1 = sym.which(site1)[0]
shift2 = sym.which(site2)[0]
if shift1 >= shift2:
yield (sym.act(shift + i, site1),
sym.act(shift + i, site2),
i + shift1, i + shift2)
else:
# Note: this makes sure that hoppings beyond the unit
# cell are always ordered such that they are into
# the previous slice
yield (sym.act(shift + i - 1, site1),
sym.act(shift + i - 1, site2),
i - 1 + shift1, i - 1 + shift2)
def iterate_scattreg_sites_builder(system):
for site in system.sites():
yield site
def iterate_scattreg_hoppings_builder(system):
for hopping in system.hoppings():
yield hopping
def empty_generator(*args, **kwds):
return
yield
def iterate_scattreg_sites_llsys(system):
return xrange(system.graph.num_nodes)
def iterate_scattreg_hoppings_llsys(system):
for i in xrange(system.graph.num_nodes):
for j in system.graph.out_neighbors(i):
# Only yield half of the hoppings (as builder does)
if i < j:
yield i, j
def iterate_all_sites(system, lead_copies=0):
for site in iterate_scattreg_sites(system):
yield site
......@@ -576,22 +640,8 @@ def plot(system, filename=defaultname, fmt=None, a=None,
if width is None and height is None:
raise ValueError("One of width and height must be not None")
if a is None:
dist = 0
else:
if a > 0:
dist = a
else:
raise ValueError("The distance a must be >0")
if pos is None:
if is_builder:
pos = lambda site: site.pos
elif is_lowlevel:
pos = lambda i: system.site(i).pos
else:
raise ValueError("`pos` argument needed when plotting"
" systems which are not (finalized) builders")
pos = default_pos(system)
if fmt is None and filename is not None:
# Try to figure out the format from the filename
......@@ -607,9 +657,7 @@ def plot(system, filename=defaultname, fmt=None, a=None,
raise ValueError("The requested functionality requires the "
"Python Image Library (PIL)")
# symbols and lines may be constant or functions
# Here they are wrapped as a function
# Symbols and lines may be constant or functions. Wrap them as functions.
if hasattr(symbols, "__call__"):
fsymbols = symbols
elif is_builder and hasattr(symbols, "__getitem__"):
......@@ -645,92 +693,29 @@ def plot(system, filename=defaultname, fmt=None, a=None,
else:
fllines = lambda x, y : lead_lines
# Figure out the extent of the system
nsites = 0
first = True
for site in iterate_all_sites(system, len(lead_fading)):
sitepos = pos(site)
nsites += 1
if len(sitepos) != 2:
raise RuntimeError("Only 2 dimensions are supported by plot")
if first:
minx = maxx = sitepos[0]
miny = maxy = sitepos[1]
first = False
else:
minx = min(sitepos[0], minx)
maxx = max(sitepos[0], maxx)
miny = min(sitepos[1], miny)
maxy = max(sitepos[1], maxy)
if nsites == 0:
warnings.warn("Empty system. No output generated");
return
rangex = (maxx - minx) / (1 - 2 * border)
rangey = (maxy - miny) / (1 - 2 * border)
minx, maxx, miny, maxy = \
extent(pos, iterate_all_sites(system, len(lead_fading)))
# If the user gave no typical distance between sites, we need to figure it
# out ourselves
# (Note: it is enough to consider one copy of the lead unit cell for
# figuring out distances, because of the translational symmetry)
if a is None:
first = True
for site1, site2 in iterate_all_hoppings(system, lead_copies=1):
tmp = subtract(pos(site1), pos(site2))
sitedist = sqrt(dot(tmp, tmp))
if sitedist > 0:
if first:
dist = sitedist
first = False
else:
dist = min(dist, sitedist)
# If there were no hoppings, then we can only find the distance
# by checking the distance between all sites (potentially slow)
if dist == 0:
warnings.warn("Finding the typical distance automatically"
"may be slow!")
first = True
# TODO: hm, in this way I will check distances always twice
# in principle, it would be enoughto go through all sites
# site2 *after* site1
# it's only a factor 2 in speed, so not clear if it makes
# sense to change it
for site1 in iterate_all_sites(system, lead_copies=1):
for site2 in iterate_all_sites(system, lead_copies=1):
tmp = subtract(pos(site1), pos(site2))
sitedist = sqrt(dot(tmp, tmp))
if sitedist > 0:
if first:
dist = sitedist
first = False
else:
dist = min(dist, sitedist)
# If dist ist still 0, all sites sit at the same spot
# In this case I can just use any value for dist
# (rangex and rangey will also be 0 then)
if dist == 0:
dist = 1
a = typical_distance(pos, iterate_all_hoppings(system, lead_copies=1),
iterate_all_sites(system, lead_copies=1))
elif a <= 0:
raise ValueError("The distance a must be >0")
# Use the typical distance, if one of the ranges is 0
# (e.g. in a one-dimensional system)
rangex = (maxx - minx) / (1 - 2 * border)
if rangex == 0:
rangex = dist / (1 - 2 * border)
rangex = a / (1 - 2 * border)
rangey = (maxy - miny) / (1 - 2 * border)
if rangey == 0:
rangey = dist / (1 - 2 * border)
rangey = a / (1 - 2 * border)
# Compare with the desired dimensions of the plot
if height is None:
height = width * rangey / rangex
elif width is None:
......@@ -756,7 +741,6 @@ def plot(system, filename=defaultname, fmt=None, a=None,
elif fmt == "png" or fmt == "jpg" or fmt is None:
surface = cairo.ImageSurface(cairo.FORMAT_ARGB32,
int(round(width)), int(round(height)))
ctx = cairo.Context(surface)
# The default background in the image surface is black
......@@ -788,15 +772,12 @@ def plot(system, filename=defaultname, fmt=None, a=None,
ctx.scale(width/rangex, -height/rangey)
ctx.translate(-minx, -miny)
# Now draw the system!
# The lines for the hoppings
#### Draw the lines for the hoppings.
for site1, site2 in iterate_scattreg_hoppings(system):
line = flines(site1, site2)
if line is not None:
line._draw_cairo(ctx, pos(site1), pos(site2), dist)
line._draw_cairo(ctx, pos(site1), pos(site2), a)
for site1, site2, ucindx1, ucindx2 in \
iterate_lead_hoppings(system, len(lead_fading)):
......@@ -804,7 +785,7 @@ def plot(system, filename=defaultname, fmt=None, a=None,
line = fllines(site1, site2)
if line is not None:
line._draw_cairo(ctx, pos(site1), pos(site2), dist,
line._draw_cairo(ctx, pos(site1), pos(site2), a,
fading=(bcol, lead_fading[ucindx1]))
else:
if ucindx1 > -1:
......@@ -812,44 +793,44 @@ def plot(system, filename=defaultname, fmt=None, a=None,
if line is not None:
line._draw_cairo(ctx, pos(site1),
0.5 * add(pos(site1), pos(site2)),
dist, fading=(bcol, lead_fading[ucindx1]))
a, fading=(bcol, lead_fading[ucindx1]))
else:
#one end of the line is in the system
line = flines(site1, site2)
if line is not None:
line._draw_cairo(ctx, pos(site1),
0.5 * add(pos(site1), pos(site2)), dist)
0.5 * add(pos(site1), pos(site2)), a)
if ucindx2 > -1:
line = fllines(site2, site1)
if line is not None:
line._draw_cairo(ctx, pos(site2),
0.5 * add(pos(site1), pos(site2)),
dist, fading=(bcol, lead_fading[ucindx2]))
a, fading=(bcol, lead_fading[ucindx2]))
else:
#one end of the line is in the system
# One end of the line is in the system
line = flines(site2, site1)
if line is not None:
line._draw_cairo(ctx, pos(site2),
0.5 * add(pos(site1), pos(site2)), dist)
# the symbols for the sites
0.5 * add(pos(site1), pos(site2)), a)
#### Draw the symbols for the sites.
for site in iterate_scattreg_sites(system):
symbol = fsymbols(site)
if symbol is not None:
symbol._draw_cairo(ctx, pos(site), dist)
symbol._draw_cairo(ctx, pos(site), a)
for site, ucindx in iterate_lead_sites(system,
lead_copies=len(lead_fading)):
symbol = flsymbols(site)
if symbol is not None:
symbol._draw_cairo(ctx, pos(site), dist,
symbol._draw_cairo(ctx, pos(site), a,
fading=(bcol, lead_fading[ucindx]))
# Show or save the picture, if necessary (depends on format)
# Show or save the picture, if necessary (depends on format).
if fmt == None:
im = Image.frombuffer("RGBA",
(surface.get_width(), surface.get_height()),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment