From 6d6be7b152ae677ed8f825459ed01981f3b5142c Mon Sep 17 00:00:00 2001
From: Anton Akhmerov <anton.akhmerov@gmail.com>
Date: Thu, 8 Dec 2016 15:09:41 +0100
Subject: [PATCH] allow compressed graph pickling, add test for system pickling

---
 kwant/graph/core.pxd           |  9 ++++-
 kwant/graph/core.pyx           | 67 ++++++++++++++++++++++++----------
 kwant/graph/tests/test_core.py | 28 +++++++++++++-
 kwant/tests/test_system.py     | 24 +++++++++++-
 4 files changed, 106 insertions(+), 22 deletions(-)

diff --git a/kwant/graph/core.pxd b/kwant/graph/core.pxd
index d3f7fd09..52c851cc 100644
--- a/kwant/graph/core.pxd
+++ b/kwant/graph/core.pxd
@@ -1,4 +1,4 @@
-# Copyright 2011-2013 Kwant authors.
+# Copyright 2011-2016 Kwant authors.
 #
 # This file is part of Kwant.  It is subject to the license terms in the file
 # LICENSE.rst found in the top-level directory of this distribution and at
@@ -7,6 +7,7 @@
 # http://kwant-project.org/authors.
 
 cimport numpy as np
+from cpython cimport array
 from .defs cimport gint
 
 cdef struct Edge:
@@ -30,11 +31,17 @@ cdef class gintArraySlice:
 cdef class CGraph:
     cdef readonly bint twoway, edge_nr_translation
     cdef readonly gint num_nodes, num_edges, num_px_edges, num_xp_edges
+    cdef array.array _heads_idxs
     cdef gint *heads_idxs
+    cdef array.array _heads
     cdef gint *heads
+    cdef array.array _tails_idxs
     cdef gint *tails_idxs
+    cdef array.array _tails
     cdef gint *tails
+    cdef array.array _edge_ids
     cdef gint *edge_ids
+    cdef array.array _edge_ids_by_edge_nr
     cdef gint *edge_ids_by_edge_nr
     cdef gint edge_nr_end
 
diff --git a/kwant/graph/core.pyx b/kwant/graph/core.pyx
index 18cddb2b..1f9c15f2 100644
--- a/kwant/graph/core.pyx
+++ b/kwant/graph/core.pyx
@@ -1,4 +1,4 @@
-# Copyright 2011-2013 Kwant authors.
+# Copyright 2011-2016 Kwant authors.
 #
 # This file is part of Kwant.  It is subject to the license terms in the file
 # LICENSE.rst found in the top-level directory of this distribution and at
@@ -28,6 +28,8 @@ __all__ = ['Graph', 'CGraph']
 
 from libc.stdlib cimport malloc, realloc, free
 from libc.string cimport memset
+from cpython cimport array
+import array
 import numpy as np
 cimport numpy as np
 from .defs cimport gint
@@ -661,33 +663,42 @@ cdef class CGraph:
 cdef class CGraph_malloc(CGraph):
     """A CGraph which allocates and frees its own memory."""
 
