diff --git a/doc/source/images/1-quantum_wire.py.diff b/doc/source/images/1-quantum_wire.py.diff index ed74c0973b9cad8cf2ae73067f22a21422d0c192..30d2dc9975651db99a185152b25d3d3bf38d1713 100644 --- a/doc/source/images/1-quantum_wire.py.diff +++ b/doc/source/images/1-quantum_wire.py.diff @@ -1,26 +1,29 @@ --- original +++ modified -@@ -9,6 +9,7 @@ - # - Using the simple sparse solver for computing Landauer conductance +@@ -10,6 +10,8 @@ + from matplotlib import pyplot import kwant -+import latex, html ++import latex ++import html # First, define the tight-binding system -@@ -73,7 +74,8 @@ +@@ -73,8 +75,9 @@ + sys.attach_lead(lead1) # Plot it, to make sure it's OK - +- -kwant.plot(sys) -+kwant.plot(sys, "1-quantum_wire_sys.pdf", width=latex.figwidth_pt) -+kwant.plot(sys, "1-quantum_wire_sys.png", width=html.figwidth_px) ++size = (latex.figwidth_in, 0.3 * latex.figwidth_in) ++kwant.plot(sys, file="1-quantum_wire_sys.pdf", fig_size=size, dpi=html.dpi) ++kwant.plot(sys, file="1-quantum_wire_sys.png", fig_size=size, dpi=html.dpi) # Finalize the system -@@ -98,8 +100,14 @@ +@@ -98,8 +101,13 @@ + # Use matplotlib to write output # We should see conductance steps - from matplotlib import pyplot -pyplot.figure() +fig = pyplot.figure() @@ -32,8 +35,7 @@ +pyplot.ylabel("conductance [in units of e^2/h]", fontsize=latex.mpl_label_size) +pyplot.setp(fig.get_axes()[0].get_xticklabels(), fontsize=latex.mpl_tick_size) +pyplot.setp(fig.get_axes()[0].get_yticklabels(), fontsize=latex.mpl_tick_size) -+fig.set_size_inches(latex.mpl_width_in, latex.mpl_width_in*3./4.) ++fig.set_size_inches(latex.mpl_width_in, latex.mpl_width_in * 3. / 4.) +fig.subplots_adjust(left=0.15, right=0.95, top=0.95, bottom=0.15) +fig.savefig("1-quantum_wire_result.pdf") -+fig.savefig("1-quantum_wire_result.png", -+ dpi=(html.figwidth_px/latex.mpl_width_in)) ++fig.savefig("1-quantum_wire_result.png", dpi=html.dpi) diff --git a/doc/source/images/2-ab_ring.py.diff b/doc/source/images/2-ab_ring.py.diff index 29bd464fd4dc9b4573c7871ec43b923edf93589c..1075a5ce81fad6d59ee4e43ce3a4cb182f4a4109 100644 --- a/doc/source/images/2-ab_ring.py.diff +++ b/doc/source/images/2-ab_ring.py.diff @@ -81,11 +81,10 @@ + fontsize=latex.mpl_tick_size) + pyplot.setp(fig.get_axes()[0].get_yticklabels(), + fontsize=latex.mpl_tick_size) -+ fig.set_size_inches(latex.mpl_width_in, latex.mpl_width_in*3./4.) ++ fig.set_size_inches(latex.mpl_width_in, latex.mpl_width_in * 3. / 4.) + fig.subplots_adjust(left=0.15, right=0.95, top=0.95, bottom=0.15) + fig.savefig("2-ab_ring_result.pdf") -+ fig.savefig("2-ab_ring_result.png", -+ dpi=(html.figwidth_px/latex.mpl_width_in)) ++ fig.savefig("2-ab_ring_result.png", dpi=html.dpi) def main(): @@ -93,8 +92,9 @@ # Check that the system looks as intended. - kwant.plot(sys) -+ kwant.plot(sys, "2-ab_ring_sys.pdf", width=latex.figwidth_pt) -+ kwant.plot(sys, "2-ab_ring_sys.png", width=html.figwidth_px) ++ size = (latex.figwidth_in, latex.figwidth_in) ++ kwant.plot(sys, file="2-ab_ring_sys.pdf", fig_size=size, dpi=html.dpi) ++ kwant.plot(sys, file="2-ab_ring_sys.png", fig_size=size, dpi=html.dpi) # Finalize the system. sys = sys.finalized() @@ -104,11 +104,11 @@ + # Finally, some plots needed for the notes + sys = make_system_note1() -+ kwant.plot(sys, "2-ab_ring_note1.pdf", width=latex.figwidth_small_pt) -+ kwant.plot(sys, "2-ab_ring_note1.png", width=html.figwidth_small_px) ++ kwant.plot(sys, file="2-ab_ring_note1.pdf", fig_size=size, dpi=html.dpi) ++ kwant.plot(sys, file="2-ab_ring_note1.png", fig_size=size, dpi=html.dpi) + sys = make_system_note2() -+ kwant.plot(sys, "2-ab_ring_note2.pdf", width=latex.figwidth_small_pt) -+ kwant.plot(sys, "2-ab_ring_note2.png", width=html.figwidth_small_px) ++ kwant.plot(sys, file="2-ab_ring_note2.pdf", fig_size=size, dpi=html.dpi) ++ kwant.plot(sys, file="2-ab_ring_note2.png", fig_size=size, dpi=html.dpi) + + # Call the main function if the script gets executed (as opposed to imported). diff --git a/doc/source/images/2-quantum_well.py.diff b/doc/source/images/2-quantum_well.py.diff index b3e158e4e955fc36decf046d494e248248e9a358..579ba29460feff9a3f1db0827a4ee8b72589b022 100644 --- a/doc/source/images/2-quantum_well.py.diff +++ b/doc/source/images/2-quantum_well.py.diff @@ -8,7 +8,7 @@ # global variable governing the behavior of potential() in # make_system() -@@ -76,19 +77,26 @@ +@@ -76,19 +77,25 @@ smatrix = kwant.solve(sys, energy) data.append(smatrix.transmission(1, 0)) @@ -26,11 +26,10 @@ + fontsize=latex.mpl_tick_size) + pyplot.setp(fig.get_axes()[0].get_yticklabels(), + fontsize=latex.mpl_tick_size) -+ fig.set_size_inches(latex.mpl_width_in, latex.mpl_width_in*3./4.) ++ fig.set_size_inches(latex.mpl_width_in, latex.mpl_width_in * 3. / 4.) + fig.subplots_adjust(left=0.15, right=0.95, top=0.95, bottom=0.15) + fig.savefig("2-quantum_well_result.pdf") -+ fig.savefig("2-quantum_well_result.png", -+ dpi=(html.figwidth_px/latex.mpl_width_in)) ++ fig.savefig("2-quantum_well_result.png", dpi=html.dpi) def main(): diff --git a/doc/source/images/2-spin_orbit.py.diff b/doc/source/images/2-spin_orbit.py.diff index ec1bbd826eb7a7a28a453ed43d7d4ccde9afe649..d935aa9411b4f4d7c86871e931c7c6228d3508f0 100644 --- a/doc/source/images/2-spin_orbit.py.diff +++ b/doc/source/images/2-spin_orbit.py.diff @@ -8,7 +8,7 @@ # define Pauli-matrices for convenience sigma_0 = numpy.eye(2) -@@ -73,19 +74,25 @@ +@@ -73,19 +74,24 @@ smatrix = kwant.solve(sys, energy) data.append(smatrix.transmission(1, 0)) @@ -25,11 +25,10 @@ + fontsize=latex.mpl_tick_size) + pyplot.setp(fig.get_axes()[0].get_yticklabels(), + fontsize=latex.mpl_tick_size) -+ fig.set_size_inches(latex.mpl_width_in, latex.mpl_width_in*3./4.) ++ fig.set_size_inches(latex.mpl_width_in, latex.mpl_width_in * 3. / 4.) + fig.subplots_adjust(left=0.15, right=0.95, top=0.95, bottom=0.15) + fig.savefig("2-spin_orbit_result.pdf") -+ fig.savefig("2-spin_orbit_result.png", -+ dpi=(html.figwidth_px/latex.mpl_width_in)) ++ fig.savefig("2-spin_orbit_result.png", dpi=html.dpi) def main(): diff --git a/doc/source/images/3-band_structure.py.diff b/doc/source/images/3-band_structure.py.diff index 0ca2e45df01c597c3b76c782b6fe0eb5608c110c..277316381a5e55a82570c587951ce8244073e03f 100644 --- a/doc/source/images/3-band_structure.py.diff +++ b/doc/source/images/3-band_structure.py.diff @@ -8,7 +8,7 @@ def make_lead(a=1, t=1.0, W=10): -@@ -39,11 +40,20 @@ +@@ -39,11 +40,19 @@ # the bandstructure energy_list = [lead.energies(k) for k in momenta] @@ -25,11 +25,10 @@ + fontsize=latex.mpl_tick_size) + pyplot.setp(fig.get_axes()[0].get_yticklabels(), + fontsize=latex.mpl_tick_size) -+ fig.set_size_inches(latex.mpl_width_in, latex.mpl_width_in*3./4.) ++ fig.set_size_inches(latex.mpl_width_in, latex.mpl_width_in * 3. / 4.) + fig.subplots_adjust(left=0.15, right=0.95, top=0.95, bottom=0.15) + fig.savefig("3-band_structure_result.pdf") -+ fig.savefig("3-band_structure_result.png", -+ dpi=(html.figwidth_px/latex.mpl_width_in)) ++ fig.savefig("3-band_structure_result.png", dpi=html.dpi) def main(): diff --git a/doc/source/images/3-closed_system.py.diff b/doc/source/images/3-closed_system.py.diff index 0bb1a9c1ffcfbac10fa1d83d1ac133ed3c6b9507..7394098d154243250ba640a67d32ab87c65fd58f 100644 --- a/doc/source/images/3-closed_system.py.diff +++ b/doc/source/images/3-closed_system.py.diff @@ -8,7 +8,7 @@ def make_system(a=1, t=1.0, r=10): -@@ -69,18 +70,30 @@ +@@ -69,19 +70,24 @@ # we only plot the 15 lowest eigenvalues energies.append(ev[:15]) @@ -25,22 +25,18 @@ + fontsize=latex.mpl_tick_size) + pyplot.setp(fig.get_axes()[0].get_yticklabels(), + fontsize=latex.mpl_tick_size) -+ fig.set_size_inches(latex.mpl_width_in, latex.mpl_width_in*3./4.) ++ fig.set_size_inches(latex.mpl_width_in, latex.mpl_width_in * 3. / 4.) + fig.subplots_adjust(left=0.15, right=0.95, top=0.95, bottom=0.15) + fig.savefig("3-closed_system_result.pdf") -+ fig.savefig("3-closed_system_result.png", -+ dpi=(html.figwidth_px/latex.mpl_width_in)) ++ fig.savefig("3-closed_system_result.png", dpi=html.dpi) def main(): sys = make_system() - # Check that the system looks as intended. +- # Check that the system looks as intended. - kwant.plot(sys) -+ kwant.plot(sys, filename="3-closed_system_sys.pdf", -+ width=latex.figwidth_pt) -+ kwant.plot(sys, filename="3-closed_system_sys.png", -+ width=html.figwidth_px) - +- # Finalize the system. sys = sys.finalized() + diff --git a/doc/source/images/4-graphene.py.diff b/doc/source/images/4-graphene.py.diff index c86c7b55e1f331c641221f8af55639a3b553d57f..09401251ca6840faed6ca03f25b42622906423ff 100644 --- a/doc/source/images/4-graphene.py.diff +++ b/doc/source/images/4-graphene.py.diff @@ -1,14 +1,15 @@ --- original +++ modified -@@ -17,6 +17,7 @@ +@@ -17,6 +17,8 @@ # For plotting from matplotlib import pyplot -+import latex, html ++import latex ++import html # Define the graphene lattice -@@ -63,7 +64,7 @@ +@@ -63,7 +65,7 @@ return (-1 < x < 1) and (-0.4 * r < y < 0.4 * r) lead0 = kwant.Builder(sym0) @@ -17,7 +18,7 @@ for hopping in hoppings: lead0[lead0.possible_hoppings(*hopping)] = -1 -@@ -105,11 +106,21 @@ +@@ -105,11 +107,20 @@ smatrix = kwant.solve(sys, energy) data.append(smatrix.transmission(0, 1)) @@ -35,15 +36,14 @@ + fontsize=latex.mpl_tick_size) + pyplot.setp(fig.get_axes()[0].get_yticklabels(), + fontsize=latex.mpl_tick_size) -+ fig.set_size_inches(latex.mpl_width_in, latex.mpl_width_in*3./4.) ++ fig.set_size_inches(latex.mpl_width_in, latex.mpl_width_in * 3. / 4.) + fig.subplots_adjust(left=0.15, right=0.95, top=0.95, bottom=0.15) + fig.savefig("4-graphene_result.pdf") -+ fig.savefig("4-graphene_result.png", -+ dpi=(html.figwidth_px/latex.mpl_width_in)) ++ fig.savefig("4-graphene_result.png", dpi=html.dpi) def plot_bandstructure(flead, momenta): -@@ -117,11 +128,21 @@ +@@ -117,11 +128,20 @@ # the bandstructure energy_list = [flead.energies(k) for k in momenta] @@ -61,37 +61,38 @@ + fontsize=latex.mpl_tick_size) + pyplot.setp(fig.get_axes()[0].get_yticklabels(), + fontsize=latex.mpl_tick_size) -+ fig.set_size_inches(latex.mpl_width_in, latex.mpl_width_in*3./4.) ++ fig.set_size_inches(latex.mpl_width_in, latex.mpl_width_in * 3. / 4.) + fig.subplots_adjust(left=0.15, right=0.95, top=0.95, bottom=0.15) + fig.savefig("4-graphene_bs.pdf") -+ fig.savefig("4-graphene_bs.png", -+ dpi=(html.figwidth_px/latex.mpl_width_in)) ++ fig.savefig("4-graphene_bs.png", dpi=html.dpi) def main(): -@@ -136,17 +157,20 @@ - lcol=kwant.plotter.black)} +@@ -134,17 +154,22 @@ + return 0 if site.group == a else 1 # Plot the closed system without leads. -- kwant.plot(sys, symbols=plotter_symbols) +- kwant.plot(sys, site_color=group_colors, colorbar=False) - - # Compute some eigenvalues. - compute_evs(sys.finalized()) -+ kwant.plot(sys, symbols=plotter_symbols, -+ filename="4-graphene_sys1.pdf", width=latex.figwidth_pt) -+ kwant.plot(sys, symbols=plotter_symbols, -+ filename="4-graphene_sys1.png", width=html.figwidth_px) ++ size = (latex.figwidth_in, latex.figwidth_in) ++ kwant.plot(sys, site_color=group_colors, colorbar=False, ++ file="4-graphene_sys1.pdf", fig_size=size, dpi=html.dpi) ++ kwant.plot(sys, site_color=group_colors, colorbar=False, ++ file="4-graphene_sys1.png", fig_size=size, dpi=html.dpi) # Attach the leads to the system. for lead in leads: sys.attach_lead(lead) # Then, plot the system with leads. -- kwant.plot(sys, symbols=plotter_symbols) -+ kwant.plot(sys, symbols=plotter_symbols, -+ filename="4-graphene_sys2.pdf", width=latex.figwidth_pt) -+ kwant.plot(sys, symbols=plotter_symbols, -+ filename="4-graphene_sys2.png", width=html.figwidth_px) +- kwant.plot(sys, site_color=group_colors, colorbar=False) ++ size = (latex.figwidth_in, 0.9 * latex.figwidth_in) ++ kwant.plot(sys, site_color=group_colors, colorbar=False, ++ file="4-graphene_sys2.pdf", fig_size=size, dpi=html.dpi) ++ kwant.plot(sys, site_color=group_colors, colorbar=False, ++ file="4-graphene_sys2.png", fig_size=size, dpi=html.dpi) # Finalize the system. sys = sys.finalized() diff --git a/doc/source/images/5-superconductor_band_structure.py.diff b/doc/source/images/5-superconductor_band_structure.py.diff index ee2fbdd6dd7500a0c6c34153447c38aece0ab6c4..7392cd639bca0e8f253f26978dedfb8a30aaea65 100644 --- a/doc/source/images/5-superconductor_band_structure.py.diff +++ b/doc/source/images/5-superconductor_band_structure.py.diff @@ -8,7 +8,7 @@ tau_x = np.array([[0, 1], [1, 0]]) tau_z = np.array([[1, 0], [0, -1]]) -@@ -46,12 +47,20 @@ +@@ -46,12 +47,19 @@ # the bandstructure energy_list = [lead.energies(k) for k in momenta] @@ -23,11 +23,10 @@ + fontsize=latex.mpl_tick_size) + pyplot.setp(fig.get_axes()[0].get_yticklabels(), + fontsize=latex.mpl_tick_size) -+ fig.set_size_inches(latex.mpl_width_in, latex.mpl_width_in*3./4.) ++ fig.set_size_inches(latex.mpl_width_in, latex.mpl_width_in * 3. / 4.) + fig.subplots_adjust(left=0.15, right=0.95, top=0.95, bottom=0.15) + fig.savefig("5-superconductor_band_structure_result.pdf") -+ fig.savefig("5-superconductor_band_structure_result.png", -+ dpi=(html.figwidth_px/latex.mpl_width_in)) ++ fig.savefig("5-superconductor_band_structure_result.png", dpi=html.dpi) def main(): diff --git a/doc/source/images/5-superconductor_transport.py.diff b/doc/source/images/5-superconductor_transport.py.diff index 241d417dfc0ce8261658f2a10c6afc29deb35958..8f6d6d9dcf8f9e42cf8c6083be3a814ebc2e7bfc 100644 --- a/doc/source/images/5-superconductor_transport.py.diff +++ b/doc/source/images/5-superconductor_transport.py.diff @@ -8,7 +8,7 @@ def make_system(a=1, W=10, L=10, barrier=1.5, barrierpos=(3, 4), -@@ -95,19 +96,24 @@ +@@ -95,19 +96,23 @@ smatrix.transmission(0, 0) + smatrix.transmission(1, 0)) @@ -22,11 +22,10 @@ + fontsize=latex.mpl_tick_size) + pyplot.setp(fig.get_axes()[0].get_yticklabels(), + fontsize=latex.mpl_tick_size) -+ fig.set_size_inches(latex.mpl_width_in, latex.mpl_width_in*3./4.) ++ fig.set_size_inches(latex.mpl_width_in, latex.mpl_width_in * 3. / 4.) + fig.subplots_adjust(left=0.15, right=0.95, top=0.95, bottom=0.15) + fig.savefig("5-superconductor_transport_result.pdf") -+ fig.savefig("5-superconductor_transport_result.png", -+ dpi=(html.figwidth_px/latex.mpl_width_in)) ++ fig.savefig("5-superconductor_transport_result.png", dpi=html.dpi) def main(): diff --git a/doc/source/images/html.py b/doc/source/images/html.py index 120f7de2110c088c243dd72b3be884a431a0eb79..e0415a3b3aab1fe67d8f87c77a80829c1f926175 100644 --- a/doc/source/images/html.py +++ b/doc/source/images/html.py @@ -1,4 +1,2 @@ -# Default width of figures in pixels -figwidth_px = 600 -# Width for smaller figures -figwidth_small_px = 400 +# dpi for conversion from inches +dpi = 90 diff --git a/doc/source/images/latex.py b/doc/source/images/latex.py index 870c858e678a4d1a9f1a9ebdec2b5664f7410e5d..4abf019cc69b4b1ed730045fe2c32e53dc2ff9d3 100644 --- a/doc/source/images/latex.py +++ b/doc/source/images/latex.py @@ -1,13 +1,14 @@ -pt_to_in = 1./72. +pt_to_in = 1. / 72. # Default width of figures in pts -figwidth_pt = 300 +figwidth_pt = 600 +figwidth_in = figwidth_pt * pt_to_in # Width for smaller figures -figwidth_small_pt = 200 +figwidth_small_pt = 400 figwidth_small_in = figwidth_small_pt * pt_to_in # Sizes for matplotlib figures mpl_width_in = figwidth_pt * pt_to_in -mpl_label_size = 10 # font sizes in points -mpl_tick_size = 9 +mpl_label_size = None # font sizes in points +mpl_tick_size = None diff --git a/doc/source/reference/kwant.plotter.rst b/doc/source/reference/kwant.plotter.rst index d88d9b53795c94a921c1313347c78635d1cc8c34..0ffcae77e9671cd446aa4664ffbc4c6e9dec2cae 100644 --- a/doc/source/reference/kwant.plotter.rst +++ b/doc/source/reference/kwant.plotter.rst @@ -10,30 +10,16 @@ Plotting routine :toctree: generated/ plot - show - interpolate + map -Auxiliary types ----------------- +Data-generating functions +------------------------- .. autosummary:: :toctree: generated/ - Circle - Polygon - Line - LineStyle - Color + sys_leads_sites + sys_leads_hoppings + sys_leads_pos + sys_leads_hopping_pos + mask_interpolate -Pre-defined colors ------------------- -+------------------------+ -| `~kwant.plotter.black` | -+------------------------+ -| `~kwant.plotter.white` | -+------------------------+ -| `~kwant.plotter.red` | -+------------------------+ -| `~kwant.plotter.green` | -+------------------------+ -| `~kwant.plotter.blue` | -+------------------------+ diff --git a/doc/source/tutorial/1-quantum_wire.py b/doc/source/tutorial/1-quantum_wire.py index 9e8ad75a5d3ea08f54e337dc229327073c0b56e0..670f99a55ed0f17e81cc9df94bb157a599c96219 100644 --- a/doc/source/tutorial/1-quantum_wire.py +++ b/doc/source/tutorial/1-quantum_wire.py @@ -8,6 +8,7 @@ # - Making scattering region and leads # - Using the simple sparse solver for computing Landauer conductance +from matplotlib import pyplot #HIDDEN_BEGIN_dwhx import kwant #HIDDEN_END_dwhx @@ -119,7 +120,6 @@ for ie in xrange(100): # Use matplotlib to write output # We should see conductance steps #HIDDEN_BEGIN_lliv -from matplotlib import pyplot pyplot.figure() pyplot.plot(energies, data) diff --git a/doc/source/tutorial/4-graphene.py b/doc/source/tutorial/4-graphene.py index e024b7f0bee2a43882553aedc862d09ebe1bf5aa..147fd890038defb62afd4fb9419529af1c6888e9 100644 --- a/doc/source/tutorial/4-graphene.py +++ b/doc/source/tutorial/4-graphene.py @@ -151,13 +151,11 @@ def main(): # To highlight the two sublattices of graphene, we plot one with # a filled, and the other one with an open circle: - plotter_symbols = {a: kwant.plotter.Circle(r=0.3), - b: kwant.plotter.Circle(r=0.3, - fcol=kwant.plotter.white, - lcol=kwant.plotter.black)} + def group_colors(site): + return 0 if site.group == a else 1 # Plot the closed system without leads. - kwant.plot(sys, symbols=plotter_symbols) + kwant.plot(sys, site_color=group_colors, colorbar=False) #HIDDEN_END_itkk # Compute some eigenvalues. @@ -170,7 +168,7 @@ def main(): sys.attach_lead(lead) # Then, plot the system with leads. - kwant.plot(sys, symbols=plotter_symbols) + kwant.plot(sys, site_color=group_colors, colorbar=False) # Finalize the system. sys = sys.finalized() diff --git a/doc/source/tutorial/tutorial4.rst b/doc/source/tutorial/tutorial4.rst index 4d9eb5f4402e377fba10f9616168b266e700bbb7..528850fb85a00130c7bfad3d4cc14cb4c8ccc8e5 100644 --- a/doc/source/tutorial/tutorial4.rst +++ b/doc/source/tutorial/tutorial4.rst @@ -118,27 +118,16 @@ plot the system: :start-after: #HIDDEN_BEGIN_itkk :end-before: #HIDDEN_END_itkk -We customize the plotting: `plotter_symbols` is a dictionary with the -sublattice objects `a` and `b` as keys, and the `~kwant.plotter.Circle` objects -specify that the sublattice `a` should be drawn using a filled black circle, -and `b` using a white circle with a black outline. :: - - plotter_symbols = {a: kwant.plotter.Circle(r=0.3), - b: kwant.plotter.Circle(r=0.3, - fcol=kwant.plotter.white, - lcol=kwant.plotter.black)} - -The radius of the circle is given in relative units: `~kwant.plotter.plot` uses -a typical length scale as a reference length. By default, the typical length -scale is the smallest distance between lattice points. `~kwant.plotter.plot` -can find this length by itself, but must then go through all -hoppings. Alternatively, one can specify the typical length scale using the -argument `a` as in the example (not to be confused with the sublattice `a`) -which is here set to the distance between carbon atoms in the graphene -lattice. Specifying ``r=0.3`` in `~kwant.plotter.Circle` hence means that the -radius of the circle is 30% of the carbon-carbon distance. Using this relative -unit it is easy to make good-looking plots where the symbols cover a -well-defined part of the plot. +We customize the plotting: we set the `site_colors` argument of +`~kwant.plotter.plot` to a function which returns 0 for +sublattice `a` and 1 for sublattice `b`:: + + def group_colors(site): + return 0 if site.group == a else 1 + +The function `~kwant.plotter.plot` shows these values using a color scale +(grayscale by default). The symbol `size` is specified in points, and is +independent on the overall figure size. Plotting the closed system gives this result: diff --git a/doc/source/whatsnew/0.2.rst b/doc/source/whatsnew/0.2.rst index 1344b7afc27740fa215ecf3bdac7ea791c68644c..2fc2710d112e2e95cba32d46407b10168772be76 100644 --- a/doc/source/whatsnew/0.2.rst +++ b/doc/source/whatsnew/0.2.rst @@ -29,33 +29,31 @@ New tutorial dealing with superconductivity ------------------------------------------- :doc:`../tutorial/tutorial5` -`~kwant.plotter.plot` more useful for low level systems -------------------------------------------------------- -The behavior of `~kwant.plotter.plot` has been changed when a `low level system -<kwant.system.System>` is plotted. Previously, only low level systems which -were finalized `builders <kwant.builder.Builder>` were supported and there was -no difference between plotting a low-level system and a builder. +New `~kwant.plotter` module +--------------------------- +`~kwant.plotter` has been rewritten using `matplotlib`, which allows +plot post-processing, basic 3D plotting and many other features. Due to the +possibility to easily modify a `matplotlib` plot after it has been generated, +function `~kwant.plotter.plot` has much fewer input parameters, and is less +flexible than its previous implementation. Its interface is also much more +similar to that of `matplotlib`. For the detailed interface and input +description check `~kwant.plotter.plot` documentation. -* Arguments of plot which are functions are given site numbers in place of - `~kwant.builder.Site` objects when plotting a low level system. This - provides an easy way to make the appearance of lines and symbols depend on - computation results. +The behavior of `~kwant.plotter.plot` with low level systems has changed. +Arguments of plot which are functions are given site numbers in place of +`~kwant.builder.Site` objects when plotting a low level system. This +provides an easy way to make the appearance of lines and symbols depend on +computation results. -* Only the scattering region (without leads) is plotted for low level systems. - -* For plotting low level systems, dictionaries with site group keys are no - longer supported as plot arguments. +A new function `~kwant.plotter.map` was implemented. It allows to show a map of +spatial dependence of a function of a system site (e.g. density of states) +without showing the sites themselves. Calculation of the local density of states ------------------------------------------ The new function of sparse solvers `~kwant.solvers.common.SparseSolver.ldos` allows the calculation of the local density of states. -Plotting of functions of system sites -------------------------------------- -The new function `kwant.plotter.show` plots functions of the system, i.e. the -potential or the LDOS. - Return value of sparse solver ----------------------------- The function `~kwant.solvers.common.SparseSolver.solve` of sparse solvers now @@ -63,6 +61,6 @@ always returns a single instance of `~kwant.solvers.common.BlockResult`. The latter has been generalized to include more information for leads defined as infinite systems. -Return value of `~kwant.solvers.sparse.make_linear_sys` has changed -------------------------------------------------------------------- +Return value of `~kwant.solvers.common.SparseSolver.make_linear_sys` has changed +-------------------------------------------------------------------------------- A namedtuple is used for more clarity. diff --git a/kwant/plotter.py b/kwant/plotter.py index 46ccdd314726935e31ca1a424f054a82187516a0..a73ecc3495cfbf4db5353da98befc08e4532a4f7 100644 --- a/kwant/plotter.py +++ b/kwant/plotter.py @@ -1,876 +1,704 @@ -"""kwant.plotter docstring""" +"""Plotter module for kwant. -from __future__ import division -from math import sqrt, pi, sin, cos, tan -import scipy.interpolate -from numpy import dot, add, subtract -import numpy as np +This module provides iterators useful for any plotter routine, such as a list +of system sites, their coordinates, lead sites at any lead unit cell, etc. If +`matplotlib` is available, it also provides simple functions for plotting the +system in two or three dimensions. +""" + +import itertools import warnings -import cairo -import matplotlib.pyplot as plt +import numpy as np +from scipy import spatial, interpolate +# All matplotlib imports must be isolated in a try, because even without +# matplotlib iterators remain useful. Further, mpl_toolkits used for 3D +# plotting are also imported separately, to ensure that 2D plotting works even +# if 3D does not. try: - import Image - defaultname = None - has_pil = True -except: - defaultname = "plot.pdf" - has_pil = False - -from . import builder, system + import matplotlib + from matplotlib.figure import Figure + from matplotlib import collections + from matplotlib.backends.backend_agg import FigureCanvasAgg + _mpl_enabled = True + from matplotlib.cbook import is_string_like, is_sequence_of_strings + try: + from mpl_toolkits import mplot3d + except ImportError: + warnings.warn("3D plotting not available.") +except ImportError: + warnings.warn("matplotlib is not available, only iterator-providing" + "functions will work.") + _mpl_enabled = False + +from . import system, builder + +__all__ = ['plot', 'map', 'sys_leads_sites', 'sys_leads_hoppings', + 'sys_leads_pos', 'sys_leads_hopping_pos', 'mask_interpolate'] + + +# matplotlib helper functions. + +def set_edge_colors(color, collection, cmap, norm=None): + """Process a color specification to a format accepted by collections. -__all__ = ['plot', 'Circle', 'Polygon', 'Line', 'Color', 'LineStyle', - 'black', 'white', 'red', 'green', 'blue', - 'interpolate', 'show'] - -class Color(object): - """RGBA color. + Parameters + ---------- + color : color specification + collection : instance of a subclass of `matplotlib.collections.Collection` + Collection to which the color is added. + cmap : `matplotlib` color map specification or None + Color map to be used if colors are specified as floats. + norm : `matplotlib` color norm + Norm to be used if colors are specified as floats. + """ + length = len(collection.get_paths()) + if isinstance(collection, mplot3d.art3d.Line3DCollection): + length = len(collection._segments3d) # Once again, matplotlib fault! + color_is_stringy = is_string_like(color) or is_sequence_of_strings(color) + if not color_is_stringy: + color = np.asanyarray(color) + if color.size == length: + color = np.ma.ravel(color) + if color_is_stringy: + colors = matplotlib.colors.colorConverter.to_rgba_array(color) + else: + # The inherent ambiguity is resolved in favor of color + # mapping, not interpretation as rgb or rgba: + if color.size == length: + colors = None # use cmap, norm after collection is created + else: + colors = matplotlib.colors.colorConverter.to_rgba_array(color) + collection.set_color(colors) + if colors is None: + if norm is not None and not isinstance(norm, + matplotlib.colors.Normalize): + raise ValueError('Illegal value of norm.') + collection.set_array(np.asarray(color)) + collection.set_cmap(cmap) + collection.set_norm(norm) - Standard Color object that can be used to specify colors in - `plot`. - When creating the Color object, the color is specified in an RGBA scheme, - i.e. by specifying the red (r), green (g) and blue (b) components - of the color and optionally an alpha channel controlling the transparancy. +def lines(axes, x0, x1, y0, y1, colors='k', linestyles='solid', cmap=None, + norm=None, **kwargs): + """Add a collection of line segments to an axes instance. Parameters ---------- - r, g, b : float in the range [0, 1] - specifies the values of the red, green and blue components of the color - alpha : float in the range [0, 1], optional - specifies the transparancy, with alpha=0 completely transparent and - alpha=1 completely opaque (not transparent). - Defaults to 1 (opaque). + axes : matplotlib.axes.Axes instance + Axes to which the lines have to be added. + x0 : array_like + Starting x-coordinates of each line segment + x1 : array_like + Ending x-coordinates of each line segment + y0 : array_like + Starting y-coordinates of each line segment + y1 : array_like + Ending y-coordinates of each line segment + colors : color definition, optional + Either a single object that is a proper matplotlib color definition + or a sequence of such objects of appropriate length. Defaults to all + segments black. + linestyles :linestyle definition, optional + Either a single object that is a proper matplotlib line style + definition or a sequence of such objects of appropriate length. + Defaults to all segments solid. + cmap : `matplotlib` color map specification or None + Color map to be used if colors are specified as floats. + norm : `matplotlib` color norm + Norm to be used if colors are specified as floats. + **kwargs : dict + keyword arguments to pass to `matplotlib.collections.LineCollection`. - Examples - -------- - The color black is specified using + Returns + ------- + `matplotlib.collections.LineCollection` instance containing all the + segments that were added. + """ + coords = (y0, y1, x0, x1) - >>> black = Color(0, 0, 0) + if not all(len(coord) == len(y0) for coord in coords): + raise ValueError('Incompatible lengths of coordinate arrays.') - and white using + if len(x0) == 0: + coll = collections.LineCollection([], linestyles=linestyles) + axes.add_collection(coll) + axes.autoscale_view() + return coll - >>> white = Color(1, 1, 1) + segments = (((i[0], i[1]), (i[2], i[3])) for + i in itertools.izip(x0, y0, x1, y1)) + coll = collections.LineCollection(segments, linestyles=linestyles) + set_edge_colors(colors, coll, cmap, norm) + axes.add_collection(coll) + coll.update(kwargs) - By default, a color is completely opaque (not transparent). Using the - optional parameter alpha one can specify transparancy. For example, + minx = min(x0.min(), x1.min()) + maxx = max(x0.max(), x1.max()) + miny = min(y0.min(), y1.min()) + maxy = max(y0.max(), y1.max()) - >>> black_transp = Color(0, 0, 0, alpha=0.5) + corners = (minx, miny), (maxx, maxy) - is black with 50% transparancy. + axes.update_datalim(corners) + axes.autoscale_view() - """ - def __init__(self, r, g, b, alpha=1.0): - for val in (r, g, b, alpha): - if val < 0 or val > 1: - raise ValueError("r, g, b, and alpha must be in " - "the range [0,1]") - self.r = r - self.g = g - self.b = b - self.alpha = alpha - - def _set_color_cairo(self, ctx, fading=None): - if fading is not None: - ctx.set_source_rgba(self.r + fading[1] * (fading[0].r - self.r), - self.g + fading[1] * (fading[0].g - self.g), - self.b + fading[1] * (fading[0].b - self.b), - self.alpha + fading[1] * - (fading[0].alpha - self.alpha)) - else: - ctx.set_source_rgba(self.r, self.g, self.b, self.alpha) + return coll -black = Color(0, 0, 0) -white = Color(1, 1, 1) -red = Color(1, 0, 0) -green = Color(0, 1, 0) -blue = Color(0, 0, 1) -# TODO: possibly add dashed, etc. -class LineStyle(object): - """Object for describing a line style. Can be used as a parameter in the - class `Line`. - - Right now, the LineStyle object only allows to specify the line cap (i.e. - the shape of the end of the line). In the future might include dashing, - etc. +def lines3d(axes, x0, x1, y0, y1, z0, z1, + colors='k', linestyles='solid', cmap=None, norm=None, **kwargs): + """Add a collection of 3D line segments to an Axes3D instance. Parameters ---------- - lcap : { 'butt', 'round', 'square'}, optional - Specifies the shape of the end of the line: - - 'butt' - End of the line is rectangular and ends exactly at the end point. - 'round' - End of the line is rounded, as if a half-circle is drawn around the - end point. - 'square' - End of the line is rectangular, but protrudes beyond the end point, - as if a square was drawn centered at the end point. - - Defaults to 'butt'. + axes : matplotlib.axes.Axes instance + Axes to which the lines have to be added. + x0 : array_like + Starting x-coordinates of each line segment + x1 : array_like + Ending x-coordinates of each line segment + y0 : array_like + Starting y-coordinates of each line segment + y1 : array_like + Ending y-coordinates of each line segment + z0 : array_like + Starting z-coordinates of each line segment + z1 : array_like + Ending z-coordinates of each line segment + colors : color definition, optional + Either a single object that is a proper matplotlib color definition + or a sequence of such objects of appropriate length. Defaults to all + segments black. + linestyles :linestyle definition, optional + Either a single object that is a proper matplotlib line style + definition or a sequence of such objects of appropriate length. + Defaults to all segments solid. + cmap : `matplotlib` color map specification or None + Color map to be used if colors are specified as floats. + norm : `matplotlib` color norm + Norm to be used if colors are specified as floats. + **kwargs : dict + keyword arguments to pass to `matplotlib.collections.LineCollection`. + + Returns + ------- + `mpl_toolkits.mplot3d.art3d.Line3DCollection` instance containing all the + segments that were added. """ - def __init__(self, lcap="butt"): - if lcap == "butt": - self.lcap = cairo.LINE_CAP_BUTT - elif lcap == "round": - self.lcap = cairo.LINE_CAP_ROUND - elif lcap == "square": - self.lcap = cairo.LINE_CAP_SQUARE - else: - raise ValueError("Unknown line cap style "+lcap) + had_data = axes.has_data() + coords = (y0, y1, x0, x1, z0, z1) - def _set_props_cairo(self, ctx, reflen): - ctx.set_line_cap(self.lcap) + if not all(len(coord) == len(y0) for coord in coords): + raise ValueError('Incompatible lengths of coordinate arrays.') -class Line(object): - """Draws a straight line between the two sites connected by a hopping. + if len(x0) == 0: + coll = mplot3d.art3d.Line3DCollection([], linestyles=linestyles) + axes.add_collection(coll) + return coll - Standard object that can be used to specify how to draw - a line representing a hopping in `plot`. + segments = [(i[: 3], i[3:]) for + i in itertools.izip(x0, y0, z0, x1, y1, z1)] + coll = mplot3d.art3d.Line3DCollection(segments, linestyles=linestyles) + set_edge_colors(colors, coll, cmap, norm) + coll.update(kwargs) + axes.add_collection(coll) - Parameters - ---------- - lw : float - line width relative to the reference length (see `plot`) - lcol : object realizing the "color functionality" (see `plot`) - line color - lsty : a LineStyle object - line style - """ - def __init__(self, lw, lcol=black, lsty=LineStyle()): - self.lw = lw - self.lcol = lcol - self.lsty = lsty - - def _draw_cairo(self, ctx, pos1, pos2, reflen, fading=None): - ctx.new_path() - if self.lw > 0 and self.lcol is not None and self.lsty is not None: - ctx.set_line_width(self.lw * reflen) - self.lcol._set_color_cairo(ctx, fading=fading) - self.lsty._set_props_cairo(ctx, reflen) - ctx.move_to(pos1[0], pos1[1]) - ctx.line_to(pos2[0], pos2[1]) - ctx.stroke() - -class Circle(object): - """Draw circle with (relative) radius r centered at a site. - - Standard symbol object that can be used with `plot`. - Sizes are always given in terms of the reference length - of `plot`. + min_max = lambda a, b: np.array(min(a.min(), b.min()), + max(a.max(), b.max())) + x, y, z = min_max(x0, x1), min_max(y0, y1), min_max(z0, z1) - Parameters - ---------- - r : float - Radius of the circle - fcol : color_like object or None, optional - Fill color. If None, the circle is not filled. Defaults to black. - lw : float, optional - Line width of the outline. If 0, no outline is drawn. - Defaults to 0.1. - lcol : color_like object or None, optional - Color of the outline. If None, no outline is drawn. Defaults to None. - lsty : `LineStyle` object - Line style of the outline. Defaults to LineStyle(). - """ - def __init__(self, r, fcol=black, lw=0.1, lcol=None, lsty=LineStyle()): - self.r = r - self.fcol = fcol - self.lw = lw - self.lcol= lcol - self.lsty = lsty - - def _draw_cairo(self, ctx, pos, reflen, fading=None): - ctx.new_path() - - if self.fcol is not None: - self.fcol._set_color_cairo(ctx, fading=fading) - ctx.arc(pos[0], pos[1], self.r * reflen, 0, 2*pi) - ctx.fill() - - if self.lw > 0 and self.lcol is not None and self.lsty is not None: - ctx.set_line_width(self.lw * reflen) - self.lcol._set_color_cairo(ctx, fading=fading) - self.lsty._set_props_cairo(ctx, reflen) - ctx.arc(pos[0], pos[1], self.r * reflen, 0, 2*pi) - ctx.stroke() - -class Polygon(object): - """Draw a regular n-sided polygon centered at a site. - - Standard symbol object that can be used with `plot`. - Sizes are always given in terms of the reference length - of `plot`. - - The size of the polygon can be specifed in one of two ways: - - either by specifying the side length `a` - - or by demanding that the area of the polygon is equal to a circle - with radius `size` + axes.auto_scale_xyz(x, y, z, had_data) - Parameters - ---------- - n : int - Number of sides (i.e. `n=3` is a triangle, `n=4` a square, etc.) - a, size : float, exactly one must be given - The size of the polygon, either specified by the side length `a` - or the radius `size` of a circle of equal area. - angle : float, optional - Rotate the polygon counter-clockwise by `angle` (specified - in radians. Defaults to 0. - fcol : color_like object or None, optional - Fill color. If None, the polygon is not filled. Defaults to black. - lw : float, optional - Line width of the outline. If 0, no outline is drawn. - Defaults to 0.1. - lcol : color_like object or None, optional - Color of the outline. If None, no outline is drawn. Defaults to None. - lsty : `LineStyle` object - Line style of the outline. Defaults to LineStyle(). - """ - def __init__(self, n, a=None, size=None, - angle=0, fcol=black, lw=0.1, lcol=None, lsty=LineStyle()): - if ((a is None and size is None) or - (a is not None and size is not None)): - raise ValueError("Either sidelength or equivalent circle radius " - "must be specified") - - self.n = n - if a is None: - # make are of triangle equal to circle of radius size - a = sqrt(4 * tan(pi / n) / n * pi) * size - # note: self.rc is the radius of the circumscribed circle - self.rc = a / (2 * sin(pi / n)) - self.angle = angle - self.fcol = fcol - self.lw = lw - self.lcol = lcol - self.lsty = lsty - - def _draw_cairo_poly(self, ctx, pos, reflen): - ctx.move_to(pos[0] + sin(self.angle) * self.rc * reflen, - pos[1] + cos(self.angle) * self.rc * reflen) - for i in xrange(1, self.n): - phi = i * 2 * pi / self.n - ctx.line_to(pos[0] + sin(self.angle + phi) * self.rc * reflen, - pos[1] + cos(self.angle + phi) * self.rc * reflen) - ctx.close_path() - - def _draw_cairo(self, ctx, pos, reflen, fading=None): - ctx.new_path() - - if self.fcol is not None: - self.fcol._set_color_cairo(ctx, fading=fading) - self._draw_cairo_poly(ctx, pos, reflen) - ctx.fill() - - if self.lw > 0 and self.lcol is not None and self.lsty is not None: - ctx.set_line_width(self.lw * reflen) - self.lcol._set_color_cairo(ctx, fading=fading) - self.lsty._set_props_cairo(ctx, reflen) - self._draw_cairo_poly(ctx, pos, reflen) - ctx.stroke() - - -def iterate_lead_sites_builder(syst, lead_copies): - for lead in syst.leads: - if not isinstance(lead, builder.BuilderLead): - continue - sym = lead.builder.symmetry - shift = sym.which(lead.interface[0]) + 1 - - for i in xrange(lead_copies): - for site in lead.builder.sites(): - yield sym.act(shift + i, site), i - - -def iterate_lead_hoppings_builder(syst, lead_copies): - for lead in syst.leads: - if not isinstance(lead, builder.BuilderLead): - continue - sym = lead.builder.symmetry - shift = sym.which(lead.interface[0]) + 1 - - for i in xrange(lead_copies): - for site1, site2 in lead.builder.hoppings(): - shift1 = sym.which(site1)[0] - shift2 = sym.which(site2)[0] - if shift1 >= shift2: - yield (sym.act(shift + i, site1), - sym.act(shift + i, site2), - i + shift1, i + shift2) - else: - # Note: this makes sure that hoppings beyond the unit - # cell are always ordered such that they are into - # the previous slice - yield (sym.act(shift + i - 1, site1), - sym.act(shift + i - 1, site2), - i - 1 + shift1, i - 1 + shift2) + return coll -def iterate_scattreg_sites_builder(syst): - for site in syst.sites(): - yield site +def output_fig(fig, output_mode='auto', file=None, savefile_opts=None, + show=True): + """Output a matplotlib figure using a given output mode. + Parameters + ---------- + fig : matplotlib.figure.Figure instance + The figure to be output. + output_mode : string + The output mode to be used. Can be one of the following: + 'pyplot' : attach the figure to pyplot, with the same behavior as if + pyplot.plot was called to create this figure. + 'ipython' : attach a `FigureCanvasAgg` to the figure and return it. + 'return' : return the figure. + 'file' : same as 'ipython', but also save the figure into a file. + 'auto' : if fname is given, save to a file, else if pyplot + is imported, attach to pyplot, otherwise just return. See also the + notes below. + file : string or a file object + The name of the target file or the target file itself + (opened for writing). + savefile_opts : (list, dict) or None + args and kwargs passed to `print_figure` of `matplotlib` + show : bool + Whether to call `matplotlib.pyplot.show()`. Only has an effect if the + output uses pyplot. -def iterate_scattreg_hoppings_builder(syst): - for hopping in syst.hoppings(): - yield hopping + Notes + ----- + For IPython with inline plotting, automatic mode selects 'return', since + there is a better way to show a figure by just calling `display(figure)`. + The behavior of this function producing a file is different from that of + matplotlib in that the `dpi` attribute of the figure is used by defaul + instead of the matplotlib config setting. + """ + if not _mpl_enabled: + raise RuntimeError('matplotlib is not installed.') + if output_mode == 'auto': + if file is not None: + output_mode = 'file' + else: + try: + if matplotlib.pyplot.get_backend() != \ + 'module://IPython.zmq.pylab.backend_inline': + output_mode = 'pyplot' + else: + output_mode = 'return' + except AttributeError: + output_mode = 'pyplot' + if output_mode == 'pyplot': + try: + fake_fig = matplotlib.pyplot.figure() + except AttributeError: + msg = 'matplotlib.pyplot is unavailable. Execute `import ' \ + 'matplotlib.pyplot` or use a different output mode.' + raise RuntimeError(msg) + fake_fig.canvas.figure = fig + fig.canvas = fake_fig.canvas + for ax in fig.axes: + try: + ax.mouse_init() # Make 3D interface interactive. + except AttributeError: + pass + if show: + matplotlib.pyplot.show() + return fig + elif output_mode == 'return': + canvas = FigureCanvasAgg(fig) + fig.canvas = canvas + return fig + elif output_mode == 'file': + canvas = FigureCanvasAgg(fig) + if savefile_opts is None: + savefile_opts = ([], {}) + if 'dpi' not in savefile_opts[1]: + savefile_opts[1]['dpi'] = fig.dpi + canvas.print_figure(file, *savefile_opts[0], + **savefile_opts[1]) + return fig + else: + assert False, 'Unknown output_mode' -def empty_generator(*args, **kwds): - return - yield + +# Extracting necessary data from the system. +def sys_leads_sites(sys, n_lead_copies=2): + """Return all the sites of the system and of the leads as a list. -def iterate_scattreg_sites_llsys(syst): - return xrange(syst.graph.num_nodes) + Parameters + ---------- + sys : kwant.builder.Builder or kwant.system.System instance + The system, sites of which should be returned. + n_lead_copies : integer + The number of times lead sites from each lead should be returned. + This is useful for showing several unit cells of the lead next to the + system. + Returns + ------- + sites : list of (site, lead_number, copy_number) tuples + A site is a `builder.Site` instance if the system is not finalized, + and an integer otherwise. For system sites `lead_number` is `None` and + `copy_number` is `0`, for leads both are integers. -def iterate_scattreg_hoppings_llsys(syst): - for i in xrange(syst.graph.num_nodes): - for j in syst.graph.out_neighbors(i): - # Only yield half of the hoppings (as builder does) - if i < j: - yield i, j + Notes + ----- + Leads are only supported if they are of the same type as the original + system, i.e. sites of `builder.BuilderLead` leads are returned with an + unfinalized system, and sites of `system.InfiniteSystem` leads are + returned with a finalized system. + """ + if isinstance(sys, builder.Builder): + sites = [(site, None, 0) for site in sys.sites()] + for leadnr, lead in enumerate(sys.leads): + if hasattr(lead, 'builder'): + sites.extend(((site, leadnr, i) for site in + lead.builder.sites() for i in + xrange(n_lead_copies))) + elif isinstance(sys, system.FiniteSystem): + sites = [(i, None, 0) for i in xrange(sys.graph.num_nodes)] + for leadnr, lead in enumerate(sys.leads): + # We will only plot leads with a graph and with a symmetry. + if hasattr(lead, 'graph') and hasattr(lead, 'symmetry'): + sites.extend(((site, leadnr, i) for site in + xrange(lead.slice_size) for i in + xrange(n_lead_copies))) + else: + raise TypeError('Unrecognized system type.') + return sites -def extent(pos, sites): - """Figure out the extent of the system.""" - minx = miny = inf = float('inf') - maxx = maxy = float('-inf') - for site in sites: - point = pos(site) - try: - x, y = point - except: - raise ValueError( - "Position must be 2d. Consider using the `pos` argument.") - minx = min(x, minx) - maxx = max(x, maxx) - miny = min(y, miny) - maxy = max(y, maxy) - if minx == inf: - warnings.warn("Plotting empty system"); - return 0, 1, 0, 1 - return minx, maxx, miny, maxy - - -def typical_distance(pos, hoppings, sites): - min_sq_dist = inf = float('inf') - for site1, site2 in hoppings: - tmp = subtract(pos(site1), pos(site2)) - sq_dist = dot(tmp, tmp) - if 0 < sq_dist < min_sq_dist: - min_sq_dist = sq_dist - - # If there were no hoppings, then we can only find the distance by checking - # the distances between all pairs sites (potentially slow). To speed this - # only look at the distances between 10 chosen sites and all the remaining - # sites. This simple heuristics works well in practice and is fast enough. - if min_sq_dist == inf: - first = True - positions = list(pos(site) for site in sites) - - for site1 in positions[:: max(len(positions) // 10, 1)]: - for site2 in positions: - tmp = subtract(site1, site2) - sq_dist = dot(tmp, tmp) - if 0 < sq_dist < min_sq_dist: - min_sq_dist = sq_dist - - # If min_sq_dist ist still 0, all sites sit at the same spot In this case I - # can just use any value for dist (rangex and rangey will also be 0 then) - return sqrt(min_sq_dist) if min_sq_dist != inf else 1 - - -def default_pos(syst): - if isinstance(syst, builder.Builder): - return lambda site: site.pos - elif isinstance(syst, builder.FiniteSystem): - return lambda i: syst.site(i).pos - else: - raise ValueError("`pos` argument needed when plotting" - " systems which are not (finalized) builders") - - -def plot(syst, filename=defaultname, fmt=None, a=None, - width=600, height=None, border=0.1, bcol=white, pos=None, - symbols=Circle(r=0.3), lines=Line(lw=0.1), - lead_symbols=-1, lead_lines=-1, - lead_fading=[0.6, 0.85]): - """Plot two-dimensional systems (or two-dimensional representations - of a system). - - `plot` can be used to plot both unfinalized kwant.builder.Builder - instances, and low level systems (i.e. instances of - kwant.system.FiniteSystem), including finalized builders. - - This function behaves differently for builders and low-level systems: - builders are plotted including those of their leads which are builders - themselves. For the leads, several copies of the lead unit cell are - plotted (per default 2), and they are gradually faded towards the - background color (at least in the default behavior). For low-level systems - the leads are ignored as there is no general way to recover the necessary - information about leads for low level systems. - - When arguments to this function are functions themselves, "sites" will be - passed to them as arguments. The meaning of "site" depends on whether the - system to be plotted is a builder or a low level system. For builders, a - site is a kwant.builder.Site object. For low level systems, a site is an - integer -- the site number. - - The output of `plot` is highly modifyable, as it does not perform any - drawing itself, but instead lets objects passed by the user (or as default - parameters) do the actual drawing work. `plot` itself does figure out the - range of positions occupied by the sites, as well as the smallest distance - between two sites which then serves as a reference length, unless the user - specifies explicitely a reference length. This reference length is then - used so that the sizes of symbols or lines are always given relative to - that reference length. This is particularly advantageous for regular - lattices, as it makes it easy to specify the area covered by symbols, etc. - - The objects that determine `plot`'s behavior are symbol_like (symbols - representing sites), line_like (lines representing hoppings) and color_like - (representing colors). The notes below explain in detail how to implement - custom classes. In most cases it is enough to use the predefined standard - objects: - - - for symbol_like: `Circle` and `Polygon` - - for line_like: `Line` - - for color_like: `Color`. +def sys_leads_pos(sys, site_lead_nr): + """Return an array of positions of sites in a system. Parameters ---------- - syst : (un)finalized system - System to plot. Either an unfinalized Builder - (instance of `kwant.builder.Builder`) - or a finalized builder (instance of - `kwant.builder.FiniteSystem`). - filename : string or None, optional - Name of the file the plot should be written to. The format - of the file can be determined from the suffix (see `fmt`). - If None, the plot is output on the screen [provided that the - Python Image Library (PIL) is installed]. Default is - None if the PIL is installed, and "plot.pdf" otherwise. - fmt : {"pdf", "ps", "eps", "svg", "png", "jpg", None}, optional - Format of the output file, if `filename` is not None. If - `fmt` is None, the format is determined from the suffix of the - `filename`. Defaults to None. - a : float, optional - Reference length. If None, the reference length is determined - as the smallest nonzero distance between sites. Defaults to None. - width, height : float or None, optional - Width and height of the output picture. In units of - "pt" for the vector graphics formats (pdf, ps, eps, svg) - and in pixels for the bitmap formats (png, jpg, and output to screen). - For the bitmap formats, `width` and `height` are rounded to the nearest - integer. One of `width` and `height` may be None (but not both - simultaneously). In this case, the unspecified size is chosen to - fit with the aspect ratio of the plot. If both are specified, the plot - is centered on the canvas (possibly with increasing the blank borders). - `width` defaults to 600, and `height` to None. - border : float, optional - Size of the blank border around the plot, relative to the - total size. Defaults to 0.1. - bcol : color_like, optional - Background color. Defaults to white. - - (If the plot is saved in a vector graphics format, `white` - actually corresponds to no background. This is a bit hacky - maybe [fading to bcol e.g. still makes a white symbol, not a - transparant symbol], but then again there is no reason for - having a white box behind everything) - pos : function or None, optional - When passed a site should return its (2D) position as a sequence of - length 2. If None, the real space position of the site is used if the - system to be plotted is a (finalized) builder. For other low level - systems it is required to specify this argument and an error will be - reported if it is missing. Defaults to None. - symbols : {symbol_like, function, dict, None}, optional - Object responsible for drawing the symbols correspodning to sites. - Either must be a single symbol_like object (the same symbol is drawn - for every site), a function that returns a symbol_like object when - passed a site, or None (in which case no symbols are drawn). Instead of - a symbol_like object the function may also return None corresponding to - no symbol. - - If the system is a builder, `symbols` may also be a dictionary with - site groups as keys and symbol_like as values. This allows to specify - different symbols for different site groups. - - Defaults to ``Circle(r=0.3)``. - - The standard symbols available are `Circle` and `Polygon`. - lines : {line_like, function, dict, None}, optional - Object responsible for drawing the lines representing the hoppings - between sites. Either a single line_like object (the same type of line - is drawn for all hoppings), a function that returns a line_like object - when passed two sites, or None (in which case no hoppings are - drawn). Instead of a line_like object the function may also return None - corresponding to no line. Defaults to ``Line(lw=0.1)``. - - If the system is a builder, `lines` may also be a dictionary with - tuples of two site groups as keys and line_like objects as values. - This allows to specify different line styles for different hoppings. - Note that if the hopping (a, b) is specified, (b, a) needs not be - included in the dictionary. - - The standard line available is `Line`. - lead_symbols : {symbol_like, function, dict, -1, None}, optional - Symbols to be drawn for the sites in the leads. The special - value -1 indicates that `symbols` (which is used for system sites) - should be used also for the leads. The other possible values are - as for the system `symbols`. - Defaults to -1. - lead_lines : {line_like, function, dict, -1, None}, optional - Lines to be drawn for the hoppings in the leads. The special - value -1 indicates that `lines` (which is used for system hoppings) - should be used also for the leads. The other possible values are - as for the system `lines`. - Defaults to -1. - lead_fading : list, optional - The number of entries in the list determines the number of - lead unit cells that are plotted. The unit cell `i` is then - faded by the ratio ``lead_fading[i]`` towards the - background color `bcol`. Here ``lead_fading[i]==0`` implies no fading - (i.e. the original symbols and lines), - whereas ``lead_fading[i]==1`` corresponds to the background color. + sys : `kwant.builder.Builder` or `kwant.system.System` instance + The system, coordinates of sites of which should be returned. + sites : list of `(site, leadnr, copynr)` tuples + Output of `sys_leads_sites` applied to the system. + + Returns + ------- + coords : numpy.ndarray of floats + Array of coordinates of the sites. Notes ----- - - `plot` knows three different legitimate classes representing - symbols (symbol_like), lines (line_like), and colors (color_like). - In order to serve as a legitimate object for these, - a class has to implement certain methods. In particular these - are - - - symbol_like: objects representing symbols for sites:: - - _draw_cairo(ctx, pos, reflen[, fading]) - - which draws the symbol onto the cairo context `ctx` - at the position `pos` (passed as a sequence of length 2). - `reflen` is the reference length, allowing the symbol to use - relative sizes. (Note though that `pos` is in **absolute** cairo - coordinates). - - If the symbol should also be used to draw leads, `_draw_cairo` - should also take the optional parameter `fading` wich is a tuple - `(fadecol, percent)` where `fadecol` is the color towards which - the symbol should be faded, and `percent` is a number between 0 - and 1 indicating the amount of fading, with `percent=0` no - fading, and `percent=1` fully faded to `fadecol`. Note that - while "fading" usually will imply color fading, this is not - required by plot. Anything conceivable is legitimate. - - The module :mod:`plot` provides two standard symbol classes: - `Circle` and `Polygon`. - - - line_like: objects representing lines for hoppings:: - - _draw_cairo(ctx, pos1, pos2, reflen[, fading]) - - which draws the something (typically a line of some sort) onto - the cairo context `ctx` connecting the position `pos1` and - `pos2` (passed as sequences of length 2). `reflen` is the - reference length, allowing the line to use relative sizes. (Note - though that `pos1` and `pos2` are in **absolute** cairo - coordinates). - - If the line should also be used to draw leads, `_draw_cairo` - should also take the optional parameter `fading` wich is a tuple - `(fadecol, percent)` where `fadecol` is the color towards which - the symbol should be faded, and `percent` is a number between 0 - and 1 indicating the amount of fading, with `percent=0` no - fading, and `percent=1` fully faded to `fadecol`. Note that - while "fading" usually will imply color fading, this is not - required by plot. Anything conceivable is legitimate. - - The module :mod:`plot` provides one standard line class: `Line`. - - - color_like: for objects representing colors:: - - def _set_color_cairo(ctx[, fading]): - - which sets the current color of the cairo context `ctx`. - - If the color is passed to an object that requires fading in - order to be applicable for the representation of leads, - it must also take the optional parameter 'fading' wich is a tuple - `(fadecol, percent)` where `fadecol` is the color towards which - the symbol should be faded, and `percent` is a number between 0 - and 1 indicating the amount of fading, with `percent=0` no - fading, and `percent=1` fully faded to `fadecol`. Note that - while "fading" usually will imply color fading, this is not - required by plot. Anything conceivable is legitimate. - - The module :mod:`plot` provides one standard color class: - `Color`. In addition, a few common colors are predefined - as instances of `Color`:`black`, `white`, `red`, `green`, - and `blue`. + This function uses `site.pos` property to get the position of a builder + site and `sys.pos(sitenr)` for finalized systems. This function requires + that all the positions of all the sites have the same dimensionality. """ + is_builder = isinstance(sys, builder.Builder) + n_lead_copies = site_lead_nr[-1][2] + 1 + if is_builder: + pos = np.array([i[0].pos for i in site_lead_nr]) + else: + sys_from_lead = lambda lead: (sys if (lead is None) + else sys.leads[lead]) + pos = np.array([sys_from_lead(i[1]).pos(i[0]) for i in site_lead_nr]) + if pos.dtype == object: # Happens if not all the pos are same length. + raise ValueError("pos attribute of the sites does not have consistent" + " values.") + dim = pos.shape[1] + + def get_vec_domain(lead_nr): + if lead_nr is None: + return np.zeros((dim,)), 0 + if is_builder: + sym = sys.leads[lead_nr].builder.symmetry + site = sys.leads[lead_nr].interface[0] + else: + sym = sys.leads[lead_nr].symmetry + site = sys.site(sys.lead_interfaces[lead_nr][0]) + dom = sym.which(site)[0] + 1 + # TODO (Anton): vec = sym.periods[0] not supported by ta.ndarray + # Remove conversion to np.ndarray when not necessary anymore. + vec = np.array(sym.periods)[0] + return vec, dom + vecs_doms = dict((i, get_vec_domain(i)) for i in xrange(len(sys.leads))) + vecs_doms[None] = np.zeros((dim,)), 0 + for k, v in vecs_doms.iteritems(): + vecs_doms[k] = [v[0] * i for i in xrange(v[1], v[1] + n_lead_copies)] + pos += [vecs_doms[i[1]][i[2]] for i in site_lead_nr] + return pos + + +def sys_leads_hoppings(sys, n_lead_copies=2): + """Return all the hoppings of the system and of the leads as an iterator. - def iterate_all_sites(syst, lead_copies=0): - for site in iterate_scattreg_sites(syst): - yield site - - for site, ucindx in iterate_lead_sites(syst, lead_copies): - yield site + Parameters + ---------- + sys : kwant.builder.Builder or kwant.system.System instance + The system, sites of which should be returned. + n_lead_copies : integer + The number of times lead sites from each lead should be returned. + This is useful for showing several unit cells of the lead next to the + system. - def iterate_all_hoppings(syst, lead_copies=0): - for site1, site2 in iterate_scattreg_hoppings(syst): - yield site1, site2 + Returns + ------- + hoppings : list of (hopping, lead_number, copy_number) tuples + A site is a `builder.Site` instance if the system is not finalized, + and an integer otherwise. For system sites `lead_number` is `None` and + `copy_number` is `0`, for leads both are integers. - for site1, site2, i1, i2 in iterate_lead_hoppings(syst, lead_copies): - yield site1, site2 + Notes + ----- + Leads are only supported if they are of the same type as the original + system, i.e. hoppings of `builder.BuilderLead` leads are returned with an + unfinalized system, and hoppings of `system.InfiniteSystem` leads are + returned with a finalized system. + """ + hoppings = [] + if isinstance(sys, builder.Builder): + hoppings.extend(((hop, None, 0) for hop in sys.hoppings())) - is_builder = isinstance(syst, builder.Builder) - is_lowlevel = isinstance(syst, system.FiniteSystem) - if is_builder: - iterate_scattreg_sites = iterate_scattreg_sites_builder - iterate_scattreg_hoppings = iterate_scattreg_hoppings_builder - iterate_lead_sites = iterate_lead_sites_builder - iterate_lead_hoppings = iterate_lead_hoppings_builder - elif is_lowlevel: - iterate_scattreg_sites = iterate_scattreg_sites_llsys - iterate_scattreg_hoppings = iterate_scattreg_hoppings_llsys - # We do not plot leads for low level systems, as there is no general - # way to do that. - iterate_lead_sites = empty_generator - iterate_lead_hoppings = empty_generator + def lead_hoppings(lead): + sym = lead.symmetry + for site2, site1 in lead.hoppings(): + shift1 = sym.which(site1)[0] + shift2 = sym.which(site2)[0] + # We need to make sure that the hopping is between a site in a + # fundamental domain and a site with a negative domain. The + # direction of the hopping is chosen arbitrarily + # NOTE(Anton): This may need to be revisited with the future + # builder format changes. + shift = max(shift1, shift2) + yield sym.act([-shift], site2), sym.act([-shift], site1) + + for leadnr, lead in enumerate(sys.leads): + if hasattr(lead, 'builder'): + hoppings.extend(((hop, leadnr, i) for hop in + lead_hoppings(lead.builder) for i in + xrange(n_lead_copies))) + elif isinstance(sys, system.System): + def ll_hoppings(sys): + for i in xrange(sys.graph.num_nodes): + for j in sys.graph.out_neighbors(i): + if i < j: + yield i, j + hoppings.extend(((hop, None, 0) for hop in ll_hoppings(sys))) + for leadnr, lead in enumerate(sys.leads): + # We will only plot leads with a graph and with a symmetry. + if hasattr(lead, 'graph') and hasattr(lead, 'symmetry'): + hoppings.extend(((hop, leadnr, i) for hop in + ll_hoppings(lead) for i in + xrange(n_lead_copies))) else: - raise ValueError("Plotting not suported for given system") - - if width is None and height is None: - raise ValueError("One of width and height must be not None") + raise TypeError('Unrecognized system type.') + return hoppings - if pos is None: - pos = default_pos(syst) - if fmt is None and filename is not None: - # Try to figure out the format from the filename - fmt = filename.split(".")[-1].lower() - elif fmt is not None and filename is None: - raise ValueError("If fmt is specified, filename must be given, too") +def sys_leads_hopping_pos(sys, hop_lead_nr): + """Return arrays of coordinates of all hoppings in a system. - if fmt not in [None, "pdf", "ps", "eps", "svg", "png", "jpg"]: - raise ValueError("Unknwon format " + fmt) + Parameters + ---------- + sys : `kwant.builder.Builder` or `kwant.system.System` instance + The system, coordinates of sites of which should be returned. + hoppings : list of `(hopping, leadnr, copynr)` tuples + Output of `sys_leads_hoppings` applied to the system. - # Those two need the PIL - if fmt in [None, "jpg"] and not has_pil: - raise ValueError("The requested functionality requires the " - "Python Image Library (PIL)") + Returns + ------- + coords : (end_site, start_site): tuple of numpy arrays of floats + Array of coordinates of the hoppings. The first half of coordinates + in each array entry are those of the first site in the hopping, the + last half are those of the second site. - # Symbols and lines may be constant or functions. Wrap them as functions. - if hasattr(symbols, "__call__"): - fsymbols = symbols - elif is_builder and hasattr(symbols, "__getitem__"): - fsymbols = lambda x : symbols[x.group] + Notes + ----- + This function uses `site.pos` property to get the position of a builder + site and `sys.pos(sitenr)` for finalized systems. This function requires + that all the positions of all the sites have the same dimensionality. + """ + is_builder = isinstance(sys, builder.Builder) + if len(hop_lead_nr) == 0: + return np.empty((0, 3)), np.empty((0, 3)) + n_lead_copies = hop_lead_nr[-1][2] + 1 + if is_builder: + pos = np.array([np.r_[i[0][0].pos, i[0][1].pos] for i in hop_lead_nr]) else: - fsymbols = lambda x : symbols + sys_from_lead = lambda lead: (sys if (lead is None) + else sys.leads[lead]) + pos = [(sys_from_lead(i[1]).pos(i[0][0]), + sys_from_lead(i[1]).pos(i[0][1])) for i in hop_lead_nr] + pos = np.array([np.r_[i[0], i[1]] for i in pos]) + if pos.dtype == object: # Happens if not all the pos are same length. + raise ValueError("pos attribute of the sites does not have consistent" + " values.") + dim = pos.shape[1] + + def get_vec_domain(lead_nr): + if lead_nr is None: + return np.zeros((dim,)), 0 + if is_builder: + sym = sys.leads[lead_nr].builder.symmetry + site = sys.leads[lead_nr].interface[0] + else: + sym = sys.leads[lead_nr].symmetry + site = sys.site(sys.lead_interfaces[lead_nr][0]) + dom = sym.which(site)[0] + 1 + # TODO (Anton): vec = sym.periods[0] not supported by ta.ndarray + # Remove conversion to np.ndarray when not necessary anymore. + vec = np.array(sym.periods)[0] + return np.r_[vec, vec], dom + + vecs_doms = dict((i, get_vec_domain(i)) for i in xrange(len(sys.leads))) + vecs_doms[None] = np.zeros((dim,)), 0 + for k, v in vecs_doms.iteritems(): + vecs_doms[k] = [v[0] * i for i in xrange(v[1], v[1] + n_lead_copies)] + pos += [vecs_doms[i[1]][i[2]] for i in hop_lead_nr] + return np.copy(pos[:, : dim / 2]), np.copy(pos[:, dim / 2:]) + + +# Useful plot functions (to be extended). + +def plot(sys, n_lead_copies=2, site_color='b', hop_color='b', cmap='gray', + size=4, thickness=None, pos_transform=None, colorbar=True, file=None, + show=True, dpi=None, fig_size=None): + """Plot a system in 2 or 3 dimensions. - if hasattr(lines, "__call__"): - flines = lines - elif is_builder and hasattr(lines, "__getitem__"): - flines = lambda x, y : (lines[x.group, y.group] if (x.group, y.group) - in lines else lines[y.group, x.group]) - else: - flines = lambda x, y : lines - - if lead_symbols == -1: - flsymbols = fsymbols - elif hasattr(lead_symbols, "__call__"): - flsymbols = lead_symbols - elif is_builder and hasattr(lead_symbols, "__getitem__"): - flsymbols = lambda x : lead_symbols[x.group] - else: - flsymbols = lambda x : lead_symbols - - if lead_lines == -1: - fllines = flines - elif hasattr(lead_lines, "__call__"): - fllines = lead_lines - elif is_builder and hasattr(lines, "__getitem__"): - fllines = lambda x, y : (lead_lines[x.group, y.group] - if (x.group, y.group) in lead_lines - else lead_lines[y.group ,x.group]) - else: - fllines = lambda x, y : lead_lines + Parameters + ---------- + sys : kwant.builder.Builder or kwant.system.FiniteSystem + A system to be plotted. + n_lead_copies : int + Number of lead copies to be shown with the system. + site_color : `matplotlib` color description or a function + A color used for plotting a site in the system or a function returning + this color when given a site of the system (ignored for lead sites). + If a colormap is used, this function should return a single float. + hop_color : `matplotlib` color description or a function + Same as site_color, but for hoppings. A function is passed two sites + in this case. + cmap : `matplotlib` color map or a tuple of two color maps or `None` + The color map used for sites and optionally hoppings. + size : float or `None` + Site size in points. If `None`, `matplotlib` default is used. + thickness : float or `None` + Line thickness in points. If `None`, `matplotlib` default is used. + pos_transform : function or None + Transformation to be applied to the site position. + colorbar : bool + Whether to show a colorbar if colormap is used + file : string or file object or `None` + The output file. If `None`, output will be shown instead. + show : bool + Whether `matplotlib.pyplot.show()` is to be called, and the output is + to be shown immediately. Defaults to `True`. + dpi : float + Number of pixels per inch. If not set the `matplotlib` default is + used. + fig_size : tuple + Figure size `(width, height)` in inches. If not set, the default + `matplotlib` value is used. - minx, maxx, miny, maxy = \ - extent(pos, iterate_all_sites(syst, len(lead_fading))) + Notes + ----- + - The meaning of "site" depends on whether the system to be plotted is a + builder or a low level system. For builders, a site is a + kwant.builder.Site object. For low level systems, a site is an integer + -- the site number. - # If the user gave no typical distance between sites, we need to figure it - # out ourselves - # (Note: it is enough to consider one copy of the lead unit cell for - # figuring out distances, because of the translational symmetry) - if a is None: - a = typical_distance(pos, iterate_all_hoppings(syst, lead_copies=1), - iterate_all_sites(syst, lead_copies=1)) - elif a <= 0: - raise ValueError("The distance a must be >0") - - # Use the typical distance, if one of the ranges is 0 - # (e.g. in a one-dimensional system) - rangex = (maxx - minx) / (1 - 2 * border) - if rangex == 0: - rangex = a / (1 - 2 * border) - rangey = (maxy - miny) / (1 - 2 * border) - if rangey == 0: - rangey = a / (1 - 2 * border) - - # Compare with the desired dimensions of the plot - if height is None: - height = width * rangey / rangex - elif width is None: - width = height * rangex / rangey + - The dimensionality of the plot (2D vs 3D) is inferred from the coordinate + array. If there are more than three coordinates, only the first three + are used. If there is just one coordinate, the second one is padded with + zeros. + + - The system is scaled to fit the smaller dimension of the figure, given + its aspect ratio. + """ + # Generate data. + sites = sys_leads_sites(sys, n_lead_copies) + n_sys_sites = sum(i[1] is None for i in sites) + sites_pos = sys_leads_pos(sys, sites) + hops = sys_leads_hoppings(sys, n_lead_copies) + n_sys_hops = sum(i[1] is None for i in hops) + end_pos, start_pos = sys_leads_hopping_pos(sys, hops) + # Apply transformations to the data and generate the colors. + if pos_transform is not None: + sites_pos = np.apply_along_axis(pos_transform, 1, sites_pos) + end_pos = np.apply_along_axis(pos_transform, 1, end_pos) + start_pos = np.apply_along_axis(pos_transform, 1, start_pos) + if hasattr(site_color, '__call__'): + site_color = [site_color(i[0]) for i in sites if i[1] is None] + if hasattr(hop_color, '__call__'): + hop_color = [hop_color(*i[0]) for i in hops if i[1] is None] + # Choose plot type. + dim = 3 if (sites_pos.shape[1] == 3) else 2 + ar = np.zeros((len(sites_pos), dim)) + ar[:, : min(dim, sites_pos.shape[1])] = sites_pos + sites_pos = ar + end_pos.resize(len(end_pos), min(dim, end_pos.shape[1])) + start_pos.resize(len(start_pos), min(dim, start_pos.shape[1])) + fig = Figure() + if dpi is not None: + fig.set_dpi(dpi) + if fig_size is not None: + fig.set_figwidth(fig_size[0]) + fig.set_figheight(fig_size[1]) + if isinstance(cmap, tuple): + hop_cmap = cmap[1] + cmap = cmap[0] else: - # both width and height specified - # check in which direction to expand the border - if width/height > rangex / rangey: - rangex = rangey * width / height + hop_cmap = None + if dim == 2: + ax = fig.add_subplot(111, aspect='equal') + ax.scatter(*sites_pos[: n_sys_sites].T, c=site_color, cmap=cmap, + s=size ** 2, zorder=2) + end, start = end_pos[: n_sys_hops], start_pos[: n_sys_hops] + lines(ax, end[:, 0], start[:, 0], end[:, 1], start[:, 1], hop_color, + linewidths=thickness, zorder=1, cmap=hop_cmap) + lead_site_colors = np.array([i[2] for i in + sites if i[1] is not None], dtype=float) + # Avoid the matplotlib autoscale bug (remove when fixed) + if len(sites_pos) > n_sys_sites: + ax.scatter(*sites_pos[n_sys_sites:].T, c=lead_site_colors, + cmap='gist_yarg_r', s=size ** 2, zorder=2, + norm=matplotlib.colors.Normalize(-1, n_lead_copies + 1)) else: - rangey = rangex * height / width - - # Setup cairo - if fmt == "pdf": - surface = cairo.PDFSurface(filename, width, height) - elif fmt == "ps": - surface = cairo.PSSurface(filename, width, height) - elif fmt == "eps": - surface = cairo.PSSurface(filename, width, height) - surface.set_eps(True) - elif fmt == "svg": - surface = cairo.SVGSurface(filename, width, height) - elif fmt == "png" or fmt == "jpg" or fmt is None: - surface = cairo.ImageSurface(cairo.FORMAT_ARGB32, - int(round(width)), int(round(height))) - ctx = cairo.Context(surface) - - # The default background in the image surface is black - if fmt == "png" or fmt == "jpg" or fmt is None: - bcol._set_color_cairo(ctx) - ctx.rectangle(0, 0, int(round(width)), int(round(height))) - ctx.fill() - elif bcol is not white: - # only draw a background rectangle if background color is not white - bcol._set_color_cairo(ctx) - ctx.rectangle(0, 0, width, height) - ctx.fill() - - # Setup the coordinate transformation - - # Note: Cairo uses a coordinate system - # ---> x positioned in the left upper corner - # | of the screen. - # v y - # - # Instead, we use a mathematical coordinate system. - - # TODO: figure out, if file sizes are smaller without transformation - # i. e. if we do the transformation ourselves - scrminx = width * 0.5 * (rangex - (maxx - minx)) / rangex - scrminy = height * 0.5 * (rangey - (maxy - miny)) / rangey - - ctx.translate(scrminx, height - scrminy) - ctx.scale(width/rangex, -height/rangey) - ctx.translate(-minx, -miny) - - #### Draw the lines for the hoppings. - for site1, site2 in iterate_scattreg_hoppings(syst): - line = flines(site1, site2) - - if line is not None: - line._draw_cairo(ctx, pos(site1), pos(site2), a) - - for site1, site2, ucindx1, ucindx2 in \ - iterate_lead_hoppings(syst, len(lead_fading)): - if ucindx1 == ucindx2: - line = fllines(site1, site2) - - if line is not None: - line._draw_cairo(ctx, pos(site1), pos(site2), a, - fading=(bcol, lead_fading[ucindx1])) + ax.add_collection(matplotlib.collections.PathCollection([])) + lead_hop_colors = np.array([i[2] for i in + hops if i[1] is not None], dtype=float) + end, start = end_pos[n_sys_hops:], start_pos[n_sys_hops:] + lines(ax, end[:, 0], start[:, 0], end[:, 1], start[:, 1], + lead_hop_colors, linewidths=thickness, cmap='gist_yarg_r', + norm=matplotlib.colors.Normalize(-1, n_lead_copies + 1), + zorder=1) + else: + warnings.filterwarnings('ignore', message=r'.*rotation.*') + ax = fig.add_subplot(111, projection='3d') + warnings.resetwarnings() + ax.scatter(*sites_pos[: n_sys_sites].T, c=site_color, cmap=cmap, + s=size ** 2) + end, start = end_pos[: n_sys_hops], start_pos[: n_sys_hops] + lines3d(ax, end[:, 0], start[:, 0], end[:, 1], start[:, 1], + end[:, 2], start[:, 2], hop_color, cmap=hop_cmap, + linewidths=thickness) + lead_site_colors = np.array([i[2] for i in + sites if i[1] is not None], dtype=float) + lead_site_colors = 1 / np.sqrt(1. + lead_site_colors) + # Avoid the matplotlib autoscale bug (remove when fixed) + if len(sites_pos) > n_sys_sites: + ax.scatter(*sites_pos[n_sys_sites:].T, c=lead_site_colors, + cmap='gist_yarg_r', s=size ** 2, + norm=matplotlib.colors.Normalize(-1, n_lead_copies + 1)) else: - if ucindx1 > -1: - line = fllines(site1, site2) - if line is not None: - line._draw_cairo(ctx, pos(site1), - 0.5 * add(pos(site1), pos(site2)), - a, fading=(bcol, lead_fading[ucindx1])) - else: - #one end of the line is in the system - line = flines(site1, site2) - if line is not None: - line._draw_cairo(ctx, pos(site1), - 0.5 * add(pos(site1), pos(site2)), a) - - if ucindx2 > -1: - line = fllines(site2, site1) - if line is not None: - line._draw_cairo(ctx, pos(site2), - 0.5 * add(pos(site1), pos(site2)), - a, fading=(bcol, lead_fading[ucindx2])) - else: - # One end of the line is in the system - line = flines(site2, site1) - if line is not None: - line._draw_cairo(ctx, pos(site2), - 0.5 * add(pos(site1), pos(site2)), a) - - #### Draw the symbols for the sites. - for site in iterate_scattreg_sites(syst): - symbol = fsymbols(site) - - if symbol is not None: - symbol._draw_cairo(ctx, pos(site), a) - - for site, ucindx in iterate_lead_sites(syst, - lead_copies=len(lead_fading)): - symbol = flsymbols(site) - - if symbol is not None: - symbol._draw_cairo(ctx, pos(site), a, - fading=(bcol, lead_fading[ucindx])) - - - # Show or save the picture, if necessary (depends on format). - if fmt == None: - im = Image.frombuffer("RGBA", - (surface.get_width(), surface.get_height()), - surface.get_data(), "raw", "BGRA", 0, 1) - im.show() - elif fmt == "png": - surface.write_to_png(filename) - elif fmt == "jpg": - im = Image.frombuffer("RGBA", - (surface.get_width(), surface.get_height()), - surface.get_data(), "raw", "BGRA", 0, 1) - im.save(filename, "JPG") - - -def interpolate(syst, function, a=None, pos=None, - method='nearest', oversampling=3): - """Interpolate a scalar function defined for the sites of a system. + ax.add_collection(mplot3d.art3d.Patch3DCollection([])) + lead_hop_colors = np.array([i[2] for i in + hops if i[1] is not None], dtype=float) + lead_hop_colors = 1 / np.sqrt(1. + lead_hop_colors) + end, start = end_pos[n_sys_hops:], start_pos[n_sys_hops:] + lines3d(ax, end[:, 0], start[:, 0], end[:, 1], start[:, 1], + end[:, 2], start[:, 2], + lead_hop_colors, linewidths=thickness, cmap='gist_yarg_r', + norm=matplotlib.colors.Normalize(-1, n_lead_copies + 1)) + min_ = np.min(sites_pos, 0) + max_ = np.max(sites_pos, 0) + w = np.max(max_ - min_) / 2 + m = (min_ + max_) / 2 + ax.auto_scale_xyz(*[(i - w, i + w) for i in m], had_data=True) + if ax.collections[0].get_array() is not None and colorbar: + fig.colorbar(ax.collections[0]) + if ax.collections[1].get_array() is not None and colorbar: + fig.colorbar(ax.collections[1]) + return output_fig(fig, file=file, show=show) + + +def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3): + """Interpolate a scalar function in vicinity of given points. + + Create a masked array corresponding to interpolated values of the function + at points lying not further than a certain distance from the original + data points provided. Parameters ---------- - syst : kwant.system.FiniteSystem or kwant.builder.Builder - The system for whose sites `function` is to be plotted. - function : function or mapping - Function which takes a site and returns a number, or a mapping whose - keys are sites and whose values are numbers. + coords : np.ndarray + An array with site coordinates. + values : np.ndarray + An array with the values from which the interpolation should be built. a : float, optional - Reference length. If not given, it is determined as the smallest - nonzero distance between sites. - pos : function, optional - When passed a site should return its (2D) position as a sequence of - length 2. If None, the real space position of the site is used if the - system to be plotted is a (finalized) builder. For other low level - systems it is required to specify this argument and an error will be - reported if it is missing. Defaults to None. + Reference length. If not given, it is determined as a typical + nearest neighbor distance. method : string, optional Passed to `scipy.interpolate.griddata`: "nearest" (default), "linear", or "cubic" @@ -887,11 +715,6 @@ def interpolate(syst, function, a=None, pos=None, Notes ----- - - The meaning of "site" depends on whether the system to be plotted is a - builder or a low level system. For builders, a site is a - kwant.builder.Site object. For low level systems, a site is an integer - -- the site number. - - `min` and `max` are chosen such that when plotting a system on a square lattice and `oversampling` is set to an odd integer, each site will lie exactly at the center of a pixel of the output array. @@ -900,90 +723,99 @@ def interpolate(syst, function, a=None, pos=None, makes sense to set `oversampling` to ``1`` to minimize the size of the output array. """ - if isinstance(syst, builder.Builder): - iterate_scattreg_sites = iterate_scattreg_sites_builder - iterate_scattreg_hoppings = iterate_scattreg_hoppings_builder - elif isinstance(syst, system.FiniteSystem): - iterate_scattreg_sites = iterate_scattreg_sites_llsys - iterate_scattreg_hoppings = iterate_scattreg_hoppings_llsys - else: - raise ValueError("Plotting not suported for given system.") + # Build the bounding box. + cmin, cmax = coords.min(0), coords.max(0) - if not hasattr(function, '__call__'): - try: - function = function.__getitem__ - except: - raise TypeError("`function` must be either callable or a mapping.") - - if pos is None: - pos = default_pos(syst) + tree = spatial.cKDTree(coords) if a is None: - a = typical_distance(pos, iterate_scattreg_hoppings(syst), - iterate_scattreg_sites(syst)) + points = coords[np.random.randint(len(coords), size=10)] + a = np.min(tree.query(points, 2)[0][:, 1]) elif a <= 0: - raise ValueError("The distance a must be >0") - - points = [] - values = [] - for site in iterate_scattreg_sites(syst): - point = pos(site) - if point.shape != (2,): - raise ValueError( - 'Position must be 2d. Consider using the `pos` argument.') - points.append(point) - values.append(function(site)) - points = np.array(points) - values = np.array(values) - - min = points.min(0) - max = points.max(0) - shape = (((max - min) / a + 1) * oversampling).round() + raise ValueError("The distance a must be strictly positive.") + + shape = (((cmin - cmax) / a + 1) * oversampling).round() delta = 0.5 * (oversampling - 1) * a / oversampling - min -= delta - max += delta - grid_x, grid_y = np.ogrid[min[0]:max[0]:complex(0, shape[0]), - min[1]:max[1]:complex(0, shape[1])] - img = scipy.interpolate.griddata(points, values, (grid_x, grid_y), - method) + cmin -= delta + cmax += delta + dims = tuple(slice(cmin[i], cmax[i], 1j * shape[i]) for i in + range(len(cmin))) + grid = tuple(np.ogrid[dims]) + img = interpolate.griddata(coords, values, grid, method) + mask = np.mgrid[dims].reshape(len(cmin), -1).T + mask = tree.query(mask, eps=1.)[0] > 1.5 * a + + return np.ma.masked_array(img, mask), cmin, cmax - return img, min, max -def show(syst, function, colorbar=True, **kwds): - """Show a scalar function defined for the sites of a systems. +def map(sys, value, colorbar=True, cmap=None, + a=None, method='nearest', oversampling=3, file=None, show=True): + """Show interpolated map of a function defined for the sites of a system. Create a pixmap representation of a function of the sites of a system by - calling `~kwant.plotter.interpolate` and show this pixmap using matplotlib. + calling `~kwant.plotter.mask_interpolate` and show this pixmap using + matplotlib. Parameters ---------- - syst : kwant.system.FiniteSystem or kwant.builder.Builder - The system for whose sites `function` is to be plotted. - function : function or mapping - Function which takes a site and returns a number, or a mapping whose - keys are sites and whose values are numbers. + sys : kwant.system.FiniteSystem or kwant.builder.Builder + The system for whose sites `value` is to be plotted. + value : function or list + Function which takes a site and returns a value if the system is a + builder, or a list of function values for each system site of the + finalized system. colorbar : bool, optional Whether to show a color bar. Defaults to `true`. - kwds : other arguments - All other arguments are passed to `~kwant.plotter.interpolate`. + cmap : `matplotlib` color map or `None` + The color map used for sites and optionally hoppings, if `None`, + `matplotlib` default is used. + a : float, optional + Reference length. If not given, it is determined as a typical + nearest neighbor distance. + method : string, optional + Passed to `scipy.interpolate.griddata` and to `matplotlib` + `Axes.imshow.interpolation`: "nearest" (default), "linear", or "cubic". + oversampling : integer, optional + Number of pixels per reference length. Defaults to 3. + file : string or file object or `None` + The output file. If `None`, output will be shown instead. + show : bool + Whether `matplotlib.pyplot.show()` is to be called, and the output is + to be shown immediately. Defaults to `True`. Notes ----- - - See notes of `~kwant.plotter.interpolate`. - - - This function uses matplotlib to show the interpolated function. + - See notes of `~kwant.plotter.show_interpolate`. - Matplotlib's interpolation is turned off, if the keyword argument `method` is not set or set to the default value "nearest". """ - img, min, max = interpolate(syst, function, **kwds) + sites = sys_leads_sites(sys, 0) + coords = sys_leads_pos(sys, sites) + if coords.shape[1] != 2: + raise ValueError('Only 2D systems can be plotted this way.') + if hasattr(value, '__call__'): + value = [value(site[0]) for site in sites] + else: + if not isinstance(sys, system.FiniteSystem): + raise ValueError('List of values is only allowed as input' + 'for finalized systems.') + value = np.array(value) + img, min, max = mask_interpolate(coords, value, a, method, oversampling) border = 0.5 * (max - min) / (np.asarray(img.shape) - 1) min -= border max += border - interpolation = 'nearest' if kwds.get('method') in [None, 'nearest'] \ - else None - plt.imshow(img.T, extent=(min[0], max[0], min[1], max[1]), origin='lower', - interpolation=interpolation) + fig = Figure() + ax = fig.add_subplot(111, aspect='equal') + if method != 'nearest': + method = 'bi' + method + image = ax.imshow(img.T, extent=(min[0], max[0], min[1], max[1]), + origin='lower', interpolation=method, cmap=cmap) if colorbar: - plt.colorbar() - plt.show() + fig.colorbar(image) + return output_fig(fig, file=file, show=show) + +# TODO (Anton): Fix plotting of parts of the system using color = np.nan. +# Not plotting sites currently works, not plotting hoppings does not. +# TODO (Anton): Allow a more flexible treatment of position than pos_transform +# (an interface for user-defined pos). diff --git a/kwant/tests/test_plotter.py b/kwant/tests/test_plotter.py index 29dc81bc23f0ed6e61170e8a94225c0f1fc9c471..61fffe207eaf05afb955d7dc647969251cc5245a 100644 --- a/kwant/tests/test_plotter.py +++ b/kwant/tests/test_plotter.py @@ -1,83 +1,114 @@ -import tempfile, os -from nose.tools import assert_raises -import numpy as np +import tempfile +import nose import kwant from kwant import plotter +if plotter._mpl_enabled: + from mpl_toolkits import mplot3d + from matplotlib import pyplot + + +def sys_2d(W=3, r1=3, r2=8): + a = 1 + t = 1.0 + lat = kwant.lattice.Square(a) + sys = kwant.Builder() + + def ring(pos): + (x, y) = pos + rsq = x ** 2 + y ** 2 + return r1 ** 2 < rsq < r2 ** 2 + + sys[lat.shape(ring, (0, r1 + 1))] = 4 * t + for hopping in lat.nearest: + sys[sys.possible_hoppings(*hopping)] = - t + sym_lead0 = kwant.TranslationalSymmetry([lat.vec((-1, 0))]) + lead0 = kwant.Builder(sym_lead0) + + def lead_shape(pos): + (x, y) = pos + return (-1 < x < 1) and (-W / 2 < y < W / 2) + + lead0[lat.shape(lead_shape, (0, 0))] = 4 * t + for hopping in lat.nearest: + lead0[lead0.possible_hoppings(*hopping)] = - t + lead1 = lead0.reversed() + sys.attach_lead(lead0) + sys.attach_lead(lead1) + return sys + + +def sys_3d(W=3, r1=2, r2=4, a=1, t=1.0): + lat = kwant.make_lattice(((a, 0, 0), (0, a, 0), (0, 0, a))) + lat.nearest = (((1, 0, 0), lat, lat), ((0, 1, 0), lat, lat), + ((0, 0, 1), lat, lat)) + sys = kwant.Builder() + + def ring(pos): + (x, y, z) = pos + rsq = x ** 2 + y ** 2 + return (r1 ** 2 < rsq < r2 ** 2) and abs(z) < 2 + sys[lat.shape(ring, (0, -r2 + 1, 0))] = 4 * t + for hopping in lat.nearest: + sys[sys.possible_hoppings(*hopping)] = - t + sym_lead0 = kwant.TranslationalSymmetry([lat.vec((-1, 0, 0))]) + lead0 = kwant.Builder(sym_lead0) + + def lead_shape(pos): + (x, y, z) = pos + return (-1 < x < 1) and (-W / 2 < y < W / 2) and abs(z) < 2 + + lead0[lat.shape(lead_shape, (0, 0, 0))] = 4 * t + for hopping in lat.nearest: + lead0[lead0.possible_hoppings(*hopping)] = - t + lead1 = lead0.reversed() + sys.attach_lead(lead0) + sys.attach_lead(lead1) + return sys -lat = kwant.lattice.Square() - -def make_ribbon(width, dir, E, t): - b = kwant.Builder(kwant.TranslationalSymmetry([(dir, 0)])) - - # Add sites to the builder. - for y in xrange(width): - b[lat(0, y)] = E - - # Add hoppings to the builder. - for y in xrange(width): - b[lat(0, y), lat(1, y)] = t - if y+1 < width: - b[lat(0, y), lat(0, y+1)] = t - - return b - - -def make_rectangle(length, width, E, t): - b = kwant.Builder() - - # Add sites to the builder. - for x in xrange(length): - for y in xrange(width): - b[lat(x, y)] = E - - # Add hoppings to the builder. - for x in xrange(length): - for y in xrange(width): - if x+1 < length: - b[lat(x, y), lat(x+1, y)] = t - if y+1 < width: - b[lat(x, y), lat(x, y+1)] = t - - return b def test_plot(): - E = 4.0 - t = -1.0 - length = 5 - width = 5 - - b = make_rectangle(length, width, E, t) - b.attach_lead(make_ribbon(width, -1, E, t)) - b.attach_lead(make_ribbon(width, 1, E, t)) - - directory = tempfile.mkdtemp() - filename = os.path.join(directory, "test.pdf") - - kwant.plot(b.finalized(), filename=filename, - symbols=plotter.Circle(r=0.25, fcol=plotter.red), - lines=plotter.Line(lw=0.1, lcol=plotter.red), - lead_symbols=plotter.Circle(r=0.25, fcol=plotter.black), - lead_lines=plotter.Line(lw=0.1, lcol=plotter.black), - lead_fading=[0, 0.2, 0.4, 0.6, 0.8]) - - os.unlink(filename) - os.rmdir(directory) - -def test_non_2d_fails(): - directory = tempfile.mkdtemp() - filename = os.path.join(directory, "test.pdf") - - for d in [1, 2, 3, 15]: - b = kwant.Builder() - lat = kwant.make_lattice(np.identity(d)) - site = kwant.builder.Site(lat, (0,) * d) - b[site] = 0 - if d == 2: - kwant.plot(b, filename=filename) - plotter.interpolate(b, b) - else: - assert_raises(ValueError, kwant.plot, b) - assert_raises(ValueError, plotter.interpolate, b, b) - - os.unlink(filename) - os.rmdir(directory) + plot = plotter.plot + if not plotter._mpl_enabled: + raise nose.SkipTest + sys2d = sys_2d() + sys3d = sys_3d() + color_opts = ['k', (lambda site: site.tag[0]), + lambda site: (abs(site.tag[0] / 100), + abs(site.tag[1] / 100), 0)] + for color in color_opts: + for sys in (sys2d, sys3d): + fig = plot(sys, site_color=color, cmap='binary', show=False) + if color != 'k' and isinstance(color(iter(sys2d.sites()).next()), + float): + assert fig.axes[0].collections[0].get_array() is not None + assert len(fig.axes[0].collections) == 4 + color_opts = ['k', (lambda site, site2: site.tag[0]), + lambda site, site2: (abs(site.tag[0] / 100), + abs(site.tag[1] / 100), 0)] + for color in color_opts: + for sys in (sys2d, sys3d): + fig = plot(sys2d, hop_color=color, cmap='binary', show=False, + fig_size=(2, 10), dpi=30) + if color != 'k' and isinstance(color(iter(sys2d.sites()).next(), + None), float): + assert fig.axes[0].collections[1].get_array() is not None + + assert isinstance(plot(sys3d, show=False).axes[0], mplot3d.axes3d.Axes3D) + + sys2d.leads = [] + plot(sys2d, show=False) + del sys2d[list(sys2d.hoppings())] + plot(sys2d, show=False) + with tempfile.TemporaryFile('w+b') as output: + plot(sys3d, file=output) + + +def test_map(): + sys = sys_2d() + with tempfile.TemporaryFile('w+b') as output: + plotter.map(sys, lambda site: site.tag[0], file=output, + method='linear', a=4, oversampling=4, cmap='flag') + plotter.map(sys.finalized(), xrange(len(sys.sites())), + file=output) + nose.tools.assert_raises(ValueError, plotter.map, + sys, xrange(len(sys.sites())), file=output)