Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix pickle/deepcopy not preserve original edge indices #589

Merged
merged 8 commits into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 185 additions & 67 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ use petgraph::graph::{EdgeIndex, NodeIndex};
use petgraph::prelude::*;

use petgraph::visit::{
GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered, NodeIndexable,
Visitable,
EdgeIndexable, GraphBase, IntoEdgeReferences, IntoNodeReferences, NodeCount, NodeFiltered,
NodeIndexable, Visitable,
};

use super::dot_utils::build_dot;
Expand Down Expand Up @@ -273,85 +273,203 @@ impl PyDiGraph {
}

fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
let out_dict = PyDict::new(py);
let node_dict = PyDict::new(py);
let mut out_list: Vec<PyObject> = Vec::with_capacity(self.graph.edge_count());
out_dict.set_item("nodes", node_dict)?;
out_dict.set_item("nodes_removed", self.node_removed)?;
out_dict.set_item("multigraph", self.multigraph)?;
let dir = petgraph::Direction::Incoming;
for node_index in self.graph.node_indices() {
let node_data = self.graph.node_weight(node_index).unwrap();
node_dict.set_item(node_index.index(), node_data)?;
for edge in self.graph.edges_directed(node_index, dir) {
let edge_w = edge.weight();
let triplet = (edge.source().index(), edge.target().index(), edge_w).to_object(py);
out_list.push(triplet);
}
let mut nodes: Vec<PyObject> = Vec::with_capacity(self.graph.node_count());
let mut edges: Vec<PyObject> = Vec::with_capacity(self.graph.edge_bound());

// save nodes to a list along with its index
for node_idx in self.graph.node_indices() {
let node_data = self.graph.node_weight(node_idx).unwrap();
nodes.push((node_idx.index(), node_data).to_object(py));
}

// edges are saved with none (deleted edges) instead of their index to save space
for i in 0..self.graph.edge_bound() {
let idx = EdgeIndex::new(i);
let edge = match self.graph.edge_weight(idx) {
Some(edge_w) => {
let endpoints = self.graph.edge_endpoints(idx).unwrap();
(endpoints.0.index(), endpoints.1.index(), edge_w).to_object(py)
}
None => py.None(),
};
edges.push(edge);
mtreinish marked this conversation as resolved.
Show resolved Hide resolved
}
let py_out_list: PyObject = PyList::new(py, out_list).into();
out_dict.set_item("edges", py_out_list)?;
Ok(out_dict.into())

let outdict = PyDict::new(py);
let nodes_lst: PyObject = PyList::new(py, nodes).into();
let edges_lst: PyObject = PyList::new(py, edges).into();
outdict.set_item("nodes", nodes_lst)?;
outdict.set_item("edges", edges_lst)?;
outdict.set_item(
"node_removed",
(self.graph.node_bound() != self.graph.node_count()).to_object(py),
mtreinish marked this conversation as resolved.
Show resolved Hide resolved
)?;
outdict.set_item("multigraph", self.multigraph)?;
Ok(outdict.into())
}

