diff --git a/Cargo.lock b/Cargo.lock index 8eee8253d..6e4267fc2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -562,6 +562,7 @@ dependencies = [ "petgraph", "priority-queue", "rand", + "rand_pcg", "rayon", "rayon-cond", ] diff --git a/docs/source/api.rst b/docs/source/api.rst index 487d543b8..44a7dab7a 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -370,6 +370,7 @@ typed API based on the data type. rustworkx.graph_complement rustworkx.graph_union rustworkx.graph_tensor_product + rustworkx.graph_token_swapper rustworkx.graph_cartesian_product rustworkx.graph_random_layout rustworkx.graph_bipartite_layout diff --git a/releasenotes/notes/added-token-swapper-bd168eeb5a31bd99.yaml b/releasenotes/notes/added-token-swapper-bd168eeb5a31bd99.yaml new file mode 100644 index 000000000..e26ec104e --- /dev/null +++ b/releasenotes/notes/added-token-swapper-bd168eeb5a31bd99.yaml @@ -0,0 +1,6 @@ +--- +features: + - | + Added a new function, :func:`~.graph_token_swapper()`, which performs an + approximately optimal Token Swapping algorithm and supports partial + mappings (i.e. non-permutations) for graphs with missing tokens. diff --git a/rustworkx-core/Cargo.toml b/rustworkx-core/Cargo.toml index 7a3762a02..33c7e1200 100644 --- a/rustworkx-core/Cargo.toml +++ b/rustworkx-core/Cargo.toml @@ -14,6 +14,8 @@ keywords = ["graph"] ahash = "0.8.0" fixedbitset = "0.4.2" petgraph = "0.6.3" +rand = "0.8.5" +rand_pcg = "0.3.1" rayon = "1.6" num-traits = "0.2" priority-queue = "1.2" diff --git a/rustworkx-core/src/lib.rs b/rustworkx-core/src/lib.rs index 4da244f5c..ab54ad5dc 100644 --- a/rustworkx-core/src/lib.rs +++ b/rustworkx-core/src/lib.rs @@ -50,6 +50,7 @@ //! * [`connectivity`](./connectivity/index.html) //! * [`max_weight_matching`](./max_weight_matching/index.html) //! * [`shortest_path`](./shortest_path/index.html) +//! * [`token_swapper`](./token_swapper/index.html) //! * [`traversal`](./traversal/index.html) //! * [`generators`](./generators/index.html) //! 
@@ -82,6 +83,8 @@ pub mod traversal; pub mod dictmap; pub mod distancemap; mod min_scored; +/// Module for swapping tokens +pub mod token_swapper; pub mod utils; // re-export petgraph so there is a consistent version available to users and diff --git a/rustworkx-core/src/token_swapper.rs b/rustworkx-core/src/token_swapper.rs new file mode 100644 index 000000000..469236acc --- /dev/null +++ b/rustworkx-core/src/token_swapper.rs @@ -0,0 +1,608 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. 
+ +use rand::distributions::{Standard, Uniform}; +use rand::prelude::*; +use rand_pcg::Pcg64; +use std::hash::Hash; + +use hashbrown::HashMap; +use petgraph::stable_graph::{NodeIndex, StableGraph}; +use petgraph::visit::{ + EdgeCount, GraphBase, IntoEdges, IntoNeighborsDirected, IntoNodeIdentifiers, NodeCount, + NodeIndexable, Visitable, +}; +use petgraph::Directed; +use petgraph::Direction::{Incoming, Outgoing}; +use rayon_cond::CondIterator; + +use crate::connectivity::find_cycle; +use crate::dictmap::*; +use crate::shortest_path::dijkstra; +use crate::traversal::dfs_edges; + +type Swap = (NodeIndex, NodeIndex); +type Edge = (NodeIndex, NodeIndex); + +struct TokenSwapper +where + G::NodeId: Eq + Hash, +{ + // The input graph + graph: G, + // The user-supplied mapping to use for swapping tokens + mapping: HashMap, + // Number of trials + trials: usize, + // Seed for random selection of a node for a trial + seed: Option, + // Threshold for how many nodes will trigger parallel iterator + parallel_threshold: usize, + // Map of NodeId to NodeIndex + node_map: HashMap, + // Map of NodeIndex to NodeId + rev_node_map: HashMap, +} + +impl TokenSwapper +where + G: NodeCount + + EdgeCount + + IntoEdges + + Visitable + + NodeIndexable + + IntoNeighborsDirected + + IntoNodeIdentifiers + + Send + + Sync, + G::NodeId: Hash + Eq + Send + Sync, +{ + fn new( + graph: G, + mapping: HashMap, + trials: Option, + seed: Option, + parallel_threshold: Option, + ) -> Self { + TokenSwapper { + graph, + mapping, + trials: trials.unwrap_or(4), + seed, + parallel_threshold: parallel_threshold.unwrap_or(50), + node_map: HashMap::with_capacity(graph.node_count()), + rev_node_map: HashMap::with_capacity(graph.node_count()), + } + } + + fn map(&mut self) -> Vec { + let num_nodes = self.graph.node_bound(); + let num_edges = self.graph.edge_count(); + + // Directed graph with nodes matching ``graph`` and + // edges for neighbors closer than nodes + let mut digraph = 
StableGraph::with_capacity(num_nodes, num_edges); + + // First fill the digraph with nodes. Then since it's a stable graph, + // must go through and remove nodes that were removed in original graph + for _ in 0..self.graph.node_bound() { + digraph.add_node(()); + } + let mut count: usize = 0; + for gnode in self.graph.node_identifiers() { + let gidx = self.graph.to_index(gnode); + if gidx != count { + for idx in count..gidx { + digraph.remove_node(NodeIndex::new(idx)); + } + count = gidx; + } + count += 1; + } + + // Create maps between NodeId and NodeIndex + for node in self.graph.node_identifiers() { + self.node_map + .insert(node, NodeIndex::new(self.graph.to_index(node))); + self.rev_node_map + .insert(NodeIndex::new(self.graph.to_index(node)), node); + } + // sub will become same as digraph but with no self edges in add_token_edges + let mut sub_digraph = digraph.clone(); + + // The mapping in HashMap form using NodeIndex + let mut tokens: HashMap = self + .mapping + .iter() + .map(|(k, v)| (self.node_map[k], self.node_map[v])) + .collect(); + + // todo_nodes are all the mapping entries where left != right + let todo_nodes: Vec = tokens + .iter() + .filter_map(|(node, dest)| if node != dest { Some(*node) } else { None }) + .collect(); + + // Add initial edges to the digraph/sub_digraph + for node in self.graph.node_identifiers() { + self.add_token_edges( + self.node_map[&node], + &mut digraph, + &mut sub_digraph, + &mut tokens, + ); + } + // First collect the self.trial number of random numbers + // into a Vec based on the given seed + let outer_rng: Pcg64 = match self.seed { + Some(rng_seed) => Pcg64::seed_from_u64(rng_seed), + None => Pcg64::from_entropy(), + }; + let trial_seeds_vec: Vec = + outer_rng.sample_iter(&Standard).take(self.trials).collect(); + + CondIterator::new( + trial_seeds_vec, + self.graph.node_count() >= self.parallel_threshold, + ) + .map(|trial_seed| { + self.trial_map( + digraph.clone(), + sub_digraph.clone(), + tokens.clone(), + 
todo_nodes.clone(), + trial_seed, + ) + }) + .min_by_key(|result| result.len()) + .unwrap() + } + + fn add_token_edges( + &self, + node: NodeIndex, + digraph: &mut StableGraph<(), (), Directed>, + sub_digraph: &mut StableGraph<(), (), Directed>, + tokens: &mut HashMap, + ) { + // Adds an edge to digraph if distance from the token to a neighbor is + // less than distance from token to node. sub_digraph is same except + // for self-edges. + if !(tokens.contains_key(&node)) { + return; + } + if tokens[&node] == node { + digraph.update_edge(node, node, ()); + return; + } + let id_node = self.rev_node_map[&node]; + let id_token = self.rev_node_map[&tokens[&node]]; + for id_neighbor in self.graph.neighbors(id_node) { + let neighbor = self.node_map[&id_neighbor]; + let dist_neighbor: DictMap = dijkstra( + &self.graph, + id_neighbor, + Some(id_token), + |_| Ok::(1), + None, + ) + .unwrap(); + + let dist_node: DictMap = dijkstra( + &self.graph, + id_node, + Some(id_token), + |_| Ok::(1), + None, + ) + .unwrap(); + + if dist_neighbor[&id_token] < dist_node[&id_token] { + digraph.update_edge(node, neighbor, ()); + sub_digraph.update_edge(node, neighbor, ()); + } + } + } + + fn trial_map( + &self, + mut digraph: StableGraph<(), (), Directed>, + mut sub_digraph: StableGraph<(), (), Directed>, + mut tokens: HashMap, + mut todo_nodes: Vec, + trial_seed: u64, + ) -> Vec { + // Create a random trial list of swaps to move tokens to optimal positions + let mut steps = 0; + let mut swap_edges: Vec = vec![]; + let mut rng_seed: Pcg64 = Pcg64::seed_from_u64(trial_seed); + while !todo_nodes.is_empty() && steps <= 4 * digraph.node_count().pow(2) { + // Choose a random todo_node + let between = Uniform::new(0, todo_nodes.len()); + let random: usize = between.sample(&mut rng_seed); + let todo_node = todo_nodes[random]; + + // If there's a cycle in sub_digraph, add it to swap_edges and do swap + let cycle = find_cycle(&sub_digraph, Some(todo_node)); + if !cycle.is_empty() { + for edge in 
cycle[1..].iter().rev() { + swap_edges.push(*edge); + self.swap( + edge.0, + edge.1, + &mut digraph, + &mut sub_digraph, + &mut tokens, + &mut todo_nodes, + ); + } + steps += cycle.len() - 1; + // If there's no cycle, see if there's an edge target that matches a token key. + // If so, add to swap_edges and do swap + } else { + let mut found = false; + let sub2 = &sub_digraph.clone(); + for edge in dfs_edges(sub2, Some(todo_node)) { + let new_edge = (NodeIndex::new(edge.0), NodeIndex::new(edge.1)); + if !tokens.contains_key(&new_edge.1) { + swap_edges.push(new_edge); + self.swap( + new_edge.0, + new_edge.1, + &mut digraph, + &mut sub_digraph, + &mut tokens, + &mut todo_nodes, + ); + steps += 1; + found = true; + break; + } + } + // If none found, look for cycle in digraph which will result in + // an unhappy node. Look for a predecessor and add node and pred + // to swap_edges and do swap + if !found { + let cycle: Vec = find_cycle(&digraph, Some(todo_node)); + let unhappy_node = cycle[0].0; + let mut found = false; + let di2 = &mut digraph.clone(); + for predecessor in di2.neighbors_directed(unhappy_node, Incoming) { + if predecessor != unhappy_node { + swap_edges.push((unhappy_node, predecessor)); + self.swap( + unhappy_node, + predecessor, + &mut digraph, + &mut sub_digraph, + &mut tokens, + &mut todo_nodes, + ); + steps += 1; + found = true; + break; + } + } + assert!( + found, + "The token swap process has ended unexpectedly, this points to a bug in rustworkx, please open an issue." + ); + } + } + } + assert!( + todo_nodes.is_empty(), + "The output final swap map is incomplete, this points to a bug in rustworkx, please open an issue." 
+ ); + swap_edges + } + + fn swap( + &self, + node1: NodeIndex, + node2: NodeIndex, + digraph: &mut StableGraph<(), (), Directed>, + sub_digraph: &mut StableGraph<(), (), Directed>, + tokens: &mut HashMap, + todo_nodes: &mut Vec, + ) { + // Get token values for the 2 nodes and remove them + let token1 = tokens.remove(&node1); + let token2 = tokens.remove(&node2); + + // Swap the token edge values + if let Some(t2) = token2 { + tokens.insert(node1, t2); + } + if let Some(t1) = token1 { + tokens.insert(node2, t1); + } + // For each node, remove the (node, successor) from digraph and + // sub_digraph. Then add new token edges back in. + for node in [node1, node2] { + let edge_nodes: Vec<(NodeIndex, NodeIndex)> = digraph + .neighbors_directed(node, Outgoing) + .map(|successor| (node, successor)) + .collect(); + for (edge_node1, edge_node2) in edge_nodes { + let edge = digraph.find_edge(edge_node1, edge_node2).unwrap(); + digraph.remove_edge(edge); + } + let edge_nodes: Vec<(NodeIndex, NodeIndex)> = sub_digraph + .neighbors_directed(node, Outgoing) + .map(|successor| (node, successor)) + .collect(); + for (edge_node1, edge_node2) in edge_nodes { + let edge = sub_digraph.find_edge(edge_node1, edge_node2).unwrap(); + sub_digraph.remove_edge(edge); + } + self.add_token_edges(node, digraph, sub_digraph, tokens); + + // If a node is a token key and not equal to the value, add it to todo_nodes + if tokens.contains_key(&node) && tokens[&node] != node { + if !todo_nodes.contains(&node) { + todo_nodes.push(node); + } + // Otherwise if node is in todo_nodes, remove it + } else if todo_nodes.contains(&node) { + todo_nodes.swap_remove(todo_nodes.iter().position(|x| *x == node).unwrap()); + } + } + } +} + +/// Module to perform an approximately optimal Token Swapping algorithm. Supports partial +/// mappings (i.e. not-permutations) for graphs with missing tokens. +/// +/// Based on the paper: Approximation and Hardness for Token Swapping by Miltzow et al. 
(2016) +/// ArXiV: <https://arxiv.org/abs/1602.05150> +/// +/// Arguments: +/// +/// * `graph` - The graph on which to perform the token swapping. +/// * `mapping` - A partial mapping to be implemented in swaps. +/// * `trials` - Optional number of trials. If None, defaults to 4. +/// * `seed` - Optional integer seed. If None, the internal rng will be initialized from system entropy. +/// * `parallel_threshold` - Optional integer for the number of nodes in the graph that will +/// trigger the use of parallel threads. If the number of nodes in the graph is less than this value +/// it will run in a single thread. The default value is 50. +/// +/// It returns a list of tuples representing the swaps to perform. +/// +/// This function is multithreaded and will launch a thread pool with threads equal to +/// the number of CPUs by default. You can tune the number of threads with +/// the ``RAYON_NUM_THREADS`` environment variable. For example, setting ``RAYON_NUM_THREADS=4`` +/// would limit the thread pool to 4 threads. +/// +/// # Example +/// ```rust +/// use hashbrown::HashMap; +/// use rustworkx_core::petgraph; +/// use rustworkx_core::token_swapper::token_swapper; +/// use rustworkx_core::petgraph::graph::NodeIndex; +/// +/// let g = petgraph::graph::UnGraph::<(), ()>::from_edges(&[(0, 1), (1, 2), (2, 3)]); +/// let mapping = HashMap::from([ +/// (NodeIndex::new(0), NodeIndex::new(0)), +/// (NodeIndex::new(1), NodeIndex::new(3)), +/// (NodeIndex::new(3), NodeIndex::new(1)), +/// (NodeIndex::new(2), NodeIndex::new(2)), +/// ]); +/// // Do the token swap +/// let output = token_swapper(&g, mapping, Some(4), Some(4), Some(50)); +/// assert_eq!(3, output.len()); +/// +/// ``` + +pub fn token_swapper( + graph: G, + mapping: HashMap, + trials: Option, + seed: Option, + parallel_threshold: Option, +) -> Vec +where + G: NodeCount + + EdgeCount + + IntoEdges + + Visitable + + NodeIndexable + + IntoNeighborsDirected + + IntoNodeIdentifiers + + Send + + Sync, + G::NodeId: Hash + Eq + Send + Sync, +{ + let mut 
swapper = TokenSwapper::new(graph, mapping, trials, seed, parallel_threshold); + swapper.map() +} + +#[cfg(test)] +mod test_token_swapper { + + use crate::petgraph; + use crate::token_swapper::token_swapper; + use hashbrown::HashMap; + use petgraph::graph::NodeIndex; + + fn do_swap(mapping: &mut HashMap, swaps: &Vec<(NodeIndex, NodeIndex)>) { + // Apply the swaps to the mapping to get final result + for (swap1, swap2) in swaps { + //Need to create temp nodes in case of partial mapping + let mut temp_node1: Option = None; + let mut temp_node2: Option = None; + if mapping.contains_key(swap1) { + temp_node1 = Some(mapping[swap1]); + mapping.remove(swap1); + } + if mapping.contains_key(swap2) { + temp_node2 = Some(mapping[swap2]); + mapping.remove(swap2); + } + if let Some(t1) = temp_node1 { + mapping.insert(*swap2, t1); + } + if let Some(t2) = temp_node2 { + mapping.insert(*swap1, t2); + } + } + } + + #[test] + fn test_simple_swap() { + // Simple arbitrary swap + let g = petgraph::graph::UnGraph::<(), ()>::from_edges(&[(0, 1), (1, 2), (2, 3)]); + let mapping = HashMap::from([ + (NodeIndex::new(0), NodeIndex::new(0)), + (NodeIndex::new(1), NodeIndex::new(3)), + (NodeIndex::new(3), NodeIndex::new(1)), + (NodeIndex::new(2), NodeIndex::new(2)), + ]); + let swaps = token_swapper(&g, mapping, Some(4), Some(4), Some(50)); + assert_eq!(3, swaps.len()); + } + + #[test] + fn test_small_swap() { + // Reverse all small swap + let g = petgraph::graph::UnGraph::<(), ()>::from_edges(&[ + (0, 1), + (1, 2), + (2, 3), + (3, 4), + (4, 5), + (5, 6), + (6, 7), + ]); + let mut mapping = HashMap::with_capacity(8); + for i in 0..8 { + mapping.insert(NodeIndex::new(i), NodeIndex::new(7 - i)); + } + // Do the token swap + let mut new_map = mapping.clone(); + let swaps = token_swapper(&g, mapping, Some(4), Some(4), Some(50)); + do_swap(&mut new_map, &swaps); + let mut expected = HashMap::with_capacity(8); + for i in 0..8 { + expected.insert(NodeIndex::new(i), NodeIndex::new(i)); + } + 
assert_eq!(expected, new_map); + } + + #[test] + fn test_happy_swap_chain() { + // Reverse all happy swap chain > 2 + let g = petgraph::graph::UnGraph::<(), ()>::from_edges(&[ + (0, 1), + (0, 2), + (0, 3), + (0, 4), + (1, 2), + (1, 3), + (1, 4), + (2, 3), + (2, 4), + (3, 4), + (3, 6), + ]); + let mapping = HashMap::from([ + (NodeIndex::new(0), NodeIndex::new(4)), + (NodeIndex::new(1), NodeIndex::new(0)), + (NodeIndex::new(2), NodeIndex::new(3)), + (NodeIndex::new(3), NodeIndex::new(6)), + (NodeIndex::new(4), NodeIndex::new(2)), + (NodeIndex::new(6), NodeIndex::new(1)), + ]); + // Do the token swap + let mut new_map = mapping.clone(); + let swaps = token_swapper(&g, mapping, Some(4), Some(4), Some(50)); + do_swap(&mut new_map, &swaps); + let mut expected = HashMap::with_capacity(6); + for i in (0..5).chain(6..7) { + expected.insert(NodeIndex::new(i), NodeIndex::new(i)); + } + assert_eq!(expected, new_map); + } + + #[test] + fn test_partial_simple() { + // Simple partial swap + let g = petgraph::graph::UnGraph::<(), ()>::from_edges(&[(0, 1), (1, 2), (2, 3)]); + let mapping = HashMap::from([(NodeIndex::new(0), NodeIndex::new(3))]); + let mut new_map = mapping.clone(); + let swaps = token_swapper(&g, mapping, Some(4), Some(4), Some(1)); + do_swap(&mut new_map, &swaps); + let mut expected = HashMap::with_capacity(4); + expected.insert(NodeIndex::new(3), NodeIndex::new(3)); + assert_eq!(expected, new_map); + } + + #[test] + fn test_partial_simple_remove_node() { + // Simple partial swap + let mut g = + petgraph::graph::UnGraph::<(), ()>::from_edges(&[(0, 1), (1, 2), (2, 3), (3, 4)]); + let mapping = HashMap::from([(NodeIndex::new(0), NodeIndex::new(3))]); + g.remove_node(NodeIndex::new(2)); + g.add_edge(NodeIndex::new(1), NodeIndex::new(3), ()); + let mut new_map = mapping.clone(); + let swaps = token_swapper(&g, mapping, Some(4), Some(4), Some(1)); + do_swap(&mut new_map, &swaps); + let mut expected = HashMap::with_capacity(4); + expected.insert(NodeIndex::new(3), 
NodeIndex::new(3)); + assert_eq!(expected, new_map); + } + + #[test] + fn test_partial_small() { + // Partial inverting on small path graph + let g = petgraph::graph::UnGraph::<(), ()>::from_edges(&[(0, 1), (1, 2), (2, 3)]); + let mapping = HashMap::from([ + (NodeIndex::new(0), NodeIndex::new(3)), + (NodeIndex::new(1), NodeIndex::new(2)), + ]); + let mut new_map = mapping.clone(); + let swaps = token_swapper(&g, mapping, Some(4), Some(4), Some(50)); + do_swap(&mut new_map, &swaps); + let expected = HashMap::from([ + (NodeIndex::new(2), NodeIndex::new(2)), + (NodeIndex::new(3), NodeIndex::new(3)), + ]); + assert_eq!(5, swaps.len()); + assert_eq!(expected, new_map); + } +} + +// TODO: Port this test when rustworkx-core adds random graphs + +// def test_large_partial_random(self) -> None: +// """Test a random (partial) mapping on a large randomly generated graph""" +// size = 100 +// # Note that graph may have "gaps" in the node counts, i.e. the numbering is noncontiguous. +// graph = rx.undirected_gnm_random_graph(size, size**2 // 10) +// for i in graph.node_indexes(): +// try: +// graph.remove_edge(i, i) # Remove self-loops. +// except rx.NoEdgeBetweenNodes: +// continue +// # Make sure the graph is connected by adding C_n +// graph.add_edges_from_no_data([(i, i + 1) for i in range(len(graph) - 1)]) +// swapper = ApproximateTokenSwapper(graph) # type: ApproximateTokenSwapper[int] + +// # Generate a randomized permutation. +// rand_perm = random.permutation(graph.nodes()) +// permutation = dict(zip(graph.nodes(), rand_perm)) +// mapping = dict(itertools.islice(permutation.items(), 0, size, 2)) # Drop every 2nd element. 
+ +// out = list(swapper.map(mapping, trials=40)) +// util.swap_permutation([out], mapping, allow_missing_keys=True) +// self.assertEqual({i: i for i in mapping.values()}, mapping) diff --git a/src/lib.rs b/src/lib.rs index 87d6eea3c..7f941e2e6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,6 +31,7 @@ mod score; mod shortest_path; mod steiner_tree; mod tensor_product; +mod token_swapper; mod toposort; mod transitivity; mod traversal; @@ -52,6 +53,7 @@ use random_graph::*; use shortest_path::*; use steiner_tree::*; use tensor_product::*; +use token_swapper::*; use transitivity::*; use traversal::*; use tree::*; @@ -446,6 +448,7 @@ fn rustworkx(py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(minimum_spanning_tree))?; m.add_wrapped(wrap_pyfunction!(graph_transitivity))?; m.add_wrapped(wrap_pyfunction!(digraph_transitivity))?; + m.add_wrapped(wrap_pyfunction!(graph_token_swapper))?; m.add_wrapped(wrap_pyfunction!(graph_core_number))?; m.add_wrapped(wrap_pyfunction!(digraph_core_number))?; m.add_wrapped(wrap_pyfunction!(graph_complement))?; diff --git a/src/token_swapper.rs b/src/token_swapper.rs new file mode 100644 index 000000000..0adf5b85b --- /dev/null +++ b/src/token_swapper.rs @@ -0,0 +1,69 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations +// under the License. 
+ +use crate::graph; +use crate::iterators::EdgeList; + +use hashbrown::HashMap; +use petgraph::graph::NodeIndex; +use pyo3::prelude::*; +use rustworkx_core::token_swapper; + +/// This function performs an approximately optimal Token Swapping algorithm. +/// Supports partial mappings (i.e. non-permutations) for graphs with missing tokens. +/// +/// Based on the paper: Approximation and Hardness for Token Swapping by Miltzow et al. (2016) +/// ArXiV: https://arxiv.org/abs/1602.05150 +/// +/// The inputs are a partial ``mapping`` to be implemented in swaps, and the number of ``trials`` +/// to perform the mapping. The shortest swap sequence found across the trials is returned. +/// +/// It returns a list of tuples representing the swaps to perform. +/// +/// :param PyGraph graph: The input graph +/// :param dict[int: int] mapping: Map of (node, token) +/// :param int trials: The number of trials to run +/// :param int seed: The random seed to be used in producing random ints for selecting +/// which nodes to process next +/// :param int parallel_threshold: The number of nodes in the graph that will +/// trigger the use of parallel threads. If the number of nodes in the graph is less +/// than this value it will run in a single thread. The default value is 50. +/// +/// This function is multithreaded and will launch a thread pool with threads equal to +/// the number of CPUs by default. You can tune the number of threads with +/// the ``RAYON_NUM_THREADS`` environment variable. For example, setting ``RAYON_NUM_THREADS=4`` +/// would limit the thread pool to 4 threads. +/// +/// :returns: A list of tuples which are the swaps to be applied to the mapping to rearrange +/// the tokens. 
+/// :rtype: EdgeList +#[pyfunction] +#[pyo3(text_signature = "(graph, mapping, /, trials=None, seed=None, parallel_threshold=50)")] +pub fn graph_token_swapper( + graph: &graph::PyGraph, + mapping: HashMap, + trials: Option, + seed: Option, + parallel_threshold: Option, +) -> EdgeList { + let map: HashMap = mapping + .iter() + .map(|(s, t)| (NodeIndex::new(*s), NodeIndex::new(*t))) + .collect(); + let swaps = token_swapper::token_swapper(&graph.graph, map, trials, seed, parallel_threshold); + EdgeList { + edges: swaps + .into_iter() + .map(|(s, t)| (s.index(), t.index())) + .collect(), + } +} diff --git a/tests/rustworkx_tests/test_token_swapper.py b/tests/rustworkx_tests/test_token_swapper.py new file mode 100644 index 000000000..b5a207e32 --- /dev/null +++ b/tests/rustworkx_tests/test_token_swapper.py @@ -0,0 +1,118 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +import unittest +import itertools +import rustworkx as rx + +from numpy import random + + +def swap_permutation( + mapping, + swaps, +) -> None: + for (sw1, sw2) in list(swaps): + val1 = mapping.pop(sw1, None) + val2 = mapping.pop(sw2, None) + + if val1 is not None: + mapping[sw2] = val1 + if val2 is not None: + mapping[sw1] = val2 + + +class TestGeneral(unittest.TestCase): + """The test cases""" + + def setUp(self) -> None: + """Set up test cases.""" + super().setUp() + random.seed(0) + + def test_simple(self) -> None: + """Test a simple permutation on a path graph of size 4.""" + graph = rx.generators.path_graph(4) + permutation = {0: 0, 1: 3, 3: 1, 2: 2} + swaps = rx.graph_token_swapper(graph, permutation, 4, 4, 1) + swap_permutation(permutation, swaps) + self.assertEqual(3, len(swaps)) + self.assertEqual({i: i for i in range(4)}, permutation) + + def test_small(self) -> None: + """Test an inverting permutation on a small path graph of size 8""" + graph = rx.generators.path_graph(8) + permutation = {i: 7 - i for i in range(8)} + swaps = rx.graph_token_swapper(graph, permutation, 4, 4, 1) + swap_permutation(permutation, swaps) + self.assertEqual({i: i for i in range(8)}, permutation) + + def test_bug1(self) -> None: + """Tests for a bug that occurred in happy swap chains of length >2.""" + graph = rx.PyGraph() + graph.extend_from_edge_list( + [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4), (3, 6)] + ) + permutation = {0: 4, 1: 0, 2: 3, 3: 6, 4: 2, 6: 1} + swaps = rx.graph_token_swapper(graph, permutation, 4, 4, 1) + swap_permutation(permutation, swaps) + self.assertEqual({i: i for i in permutation}, permutation) + + def test_partial_simple(self) -> None: + """Test a partial mapping on a small graph.""" + graph = rx.generators.path_graph(4) + mapping = {0: 3} + swaps = rx.graph_token_swapper(graph, mapping, 4, 4, 10) + swap_permutation(mapping, swaps) + self.assertEqual(3, len(swaps)) + self.assertEqual({3: 3}, mapping) + + def 
test_partial_simple_remove_node(self) -> None: + """Test a partial mapping on a small graph with a node removed.""" + graph = rx.generators.path_graph(5) + graph.remove_node(2) + graph.add_edge(1, 3, None) + mapping = {0: 3} + swaps = rx.graph_token_swapper(graph, mapping, 4, 4, 10) + swap_permutation(mapping, swaps) + self.assertEqual(2, len(swaps)) + self.assertEqual({3: 3}, mapping) + + def test_partial_small(self) -> None: + """Test a partial inverting permutation on a small path graph of size 4""" + graph = rx.generators.path_graph(4) + permutation = {i: 3 - i for i in range(2)} + swaps = rx.graph_token_swapper(graph, permutation, 4, 4, 10) + swap_permutation(permutation, swaps) + self.assertEqual(5, len(swaps)) + self.assertEqual({i: i for i in permutation.values()}, permutation) + + def test_large_partial_random(self) -> None: + """Test a random (partial) mapping on a large randomly generated graph""" + size = 100 + # Note that graph may have "gaps" in the node counts, i.e. the numbering is noncontiguous. + graph = rx.undirected_gnm_random_graph(size, size**2 // 10) + for i in graph.node_indexes(): + try: + graph.remove_edge(i, i) # Remove self-loops. + except rx.NoEdgeBetweenNodes: + continue + # Make sure the graph is connected by adding a path through all nodes (P_n) + graph.add_edges_from_no_data([(i, i + 1) for i in range(len(graph) - 1)]) + + # Generate a randomized permutation. + rand_perm = random.permutation(graph.nodes()) + permutation = dict(zip(graph.nodes(), rand_perm)) + mapping = dict(itertools.islice(permutation.items(), 0, size, 2)) # Drop every 2nd element. + swaps = rx.graph_token_swapper(graph, permutation, 4, 4) + swap_permutation(mapping, swaps) + self.assertEqual({i: i for i in mapping.values()}, mapping)