diff --git a/kwant/graph/core.pyx b/kwant/graph/core.pyx index 1f9c15f2249d7a186eb2134e3bed4b0d9d930dcc..a5175bee100b26328fc4fa8fd516799c5f0fa14b 100644 --- a/kwant/graph/core.pyx +++ b/kwant/graph/core.pyx @@ -719,15 +719,22 @@ cdef class CGraph_malloc(CGraph): init_args = (twoway, edge_nr_translation, num_nodes, num_pp_edges, num_pn_edges, num_np_edges) - return (init_args, self._heads_idxs, self._heads, self._tails_idxs, - self._tails, self._edge_ids, self._edge_ids_by_edge_nr) + return init_args, (self._heads_idxs, self._heads, self._tails_idxs, + self._tails, self._edge_ids, + self._edge_ids_by_edge_nr) def __setstate__(self, state): - self.__init__(*state[0]) + init_args, arrays = state + self.__init__(*init_args) array_attributes = (self._heads_idxs, self._heads, self._tails_idxs, self._tails, self._edge_ids, self._edge_ids_by_edge_nr) - for attribute, value in zip(array_attributes, state[1:]): + for attribute, value in zip(array_attributes, arrays): if attribute is None: continue attribute[:] = value + + # We are required to implement this as of Cython 0.26 + def __reduce__(self): + state = init_args, _ = self.__getstate__() + return (CGraph_malloc, init_args, state, None, None)