-    def __cinit__(self, twoway, edge_nr_translation, num_nodes,
+    def __init__(self, twoway, edge_nr_translation, num_nodes,
                   num_pp_edges, num_pn_edges, num_np_edges):
         self.twoway = twoway
         self.edge_nr_translation = edge_nr_translation
         self.num_nodes = num_nodes
         self.num_px_edges = num_pp_edges + num_pn_edges
         self.edge_nr_end = num_pp_edges + num_pn_edges + num_np_edges
-
-        self.heads_idxs = <gint*>malloc((num_nodes + 1) * sizeof(gint))
+        self._heads_idxs = array.array('i', ())
+        array.resize(self._heads_idxs, (num_nodes + 1))
+        self.heads_idxs = <gint*>self._heads_idxs.data.as_ints
         if self.twoway:
             # The graph is two-way. n->p edges will exist in the compressed
             # graph.
             self.num_xp_edges = num_pp_edges + num_np_edges
             self.num_edges = self.edge_nr_end
-            self.tails_idxs = <gint*>malloc((num_nodes + 1) * sizeof(gint))
-            self.tails = <gint*>malloc(
-                self.num_xp_edges * sizeof(gint))
-            self.edge_ids = <gint*>malloc(
-                self.num_xp_edges * sizeof(gint))
+            self._tails_idxs = array.array('i', ())
+            array.resize(self._tails_idxs, (num_nodes + 1))
+            self.tails_idxs = <gint*>self._tails_idxs.data.as_ints
+            self._tails = array.array('i', ())
+            array.resize(self._tails, self.num_xp_edges)
+            self.tails = <gint*>self._tails.data.as_ints
+            self._edge_ids = array.array('i', ())
+            array.resize(self._edge_ids, self.num_xp_edges)
+            self.edge_ids = <gint*>self._edge_ids.data.as_ints
         else:
             # The graph is one-way. n->p edges will be ignored.
             self.num_xp_edges = num_pp_edges
             self.num_edges = self.num_px_edges
-        self.heads = <gint*>malloc(self.num_edges * sizeof(gint))
+        self._heads = array.array('i', ())
+        array.resize(self._heads, self.num_edges)
+        self.heads = <gint*>self._heads.data.as_ints
         if edge_nr_translation:
-            self.edge_ids_by_edge_nr = <gint*>malloc(
-                self.edge_nr_end * sizeof(gint))
+            self._edge_ids_by_edge_nr = array.array('i', ())
+            array.resize(self._edge_ids_by_edge_nr, self.edge_nr_end)
+            self.edge_ids_by_edge_nr = (<gint*>self._edge_ids_by_edge_nr
+                                        .data.as_ints)
         if (not self.heads_idxs or not self.heads
             or (twoway and (not self.tails_idxs
                              or not self.tails
@@ -695,10 +706,44 @@ cdef class CGraph_malloc(CGraph):
             or (edge_nr_translation and not self.edge_ids_by_edge_nr)):
             raise MemoryError
 
-    def __dealloc__(self):
-        free(self.edge_ids_by_edge_nr)
-        free(self.heads)
-        free(self.edge_ids)
-        free(self.tails)
-        free(self.tails_idxs)
-        free(self.heads_idxs)
+    def __getstate__(self):
+        """Return the state needed to pickle this graph.
+
+        The result is a tuple: the arguments for `__init__`, followed by the
+        six storage arrays.  Arrays that were never allocated are None.
+        """
+        twoway = self.twoway
+        edge_nr_translation = self.edge_nr_translation
+        num_nodes = self.num_nodes
+        # Invert the bookkeeping done in `__init__` to recover its arguments.
+        num_np_edges = self.edge_nr_end - self.num_px_edges
+        if twoway:
+            num_pp_edges = self.num_xp_edges - num_np_edges
+        else:
+            num_pp_edges = self.num_xp_edges
+        num_pn_edges = self.num_px_edges - num_pp_edges
+        init_args = (twoway, edge_nr_translation, num_nodes,
+                     num_pp_edges, num_pn_edges, num_np_edges)
+
+        return (init_args, self._heads_idxs, self._heads, self._tails_idxs,
+                self._tails, self._edge_ids, self._edge_ids_by_edge_nr)
+
+    def __setstate__(self, state):
+        """Restore the graph from a tuple produced by `__getstate__`.
+
+        Raises ValueError if a stored array does not have the length that
+        `__init__` allocated for it.
+        """
+        self.__init__(*state[0])
+        array_attributes = (self._heads_idxs, self._heads, self._tails_idxs,
+                            self._tails, self._edge_ids,
+                            self._edge_ids_by_edge_nr)
+        for attribute, value in zip(array_attributes, state[1:]):
+            if attribute is None:
+                continue
+            # The raw gint pointers set up by `__init__` alias these buffers.
+            # A slice assignment of different length would resize the array
+            # and could leave those pointers dangling, so refuse it.
+            if len(attribute) != len(value):
+                raise ValueError('Pickled array does not match graph size.')
+            attribute[:] = value
diff --git a/kwant/graph/tests/test_core.py b/kwant/graph/tests/test_core.py
index e657bdc5..b52e6a10 100644
--- a/kwant/graph/tests/test_core.py
+++ b/kwant/graph/tests/test_core.py
@@ -1,4 +1,4 @@
-# Copyright 2011-2013 Kwant authors.
+# Copyright 2011-2016 Kwant authors.
 #
 # This file is part of Kwant.  It is subject to the license terms in the file
 # LICENSE.rst found in the top-level directory of this distribution and at
@@ -6,6 +6,7 @@
 # the file AUTHORS.rst at the top-level directory of this distribution and at
 # http://kwant-project.org/authors.
 
+import pickle
 from io import StringIO
 from itertools import zip_longest
 import numpy as np
@@ -179,3 +180,32 @@ def test_edge_ids():
 
     g = gr.compressed(edge_nr_translation=True, allow_lost_edges=True)
     raises(EdgeDoesNotExistError, g.edge_id, 1)
+
+
+def test_pickle():
+    """Check that compressed graphs survive a pickle round trip."""
+    # Two-way graph with negative nodes and edge number translation.
+    gr = Graph(allow_negative_nodes=True)
+    edges = [(0, -1), (-1, 0), (1, 2), (1, 2), (0, -1), (-1, 0), (-1, 0)]
+    gr.add_edges(edges)
+    g = gr.compressed(twoway=True, edge_nr_translation=True)
+    g2 = pickle.loads(pickle.dumps(g))
+    s = StringIO('')
+    g.write_dot(s)
+    s2 = StringIO('')
+    g2.write_dot(s2)
+    assert s.getvalue() == s2.getvalue()
+    assert g.__getstate__() == g2.__getstate__()
+
+    # One-way graph without edge number translation.
+    gr = Graph(allow_negative_nodes=False)
+    edges = [(0, 1), (1, 2), (1, 2), (0, 2)]
+    gr.add_edges(edges)
+    g = gr.compressed(twoway=False, edge_nr_translation=False)
+    g2 = pickle.loads(pickle.dumps(g))
+    s = StringIO('')
+    g.write_dot(s)
+    s2 = StringIO('')
+    g2.write_dot(s2)
+    assert s.getvalue() == s2.getvalue()
+    assert g.__getstate__() == g2.__getstate__()
diff --git a/kwant/tests/test_system.py b/kwant/tests/test_system.py
index f90a3b1c..f06da88a 100644
--- a/kwant/tests/test_system.py
+++ b/kwant/tests/test_system.py
@@ -1,4 +1,4 @@
-# Copyright 2011-2013 Kwant authors.
+# Copyright 2011-2016 Kwant authors.
 #
 # This file is part of Kwant.  It is subject to the license terms in the file
 # LICENSE.rst found in the top-level directory of this distribution and at
@@ -6,6 +6,8 @@
 # the file AUTHORS.rst at the top-level directory of this distribution and at
 # http://kwant-project.org/authors.
 
+import pickle
+import copy
 from pytest import raises
 import numpy as np
 from scipy import sparse
@@ -99,3 +101,28 @@ def test_hamiltonian_submatrix():
     mat = mat[perm, :]
     mat = mat[:, perm]
     np.testing.assert_array_equal(mat, mat_should_be)
+
+
+def test_pickling():
+    """Check that copied and pickled systems give the same scattering matrix.
+
+    Builders are copied/pickled before finalization, and the finalized system
+    is copied/pickled as well.
+    """
+    syst = kwant.Builder()
+    lead = kwant.Builder(symmetry=kwant.TranslationalSymmetry([1.]))
+    lat = kwant.lattice.chain()
+    syst[lat(0)] = syst[lat(1)] = 0
+    syst[lat(0), lat(1)] = 1
+    lead[lat(0)] = 0
+    lead[lat(0), lat(1)] = 1
+    syst.attach_lead(lead)
+    syst.attach_lead(lead.reversed())
+    syst_copy1 = copy.copy(syst).finalized()
+    syst_copy2 = pickle.loads(pickle.dumps(syst)).finalized()
+    syst = syst.finalized()
+    syst_copy3 = copy.copy(syst)
+    syst_copy4 = pickle.loads(pickle.dumps(syst))
+    s = kwant.smatrix(syst, 0.1)
+    for other in (syst_copy1, syst_copy2, syst_copy3, syst_copy4):
+        assert np.all(kwant.smatrix(other, 0.1).data == s.data)
-- 
GitLab