diff --git a/tket2/src/optimiser/badger.rs b/tket2/src/optimiser/badger.rs index b2b0c3f6..97825ccb 100644 --- a/tket2/src/optimiser/badger.rs +++ b/tket2/src/optimiser/badger.rs @@ -24,6 +24,7 @@ use fxhash::FxHashSet; use hugr::hugr::HugrError; use hugr::HugrView; pub use log::BadgerLogger; +use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; use std::num::NonZeroUsize; use std::time::{Duration, Instant}; @@ -131,7 +132,7 @@ impl BadgerOptimiser { impl BadgerOptimiser where - R: Rewriter + Send + Clone + 'static, + R: Rewriter + Send + Clone + Sync + 'static, S: RewriteStrategy + Send + Sync + Clone + 'static, S::Cost: serde::Serialize + Send + Sync, { @@ -440,7 +441,7 @@ where logger.log_best(circ_cost.clone(), num_rewrites); let (joins, rx_work): (Vec<_>, Vec<_>) = chunks - .iter_mut() + .par_iter_mut() .enumerate() .map(|(i, chunk)| { let (tx, rx) = crossbeam_channel::unbounded(); diff --git a/tket2/src/optimiser/badger/log.rs b/tket2/src/optimiser/badger/log.rs index 116b7ad4..c24fe379 100644 --- a/tket2/src/optimiser/badger/log.rs +++ b/tket2/src/optimiser/badger/log.rs @@ -5,7 +5,7 @@ use std::{fmt::Debug, io}; /// Logging configuration for the Badger optimiser. pub struct BadgerLogger<'w> { - circ_candidates_csv: Option>>, + circ_candidates_csv: Option>>, last_circ_processed: usize, last_progress_time: Instant, branching_factor: UsizeAverage, @@ -41,8 +41,9 @@ impl<'w> BadgerLogger<'w> { /// or [`PROGRESS_TARGET`]. /// /// [`log`]: - pub fn new(best_progress_csv_writer: impl io::Write + 'w) -> Self { - let boxed_candidates_writer: Box = Box::new(best_progress_csv_writer); + pub fn new(best_progress_csv_writer: impl io::Write + Send + Sync + 'w) -> Self { + let boxed_candidates_writer: Box = + Box::new(best_progress_csv_writer); Self { circ_candidates_csv: Some(csv::Writer::from_writer(boxed_candidates_writer)), ..Default::default() diff --git a/tket2/src/passes/chunks.rs b/tket2/src/passes/chunks.rs index d07c12d3..75b6d74e 100644 --- a/tket2/src/passes/chunks.rs +++ b/tket2/src/passes/chunks.rs @@ -18,6 +18,8 @@ use hugr::types::Signature; use hugr::{HugrView, IncomingPort, Node, OutgoingPort, PortIndex, Wire}; use itertools::Itertools; use portgraph::algorithms::ConvexChecker; +use rayon::iter::{IntoParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; +use rayon::slice::ParallelSliceMut; use crate::Circuit; @@ -442,6 +444,19 @@ impl CircuitChunks { pub fn is_empty(&self) -> bool { self.chunks.is_empty() } + + /// Supports implementation of rayon::iter::IntoParallelRefMutIterator + fn par_iter_mut( + &mut self, + ) -> rayon::iter::Map< + rayon::slice::IterMut<'_, Chunk>, + for<'a> fn(&'a mut Chunk) -> &'a mut Circuit, + > { + self.chunks + .as_parallel_slice_mut() + .into_par_iter() + .map(|chunk| &mut chunk.circ) + } } impl Index for CircuitChunks { @@ -458,6 +473,18 @@ impl IndexMut for CircuitChunks { } } +impl<'data> IntoParallelRefMutIterator<'data> for CircuitChunks { + type Item = &'data mut Circuit; + type Iter = rayon::iter::Map< + rayon::slice::IterMut<'data, Chunk>, + for<'a> fn(&'a mut Chunk) -> &'a mut Circuit, + >; + + fn par_iter_mut(&'data mut self) -> Self::Iter { + self.par_iter_mut() + } +} + #[cfg(test)] mod test { use crate::circuit::CircuitHash;