Skip to content
Snippets Groups Projects
Commit 6d6be7b1 authored by Anton Akhmerov's avatar Anton Akhmerov
Browse files

allow compressed graph pickling, add test for system pickling

parent 18c7fd24
No related branches found
No related tags found
1 merge request!56System pickle
# Copyright 2011-2013 Kwant authors.
# Copyright 2011-2016 Kwant authors.
#
# This file is part of Kwant. It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
......@@ -7,6 +7,7 @@
# http://kwant-project.org/authors.
cimport numpy as np
from cpython cimport array
from .defs cimport gint
cdef struct Edge:
......@@ -30,11 +31,17 @@ cdef class gintArraySlice:
cdef class CGraph:
cdef readonly bint twoway, edge_nr_translation
cdef readonly gint num_nodes, num_edges, num_px_edges, num_xp_edges
cdef array.array _heads_idxs
cdef gint *heads_idxs
cdef array.array _heads
cdef gint *heads
cdef array.array _tails_idxs
cdef gint *tails_idxs
cdef array.array _tails
cdef gint *tails
cdef array.array _edge_ids
cdef gint *edge_ids
cdef array.array _edge_ids_by_edge_nr
cdef gint *edge_ids_by_edge_nr
cdef gint edge_nr_end
......
# Copyright 2011-2013 Kwant authors.
# Copyright 2011-2016 Kwant authors.
#
# This file is part of Kwant. It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
......@@ -28,6 +28,8 @@ __all__ = ['Graph', 'CGraph']
from libc.stdlib cimport malloc, realloc, free
from libc.string cimport memset
from cpython cimport array
import array
import numpy as np
cimport numpy as np
from .defs cimport gint
......@@ -661,33 +663,42 @@ cdef class CGraph:
cdef class CGraph_malloc(CGraph):
"""A CGraph which allocates and frees its own memory."""
def __cinit__(self, twoway, edge_nr_translation, num_nodes,
def __init__(self, twoway, edge_nr_translation, num_nodes,
num_pp_edges, num_pn_edges, num_np_edges):
self.twoway = twoway
self.edge_nr_translation = edge_nr_translation
self.num_nodes = num_nodes
self.num_px_edges = num_pp_edges + num_pn_edges
self.edge_nr_end = num_pp_edges + num_pn_edges + num_np_edges
self.heads_idxs = <gint*>malloc((num_nodes + 1) * sizeof(gint))
self._heads_idxs = array.array('i', ())
array.resize(self._heads_idxs, (num_nodes + 1))
self.heads_idxs = <gint*>self._heads_idxs.data.as_ints
if self.twoway:
# The graph is two-way. n->p edges will exist in the compressed
# graph.
self.num_xp_edges = num_pp_edges + num_np_edges
self.num_edges = self.edge_nr_end
self.tails_idxs = <gint*>malloc((num_nodes + 1) * sizeof(gint))
self.tails = <gint*>malloc(
self.num_xp_edges * sizeof(gint))
self.edge_ids = <gint*>malloc(
self.num_xp_edges * sizeof(gint))
self._tails_idxs = array.array('i', ())
array.resize(self._tails_idxs, (num_nodes + 1))
self.tails_idxs = <gint*>self._tails_idxs.data.as_ints
self._tails = array.array('i', ())
array.resize(self._tails, self.num_xp_edges)
self.tails = <gint*>self._tails.data.as_ints
self._edge_ids = array.array('i', ())
array.resize(self._edge_ids, self.num_xp_edges)
self.edge_ids = <gint*>self._edge_ids.data.as_ints
else:
# The graph is one-way. n->p edges will be ignored.
self.num_xp_edges = num_pp_edges
self.num_edges = self.num_px_edges
self.heads = <gint*>malloc(self.num_edges * sizeof(gint))
self._heads = array.array('i', ())
array.resize(self._heads, self.num_edges)
self.heads = <gint*>self._heads.data.as_ints
if edge_nr_translation:
self.edge_ids_by_edge_nr = <gint*>malloc(
self.edge_nr_end * sizeof(gint))
self._edge_ids_by_edge_nr = array.array('i', ())
array.resize(self._edge_ids_by_edge_nr, self.edge_nr_end)
self.edge_ids_by_edge_nr = (<gint*>self._edge_ids_by_edge_nr
.data.as_ints)
if (not self.heads_idxs or not self.heads
or (twoway and (not self.tails_idxs
or not self.tails
......@@ -695,10 +706,28 @@ cdef class CGraph_malloc(CGraph):
or (edge_nr_translation and not self.edge_ids_by_edge_nr)):
raise MemoryError
def __dealloc__(self):
free(self.edge_ids_by_edge_nr)
free(self.heads)
free(self.edge_ids)
free(self.tails)
free(self.tails_idxs)
free(self.heads_idxs)
def __getstate__(self):
twoway = self.twoway
edge_nr_translation = self.edge_nr_translation
num_nodes = self.num_nodes
num_np_edges = self.edge_nr_end - self.num_px_edges
if twoway:
num_pp_edges = self.num_xp_edges - num_np_edges
else:
num_pp_edges = self.num_xp_edges
num_pn_edges = self.num_px_edges - num_pp_edges
init_args = (twoway, edge_nr_translation, num_nodes,
num_pp_edges, num_pn_edges, num_np_edges)
return (init_args, self._heads_idxs, self._heads, self._tails_idxs,
self._tails, self._edge_ids, self._edge_ids_by_edge_nr)
def __setstate__(self, state):
self.__init__(*state[0])
array_attributes = (self._heads_idxs, self._heads, self._tails_idxs,
self._tails, self._edge_ids,
self._edge_ids_by_edge_nr)
for attribute, value in zip(array_attributes, state[1:]):
if attribute is None:
continue
attribute[:] = value
# Copyright 2011-2013 Kwant authors.
# Copyright 2011-2016 Kwant authors.
#
# This file is part of Kwant. It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
......@@ -6,6 +6,7 @@
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.
import pickle
from io import StringIO
from itertools import zip_longest
import numpy as np
......@@ -179,3 +180,28 @@ def test_edge_ids():
g = gr.compressed(edge_nr_translation=True, allow_lost_edges=True)
raises(EdgeDoesNotExistError, g.edge_id, 1)
def test_pickle():
gr = Graph(allow_negative_nodes=True)
edges = [(0, -1), (-1, 0), (1, 2), (1, 2), (0, -1), (-1, 0), (-1, 0)]
gr.add_edges(edges)
g = gr.compressed(twoway=True, edge_nr_translation=True)
g2 = pickle.loads(pickle.dumps(g))
s = StringIO('')
g.write_dot(s)
s2 = StringIO('')
g2.write_dot(s2)
assert s.getvalue() == s2.getvalue()
assert g.__getstate__() == g2.__getstate__()
gr = Graph(allow_negative_nodes=False)
edges = [(0, 1), (1, 2), (1, 2), (0, 2)]
g = gr.compressed(twoway=False, edge_nr_translation=False)
g2 = pickle.loads(pickle.dumps(g))
s = StringIO('')
g.write_dot(s)
s2 = StringIO('')
g2.write_dot(s2)
assert s.getvalue() == s2.getvalue()
assert g.__getstate__() == g2.__getstate__()
# Copyright 2011-2013 Kwant authors.
# Copyright 2011-2016 Kwant authors.
#
# This file is part of Kwant. It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
......@@ -6,6 +6,8 @@
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.
import pickle
import copy
from pytest import raises
import numpy as np
from scipy import sparse
......@@ -99,3 +101,23 @@ def test_hamiltonian_submatrix():
mat = mat[perm, :]
mat = mat[:, perm]
np.testing.assert_array_equal(mat, mat_should_be)
def test_pickling():
syst = kwant.Builder()
lead = kwant.Builder(symmetry=kwant.TranslationalSymmetry([1.]))
lat = kwant.lattice.chain()
syst[lat(0)] = syst[lat(1)] = 0
syst[lat(0), lat(1)] = 1
lead[lat(0)] = syst[lat(1)] = 0
lead[lat(0), lat(1)] = 1
syst.attach_lead(lead)
syst.attach_lead(lead.reversed())
syst_copy1 = copy.copy(syst).finalized()
syst_copy2 = pickle.loads(pickle.dumps(syst)).finalized()
syst = syst.finalized()
syst_copy3 = copy.copy(syst)
syst_copy4 = pickle.loads(pickle.dumps(syst))
s = kwant.smatrix(syst, 0.1)
for other in (syst_copy1, syst_copy2, syst_copy3, syst_copy4):
assert np.all(kwant.smatrix(other, 0.1).data == s.data)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment