Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix handling of index holes in PyDiGraph pickling #116

Merged
merged 5 commits into from
Sep 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions retworkx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@ class PyDAG(PyDiGraph):
ensure that no cycles are added, ensuring that the PyDAG class truly
represents a directed acyclic graph.

.. note::
When using ``copy.deepcopy()`` or pickling node indexes are not
guaranteed to be preserved.

PyDAG is a subclass of the PyDiGraph class and behaves identically to
the :class:`~retworkx.PyDiGraph` class.
"""
Expand Down
52 changes: 32 additions & 20 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,6 @@ use super::{
/// With check_cycle set to true any calls to :meth:`PyDiGraph.add_edge` will
/// ensure that no cycles are added, ensuring that the PyDiGraph class truly
/// represents a directed acyclic graph.
///
/// .. note::
/// When using ``copy.deepcopy()`` or pickling node indexes are not
/// guaranteed to be preserved.
///
#[pyclass(module = "retworkx", subclass)]
#[text_signature = "(/, check_cycle=False)"]
pub struct PyDiGraph {
Expand Down Expand Up @@ -334,35 +329,52 @@ impl PyDiGraph {
Ok(out_dict.into())
}

fn __setstate__(&mut self, state: PyObject) -> PyResult<()> {
let mut node_mapping: HashMap<usize, NodeIndex> = HashMap::new();
fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
self.graph = StableDiGraph::<PyObject, PyObject>::new();
let gil = Python::acquire_gil();
let py = gil.python();
let dict_state = state.cast_as::<PyDict>(py)?;

let nodes_dict =
dict_state.get_item("nodes").unwrap().downcast::<PyDict>()?;
let edges_list =
dict_state.get_item("edges").unwrap().downcast::<PyList>()?;
for raw_index in nodes_dict.keys().iter() {
let mut node_indices: Vec<usize> = Vec::new();
for raw_index in nodes_dict.keys() {
let tmp_index = raw_index.downcast::<PyLong>()?;
let index: usize = tmp_index.extract()?;
let raw_data = nodes_dict.get_item(index).unwrap();
let node_index = self.graph.add_node(raw_data.into());
node_mapping.insert(index, node_index);
node_indices.push(tmp_index.extract()?);
}
let max_index: usize = *node_indices.iter().max().unwrap();
if max_index + 1 != node_indices.len() {
self.node_removed = true;
}
let mut tmp_nodes: Vec<NodeIndex> = Vec::new();
let mut node_count: usize = 0;
while max_index >= self.graph.node_bound() {
match nodes_dict.get_item(node_count) {
Some(raw_data) => {
self.graph.add_node(raw_data.into());
}
None => {
let tmp_node = self.graph.add_node(py.None());
tmp_nodes.push(tmp_node);
}
};
node_count += 1;
}
for tmp_node in tmp_nodes {
self.graph.remove_node(tmp_node);
}
for raw_edge in edges_list.iter() {
let edge = raw_edge.downcast::<PyTuple>()?;
let raw_p_index = edge.get_item(0).downcast::<PyLong>()?;
let tmp_p_index: usize = raw_p_index.extract()?;
let p_index: usize = raw_p_index.extract()?;
let raw_c_index = edge.get_item(1).downcast::<PyLong>()?;
let tmp_c_index: usize = raw_c_index.extract()?;
let c_index: usize = raw_c_index.extract()?;
let edge_data = edge.get_item(2);

let p_index = node_mapping.get(&tmp_p_index).unwrap();
let c_index = node_mapping.get(&tmp_c_index).unwrap();
self.graph.add_edge(*p_index, *c_index, edge_data.into());
self.graph.add_edge(
NodeIndex::new(p_index),
NodeIndex::new(c_index),
edge_data.into(),
);
}
Ok(())
}
Expand Down
33 changes: 18 additions & 15 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,26 +263,29 @@ impl PyGraph {
dict_state.get_item("nodes").unwrap().downcast::<PyDict>()?;
let edges_list =
dict_state.get_item("edges").unwrap().downcast::<PyList>()?;
let mut index_count = 0;
for raw_index in nodes_dict.keys().iter() {
let mut node_indices: Vec<usize> = Vec::new();
for raw_index in nodes_dict.keys() {
let tmp_index = raw_index.downcast::<PyLong>()?;
let index: usize = tmp_index.extract()?;
let mut tmp_nodes: Vec<NodeIndex> = Vec::new();
if index > index_count + 1 {
let diff = index - (index_count + 1);
for _ in 0..diff {
node_indices.push(tmp_index.extract()?);
}
let max_index: usize = *node_indices.iter().max().unwrap();
let mut tmp_nodes: Vec<NodeIndex> = Vec::new();
let mut node_count: usize = 0;
while max_index >= self.graph.node_bound() {
match nodes_dict.get_item(node_count) {
Some(raw_data) => {
self.graph.add_node(raw_data.into());
}
None => {
let tmp_node = self.graph.add_node(py.None());
tmp_nodes.push(tmp_node);
}
}
let raw_data = nodes_dict.get_item(index).unwrap();
let out_index = self.graph.add_node(raw_data.into());
for tmp_node in tmp_nodes {
self.graph.remove_node(tmp_node);
}
index_count = out_index.index();
};
node_count += 1;
}
for tmp_node in tmp_nodes {
self.graph.remove_node(tmp_node);
}

for raw_edge in edges_list.iter() {
let edge = raw_edge.downcast::<PyTuple>()?;
let raw_p_index = edge.get_item(0).downcast::<PyLong>()?;
Expand Down
12 changes: 12 additions & 0 deletions tests/test_deepcopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,15 @@ def test_isomorphic_compare_nodes_identical(self):
self.assertTrue(
retworkx.is_isomorphic_node_match(
dag_a, dag_b, lambda x, y: x == y))

def test_deepcopy_with_holes(self):
dag_a = retworkx.PyDAG()
node_a = dag_a.add_node('a_1')
node_b = dag_a.add_node('a_2')
dag_a.add_edge(node_a, node_b, 'edge_1')
node_c = dag_a.add_node('a_3')
dag_a.add_edge(node_b, node_c, 'edge_2')
dag_a.remove_node(node_b)
dag_b = copy.deepcopy(dag_a)
self.assertIsInstance(dag_b, retworkx.PyDAG)
self.assertEqual([node_a, node_c], dag_b.node_indexes())