From 6d6be7b152ae677ed8f825459ed01981f3b5142c Mon Sep 17 00:00:00 2001 From: Anton Akhmerov <anton.akhmerov@gmail.com> Date: Thu, 8 Dec 2016 15:09:41 +0100 Subject: [PATCH] allow compressed graph pickling, add test for system pickling --- kwant/graph/core.pxd | 9 ++++- kwant/graph/core.pyx | 67 ++++++++++++++++++++++++---------- kwant/graph/tests/test_core.py | 28 +++++++++++++- kwant/tests/test_system.py | 24 +++++++++++- 4 files changed, 106 insertions(+), 22 deletions(-) diff --git a/kwant/graph/core.pxd b/kwant/graph/core.pxd index d3f7fd09..52c851cc 100644 --- a/kwant/graph/core.pxd +++ b/kwant/graph/core.pxd @@ -1,4 +1,4 @@ -# Copyright 2011-2013 Kwant authors. +# Copyright 2011-2016 Kwant authors. # # This file is part of Kwant. It is subject to the license terms in the file # LICENSE.rst found in the top-level directory of this distribution and at @@ -7,6 +7,7 @@ # http://kwant-project.org/authors. cimport numpy as np +from cpython cimport array from .defs cimport gint cdef struct Edge: @@ -30,11 +31,17 @@ cdef class gintArraySlice: cdef class CGraph: cdef readonly bint twoway, edge_nr_translation cdef readonly gint num_nodes, num_edges, num_px_edges, num_xp_edges + cdef array.array _heads_idxs cdef gint *heads_idxs + cdef array.array _heads cdef gint *heads + cdef array.array _tails_idxs cdef gint *tails_idxs + cdef array.array _tails cdef gint *tails + cdef array.array _edge_ids cdef gint *edge_ids + cdef array.array _edge_ids_by_edge_nr cdef gint *edge_ids_by_edge_nr cdef gint edge_nr_end diff --git a/kwant/graph/core.pyx b/kwant/graph/core.pyx index 18cddb2b..1f9c15f2 100644 --- a/kwant/graph/core.pyx +++ b/kwant/graph/core.pyx @@ -1,4 +1,4 @@ -# Copyright 2011-2013 Kwant authors. +# Copyright 2011-2016 Kwant authors. # # This file is part of Kwant. It is subject to the license terms in the file # LICENSE.rst found in the top-level directory of this distribution and at @@ -28,6 +28,8 @@ __all__ = ['Graph', 'CGraph'] from libc.stdlib cimport malloc, realloc, free from libc.string cimport memset +from cpython cimport array +import array import numpy as np cimport numpy as np from .defs cimport gint @@ -661,33 +663,42 @@ cdef class CGraph: cdef class CGraph_malloc(CGraph): """A CGraph which allocates and frees its own memory.""" - def __cinit__(self, twoway, edge_nr_translation, num_nodes, + def __init__(self, twoway, edge_nr_translation, num_nodes, num_pp_edges, num_pn_edges, num_np_edges): self.twoway = twoway self.edge_nr_translation = edge_nr_translation self.num_nodes = num_nodes self.num_px_edges = num_pp_edges + num_pn_edges self.edge_nr_end = num_pp_edges + num_pn_edges + num_np_edges - - self.heads_idxs = <gint*>malloc((num_nodes + 1) * sizeof(gint)) + self._heads_idxs = array.array('i', ()) + array.resize(self._heads_idxs, (num_nodes + 1)) + self.heads_idxs = <gint*>self._heads_idxs.data.as_ints if self.twoway: # The graph is two-way. n->p edges will exist in the compressed # graph. self.num_xp_edges = num_pp_edges + num_np_edges self.num_edges = self.edge_nr_end - self.tails_idxs = <gint*>malloc((num_nodes + 1) * sizeof(gint)) - self.tails = <gint*>malloc( - self.num_xp_edges * sizeof(gint)) - self.edge_ids = <gint*>malloc( - self.num_xp_edges * sizeof(gint)) + self._tails_idxs = array.array('i', ()) + array.resize(self._tails_idxs, (num_nodes + 1)) + self.tails_idxs = <gint*>self._tails_idxs.data.as_ints + self._tails = array.array('i', ()) + array.resize(self._tails, self.num_xp_edges) + self.tails = <gint*>self._tails.data.as_ints + self._edge_ids = array.array('i', ()) + array.resize(self._edge_ids, self.num_xp_edges) + self.edge_ids = <gint*>self._edge_ids.data.as_ints else: # The graph is one-way. n->p edges will be ignored. self.num_xp_edges = num_pp_edges self.num_edges = self.num_px_edges - self.heads = <gint*>malloc(self.num_edges * sizeof(gint)) + self._heads = array.array('i', ()) + array.resize(self._heads, self.num_edges) + self.heads = <gint*>self._heads.data.as_ints if edge_nr_translation: - self.edge_ids_by_edge_nr = <gint*>malloc( - self.edge_nr_end * sizeof(gint)) + self._edge_ids_by_edge_nr = array.array('i', ()) + array.resize(self._edge_ids_by_edge_nr, self.edge_nr_end) + self.edge_ids_by_edge_nr = (<gint*>self._edge_ids_by_edge_nr + .data.as_ints) if (not self.heads_idxs or not self.heads or (twoway and (not self.tails_idxs or not self.tails @@ -695,10 +706,28 @@ cdef class CGraph_malloc(CGraph): or (edge_nr_translation and not self.edge_ids_by_edge_nr)): raise MemoryError - def __dealloc__(self): - free(self.edge_ids_by_edge_nr) - free(self.heads) - free(self.edge_ids) - free(self.tails) - free(self.tails_idxs) - free(self.heads_idxs) + def __getstate__(self): + twoway = self.twoway + edge_nr_translation = self.edge_nr_translation + num_nodes = self.num_nodes + num_np_edges = self.edge_nr_end - self.num_px_edges + if twoway: + num_pp_edges = self.num_xp_edges - num_np_edges + else: + num_pp_edges = self.num_xp_edges + num_pn_edges = self.num_px_edges - num_pp_edges + init_args = (twoway, edge_nr_translation, num_nodes, + num_pp_edges, num_pn_edges, num_np_edges) + + return (init_args, self._heads_idxs, self._heads, self._tails_idxs, + self._tails, self._edge_ids, self._edge_ids_by_edge_nr) + + def __setstate__(self, state): + self.__init__(*state[0]) + array_attributes = (self._heads_idxs, self._heads, self._tails_idxs, + self._tails, self._edge_ids, + self._edge_ids_by_edge_nr) + for attribute, value in zip(array_attributes, state[1:]): + if attribute is None: + continue + attribute[:] = value diff --git a/kwant/graph/tests/test_core.py b/kwant/graph/tests/test_core.py index e657bdc5..b52e6a10 100644 --- a/kwant/graph/tests/test_core.py +++ b/kwant/graph/tests/test_core.py @@ -1,4 +1,4 @@ -# Copyright 2011-2013 Kwant authors. +# Copyright 2011-2016 Kwant authors. # # This file is part of Kwant. It is subject to the license terms in the file # LICENSE.rst found in the top-level directory of this distribution and at @@ -6,6 +6,7 @@ # the file AUTHORS.rst at the top-level directory of this distribution and at # http://kwant-project.org/authors. +import pickle from io import StringIO from itertools import zip_longest import numpy as np @@ -179,3 +180,28 @@ def test_edge_ids(): g = gr.compressed(edge_nr_translation=True, allow_lost_edges=True) raises(EdgeDoesNotExistError, g.edge_id, 1) + + +def test_pickle(): + gr = Graph(allow_negative_nodes=True) + edges = [(0, -1), (-1, 0), (1, 2), (1, 2), (0, -1), (-1, 0), (-1, 0)] + gr.add_edges(edges) + g = gr.compressed(twoway=True, edge_nr_translation=True) + g2 = pickle.loads(pickle.dumps(g)) + s = StringIO('') + g.write_dot(s) + s2 = StringIO('') + g2.write_dot(s2) + assert s.getvalue() == s2.getvalue() + assert g.__getstate__() == g2.__getstate__() + + gr = Graph(allow_negative_nodes=False) + edges = [(0, 1), (1, 2), (1, 2), (0, 2)] + g = gr.compressed(twoway=False, edge_nr_translation=False) + g2 = pickle.loads(pickle.dumps(g)) + s = StringIO('') + g.write_dot(s) + s2 = StringIO('') + g2.write_dot(s2) + assert s.getvalue() == s2.getvalue() + assert g.__getstate__() == g2.__getstate__() diff --git a/kwant/tests/test_system.py b/kwant/tests/test_system.py index f90a3b1c..f06da88a 100644 --- a/kwant/tests/test_system.py +++ b/kwant/tests/test_system.py @@ -1,4 +1,4 @@ -# Copyright 2011-2013 Kwant authors. +# Copyright 2011-2016 Kwant authors. # # This file is part of Kwant. It is subject to the license terms in the file # LICENSE.rst found in the top-level directory of this distribution and at @@ -6,6 +6,8 @@ # the file AUTHORS.rst at the top-level directory of this distribution and at # http://kwant-project.org/authors. +import pickle +import copy from pytest import raises import numpy as np from scipy import sparse @@ -99,3 +101,23 @@ def test_hamiltonian_submatrix(): mat = mat[perm, :] mat = mat[:, perm] np.testing.assert_array_equal(mat, mat_should_be) + + +def test_pickling(): + syst = kwant.Builder() + lead = kwant.Builder(symmetry=kwant.TranslationalSymmetry([1.])) + lat = kwant.lattice.chain() + syst[lat(0)] = syst[lat(1)] = 0 + syst[lat(0), lat(1)] = 1 + lead[lat(0)] = syst[lat(1)] = 0 + lead[lat(0), lat(1)] = 1 + syst.attach_lead(lead) + syst.attach_lead(lead.reversed()) + syst_copy1 = copy.copy(syst).finalized() + syst_copy2 = pickle.loads(pickle.dumps(syst)).finalized() + syst = syst.finalized() + syst_copy3 = copy.copy(syst) + syst_copy4 = pickle.loads(pickle.dumps(syst)) + s = kwant.smatrix(syst, 0.1) + for other in (syst_copy1, syst_copy2, syst_copy3, syst_copy4): + assert np.all(kwant.smatrix(other, 0.1).data == s.data) -- GitLab