From b825d6c03f918aa385a42dbc7b669d7ffefe0cf6 Mon Sep 17 00:00:00 2001
From: Anton Akhmerov <anton.akhmerov@gmail.com>
Date: Mon, 2 Sep 2013 23:45:21 +0200
Subject: [PATCH] make plotter tests run without DISPLAY set

---
 kwant/tests/test_plotter.py | 66 +++++++++++++++++++------------------
 1 file changed, 34 insertions(+), 32 deletions(-)

diff --git a/kwant/tests/test_plotter.py b/kwant/tests/test_plotter.py
index 4b97cfb..525091a 100644
--- a/kwant/tests/test_plotter.py
+++ b/kwant/tests/test_plotter.py
@@ -96,47 +96,49 @@ def test_plot():
     color_opts = ['k', (lambda site: site.tag[0]),
                   lambda site: (abs(site.tag[0] / 100),
                                 abs(site.tag[1] / 100), 0)]
-    for color in color_opts:
-        for sys in (sys2d, sys3d):
-            fig = plot(sys, site_color=color, cmap='binary', show=False)
-            if color != 'k' and isinstance(color(iter(sys2d.sites()).next()),
-                                           float):
-                assert fig.axes[0].collections[0].get_array() is not None
-            assert len(fig.axes[0].collections) == (8 if sys is sys2d else 6)
-    color_opts = ['k', (lambda site, site2: site.tag[0]),
-                  lambda site, site2: (abs(site.tag[0] / 100),
-                                       abs(site.tag[1] / 100), 0)]
-    for color in color_opts:
-        for sys in (sys2d, sys3d):
-            fig = plot(sys2d, hop_color=color, cmap='binary', show=False,
-                       fig_size=(2, 10), dpi=30)
-            if color != 'k' and isinstance(color(iter(sys2d.sites()).next(),
-                                                      None), float):
-                assert fig.axes[0].collections[1].get_array() is not None
-
-    assert isinstance(plot(sys3d, show=False).axes[0], mplot3d.axes3d.Axes3D)
-
-    sys2d.leads = []
-    plot(sys2d, show=False)
-    del sys2d[list(sys2d.hoppings())]
-    plot(sys2d, show=False)
-    with tempfile.TemporaryFile('w+b') as output:
-        plot(sys3d, file=output)
+    with tempfile.TemporaryFile('w+b') as out:
+        for color in color_opts:
+            for sys in (sys2d, sys3d):
+                fig = plot(sys, site_color=color, cmap='binary', file=out)
+                if color != 'k' and \
+                   isinstance(color(iter(sys2d.sites()).next()), float):
+                    assert fig.axes[0].collections[0].get_array() is not None
+                assert len(fig.axes[0].collections) == (8 if sys is sys2d else
+                                                        6)
+        color_opts = ['k', (lambda site, site2: site.tag[0]),
+                      lambda site, site2: (abs(site.tag[0] / 100),
+                                           abs(site.tag[1] / 100), 0)]
+        for color in color_opts:
+            for sys in (sys2d, sys3d):
+                fig = plot(sys2d, hop_color=color, cmap='binary', file=out,
+                           fig_size=(2, 10), dpi=30)
+                if color != 'k' and isinstance(color(iter(sys2d.sites()).next(),
+                                                          None), float):
+                    assert fig.axes[0].collections[1].get_array() is not None
+
+        assert isinstance(plot(sys3d, file=out).axes[0], mplot3d.axes3d.Axes3D)
+
+        sys2d.leads = []
+        plot(sys2d, file=out)
+        del sys2d[list(sys2d.hoppings())]
+        plot(sys2d, file=out)
+
+        plot(sys3d, file=out)
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
-            plot(sys2d.finalized(), file=output)
+            plot(sys2d.finalized(), file=out)
 
 
 def test_map():
     if not plotter.mpl_enabled:
         raise nose.SkipTest
     sys = sys_2d()
-    with tempfile.TemporaryFile('w+b') as output:
-        plotter.map(sys, lambda site: site.tag[0], file=output,
+    with tempfile.TemporaryFile('w+b') as out:
+        plotter.map(sys, lambda site: site.tag[0], file=out,
                           method='linear', a=4, oversampling=4, cmap='flag')
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
             plotter.map(sys.finalized(), xrange(len(sys.sites())),
-                              file=output)
-        nose.tools.assert_raises(ValueError, plotter.map,
-                                 sys, xrange(len(sys.sites())), file=output)
+                              file=out)
+        nose.tools.assert_raises(ValueError, plotter.map, sys,
+                                 xrange(len(sys.sites())), file=out)
-- 
GitLab