Skip to content

Commit

Permalink
Merge branch 'main' into platform-support-updates
Browse files Browse the repository at this point in the history
  • Loading branch information
mtreinish authored Oct 6, 2023
2 parents 8e9d57c + 2d5694e commit 4cf03a7
Show file tree
Hide file tree
Showing 10 changed files with 202 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
features:
- |
Added a new function, :func:`clear` that clears all nodes and edges
from a :class:`rustworkx.PyGraph` or :class:`rustworkx.PyDiGraph`
- |
Added a new function, :func:`clear_edges` that clears all edges for
:class:`rustworkx.PyGraph` or :class:`rustworkx.PyDiGraph` without
modifying nodes
2 changes: 1 addition & 1 deletion rustworkx-core/src/centrality.rs
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,7 @@ where
{
let alpha: f64 = alpha.unwrap_or(0.1);

let mut beta: HashMap<usize, f64> = beta_map.unwrap_or_else(HashMap::new);
let mut beta: HashMap<usize, f64> = beta_map.unwrap_or_default();

if beta.is_empty() {
// beta_map was none
Expand Down
20 changes: 20 additions & 0 deletions rustworkx-core/src/token_swapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,11 @@ where
}
let id_node = self.rev_node_map[&node];
let id_token = self.rev_node_map[&tokens[&node]];

if self.graph.neighbors(id_node).next().is_none() {
return Err(MapNotPossible {});
}

for id_neighbor in self.graph.neighbors(id_node) {
let neighbor = self.node_map[&id_neighbor];
let dist_neighbor: DictMap<G::NodeId, usize> = dijkstra(
Expand Down Expand Up @@ -705,4 +710,19 @@ mod test_token_swapper {
Err(_) => (),
};
}

#[test]
fn test_edgeless_graph_fails() {
let mut g = petgraph::graph::UnGraph::<(), ()>::new_undirected();
let a = g.add_node(());
let b = g.add_node(());
let c = g.add_node(());
let d = g.add_node(());
g.add_edge(c, d, ());
let mapping = HashMap::from([(a, b), (b, a)]);
match token_swapper(&g, mapping, Some(10), Some(4), Some(50)) {
Ok(_) => panic!("This should error"),
Err(_) => (),
};
}
}
1 change: 0 additions & 1 deletion src/dag_algo/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,6 @@ pub fn collect_bicolor_runs(
}
} else {
for color in colors {
let color = color;
ensure_vector_has_index!(pending_list, block_id, color);
if let Some(color_block_id) = block_id[color] {
block_list[color_block_id].append(&mut pending_list[color]);
Expand Down
14 changes: 14 additions & 0 deletions src/digraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,20 @@ impl PyDiGraph {
}
false
}

/// Clear all nodes and edges
#[pyo3(text_signature = "(self)")]
pub fn clear(&mut self) {
self.graph.clear();
self.node_removed = true;
}

/// Clears all edges, leaves nodes intact
#[pyo3(text_signature = "(self)")]
pub fn clear_edges(&mut self) {
self.graph.clear_edges();
}

/// Return the number of nodes in the graph
#[pyo3(text_signature = "(self)")]
pub fn num_nodes(&self) -> usize {
Expand Down
13 changes: 13 additions & 0 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,19 @@ impl PyGraph {
false
}

/// Clears all nodes and edges
#[pyo3(text_signature = "(self)")]
pub fn clear(&mut self) {
self.graph.clear();
self.node_removed = true;
}

/// Clears all edges, leaves nodes intact
#[pyo3(text_signature = "(self)")]
pub fn clear_edges(&mut self) {
self.graph.clear_edges();
}

/// Return the number of nodes in the graph
#[pyo3(text_signature = "(self)")]
pub fn num_nodes(&self) -> usize {
Expand Down
2 changes: 1 addition & 1 deletion src/matching/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ fn _inner_is_matching(graph: &graph::PyGraph, matching: &HashSet<(usize, usize)>
.contains_edge(NodeIndex::new(e.0), NodeIndex::new(e.1))
};

if !matching.iter().all(|e| has_edge(e)) {
if !matching.iter().all(has_edge) {
return false;
}
let mut found: HashSet<usize> = HashSet::with_capacity(2 * matching.len());
Expand Down
1 change: 1 addition & 0 deletions src/score.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
// License for the specific language governing permissions and limitations
// under the License.
#![allow(clippy::derive_partial_eq_without_eq)]
#![allow(clippy::incorrect_partial_ord_impl_on_ord_type)]

use std::cmp::Ordering;
use std::ops::{Add, AddAssign};
Expand Down
66 changes: 66 additions & 0 deletions tests/rustworkx_tests/digraph/test_clear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import unittest

import rustworkx


class TestClear(unittest.TestCase):
def test_clear(self):
dag = rustworkx.PyDAG()
node_a = dag.add_node("a")
dag.add_child(node_a, "b", {"a": 1})
dag.add_child(node_a, "c", {"a": 2})
dag.clear()
self.assertEqual(dag.num_nodes(), 0)
self.assertEqual(dag.num_edges(), 0)
self.assertEqual(dag.nodes(), [])
self.assertEqual(dag.edges(), [])

def test_clear_reuse(self):
dag = rustworkx.PyDAG()
node_a = dag.add_node("a")
dag.add_child(node_a, "b", {"a": 1})
dag.add_child(node_a, "c", {"a": 2})
dag.clear()
node_a = dag.add_node("a")
dag.add_child(node_a, "b", {"a": 1})
dag.add_child(node_a, "c", {"a": 2})
self.assertEqual(dag.num_nodes(), 3)
self.assertEqual(dag.num_edges(), 2)
self.assertEqual(dag.nodes(), ["a", "b", "c"])
self.assertEqual(dag.edges(), [{"a": 1}, {"a": 2}])

def test_clear_edges(self):
dag = rustworkx.PyDAG()
node_a = dag.add_node("a")
dag.add_child(node_a, "b", {"a": 1})
dag.add_child(node_a, "c", {"a": 2})
dag.clear_edges()
self.assertEqual(dag.num_nodes(), 3)
self.assertEqual(dag.num_edges(), 0)
self.assertEqual(dag.nodes(), ["a", "b", "c"])
self.assertEqual(dag.edges(), [])

def test_clear_edges_reuse(self):
dag = rustworkx.PyDAG()
node_a = dag.add_node("a")
node_b = dag.add_child(node_a, "b", {"a": 1})
node_c = dag.add_child(node_a, "c", {"a": 2})
dag.clear_edges()
dag.add_edge(node_a, node_b, {"a": 1})
dag.add_edge(node_a, node_c, {"a": 2})
self.assertEqual(dag.num_nodes(), 3)
self.assertEqual(dag.num_edges(), 2)
self.assertEqual(dag.nodes(), ["a", "b", "c"])
self.assertEqual(dag.edges(), [{"a": 1}, {"a": 2}])
76 changes: 76 additions & 0 deletions tests/rustworkx_tests/graph/test_clear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Licensed under the Apache License, Version 3.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import unittest

import rustworkx


class TestClear(unittest.TestCase):
def test_clear(self):
graph = rustworkx.PyGraph()
node_a = graph.add_node("a")
node_b = graph.add_node("b")
graph.add_edge(node_a, node_b, {"a": 1})
node_c = graph.add_node("c")
graph.add_edge(node_a, node_c, {"a": 2})
graph.clear()
self.assertEqual(graph.num_nodes(), 0)
self.assertEqual(graph.num_edges(), 0)
self.assertEqual(graph.nodes(), [])
self.assertEqual(graph.edges(), [])

def test_clear_reuse(self):
graph = rustworkx.PyGraph()
node_a = graph.add_node("a")
node_b = graph.add_node("b")
graph.add_edge(node_a, node_b, {"a": 1})
node_c = graph.add_node("c")
graph.add_edge(node_a, node_c, {"a": 2})
graph.clear()
node_a = graph.add_node("a")
node_b = graph.add_node("b")
graph.add_edge(node_a, node_b, {"a": 1})
node_c = graph.add_node("c")
graph.add_edge(node_a, node_c, {"a": 2})
self.assertEqual(graph.num_nodes(), 3)
self.assertEqual(graph.num_edges(), 2)
self.assertEqual(graph.nodes(), ["a", "b", "c"])
self.assertEqual(graph.edges(), [{"a": 1}, {"a": 2}])

def test_clear_edges(self):
graph = rustworkx.PyGraph()
node_a = graph.add_node("a")
node_b = graph.add_node("b")
graph.add_edge(node_a, node_b, {"e1", 1})
node_c = graph.add_node("c")
graph.add_edge(node_a, node_c, {"e2", 2})
graph.clear_edges()
self.assertEqual(graph.num_edges(), 0)
self.assertEqual(graph.edges(), [])
self.assertEqual(graph.num_nodes(), 3)
self.assertEqual(graph.nodes(), ["a", "b", "c"])

def test_clear_edges_reuse(self):
graph = rustworkx.PyGraph()
node_a = graph.add_node("a")
node_b = graph.add_node("b")
graph.add_edge(node_a, node_b, {"e1", 1})
node_c = graph.add_node("c")
graph.add_edge(node_a, node_c, {"e2", 2})
graph.clear_edges()
graph.add_edge(node_a, node_b, {"e1", 1})
graph.add_edge(node_a, node_c, {"e2", 2})
self.assertEqual(graph.num_nodes(), 3)
self.assertEqual(graph.num_edges(), 2)
self.assertEqual(graph.nodes(), ["a", "b", "c"])
self.assertEqual(graph.edges(), [{"e1", 1}, {"e2", 2}])

0 comments on commit 4cf03a7

Please sign in to comment.