Skip to content
Snippets Groups Projects
Commit 351fcf5e authored by Joseph Weston's avatar Joseph Weston
Browse files

Merge branch 'plotting/site-size' into 'master'

Raise a TypeError when site sizes/colors are provided as arrays
and the system to plot is a kwant.Builder.

Closes #293

See merge request kwant/kwant!301
parents d1adb928 1cff8cbf
No related branches found
No related tags found
No related merge requests found
Pipeline #18380 passed
......@@ -735,17 +735,21 @@ def plot(sys, num_lead_cells=2, unit='nn',
symbols specifications (only for kwant.system.FiniteSystem).
site_size : number, function, array, or `None`
Relative (linear) size of the site symbol.
An array may not be used when 'syst' is a kwant.Builder.
site_color : ``matplotlib`` color description, function, array, or `None`
A color used for plotting a site in the system. If a colormap is used,
it should be a function returning single floats or a one-dimensional
array of floats. By default sites are colored by their site family,
using the current matplotlib color cycle.
An array of colors may not be used when 'syst' is a kwant.Builder.
site_edgecolor : ``matplotlib`` color description, function, array, or `None`
Color used for plotting the edges of the site symbols. Only
valid matplotlib color descriptions are allowed (and no
combination of floats and colormap as for site_color).
An array of colors may not be used when 'syst' is a kwant.Builder.
site_lw : number, function, array, or `None`
Linewidth of the site symbol edges.
An array may not be used when 'syst' is a kwant.Builder.
hop_color : ``matplotlib`` color description or a function
Same as `site_color`, but for hoppings. A function is passed two sites
in this case. (arrays are not allowed in this case).
......@@ -911,7 +915,11 @@ def plot(sys, num_lead_cells=2, unit='nn',
raise ValueError('Invalid value of unit argument.')
# make all specs proper: either constant or lists/np.arrays:
def make_proper_site_spec(spec, fancy_indexing=False):
def make_proper_site_spec(spec_name, spec, fancy_indexing=False):
if _p.isarray(spec) and isinstance(syst, builder.Builder):
raise TypeError('{} cannot be an array when plotting'
' a Builder; use a function instead.'
.format(spec_name))
if callable(spec):
spec = [spec(i[0]) for i in sites if i[1] is None]
if (fancy_indexing and _p.isarray(spec)
......@@ -933,7 +941,8 @@ def plot(sys, num_lead_cells=2, unit='nn',
spec = np.asarray(spec, dtype='object')
return spec
site_symbol = make_proper_site_spec(site_symbol)
site_symbol = make_proper_site_spec('site_symbol', site_symbol)
if site_symbol is None: site_symbol = defaults['site_symbol'][dim]
# separate different symbols (not done in 3D, the separation
# would mess up sorting)
......@@ -967,10 +976,10 @@ def plot(sys, num_lead_cells=2, unit='nn',
# Unknown finalized system, no sites access.
site_color = defaults['site_color'][dim]
site_size = make_proper_site_spec(site_size, fancy_indexing)
site_color = make_proper_site_spec(site_color, fancy_indexing)
site_edgecolor = make_proper_site_spec(site_edgecolor, fancy_indexing)
site_lw = make_proper_site_spec(site_lw, fancy_indexing)
site_size = make_proper_site_spec('site_size', site_size, fancy_indexing)
site_color = make_proper_site_spec('site_color', site_color, fancy_indexing)
site_edgecolor = make_proper_site_spec('site_edgecolor', site_edgecolor, fancy_indexing)
site_lw = make_proper_site_spec('site_lw', site_lw, fancy_indexing)
hop_color = make_proper_hop_spec(hop_color)
hop_lw = make_proper_hop_spec(hop_lw)
......
......@@ -169,6 +169,21 @@ def test_plot_more_site_families_than_colors():
plotter.plot(syst, file=out)
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
def test_plot_raises_on_bad_site_spec():
syst = kwant.Builder()
lat = kwant.lattice.square()
syst[(lat(i, j) for i in range(5) for j in range(5))] = None
# Cannot provide site_size as an array when syst is a Builder
with pytest.raises(TypeError):
plotter.plot(syst, site_size=[1] * 25)
# Cannot provide site_size as an array when syst is a Builder
with pytest.raises(TypeError):
plotter.plot(syst, site_symbol=['o'] * 25)
def good_transform(pos):
x, y = pos
return y, x
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment