From 205985676b5b5967d6736be593e9948a2a55e480 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Tue, 18 Aug 2020 16:31:06 -0400 Subject: [PATCH 1/2] Fix handling of index holes in PyDiGraph pickling This commit fixes an issue with pickling (and by extension deepcopy) PyDiGraph or PyDAG objects that have holes in their node id lists. Previously the node ids would not be preserved across pickling, leading to a compacted list instead of the original node ids. For example, if you had a PyDiGraph with node ids [0, 1, 3], after pickling/deepcopy the node ids would be [0, 1, 2] but the graph would be otherwise identical. This commit fixes this issue by adding a check for holes to the __setstate__ method and incrementing the node id to reproduce a 1:1 mapping with the original node ids prior to pickling. --- Cargo.lock | 2 +- retworkx/__init__.py | 4 ---- src/digraph.rs | 33 ++++++++++++++++++++------------- tests/test_deepcopy.py | 12 ++++++++++++ 4 files changed, 33 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ca279d4bd..601bd5ef0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -409,7 +409,7 @@ checksum = "41cc0f7e4d5d4544e8861606a285bb08d3e70712ccc7d2b84d7c0ccfaf4b05ce" [[package]] name = "retworkx" -version = "0.4.0" +version = "0.5.0" dependencies = [ "fixedbitset", "hashbrown", diff --git a/retworkx/__init__.py b/retworkx/__init__.py index 0c60929eb..03fe0819d 100644 --- a/retworkx/__init__.py +++ b/retworkx/__init__.py @@ -35,10 +35,6 @@ class PyDAG(PyDiGraph): ensure that no cycles are added, ensuring that the PyDAG class truly represents a directed acyclic graph. - .. note:: - When using ``copy.deepcopy()`` or pickling node indexes are not - guaranteed to be preserved. - PyDAG is a subclass of the PyDiGraph class and behaves identically to the :class:`~retworkx.PyDiGraph` class. 
""" diff --git a/src/digraph.rs b/src/digraph.rs index d9b3ec2fc..95f3fcccb 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -64,11 +64,6 @@ use super::{ /// With check_cycle set to true any calls to :meth:`PyDiGraph.add_edge` will /// ensure that no cycles are added, ensuring that the PyDiGraph class truly /// represents a directed acyclic graph. -/// -/// .. note:: -/// When using ``copy.deepcopy()`` or pickling node indexes are not -/// guaranteed to be preserved. -/// #[pyclass(module = "retworkx", subclass)] #[text_signature = "(/, check_cycle=False)"] pub struct PyDiGraph { @@ -326,7 +321,6 @@ impl PyDiGraph { } fn __setstate__(&mut self, state: PyObject) -> PyResult<()> { - let mut node_mapping: HashMap = HashMap::new(); self.graph = StableDiGraph::::new(); let gil = Python::acquire_gil(); let py = gil.python(); @@ -336,24 +330,37 @@ impl PyDiGraph { dict_state.get_item("nodes").unwrap().downcast::()?; let edges_list = dict_state.get_item("edges").unwrap().downcast::()?; + let mut index_count = 0; for raw_index in nodes_dict.keys().iter() { let tmp_index = raw_index.downcast::()?; let index: usize = tmp_index.extract()?; let raw_data = nodes_dict.get_item(index).unwrap(); + let mut tmp_nodes: Vec = Vec::new(); + if index > index_count + 1 { + let diff = index - (index_count + 1); + for _ in 0..diff { + let tmp_node = self.graph.add_node(py.None()); + tmp_nodes.push(tmp_node); + } + } let node_index = self.graph.add_node(raw_data.into()); - node_mapping.insert(index, node_index); + for tmp_node in tmp_nodes { + self.graph.remove_node(tmp_node); + } + index_count = node_index.index(); } for raw_edge in edges_list.iter() { let edge = raw_edge.downcast::()?; let raw_p_index = edge.get_item(0).downcast::()?; - let tmp_p_index: usize = raw_p_index.extract()?; + let p_index: usize = raw_p_index.extract()?; let raw_c_index = edge.get_item(1).downcast::()?; - let tmp_c_index: usize = raw_c_index.extract()?; + let c_index: usize = raw_c_index.extract()?; let 
edge_data = edge.get_item(2); - - let p_index = node_mapping.get(&tmp_p_index).unwrap(); - let c_index = node_mapping.get(&tmp_c_index).unwrap(); - self.graph.add_edge(*p_index, *c_index, edge_data.into()); + self.graph.add_edge( + NodeIndex::new(p_index), + NodeIndex::new(c_index), + edge_data.into(), + ); } Ok(()) } diff --git a/tests/test_deepcopy.py b/tests/test_deepcopy.py index 08a96cf1d..b70268f32 100644 --- a/tests/test_deepcopy.py +++ b/tests/test_deepcopy.py @@ -27,3 +27,15 @@ def test_isomorphic_compare_nodes_identical(self): self.assertTrue( retworkx.is_isomorphic_node_match( dag_a, dag_b, lambda x, y: x == y)) + + def test_deepcopy_with_holes(self): + dag_a = retworkx.PyDAG() + node_a = dag_a.add_node('a_1') + node_b = dag_a.add_node('a_2') + dag_a.add_edge(node_a, node_b, 'edge_1') + node_c = dag_a.add_node('a_3') + dag_a.add_edge(node_b, node_c, 'edge_2') + dag_a.remove_node(node_b) + dag_b = copy.deepcopy(dag_a) + self.assertIsInstance(dag_b, retworkx.PyDAG) + self.assertEqual([node_a, node_c], dag_b.node_indexes()) From e41ffa5898afcb481dfe601d2829f5d925df3235 Mon Sep 17 00:00:00 2001 From: Matthew Treinish Date: Sat, 19 Sep 2020 16:08:57 -0400 Subject: [PATCH 2/2] Fix logic bug handling holes --- src/digraph.rs | 39 ++++++++++++++++++++++----------------- src/graph.rs | 33 ++++++++++++++++++--------------- 2 files changed, 40 insertions(+), 32 deletions(-) diff --git a/src/digraph.rs b/src/digraph.rs index 800994079..0685a7d64 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -329,34 +329,39 @@ impl PyDiGraph { Ok(out_dict.into()) } - fn __setstate__(&mut self, state: PyObject) -> PyResult<()> { + fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { self.graph = StableDiGraph::::new(); - let gil = Python::acquire_gil(); - let py = gil.python(); let dict_state = state.cast_as::(py)?; let nodes_dict = dict_state.get_item("nodes").unwrap().downcast::()?; let edges_list = dict_state.get_item("edges").unwrap().downcast::()?; - 
let mut index_count = 0; - for raw_index in nodes_dict.keys().iter() { + let mut node_indices: Vec = Vec::new(); + for raw_index in nodes_dict.keys() { let tmp_index = raw_index.downcast::()?; - let index: usize = tmp_index.extract()?; - let raw_data = nodes_dict.get_item(index).unwrap(); - let mut tmp_nodes: Vec = Vec::new(); - if index > index_count + 1 { - let diff = index - (index_count + 1); - for _ in 0..diff { + node_indices.push(tmp_index.extract()?); + } + let max_index: usize = *node_indices.iter().max().unwrap(); + if max_index + 1 != node_indices.len() { + self.node_removed = true; + } + let mut tmp_nodes: Vec = Vec::new(); + let mut node_count: usize = 0; + while max_index >= self.graph.node_bound() { + match nodes_dict.get_item(node_count) { + Some(raw_data) => { + self.graph.add_node(raw_data.into()); + } + None => { let tmp_node = self.graph.add_node(py.None()); tmp_nodes.push(tmp_node); } - } - let node_index = self.graph.add_node(raw_data.into()); - for tmp_node in tmp_nodes { - self.graph.remove_node(tmp_node); - } - index_count = node_index.index(); + }; + node_count += 1; + } + for tmp_node in tmp_nodes { + self.graph.remove_node(tmp_node); } for raw_edge in edges_list.iter() { let edge = raw_edge.downcast::()?; diff --git a/src/graph.rs b/src/graph.rs index edefc0fd3..e826f22e5 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -263,26 +263,29 @@ impl PyGraph { dict_state.get_item("nodes").unwrap().downcast::()?; let edges_list = dict_state.get_item("edges").unwrap().downcast::()?; - let mut index_count = 0; - for raw_index in nodes_dict.keys().iter() { + let mut node_indices: Vec = Vec::new(); + for raw_index in nodes_dict.keys() { let tmp_index = raw_index.downcast::()?; - let index: usize = tmp_index.extract()?; - let mut tmp_nodes: Vec = Vec::new(); - if index > index_count + 1 { - let diff = index - (index_count + 1); - for _ in 0..diff { + node_indices.push(tmp_index.extract()?); + } + let max_index: usize = 
*node_indices.iter().max().unwrap(); + let mut tmp_nodes: Vec = Vec::new(); + let mut node_count: usize = 0; + while max_index >= self.graph.node_bound() { + match nodes_dict.get_item(node_count) { + Some(raw_data) => { + self.graph.add_node(raw_data.into()); + } + None => { let tmp_node = self.graph.add_node(py.None()); tmp_nodes.push(tmp_node); } - } - let raw_data = nodes_dict.get_item(index).unwrap(); - let out_index = self.graph.add_node(raw_data.into()); - for tmp_node in tmp_nodes { - self.graph.remove_node(tmp_node); - } - index_count = out_index.index(); + }; + node_count += 1; + } + for tmp_node in tmp_nodes { + self.graph.remove_node(tmp_node); } - for raw_edge in edges_list.iter() { let edge = raw_edge.downcast::()?; let raw_p_index = edge.get_item(0).downcast::()?;