fn __setstate__(&mut self, py: Python, state: PyObject) -> PyResult<()> {
self.graph = StablePyGraph::<Directed>::new();
let dict_state = state.cast_as::<PyDict>(py)?;
let nodes_lst = dict_state.get_item("nodes").unwrap().downcast::<PyList>()?;
let edges_lst = dict_state.get_item("edges").unwrap().downcast::<PyList>()?;
mtreinish marked this conversation as resolved.
Show resolved Hide resolved

let nodes_dict = dict_state.get_item("nodes").unwrap().downcast::<PyDict>()?;
let edges_list = dict_state.get_item("edges").unwrap().downcast::<PyList>()?;
let nodes_removed_raw = dict_state
.get_item("nodes_removed")
.unwrap()
.downcast::<PyBool>()?;
self.node_removed = nodes_removed_raw.extract()?;
let multigraph_raw = dict_state
self.graph = StablePyGraph::<Directed>::new();
self.multigraph = dict_state
.get_item("multigraph")
.unwrap()
.downcast::<PyBool>()?;
self.multigraph = multigraph_raw.extract()?;
let mut node_indices: Vec<usize> = Vec::new();
for raw_index in nodes_dict.keys() {
let tmp_index = raw_index.downcast::<PyLong>()?;
node_indices.push(tmp_index.extract()?);
}
if node_indices.is_empty() {
.downcast::<PyBool>()?
.extract()?;
self.node_removed = dict_state
.get_item("node_removed")
.unwrap()
.downcast::<PyBool>()?
.extract()?;

// graph is empty, stop early
if nodes_lst.is_empty() {
return Ok(());
}
let max_index: usize = *node_indices.iter().max().unwrap();
if max_index + 1 != node_indices.len() {
self.node_removed = true;
}
let mut tmp_nodes: Vec<NodeIndex> = Vec::new();
let mut node_count: usize = 0;
while max_index >= self.graph.node_bound() {
match nodes_dict.get_item(node_count) {
Some(raw_data) => {
self.graph.add_node(raw_data.into());
}
None => {

if !self.node_removed {
for item in nodes_lst.iter() {
let node_w = item
.downcast::<PyTuple>()
.unwrap()
.get_item(1)
.unwrap()
.extract()
.unwrap();
self.graph.add_node(node_w);
}
} else if nodes_lst.len() == 1 {
// graph has only one node, handle logic here to save one if in the loop later
let item = nodes_lst
.get_item(0)
.unwrap()
.downcast::<PyTuple>()
.unwrap();
let node_idx: usize = item.get_item(0).unwrap().extract().unwrap();
let node_w = item.get_item(1).unwrap().extract().unwrap();

for _i in 0..node_idx {
self.graph.add_node(py.None());
}
self.graph.add_node(node_w);
for i in 0..node_idx {
self.graph.remove_node(NodeIndex::new(i));
}
} else {
let last_item = nodes_lst
.get_item(nodes_lst.len() - 1)
.unwrap()
.downcast::<PyTuple>()
.unwrap();

// use a pointer to iter the node list
let mut pointer = 0;
let mut next_node_idx: usize = nodes_lst
.get_item(pointer)
.unwrap()
.downcast::<PyTuple>()
.unwrap()
.get_item(0)
.unwrap()
.downcast::<PyLong>()
.unwrap()
.extract()
.unwrap();

// list of temporary nodes that will be removed later to re-create holes
let node_bound_1: usize = last_item.get_item(0).unwrap().extract().unwrap();
let mut tmp_nodes: Vec<NodeIndex> =
Vec::with_capacity(node_bound_1 + 1 - nodes_lst.len());

let second_last_node_idx: usize = nodes_lst
.get_item(nodes_lst.len() - 2)
.unwrap()
.downcast::<PyTuple>()
.unwrap()
.get_item(0)
.unwrap()
.downcast::<PyLong>()
.unwrap()
.extract()
.unwrap();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it might be a bit easier to read if we created an intermediate Vec<(usize, PyObject)> here. It will eat a bit more memory since we're storing things twice but I'm not sure that's a big concern since we're just doubling the number of node indices in memory as the python data is just a reference. Just a thought, this works fine and we don't really need to change it, it's just a long chain of unwraps and extracts.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like that too. What I was thinking is to create inline functions for get_index, and get_value. But I don't know where to put it. May be at the end of the file?


for i in 0..(second_last_node_idx + 1) {
if i < next_node_idx {
// node does not exist
let tmp_node = self.graph.add_node(py.None());
tmp_nodes.push(tmp_node);
} else {
// add node to the graph, and update the next available node index
let item = nodes_lst
.get_item(pointer)
.unwrap()
.downcast::<PyTuple>()
.unwrap();

let node_w = item.get_item(1).unwrap().extract().unwrap();
self.graph.add_node(node_w);
pointer += 1;
next_node_idx = nodes_lst
.get_item(pointer)
.unwrap()
.downcast::<PyTuple>()
.unwrap()
.get_item(0)
.unwrap()
.downcast::<PyLong>()
.unwrap()
.extract()
.unwrap();
}
};
node_count += 1;
}
for tmp_node in tmp_nodes {
self.graph.remove_node(tmp_node);
}
for raw_edge in edges_list.iter() {
let edge = raw_edge.downcast::<PyTuple>()?;
let raw_p_index = edge.get_item(0)?.downcast::<PyLong>()?;
let p_index: usize = raw_p_index.extract()?;
let raw_c_index = edge.get_item(1)?.downcast::<PyLong>()?;
let c_index: usize = raw_c_index.extract()?;
let edge_data = edge.get_item(2)?;
self.graph.add_edge(
NodeIndex::new(p_index),
NodeIndex::new(c_index),
edge_data.into(),
);
}

for _i in (second_last_node_idx + 1)..next_node_idx {
let tmp_node = self.graph.add_node(py.None());
tmp_nodes.push(tmp_node);
}
let last_node_w = last_item.get_item(1).unwrap().extract().unwrap();
self.graph.add_node(last_node_w);
for tmp_node in tmp_nodes {
self.graph.remove_node(tmp_node);
}
mtreinish marked this conversation as resolved.
Show resolved Hide resolved
}

// to ensure O(1) on edge deletion, use a temporary node to store missing edges
let tmp_node = self.graph.add_node(py.None());

for item in edges_lst {
if item.is_none() {
// add a temporary edge that will be deleted later to re-create the hole
self.graph.add_edge(tmp_node, tmp_node, py.None());
} else {
let triple = item.downcast::<PyTuple>().unwrap();
let edge_p: usize = triple
.get_item(0)
.unwrap()
.downcast::<PyLong>()
.unwrap()
.extract()
.unwrap();
let edge_c: usize = triple
.get_item(1)
.unwrap()
.downcast::<PyLong>()
.unwrap()
.extract()
.unwrap();
let edge_w = triple.get_item(2).unwrap().extract().unwrap();
self.graph
.add_edge(NodeIndex::new(edge_p), NodeIndex::new(edge_c), edge_w);
}
}

// remove the temporary node will remove all deleted edges in bulk,
// the cost is equal to the number of edges
self.graph.remove_node(tmp_node);

Ok(())
}

Expand Down
Loading