diff --git a/releasenotes/notes/fix-digraph-find-cycle-141e302ff4a8fcd4.yaml b/releasenotes/notes/fix-digraph-find-cycle-141e302ff4a8fcd4.yaml new file mode 100644 index 000000000..578459129 --- /dev/null +++ b/releasenotes/notes/fix-digraph-find-cycle-141e302ff4a8fcd4.yaml @@ -0,0 +1,12 @@ +--- +fixes: + - | + Fixed the behavior of :func:`~rustworkx.digraph_find_cycle` when + no source node was provided. Previously, the function would start looking + for a cycle at an arbitrary node which was not guaranteed to return a cycle. + Now, the function will smartly choose a source node to start the search from + such that if a cycle exists, it will be found. +other: + - | + The `rustworkx-core` function `rustworkx_core::connectivity::find_cycle` now + requires the `petgraph::visit::Visitable` trait. diff --git a/rustworkx-core/src/connectivity/find_cycle.rs b/rustworkx-core/src/connectivity/find_cycle.rs index 4bbd755e5..c04847c50 100644 --- a/rustworkx-core/src/connectivity/find_cycle.rs +++ b/rustworkx-core/src/connectivity/find_cycle.rs @@ -11,8 +11,9 @@ // under the License. use hashbrown::{HashMap, HashSet}; +use petgraph::algo; use petgraph::visit::{ - EdgeCount, GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount, + EdgeCount, GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount, Visitable, }; use petgraph::Direction::Outgoing; use std::hash::Hash; @@ -57,22 +58,22 @@ where G: GraphBase, G: NodeCount, G: EdgeCount, - for<'b> &'b G: GraphBase + IntoNodeIdentifiers + IntoNeighborsDirected, + for<'b> &'b G: + GraphBase + IntoNodeIdentifiers + IntoNeighborsDirected + Visitable, G::NodeId: Eq + Hash, { // Find a cycle in the given graph and return it as a list of edges - let mut graph_nodes: HashSet = graph.node_identifiers().collect(); let mut cycle: Vec<(G::NodeId, G::NodeId)> = Vec::with_capacity(graph.edge_count()); - let temp_value: G::NodeId; - // If source is not set get an arbitrary node from the set of graph - // nodes we've not "examined" + // If source is not set get a node in an arbitrary cycle if it exists, + // otherwise return that there is no cycle let source_index = match source { Some(source_value) => source_value, - None => { - temp_value = *graph_nodes.iter().next().unwrap(); - graph_nodes.remove(&temp_value); - temp_value - } + None => match find_node_in_arbitrary_cycle(&graph) { + Some(node_in_cycle) => node_in_cycle, + None => { + return Vec::new(); + } + }, }; // Stack (ie "pushdown list") of vertices already in the spanning tree let mut stack: Vec = vec![source_index]; @@ -119,11 +120,47 @@ where cycle } +fn find_node_in_arbitrary_cycle(graph: &G) -> Option +where + G: GraphBase, + G: NodeCount, + G: EdgeCount, + for<'b> &'b G: + GraphBase + IntoNodeIdentifiers + IntoNeighborsDirected + Visitable, + G::NodeId: Eq + Hash, +{ + for scc in algo::kosaraju_scc(&graph) { + if scc.len() > 1 { + return Some(scc[0]); + } + } + for node in graph.node_identifiers() { + for neighbor in graph.neighbors_directed(node, Outgoing) { + if neighbor == node { + return Some(node); + } + } + } + None +} + #[cfg(test)] mod tests { use crate::connectivity::find_cycle; use petgraph::prelude::*; + // Utility to assert cycles in the response + macro_rules! assert_cycle { + ($g: expr, $cycle: expr) => {{ + for i in 0..$cycle.len() { + let (s, t) = $cycle[i]; + assert!($g.contains_edge(s, t)); + let (next_s, _) = $cycle[(i + 1) % $cycle.len()]; + assert_eq!(t, next_s); + } + }}; + } + #[test] fn test_find_cycle_source() { let edge_list = vec![ @@ -141,20 +178,13 @@ mod tests { (8, 9), ]; let graph = DiGraph::::from_edges(edge_list); - let mut res: Vec<(usize, usize)> = find_cycle(&graph, Some(NodeIndex::new(0))) - .iter() - .map(|(s, t)| (s.index(), t.index())) - .collect(); - assert_eq!(res, [(0, 1), (1, 2), (2, 3), (3, 0)]); - res = find_cycle(&graph, Some(NodeIndex::new(1))) - .iter() - .map(|(s, t)| (s.index(), t.index())) - .collect(); - assert_eq!(res, [(1, 2), (2, 3), (3, 0), (0, 1)]); - res = find_cycle(&graph, Some(NodeIndex::new(5))) - .iter() - .map(|(s, t)| (s.index(), t.index())) - .collect(); + for i in [0, 1, 2, 3].iter() { + let idx = NodeIndex::new(*i); + let res = find_cycle(&graph, Some(idx)); + assert_cycle!(graph, res); + assert_eq!(res[0].0, idx); + } + let res = find_cycle(&graph, Some(NodeIndex::new(5))); assert_eq!(res, []); } @@ -176,10 +206,32 @@ mod tests { ]; let mut graph = DiGraph::::from_edges(edge_list); graph.add_edge(NodeIndex::new(1), NodeIndex::new(1), 0); - let res: Vec<(usize, usize)> = find_cycle(&graph, Some(NodeIndex::new(0))) - .iter() - .map(|(s, t)| (s.index(), t.index())) - .collect(); - assert_eq!(res, [(1, 1)]); + let res = find_cycle(&graph, Some(NodeIndex::new(0))); + assert_eq!(res[0].0, NodeIndex::new(1)); + assert_cycle!(graph, res); + } + + #[test] + fn test_self_loop_no_source() { + let edge_list = vec![(0, 1), (1, 2), (2, 3), (2, 2)]; + let graph = DiGraph::::from_edges(edge_list); + let res = find_cycle(&graph, None); + assert_cycle!(graph, res); + } + + #[test] + fn test_cycle_no_source() { + let edge_list = vec![(0, 1), (1, 2), (2, 3), (3, 4), (4, 2)]; + let graph = DiGraph::::from_edges(edge_list); + let res = find_cycle(&graph, None); + assert_cycle!(graph, res); + } + + #[test] + fn test_no_cycle_no_source() { + let edge_list = vec![(0, 1), (1, 2), (2, 3)]; + let graph = DiGraph::::from_edges(edge_list); + let res = find_cycle(&graph, None); + assert_eq!(res, []); } } diff --git a/tests/digraph/test_find_cycle.py b/tests/digraph/test_find_cycle.py index 5e09d3520..0b9b67bab 100644 --- a/tests/digraph/test_find_cycle.py +++ b/tests/digraph/test_find_cycle.py @@ -13,6 +13,7 @@ import unittest import rustworkx +import rustworkx.generators class TestFindCycle(unittest.TestCase): @@ -36,6 +37,14 @@ def setUp(self): ] ) + def assertCycle(self, first_node, graph, res): + self.assertEqual(first_node, res[0][0]) + for i in range(len(res)): + s, t = res[i] + self.assertTrue(graph.has_edge(s, t)) + next_s, _ = res[(i + 1) % len(res)] + self.assertEqual(t, next_s) + def test_find_cycle(self): graph = rustworkx.PyDiGraph() graph.add_nodes_from(list(range(6))) @@ -43,13 +52,13 @@ def test_find_cycle(self): [(0, 1), (0, 3), (0, 5), (1, 2), (2, 3), (3, 4), (4, 5), (4, 0)] ) res = rustworkx.digraph_find_cycle(graph, 0) - self.assertEqual([(0, 1), (1, 2), (2, 3), (3, 4), (4, 0)], res) + self.assertCycle(0, graph, res) def test_find_cycle_multiple_roots_same_cycles(self): res = rustworkx.digraph_find_cycle(self.graph, 0) - self.assertEqual(res, [(0, 1), (1, 2), (2, 3), (3, 0)]) + self.assertCycle(0, self.graph, res) res = rustworkx.digraph_find_cycle(self.graph, 1) - self.assertEqual(res, [(1, 2), (2, 3), (3, 0), (0, 1)]) + self.assertCycle(1, self.graph, res) res = rustworkx.digraph_find_cycle(self.graph, 5) self.assertEqual(res, []) @@ -57,9 +66,9 @@ def test_find_cycle_disconnected_graphs(self): self.graph.add_nodes_from(["A", "B", "C"]) self.graph.add_edges_from_no_data([(10, 11), (12, 10), (11, 12)]) res = rustworkx.digraph_find_cycle(self.graph, 0) - self.assertEqual(res, [(0, 1), (1, 2), (2, 3), (3, 0)]) + self.assertCycle(0, self.graph, res) res = rustworkx.digraph_find_cycle(self.graph, 10) - self.assertEqual(res, [(10, 11), (11, 12), (12, 10)]) + self.assertCycle(10, self.graph, res) def test_invalid_types(self): graph = rustworkx.PyGraph() @@ -69,4 +78,28 @@ def test_invalid_types(self): def test_self_loop(self): self.graph.add_edge(1, 1, None) res = rustworkx.digraph_find_cycle(self.graph, 0) - self.assertEqual([(1, 1)], res) + self.assertCycle(1, self.graph, res) + + def test_no_cycle_no_source(self): + g = rustworkx.generators.directed_grid_graph(10, 10) + res = rustworkx.digraph_find_cycle(g) + self.assertEqual(res, []) + + def test_cycle_no_source(self): + g = rustworkx.generators.directed_path_graph(1000) + a = g.add_node(1000) + b = g.node_indices()[-2] + g.add_edge(b, a, None) + g.add_edge(a, b, None) + res = rustworkx.digraph_find_cycle(g) + self.assertEqual(len(res), 2) + self.assertTrue(res[0] == res[1][::-1]) + + def test_cycle_self_loop(self): + g = rustworkx.generators.directed_path_graph(1000) + a = g.add_node(1000) + b = g.node_indices()[-1] + g.add_edge(b, a, None) + g.add_edge(a, a, None) + res = rustworkx.digraph_find_cycle(g) + self.assertEqual(res, [(a, a)])