diff --git a/kwant/plotter.py b/kwant/plotter.py index d3bb4493bc32dca040cdf5352a653adc6a3bea4a..fabe5f61c9234b2323e0c11c860962b3b5b2bd5d 100644 --- a/kwant/plotter.py +++ b/kwant/plotter.py @@ -735,17 +735,21 @@ def plot(sys, num_lead_cells=2, unit='nn', symbols specifications (only for kwant.system.FiniteSystem). site_size : number, function, array, or `None` Relative (linear) size of the site symbol. + An array may not be used when 'syst' is a kwant.Builder. site_color : ``matplotlib`` color description, function, array, or `None` A color used for plotting a site in the system. If a colormap is used, it should be a function returning single floats or a one-dimensional array of floats. By default sites are colored by their site family, using the current matplotlib color cycle. + An array of colors may not be used when 'syst' is a kwant.Builder. site_edgecolor : ``matplotlib`` color description, function, array, or `None` Color used for plotting the edges of the site symbols. Only valid matplotlib color descriptions are allowed (and no combination of floats and colormap as for site_color). + An array of colors may not be used when 'syst' is a kwant.Builder. site_lw : number, function, array, or `None` Linewidth of the site symbol edges. + An array may not be used when 'syst' is a kwant.Builder. hop_color : ``matplotlib`` color description or a function Same as `site_color`, but for hoppings. A function is passed two sites in this case. (arrays are not allowed in this case). @@ -911,7 +915,11 @@ def plot(sys, num_lead_cells=2, unit='nn', raise ValueError('Invalid value of unit argument.') # make all specs proper: either constant or lists/np.arrays: - def make_proper_site_spec(spec, fancy_indexing=False): + def make_proper_site_spec(spec_name, spec, fancy_indexing=False): + if _p.isarray(spec) and isinstance(syst, builder.Builder): + raise TypeError('{} cannot be an array when plotting' + ' a Builder; use a function instead.' + .format(spec_name)) if callable(spec): spec = [spec(i[0]) for i in sites if i[1] is None] if (fancy_indexing and _p.isarray(spec) @@ -933,7 +941,8 @@ def plot(sys, num_lead_cells=2, unit='nn', spec = np.asarray(spec, dtype='object') return spec - site_symbol = make_proper_site_spec(site_symbol) + + site_symbol = make_proper_site_spec('site_symbol', site_symbol) if site_symbol is None: site_symbol = defaults['site_symbol'][dim] # separate different symbols (not done in 3D, the separation # would mess up sorting) @@ -967,10 +976,10 @@ def plot(sys, num_lead_cells=2, unit='nn', # Unknown finalized system, no sites access. site_color = defaults['site_color'][dim] - site_size = make_proper_site_spec(site_size, fancy_indexing) - site_color = make_proper_site_spec(site_color, fancy_indexing) - site_edgecolor = make_proper_site_spec(site_edgecolor, fancy_indexing) - site_lw = make_proper_site_spec(site_lw, fancy_indexing) + site_size = make_proper_site_spec('site_size', site_size, fancy_indexing) + site_color = make_proper_site_spec('site_color', site_color, fancy_indexing) + site_edgecolor = make_proper_site_spec('site_edgecolor', site_edgecolor, fancy_indexing) + site_lw = make_proper_site_spec('site_lw', site_lw, fancy_indexing) hop_color = make_proper_hop_spec(hop_color) hop_lw = make_proper_hop_spec(hop_lw) diff --git a/kwant/tests/test_plotter.py b/kwant/tests/test_plotter.py index 5974d893889f10061f12db69661dab27eabad0b4..a8358284528c23ea5c70e7adbcadfe3177a151db 100644 --- a/kwant/tests/test_plotter.py +++ b/kwant/tests/test_plotter.py @@ -169,6 +169,21 @@ def test_plot_more_site_families_than_colors(): plotter.plot(syst, file=out) +@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.") +def test_plot_raises_on_bad_site_spec(): + syst = kwant.Builder() + lat = kwant.lattice.square() + syst[(lat(i, j) for i in range(5) for j in range(5))] = None + + # Cannot provide site_size as an array when syst is a Builder + with pytest.raises(TypeError): + plotter.plot(syst, site_size=[1] * 25) + + # Cannot provide site_size as an array when syst is a Builder + with pytest.raises(TypeError): + plotter.plot(syst, site_symbol=['o'] * 25) + + def good_transform(pos): x, y = pos return y, x