Skip to content

Commit

Permalink
Fix pickle/deepcopy not preserving original edge indices (#589)
Browse files Browse the repository at this point in the history
* Fix issue #585: pickling a graph or digraph did not preserve the original edge indices

* fix clippy lints - collapsible_else_if

* Simplify logic in __setstate__

* Add release note

* Fix lint

---------

Co-authored-by: Matthew Treinish <[email protected]>
  • Loading branch information
binh-vu and mtreinish authored May 9, 2023
1 parent 5a3f9b3 commit 8686896
Show file tree
Hide file tree
Showing 5 changed files with 418 additions and 128 deletions.
10 changes: 10 additions & 0 deletions releasenotes/notes/fix-edge-indices-pickle-83fddf149441fa9f.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
fixes:
- |
Fixed an issue when using ``copy.deepcopy()`` on :class:`~.PyDiGraph` and
:class:`~.PyGraph` objects when there were removed edges from the graph
object. Previously, if there were any holes in the edge indices caused by
the removal, the output copy of the graph object would incorrectly have
flattened the indices. This has been corrected so that the edge indices are
recreated exactly after a ``deepcopy()``.
Fixed `#585 <https://github.com/Qiskit/rustworkx/issues/585>`__
231 changes: 165 additions & 66 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ use petgraph::graph::{EdgeIndex, NodeIndex};
use petgraph::prelude::*;

use petgraph::visit::{
GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered, NodeIndexable,
EdgeIndexable, GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered,
Visitable,
};

Expand Down Expand Up @@ -298,97 +298,196 @@ impl PyDiGraph {
}

fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let mut nodes: Vec<PyObject> = Vec::with_capacity(self.graph.node_count());
let mut edges: Vec<PyObject> = Vec::with_capacity(self.graph.edge_bound());

// save nodes to a list along with its index
for node_idx in self.graph.node_indices() {
let node_data = self.graph.node_weight(node_idx).unwrap();
nodes.push((node_idx.index(), node_data).to_object(py));
}

// edges are saved positionally, with None standing in for deleted edges, instead of storing each edge's index, to save space
for i in 0..self.graph.edge_bound() {
let idx = EdgeIndex::new(i);
let edge = match self.graph.edge_weight(idx) {
Some(edge_w) => {
let endpoints = self.graph.edge_endpoints(idx).unwrap();
(endpoints.0.index(), endpoints.1.index(), edge_w).to_object(py)
}
None => py.None(),
};
edges.push(edge);
}

let out_dict = PyDict::new(py);
let node_dict = PyDict::new(py);
let mut out_list: Vec<PyObject> = Vec::with_capacity(self.graph.edge_count());
out_dict.set_item("nodes", node_dict)?;
let nodes_lst: PyObject = PyList::new(py, nodes).into();
let edges_lst: PyObject = PyList::new(py, edges).into();
out_dict.set_item("nodes", nodes_lst)?;
out_dict.set_item("edges", edges_lst)?;
out_dict.set_item("nodes_removed", self.node_removed)?;
out_dict.set_item("multigraph", self.multigraph)?;
out_dict.set_item("attrs", self.attrs.clone_ref(py))?;
out_dict.set_item("check_cycle", self.check_cycle)?;
let dir = petgraph::Direction::Incoming;
for node_index in self.graph.node_indices() {
let node_data = self.graph.node_weight(node_index).unwrap();
node_dict.set_item(node_index.index(), node_data)?;
for edge in self.graph.edges_directed(node_index, dir) {
let edge_w = edge.weight();
let triplet = (edge.source().index(), edge.target().index(), edge_w).to_object(py);
out_list.push(triplet);
}
}
let py_out_list: PyObject = PyList::new(py, out_list).into();
out_dict.set_item("edges", py_out_list)?;
Ok(out_dict.into())
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
let dict_state = state.downcast::<PyDict>(py)?;
let nodes_lst = dict_state.get_item("nodes").unwrap().downcast::<PyList>()?;
let edges_lst = dict_state.get_item("edges").unwrap().downcast::<PyList>()?;
self.graph = StablePyGraph::<Directed>::new();
let dict_state = state.downcast::<PyDict>(py)?;

let nodes_dict = dict_state.get_item("nodes").unwrap().downcast::<PyDict>()?;
let edges_list = dict_state.get_item("edges").unwrap().downcast::<PyList>()?;
let nodes_removed_raw = dict_state
.get_item("nodes_removed")
.unwrap()
.downcast::<PyBool>()?;
self.node_removed = nodes_removed_raw.extract()?;
let multigraph_raw = dict_state
self.multigraph = dict_state
.get_item("multigraph")
.unwrap()
.downcast::<PyBool>()?;
self.multigraph = multigraph_raw.extract()?;
.downcast::<PyBool>()?
.extract()?;
self.node_removed = dict_state
.get_item("nodes_removed")
.unwrap()
.downcast::<PyBool>()?
.extract()?;
let attrs = match dict_state.get_item("attrs") {
Some(attr) => attr.into(),
None => py.None(),
};
self.attrs = attrs;
let check_cycle_raw = dict_state
self.check_cycle = dict_state
.get_item("check_cycle")
.unwrap()
.downcast::<PyBool>()?;
self.check_cycle = check_cycle_raw.extract()?;
let mut node_indices: Vec<usize> = Vec::new();
for raw_index in nodes_dict.keys() {
let tmp_index = raw_index.downcast::<PyLong>()?;
node_indices.push(tmp_index.extract()?);
}
if node_indices.is_empty() {
.downcast::<PyBool>()?
.extract()?;

// graph is empty, stop early
if nodes_lst.is_empty() {
return Ok(());
}
let max_index: usize = *node_indices.iter().max().unwrap();
if max_index + 1 != node_indices.len() {
self.node_removed = true;
}
let mut tmp_nodes: Vec<NodeIndex> = Vec::new();
let mut node_count: usize = 0;
while max_index >= self.graph.node_bound() {
match nodes_dict.get_item(node_count) {
Some(raw_data) => {
self.graph.add_node(raw_data.into());
}
None => {

if !self.node_removed {
for item in nodes_lst.iter() {
let node_w = item
.downcast::<PyTuple>()
.unwrap()
.get_item(1)
.unwrap()
.extract()
.unwrap();
self.graph.add_node(node_w);
}
} else if nodes_lst.len() == 1 {
// graph has only one node, handle logic here to save one if in the loop later
let item = nodes_lst
.get_item(0)
.unwrap()
.downcast::<PyTuple>()
.unwrap();
let node_idx: usize = item.get_item(0).unwrap().extract().unwrap();
let node_w = item.get_item(1).unwrap().extract().unwrap();

for _i in 0..node_idx {
self.graph.add_node(py.None());
}
self.graph.add_node(node_w);
for i in 0..node_idx {
self.graph.remove_node(NodeIndex::new(i));
}
} else {
let last_item = nodes_lst
.get_item(nodes_lst.len() - 1)
.unwrap()
.downcast::<PyTuple>()
.unwrap();

// use a pointer to iter the node list
let mut pointer = 0;
let mut next_node_idx: usize = nodes_lst
.get_item(pointer)
.unwrap()
.downcast::<PyTuple>()
.unwrap()
.get_item(0)
.unwrap()
.downcast::<PyLong>()
.unwrap()
.extract()
.unwrap();

// list of temporary nodes that will be removed later to re-create holes
let node_bound_1: usize = last_item.get_item(0).unwrap().extract().unwrap();
let mut tmp_nodes: Vec<NodeIndex> =
Vec::with_capacity(node_bound_1 + 1 - nodes_lst.len());

for i in 0..nodes_lst.len() + 1 {
if i < next_node_idx {
// node does not exist
let tmp_node = self.graph.add_node(py.None());
tmp_nodes.push(tmp_node);
} else {
// add node to the graph, and update the next available node index
let item = nodes_lst
.get_item(pointer)
.unwrap()
.downcast::<PyTuple>()
.unwrap();

let node_w = item.get_item(1).unwrap().extract().unwrap();
self.graph.add_node(node_w);
pointer += 1;
if pointer < nodes_lst.len() {
next_node_idx = nodes_lst
.get_item(pointer)
.unwrap()
.downcast::<PyTuple>()
.unwrap()
.get_item(0)
.unwrap()
.downcast::<PyLong>()
.unwrap()
.extract()
.unwrap();
}
}
};
node_count += 1;
}
for tmp_node in tmp_nodes {
self.graph.remove_node(tmp_node);
}
for raw_edge in edges_list.iter() {
let edge = raw_edge.downcast::<PyTuple>()?;
let raw_p_index = edge.get_item(0)?.downcast::<PyLong>()?;
let p_index: usize = raw_p_index.extract()?;
let raw_c_index = edge.get_item(1)?.downcast::<PyLong>()?;
let c_index: usize = raw_c_index.extract()?;
let edge_data = edge.get_item(2)?;
self.graph.add_edge(
NodeIndex::new(p_index),
NodeIndex::new(c_index),
edge_data.into(),
);
}
// Remove any temporary nodes we added
for tmp_node in tmp_nodes {
self.graph.remove_node(tmp_node);
}
}

// to ensure O(1) on edge deletion, use a temporary node to store missing edges
let tmp_node = self.graph.add_node(py.None());

for item in edges_lst {
if item.is_none() {
// add a temporary edge that will be deleted later to re-create the hole
self.graph.add_edge(tmp_node, tmp_node, py.None());
} else {
let triple = item.downcast::<PyTuple>().unwrap();
let edge_p: usize = triple
.get_item(0)
.unwrap()
.downcast::<PyLong>()
.unwrap()
.extract()
.unwrap();
let edge_c: usize = triple
.get_item(1)
.unwrap()
.downcast::<PyLong>()
.unwrap()
.extract()
.unwrap();
let edge_w = triple.get_item(2).unwrap().extract().unwrap();
self.graph
.add_edge(NodeIndex::new(edge_p), NodeIndex::new(edge_c), edge_w);
}
}

// removing the temporary node removes all of the placeholder edges in bulk;
// the cost is equal to the number of edges
self.graph.remove_node(tmp_node);

Ok(())
}

Expand Down
Loading

0 comments on commit 8686896

Please sign in to comment.