diff --git a/retworkx/__init__.py b/retworkx/__init__.py index 87d404c79..ae60d6b50 100644 --- a/retworkx/__init__.py +++ b/retworkx/__init__.py @@ -38,10 +38,6 @@ class PyDAG(PyDiGraph): ensure that no cycles are added, ensuring that the PyDAG class truly represents a directed acyclic graph. - .. note:: - When using ``copy.deepcopy()`` or pickling node indexes are not - guaranteed to be preserved. - PyDAG is a subclass of the PyDiGraph class and behaves identically to the :class:`~retworkx.PyDiGraph` class. """ diff --git a/src/digraph.rs b/src/digraph.rs index 545f0d01d..0685a7d64 100644 --- a/src/digraph.rs +++ b/src/digraph.rs @@ -67,11 +67,6 @@ use super::{ /// With check_cycle set to true any calls to :meth:`PyDiGraph.add_edge` will /// ensure that no cycles are added, ensuring that the PyDiGraph class truly /// represents a directed acyclic graph. -/// -/// .. note:: -/// When using ``copy.deepcopy()`` or pickling node indexes are not -/// guaranteed to be preserved. -/// #[pyclass(module = "retworkx", subclass)] #[text_signature = "(/, check_cycle=False)"] pub struct PyDiGraph { @@ -334,35 +329,52 @@ impl PyDiGraph { Ok(out_dict.into()) } - fn __setstate__(&mut self, state: PyObject) -> PyResult<()> { - let mut node_mapping: HashMap = HashMap::new(); + fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> { self.graph = StableDiGraph::::new(); - let gil = Python::acquire_gil(); - let py = gil.python(); let dict_state = state.cast_as::(py)?; let nodes_dict = dict_state.get_item("nodes").unwrap().downcast::()?; let edges_list = dict_state.get_item("edges").unwrap().downcast::()?; - for raw_index in nodes_dict.keys().iter() { + let mut node_indices: Vec = Vec::new(); + for raw_index in nodes_dict.keys() { let tmp_index = raw_index.downcast::()?; - let index: usize = tmp_index.extract()?; - let raw_data = nodes_dict.get_item(index).unwrap(); - let node_index = self.graph.add_node(raw_data.into()); - node_mapping.insert(index, node_index); + node_indices.push(tmp_index.extract()?); + } + let max_index: usize = *node_indices.iter().max().unwrap(); + if max_index + 1 != node_indices.len() { + self.node_removed = true; + } + let mut tmp_nodes: Vec = Vec::new(); + let mut node_count: usize = 0; + while max_index >= self.graph.node_bound() { + match nodes_dict.get_item(node_count) { + Some(raw_data) => { + self.graph.add_node(raw_data.into()); + } + None => { + let tmp_node = self.graph.add_node(py.None()); + tmp_nodes.push(tmp_node); + } + }; + node_count += 1; + } + for tmp_node in tmp_nodes { + self.graph.remove_node(tmp_node); } for raw_edge in edges_list.iter() { let edge = raw_edge.downcast::()?; let raw_p_index = edge.get_item(0).downcast::()?; - let tmp_p_index: usize = raw_p_index.extract()?; + let p_index: usize = raw_p_index.extract()?; let raw_c_index = edge.get_item(1).downcast::()?; - let tmp_c_index: usize = raw_c_index.extract()?; + let c_index: usize = raw_c_index.extract()?; let edge_data = edge.get_item(2); - - let p_index = node_mapping.get(&tmp_p_index).unwrap(); - let c_index = node_mapping.get(&tmp_c_index).unwrap(); - self.graph.add_edge(*p_index, *c_index, edge_data.into()); + self.graph.add_edge( + NodeIndex::new(p_index), + NodeIndex::new(c_index), + edge_data.into(), + ); } Ok(()) } diff --git a/src/graph.rs b/src/graph.rs index edefc0fd3..e826f22e5 100644 --- a/src/graph.rs +++ b/src/graph.rs @@ -263,26 +263,29 @@ impl PyGraph { dict_state.get_item("nodes").unwrap().downcast::()?; let edges_list = dict_state.get_item("edges").unwrap().downcast::()?; - let mut index_count = 0; - for raw_index in nodes_dict.keys().iter() { + let mut node_indices: Vec = Vec::new(); + for raw_index in nodes_dict.keys() { let tmp_index = raw_index.downcast::()?; - let index: usize = tmp_index.extract()?; - let mut tmp_nodes: Vec = Vec::new(); - if index > index_count + 1 { - let diff = index - (index_count + 1); - for _ in 0..diff { + node_indices.push(tmp_index.extract()?); + } + let max_index: usize = *node_indices.iter().max().unwrap(); + let mut tmp_nodes: Vec = Vec::new(); + let mut node_count: usize = 0; + while max_index >= self.graph.node_bound() { + match nodes_dict.get_item(node_count) { + Some(raw_data) => { + self.graph.add_node(raw_data.into()); + } + None => { let tmp_node = self.graph.add_node(py.None()); tmp_nodes.push(tmp_node); } - } - let raw_data = nodes_dict.get_item(index).unwrap(); - let out_index = self.graph.add_node(raw_data.into()); - for tmp_node in tmp_nodes { - self.graph.remove_node(tmp_node); - } - index_count = out_index.index(); + }; + node_count += 1; + } + for tmp_node in tmp_nodes { + self.graph.remove_node(tmp_node); } - for raw_edge in edges_list.iter() { let edge = raw_edge.downcast::()?; let raw_p_index = edge.get_item(0).downcast::()?; diff --git a/tests/test_deepcopy.py b/tests/test_deepcopy.py index 08a96cf1d..b70268f32 100644 --- a/tests/test_deepcopy.py +++ b/tests/test_deepcopy.py @@ -27,3 +27,15 @@ def test_isomorphic_compare_nodes_identical(self): self.assertTrue( retworkx.is_isomorphic_node_match( dag_a, dag_b, lambda x, y: x == y)) + + def test_deepcopy_with_holes(self): + dag_a = retworkx.PyDAG() + node_a = dag_a.add_node('a_1') + node_b = dag_a.add_node('a_2') + dag_a.add_edge(node_a, node_b, 'edge_1') + node_c = dag_a.add_node('a_3') + dag_a.add_edge(node_b, node_c, 'edge_2') + dag_a.remove_node(node_b) + dag_b = copy.deepcopy(dag_a) + self.assertIsInstance(dag_b, retworkx.PyDAG) + self.assertEqual([node_a, node_c], dag_b.node_indexes())