diff --git a/badger-optimiser/src/main.rs b/badger-optimiser/src/main.rs index 706b8455..861f94af 100644 --- a/badger-optimiser/src/main.rs +++ b/badger-optimiser/src/main.rs @@ -13,6 +13,7 @@ use std::process::exit; use clap::Parser; use tket2::json::{load_tk1_json_file, save_tk1_json_file}; use tket2::optimiser::badger::log::BadgerLogger; +use tket2::optimiser::badger::BadgerOptions; use tket2::optimiser::{BadgerOptimiser, DefaultBadgerOptimiser}; #[cfg(feature = "peak_alloc")] @@ -72,6 +73,14 @@ struct CmdLineArgs { help = "Timeout in seconds (default=None)." )] timeout: Option, + /// Maximum time in seconds to wait between circuit improvements (default=no timeout) + #[arg( + short = 'p', + long, + value_name = "PROGRESS_TIMEOUT", + help = "Maximum time in seconds to wait between circuit improvements (default=None)." + )] + progress_timeout: Option, /// Number of threads (default=1) #[arg( short = 'j', @@ -140,10 +149,13 @@ fn main() -> Result<(), Box> { let opt_circ = optimiser.optimise_with_log( &circ, badger_logger, - opts.timeout, - n_threads, - opts.split_circ, - opts.queue_size, + BadgerOptions { + timeout: opts.timeout, + progress_timeout: opts.progress_timeout, + n_threads, + split_circuit: opts.split_circ, + queue_size: opts.queue_size, + }, ); println!("Saving result"); diff --git a/tket2-py/src/optimiser.rs b/tket2-py/src/optimiser.rs index 38a8fd11..2c92a89f 100644 --- a/tket2-py/src/optimiser.rs +++ b/tket2-py/src/optimiser.rs @@ -5,6 +5,7 @@ use std::{fs, num::NonZeroUsize, path::PathBuf}; use hugr::Hugr; use pyo3::prelude::*; +use tket2::optimiser::badger::BadgerOptions; use tket2::optimiser::{BadgerLogger, DefaultBadgerOptimiser}; use crate::circuit::update_hugr; @@ -47,35 +48,55 @@ impl PyBadgerOptimiser { /// # Parameters /// /// * `circ`: The circuit to optimise. - /// * `timeout`: The timeout in seconds. - /// * `n_threads`: The number of threads to use. - /// * `split_circ`: Whether to split the circuit into chunks before - /// processing. /// - /// If this option is set, the optimise will divide the circuit into - /// `n_threads` chunks and optimise each on a separate thread. + /// * `timeout`: The maximum time (in seconds) to run the optimiser. + /// + /// If `None` the optimiser will run indefinitely, or until + /// `progress_timeout` is reached. + /// + /// * `progress_timeout`: The maximum time (in seconds) to search for new + /// improvements to the circuit. If no progress is made in this time, + /// the optimiser will stop. + /// + /// If `None` the optimiser will run indefinitely, or until `timeout` is + /// reached. + /// + /// * `n_threads`: The number of threads to use. Defaults to `1`. + /// + /// * `split_circ`: Whether to split the circuit into chunks and process + /// each in a separate thread. + /// + /// If this option is set to `true`, the optimiser will split the + /// circuit into `n_threads` chunks. + /// + /// If this option is set to `false`, the optimiser will run `n_threads` + /// parallel searches on the whole circuit (default). + /// + /// * `queue_size`: The maximum size of the circuit candidates priority + /// queue. Defaults to `20`. + /// /// * `log_progress`: The path to a CSV file to log progress to. /// #[pyo3(name = "optimise")] + #[allow(clippy::too_many_arguments)] pub fn py_optimise<'py>( &self, circ: &'py PyAny, timeout: Option, + progress_timeout: Option, n_threads: Option, split_circ: Option, - log_progress: Option, queue_size: Option, + log_progress: Option, ) -> PyResult<&'py PyAny> { - update_hugr(circ, |circ, _| { - self.optimise( - circ, - timeout, - n_threads, - split_circ, - log_progress, - queue_size, - ) - }) + let options = BadgerOptions { + timeout, + progress_timeout, + n_threads: n_threads.unwrap_or(NonZeroUsize::new(1).unwrap()), + split_circuit: split_circ.unwrap_or(false), + queue_size: queue_size.unwrap_or(100), + }; + update_hugr(circ, |circ, _| self.optimise(circ, log_progress, options)) } } @@ -84,11 +105,8 @@ impl PyBadgerOptimiser { pub(super) fn optimise( &self, circ: Hugr, - timeout: Option, - n_threads: Option, - split_circ: Option, log_progress: Option, - queue_size: Option, + options: BadgerOptions, ) -> Hugr { let badger_logger = log_progress .map(|file_name| { @@ -97,13 +115,6 @@ impl PyBadgerOptimiser { BadgerLogger::new(log_file) }) .unwrap_or_default(); - self.0.optimise_with_log( - &circ, - badger_logger, - timeout, - n_threads.unwrap_or(NonZeroUsize::new(1).unwrap()), - split_circ.unwrap_or(false), - queue_size.unwrap_or(100), - ) + self.0.optimise_with_log(&circ, badger_logger, options) } } diff --git a/tket2-py/src/passes.rs b/tket2-py/src/passes.rs index 3cef24cd..086a9256 100644 --- a/tket2-py/src/passes.rs +++ b/tket2-py/src/passes.rs @@ -5,6 +5,7 @@ pub mod chunks; use std::{cmp::min, convert::TryInto, fs, num::NonZeroUsize, path::PathBuf}; use pyo3::{prelude::*, types::IntoPyDict}; +use tket2::optimiser::badger::BadgerOptions; use tket2::{op_matches, passes::apply_greedy_commutation, Circuit, Tk2Op}; use crate::utils::{create_py_exception, ConvertPyErr}; @@ -78,6 +79,7 @@ fn badger_optimise<'py>( optimiser: &PyBadgerOptimiser, max_threads: Option, timeout: Option, + progress_timeout: Option, log_dir: Option, rebase: Option, ) -> PyResult<&'py PyAny> { @@ -124,14 +126,14 @@ fn badger_optimise<'py>( log_file.push(format!("cycle-{i}.log")); log_file }); - circ = optimiser.optimise( - circ, - Some(timeout), - Some(n_threads.try_into().unwrap()), - Some(true), - log_file, - None, - ); + let options = BadgerOptions { + timeout: Some(timeout), + progress_timeout, + n_threads: n_threads.try_into().unwrap(), + split_circuit: true, + ..Default::default() + }; + circ = optimiser.optimise(circ, log_file, options); } PyResult::Ok(circ) }) diff --git a/tket2/src/optimiser/badger.rs b/tket2/src/optimiser/badger.rs index ccf5000c..687d2813 100644 --- a/tket2/src/optimiser/badger.rs +++ b/tket2/src/optimiser/badger.rs @@ -40,6 +40,50 @@ use crate::rewrite::strategy::RewriteStrategy; use crate::rewrite::Rewriter; use crate::Circuit; +/// Configuration options for the Badger optimiser. +#[derive(Copy, Clone, Debug)] +pub struct BadgerOptions { + /// The maximum time (in seconds) to run the optimiser. + /// + /// Defaults to `None`, which means no timeout. + pub timeout: Option, + /// The maximum time (in seconds) to search for new improvements to the + /// circuit. If no progress is made in this time, the optimiser will stop. + /// + /// Defaults to `None`, which means no timeout. + pub progress_timeout: Option, + /// The number of threads to use. + /// + /// Defaults to `1`. + pub n_threads: NonZeroUsize, + /// Whether to split the circuit into chunks and process each in a separate thread. + /// + /// If this option is set to `true`, the optimiser will split the circuit into `n_threads` + /// chunks. + /// + /// If this option is set to `false`, the optimiser will run parallel searches on the whole + /// circuit. + /// + /// Defaults to `false`. + pub split_circuit: bool, + /// The maximum size of the circuit candidates priority queue. + /// + /// Defaults to `20`. + pub queue_size: usize, +} + +impl Default for BadgerOptions { + fn default() -> Self { + Self { + timeout: Default::default(), + progress_timeout: Default::default(), + n_threads: NonZeroUsize::new(1).unwrap(), + split_circuit: Default::default(), + queue_size: 20, + } + } +} + /// The Badger optimiser. /// /// Adapted from [Quartz][], and originally [TASO][]. @@ -85,22 +129,8 @@ where /// Run the Badger optimiser on a circuit. /// /// A timeout (in seconds) can be provided. - pub fn optimise( - &self, - circ: &Hugr, - timeout: Option, - n_threads: NonZeroUsize, - split_circuit: bool, - queue_size: usize, - ) -> Hugr { - self.optimise_with_log( - circ, - Default::default(), - timeout, - n_threads, - split_circuit, - queue_size, - ) + pub fn optimise(&self, circ: &Hugr, options: BadgerOptions) -> Hugr { + self.optimise_with_log(circ, Default::default(), options) } /// Run the Badger optimiser on a circuit with logging activated. @@ -110,31 +140,21 @@ where &self, circ: &Hugr, log_config: BadgerLogger, - timeout: Option, - n_threads: NonZeroUsize, - split_circuit: bool, - queue_size: usize, + options: BadgerOptions, ) -> Hugr { - if split_circuit && n_threads.get() > 1 { - return self - .split_run(circ, log_config, timeout, n_threads, queue_size) - .unwrap(); + if options.split_circuit && options.n_threads.get() > 1 { + return self.split_run(circ, log_config, options).unwrap(); } - match n_threads.get() { - 1 => self.badger(circ, log_config, timeout, queue_size), - _ => self.badger_multithreaded(circ, log_config, timeout, n_threads, queue_size), + match options.n_threads.get() { + 1 => self.badger(circ, log_config, options), + _ => self.badger_multithreaded(circ, log_config, options), } } #[tracing::instrument(target = "badger::metrics", skip(self, circ, logger))] - fn badger( - &self, - circ: &Hugr, - mut logger: BadgerLogger, - timeout: Option, - queue_size: usize, - ) -> Hugr { + fn badger(&self, circ: &Hugr, mut logger: BadgerLogger, opt: BadgerOptions) -> Hugr { let start_time = Instant::now(); + let mut last_best_time = Instant::now(); let mut best_circ = circ.clone(); let mut best_circ_cost = self.cost(circ); @@ -152,7 +172,7 @@ where }; let cost = (cost_fn)(circ); - let mut pq = HugrPQ::new(cost_fn, queue_size); + let mut pq = HugrPQ::new(cost_fn, opt.queue_size); pq.push_unchecked(circ.clone(), hash, cost); let mut circ_cnt = 0; @@ -162,6 +182,7 @@ where best_circ = circ.clone(); best_circ_cost = cost.clone(); logger.log_best(&best_circ_cost); + last_best_time = Instant::now(); } circ_cnt += 1; @@ -187,12 +208,18 @@ where logger.log_progress(circ_cnt, Some(pq.len()), seen_hashes.len()); } - if let Some(timeout) = timeout { + if let Some(timeout) = opt.timeout { if start_time.elapsed().as_secs() > timeout { timeout_flag = true; break; } } + if let Some(p_timeout) = opt.progress_timeout { + if last_best_time.elapsed().as_secs() > p_timeout { + timeout_flag = true; + break; + } + } } logger.log_processing_end( @@ -214,18 +241,16 @@ where &self, circ: &Hugr, mut logger: BadgerLogger, - timeout: Option, - n_threads: NonZeroUsize, - queue_size: usize, + opt: BadgerOptions, ) -> Hugr { - let n_threads: usize = n_threads.get(); + let n_threads: usize = opt.n_threads.get(); // multi-consumer priority channel for queuing circuits to be processed by the workers let cost_fn = { let strategy = self.strategy.clone(); move |circ: &'_ Hugr| strategy.circuit_cost(circ) }; - let (pq, rx_log) = HugrPriorityChannel::init(cost_fn.clone(), queue_size); + let (pq, rx_log) = HugrPriorityChannel::init(cost_fn.clone(), opt.queue_size); let initial_circ_hash = circ.circuit_hash().unwrap(); let mut best_circ = circ.clone(); @@ -248,7 +273,13 @@ where .collect(); // Deadline for the optimisation timeout - let timeout_event = match timeout { + let timeout_event = match opt.timeout { + None => crossbeam_channel::never(), + Some(t) => crossbeam_channel::at(Instant::now() + Duration::from_secs(t)), + }; + + // Deadline for the timeout when no progress is made + let mut progress_timeout_event = match opt.progress_timeout { None => crossbeam_channel::never(), Some(t) => crossbeam_channel::at(Instant::now() + Duration::from_secs(t)), }; @@ -267,6 +298,9 @@ where best_circ = circ; best_circ_cost = cost; logger.log_best(&best_circ_cost); + if let Some(t) = opt.progress_timeout { + progress_timeout_event = crossbeam_channel::at(Instant::now() + Duration::from_secs(t)); + } } }, Ok(PriorityChannelLog::CircuitCount{processed_count: proc, seen_count: seen, queue_length}) => { @@ -287,6 +321,12 @@ where let _ = pq.close(); break; } + recv(progress_timeout_event) -> _ => { + timeout_flag = true; + // Signal the workers to stop. + let _ = pq.close(); + break; + } } } @@ -330,12 +370,10 @@ where &self, circ: &Hugr, mut logger: BadgerLogger, - timeout: Option, - n_threads: NonZeroUsize, - queue_size: usize, + opt: BadgerOptions, ) -> Result { let circ_cost = self.cost(circ); - let max_chunk_cost = circ_cost.clone().div_cost(n_threads); + let max_chunk_cost = circ_cost.clone().div_cost(opt.n_threads); logger.log(format!( "Splitting circuit with cost {:?} into chunks of at most {max_chunk_cost:?}.", circ_cost.clone() @@ -359,10 +397,11 @@ where .spawn(move || { let res = badger.optimise( &chunk, - timeout, - NonZeroUsize::new(1).unwrap(), - false, - queue_size, + BadgerOptions { + n_threads: NonZeroUsize::new(1).unwrap(), + split_circuit: false, + ..opt + }, ); tx.send(res).unwrap(); }) @@ -384,7 +423,7 @@ where logger.log_best(best_circ_cost.clone()); } - logger.log_processing_end(n_threads.get(), None, best_circ_cost, true, false); + logger.log_processing_end(opt.n_threads.get(), None, best_circ_cost, true, false); joins.into_iter().for_each(|j| j.join().unwrap()); Ok(best_circ) @@ -446,6 +485,7 @@ mod tests { use rstest::{fixture, rstest}; use crate::json::load_tk1_json_str; + use crate::optimiser::badger::BadgerOptions; use crate::{extension::REGISTRY, Circuit, Tk2Op}; use super::{BadgerOptimiser, DefaultBadgerOptimiser}; @@ -517,14 +557,28 @@ mod tests { #[rstest] fn rz_rz_cancellation(rz_rz: Hugr, badger_opt: DefaultBadgerOptimiser) { - let opt_rz = badger_opt.optimise(&rz_rz, None, 1.try_into().unwrap(), false, 4); + let opt_rz = badger_opt.optimise( + &rz_rz, + BadgerOptions { + queue_size: 4, + ..Default::default() + }, + ); // Rzs combined into a single one. assert_eq!(gates(&opt_rz), vec![Tk2Op::AngleAdd, Tk2Op::RzF64]); } #[rstest] fn rz_rz_cancellation_parallel(rz_rz: Hugr, badger_opt: DefaultBadgerOptimiser) { - let mut opt_rz = badger_opt.optimise(&rz_rz, Some(0), 2.try_into().unwrap(), false, 4); + let mut opt_rz = badger_opt.optimise( + &rz_rz, + BadgerOptions { + timeout: Some(0), + n_threads: 2.try_into().unwrap(), + queue_size: 4, + ..Default::default() + }, + ); opt_rz.update_validate(®ISTRY).unwrap(); } @@ -536,10 +590,11 @@ mod tests { ) { let mut opt = badger_opt_full.optimise( &non_composable_rw_hugr, - Some(0), - 1.try_into().unwrap(), - false, - 10, + BadgerOptions { + timeout: Some(0), + queue_size: 4, + ..Default::default() + }, ); // No rewrites applied. opt.update_validate(®ISTRY).unwrap(); diff --git a/tket2/tests/taso_termination.rs b/tket2/tests/badger_termination.rs similarity index 93% rename from tket2/tests/taso_termination.rs rename to tket2/tests/badger_termination.rs index 04d285c4..f899ca9d 100644 --- a/tket2/tests/taso_termination.rs +++ b/tket2/tests/badger_termination.rs @@ -2,6 +2,7 @@ use hugr::Hugr; use rstest::{fixture, rstest}; +use tket2::optimiser::badger::BadgerOptions; use tket2::{ json::TKETDecode, optimiser::{BadgerOptimiser, DefaultBadgerOptimiser}, @@ -57,6 +58,12 @@ fn simple_circ() -> Hugr { #[rstest] //#[ignore = "Takes 200ms"] fn badger_termination(simple_circ: Hugr, nam_4_2: DefaultBadgerOptimiser) { - let opt_circ = nam_4_2.optimise(&simple_circ, None, 1.try_into().unwrap(), false, 10); + let opt_circ = nam_4_2.optimise( + &simple_circ, + BadgerOptions { + queue_size: 10, + ..Default::default() + }, + ); assert_eq!(opt_circ.commands().count(), 11); }