Commit 6e1b8653 authored by Kelvin Loh's avatar Kelvin Loh

Adds the plotly testing script.

parent 2f43a91c
Pipeline #25617 passed with stages
in 8 minutes and 48 seconds
......@@ -174,12 +174,17 @@ if plotly_available:
def convert_cmap_list_mpl_plotly(mpl_cmap_name, N=255):
cmap_mpl = matplotlib.cm.get_cmap(mpl_cmap_name)
cmap_mpl_arr = matplotlib.colors.makeMappingArray(N, cmap_mpl)
level = np.linspace(0, 1, N)
cmap_plotly_linear = [(level, convert_colormap_mpl_plotly(cmap_mpl))
for level, cmap_mpl in zip(level,
cmap_mpl_arr)]
if isinstance(mpl_cmap_name, str):
cmap_mpl = matplotlib.cm.get_cmap(mpl_cmap_name)
cmap_mpl_arr = matplotlib.colors.makeMappingArray(N, cmap_mpl)
level = np.linspace(0, 1, N)
cmap_plotly_linear = [(level, convert_colormap_mpl_plotly(cmap_mpl))
for level, cmap_mpl in zip(level,
cmap_mpl_arr)]
else:
assert(isinstance(mpl_cmap_name, list))
# Do not do any conversion if it's already a list
cmap_plotly_linear = mpl_cmap_name
return cmap_plotly_linear
......
......@@ -1612,6 +1612,7 @@ def mask_interpolate(coords, values, a=None, method='nearest', oversampling=3):
range(len(cmin)))
grid = tuple(np.ogrid[dims])
img = interpolate.griddata(coords, values, grid, method)
img = img.astype(np.float_)
mask = np.mgrid[dims].reshape(len(cmin), -1).T
# The numerical values in the following line are optimized for the common
# case of a square lattice:
......@@ -1793,7 +1794,8 @@ def _map_plotly(syst, img, colorbar, _max, _min, vmin, vmax, overflow_pct,
contour_object.y = np.linspace(_min[1],_max[1],img.shape[1])
contour_object.zsmooth = False
contour_object.connectgaps = False
contour_object.colorscale = _p.convert_cmap_list_mpl_plotly(cmap)
cmap = _p.convert_cmap_list_mpl_plotly(cmap)
contour_object.colorscale = cmap
contour_object.zmax = vmax
contour_object.zmin = vmin
contour_object.hoverinfo = 'none'
......
......@@ -117,7 +117,9 @@ def syst_3d(W=3, r1=2, r2=4, a=1, t=1.0):
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
def test_plot():
def test_matplotlib_plot():
plotter.set_engine('matplotlib')
plot = plotter.plot
syst2d = syst_2d()
syst3d = syst_3d()
......@@ -159,8 +161,43 @@ def test_plot():
plot(syst3d, file=out, pos_transform=lambda pos: pos[:2])
@pytest.mark.skipif(not _plotter.plotly_available, reason="Plotly unavailable.")
def test_plotly_plot():
plotter.set_engine('plotly')
plot = plotter.plot
syst2d = syst_2d()
syst3d = syst_3d()
color_opts = ['black', (lambda site: site.tag[0]),
lambda site: (abs(site.tag[0] / 100),
abs(site.tag[1] / 100), 0)]
with tempfile.TemporaryFile('w+b') as out:
out = f'{out}.html'
for color in color_opts:
for syst in (syst2d, syst3d):
plot(syst, site_color=color, cmap='binary', file=out, show=False)
color_opts = ['black', (lambda site, site2: site.tag[0]),
lambda site, site2: (abs(site.tag[0] / 100),
abs(site.tag[1] / 100), 0)]
syst2d.leads = []
plot(syst2d, file=out, show=False)
del syst2d[list(syst2d.hoppings())]
plot(syst2d, file=out, show=False)
plot(syst3d, file=out, show=False)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
plot(syst2d.finalized(), file=out, show=False)
# test 2D projections of 3D systems
plot(syst3d, file=out, pos_transform=lambda pos: pos[:2], show=False)
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
def test_plot_more_site_families_than_colors():
@pytest.mark.parametrize("engine", ["plotly", "matplotlib"])
def test_plot_more_site_families_than_colors(engine):
# test against regression reported in
# https://gitlab.kwant-project.org/kwant/kwant/issues/257
ncolors = len(pyplot.rcParams['axes.prop_cycle'])
......@@ -169,17 +206,23 @@ def test_plot_more_site_families_than_colors():
for i in range(ncolors + 1)]
for i, lat in enumerate(lattices):
syst[lat(i, 0)] = None
plotter.set_engine(engine)
with tempfile.TemporaryFile('w+b') as out:
plotter.plot(syst, file=out)
if engine == 'plotly':
out = f'{out}.html'
plotter.plot(syst, file=out, show=False)
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
def test_plot_raises_on_bad_site_spec():
@pytest.mark.parametrize("engine", ["plotly", "matplotlib"])
def test_plot_raises_on_bad_site_spec(engine):
syst = kwant.Builder()
lat = kwant.lattice.square(norbs=1)
syst[(lat(i, j) for i in range(5) for j in range(5))] = None
# Cannot provide site_size as an array when syst is a Builder
plotter.set_engine(engine)
with pytest.raises(TypeError):
plotter.plot(syst, site_size=[1] * 25)
......@@ -197,18 +240,23 @@ def bad_transform(pos):
return x, y, 0
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
def test_map():
@pytest.mark.parametrize("engine", ["plotly", "matplotlib"])
def test_map(engine):
plotter.set_engine(engine)
syst = syst_2d()
with tempfile.TemporaryFile('w+b') as out:
if engine == 'plotly':
out = f'{out}.html'
plotter.map(syst, lambda site: site.tag[0], pos_transform=good_transform,
file=out, method='linear', a=4, oversampling=4, cmap='flag')
file=out, method='linear', a=4, oversampling=4, cmap='flag', show=False)
pytest.raises(ValueError, plotter.map, syst,
lambda site: site.tag[0],
pos_transform=bad_transform, file=out)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
plotter.map(syst.finalized(), range(len(syst.sites())),
file=out)
file=out, show=False)
pytest.raises(ValueError, plotter.map, syst,
range(len(syst.sites())), file=out)
......@@ -233,22 +281,33 @@ def test_mask_interpolate():
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
def test_bands():
@pytest.mark.parametrize("engine", ["plotly", "matplotlib"])
def test_bands(engine):
plotter.set_engine(engine)
syst = syst_2d().finalized().leads[0]
with tempfile.TemporaryFile('w+b') as out:
plotter.bands(syst, file=out)
plotter.bands(syst, fig_size=(10, 10), file=out)
plotter.bands(syst, momenta=np.linspace(0, 2 * np.pi), file=out)
if engine == 'plotly':
out = f'{out}.html'
plotter.bands(syst, show=False, file=out)
plotter.bands(syst, show=False, momenta=np.linspace(0, 2 * np.pi), file=out)
if engine == 'matplotlib':
plotter.bands(syst, show=False, fig_size=(10, 10), file=out)
fig = pyplot.Figure()
ax = fig.add_subplot(1, 1, 1)
plotter.bands(syst, show=False, ax=ax, file=out)
fig = pyplot.Figure()
ax = fig.add_subplot(1, 1, 1)
plotter.bands(syst, ax=ax, file=out)
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
def test_spectrum():
@pytest.mark.parametrize("engine", ["plotly", "matplotlib"])
def test_spectrum(engine):
plotter.set_engine(engine)
def ham_1d(a, b, c):
return a**2 + b**2 + c**2
......@@ -265,37 +324,44 @@ def test_spectrum():
vals = np.linspace(0, 1, 3)
with tempfile.TemporaryFile('w+b') as out:
if engine == 'plotly':
out = f'{out}.html'
for ham in (ham_1d, ham_2d, fsyst):
plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1), file=out)
# test with explicit figsize
plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1),
fig_size=(10, 10), file=out)
plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1), file=out, show=False)
if engine == 'matplotlib':
# test with explicit figsize
plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1),
fig_size=(10, 10), file=out, show=False)
for ham in (ham_1d, ham_2d, fsyst):
plotter.spectrum(ham, ('a', vals), ('b', 2 * vals),
params=dict(c=1), file=out)
# test with explicit figsize
plotter.spectrum(ham, ('a', vals), ('b', 2 * vals),
params=dict(c=1), fig_size=(10, 10), file=out)
# test 2D plot and explicitly passing axis
fig = pyplot.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')
plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals),
params=dict(c=1), ax=ax, file=out)
# explicitly pass axis without 3D support
ax = fig.add_subplot(1, 1, 1)
with pytest.raises(TypeError):
params=dict(c=1), file=out, show=False)
if engine == 'matplotlib':
# test with explicit figsize
plotter.spectrum(ham, ('a', vals), ('b', 2 * vals),
params=dict(c=1), fig_size=(10, 10), file=out, show=False)
if engine == 'matplotlib':
# test 2D plot and explicitly passing axis
fig = pyplot.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')
plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals),
params=dict(c=1), ax=ax, file=out)
params=dict(c=1), ax=ax, file=out, show=False)
# explicitly pass axis without 3D support
ax = fig.add_subplot(1, 1, 1)
with pytest.raises(TypeError):
plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals),
params=dict(c=1), ax=ax, file=out, show=False)
def mask(a, b):
return a > 0.5
with tempfile.TemporaryFile('w+b') as out:
if engine == 'plotly':
out = f'{out}.html'
plotter.spectrum(ham, ('a', vals), ('b', 2 * vals), params=dict(c=1),
mask=mask, file=out)
mask=mask, file=out, show=False)
def syst_rect(lat, salt, W=3, L=50):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment