Commit 92993c0c authored by Kelvin Loh's avatar Kelvin Loh

Test 5.

parent ee04a2ce
Pipeline #25628 passed with stages
in 8 minutes and 17 seconds
......@@ -136,10 +136,11 @@ def test_matplotlib_plot():
lambda site: (abs(site.tag[0] / 100),
abs(site.tag[1] / 100), 0)]
engine = plotter.get_engine()
with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)).name as out:
with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
out_filename = out.name
for color in color_opts:
for syst in (syst2d, syst3d):
fig = plot(syst, site_color=color, cmap='binary', file=out)
fig = plot(syst, site_color=color, cmap='binary', file=out_filename)
if (color != 'k' and
isinstance(color(next(iter(syst2d.sites()))), float)):
assert fig.axes[0].collections[0].get_array() is not None
......@@ -149,26 +150,26 @@ def test_matplotlib_plot():
abs(site.tag[1] / 100), 0)]
for color in color_opts:
for syst in (syst2d, syst3d):
fig = plot(syst2d, hop_color=color, cmap='binary', file=out,
fig = plot(syst2d, hop_color=color, cmap='binary', file=out_filename,
fig_size=(2, 10), dpi=30)
if color != 'k' and isinstance(color(next(iter(syst2d.sites())),
None), float):
assert fig.axes[0].collections[1].get_array() is not None
assert isinstance(plot(syst3d, file=out).axes[0], mplot3d.axes3d.Axes3D)
assert isinstance(plot(syst3d, file=out_filename).axes[0], mplot3d.axes3d.Axes3D)
syst2d.leads = []
plot(syst2d, file=out)
plot(syst2d, file=out_filename)
del syst2d[list(syst2d.hoppings())]
plot(syst2d, file=out)
plot(syst2d, file=out_filename)
plot(syst3d, file=out)
plot(syst3d, file=out_filename)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
plot(syst2d.finalized(), file=out)
plot(syst2d.finalized(), file=out_filename)
# test 2D projections of 3D systems
plot(syst3d, file=out, pos_transform=lambda pos: pos[:2])
plot(syst3d, file=out_filename, pos_transform=lambda pos: pos[:2])
@pytest.mark.skipif(not _plotter.plotly_available, reason="Plotly unavailable.")
......@@ -182,27 +183,28 @@ def test_plotly_plot():
lambda site: (abs(site.tag[0] / 100),
abs(site.tag[1] / 100), 0)]
engine = plotter.get_engine()
with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)).name as out:
with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
out_filename = out.name
for color in color_opts:
for syst in (syst2d, syst3d):
plot(syst, site_color=color, cmap='binary', file=out, show=False)
plot(syst, site_color=color, cmap='binary', file=out_filename, show=False)
color_opts = ['black', (lambda site, site2: site.tag[0]),
lambda site, site2: (abs(site.tag[0] / 100),
abs(site.tag[1] / 100), 0)]
syst2d.leads = []
plot(syst2d, file=out, show=False)
plot(syst2d, file=out_filename, show=False)
del syst2d[list(syst2d.hoppings())]
plot(syst2d, file=out, show=False)
plot(syst2d, file=out_filename, show=False)
plot(syst3d, file=out, show=False)
plot(syst3d, file=out_filename, show=False)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
plot(syst2d.finalized(), file=out, show=False)
plot(syst2d.finalized(), file=out_filename, show=False)
# test 2D projections of 3D systems
plot(syst3d, file=out, pos_transform=lambda pos: pos[:2], show=False)
plot(syst3d, file=out_filename, pos_transform=lambda pos: pos[:2], show=False)
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
......@@ -218,9 +220,10 @@ def test_plot_more_site_families_than_colors(engine):
syst[lat(i, 0)] = None
plotter.set_engine(engine)
with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)).name as out:
with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
out_filename = out.name
print(out)
plotter.plot(syst, file=out, show=False)
plotter.plot(syst, file=out_filename, show=False)
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
......@@ -254,18 +257,19 @@ def test_map(engine):
plotter.set_engine(engine)
syst = syst_2d()
with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)).name as out:
with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
out_filename = out.name
plotter.map(syst, lambda site: site.tag[0], pos_transform=good_transform,
file=out, method='linear', a=4, oversampling=4, cmap='flag', show=False)
file=out_filename, method='linear', a=4, oversampling=4, cmap='flag', show=False)
pytest.raises(ValueError, plotter.map, syst,
lambda site: site.tag[0],
pos_transform=bad_transform, file=out)
pos_transform=bad_transform, file=out_filename)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
plotter.map(syst.finalized(), range(len(syst.sites())),
file=out, show=False)
file=out_filename, show=False)
pytest.raises(ValueError, plotter.map, syst,
range(len(syst.sites())), file=out)
range(len(syst.sites())), file=out_filename)
def test_mask_interpolate():
......@@ -295,16 +299,17 @@ def test_bands(engine):
syst = syst_2d().finalized().leads[0]
with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)).name as out:
plotter.bands(syst, show=False, file=out)
plotter.bands(syst, show=False, momenta=np.linspace(0, 2 * np.pi), file=out)
with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
out_filename = out.name
plotter.bands(syst, show=False, file=out_filename)
plotter.bands(syst, show=False, momenta=np.linspace(0, 2 * np.pi), file=out_filename)
if engine == 'matplotlib':
plotter.bands(syst, show=False, fig_size=(10, 10), file=out)
plotter.bands(syst, show=False, fig_size=(10, 10), file=out_filename)
fig = pyplot.Figure()
ax = fig.add_subplot(1, 1, 1)
plotter.bands(syst, show=False, ax=ax, file=out)
plotter.bands(syst, show=False, ax=ax, file=out_filename)
......@@ -328,41 +333,43 @@ def test_spectrum(engine):
vals = np.linspace(0, 1, 3)
with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)).name as out:
with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
out_filename = out.name
for ham in (ham_1d, ham_2d, fsyst):
plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1), file=out, show=False)
plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1), file=out_filename, show=False)
if engine == 'matplotlib':
# test with explicit figsize
plotter.spectrum(ham, ('a', vals), params=dict(b=1, c=1),
fig_size=(10, 10), file=out, show=False)
fig_size=(10, 10), file=out_filename, show=False)
for ham in (ham_1d, ham_2d, fsyst):
plotter.spectrum(ham, ('a', vals), ('b', 2 * vals),
params=dict(c=1), file=out, show=False)
params=dict(c=1), file=out_filename, show=False)
if engine == 'matplotlib':
# test with explicit figsize
plotter.spectrum(ham, ('a', vals), ('b', 2 * vals),
params=dict(c=1), fig_size=(10, 10), file=out, show=False)
params=dict(c=1), fig_size=(10, 10), file=out_filename, show=False)
if engine == 'matplotlib':
# test 2D plot and explicitly passing axis
fig = pyplot.figure()
ax = fig.add_subplot(1, 1, 1, projection='3d')
plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals),
params=dict(c=1), ax=ax, file=out, show=False)
params=dict(c=1), ax=ax, file=out_filename, show=False)
# explicitly pass axis without 3D support
ax = fig.add_subplot(1, 1, 1)
with pytest.raises(TypeError):
plotter.spectrum(ham_1d, ('a', vals), ('b', 2 * vals),
params=dict(c=1), ax=ax, file=out, show=False)
params=dict(c=1), ax=ax, file=out_filename, show=False)
def mask(a, b):
return a > 0.5
with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)).name as out:
with tempfile.NamedTemporaryFile('w+b', suffix=plotter_file_suffix(engine)) as out:
out_filename = out.name
plotter.spectrum(ham, ('a', vals), ('b', 2 * vals), params=dict(c=1),
mask=mask, file=out, show=False)
mask=mask, file=out_filename, show=False)
def syst_rect(lat, salt, W=3, L=50):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment