From 02157484bcb40d90c934f55c795d7627f3ee8768 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 12 Sep 2023 12:40:06 +0200 Subject: [PATCH] feat!: Let Hugr implement Circuit (#91) - Replaces the `HierarchyView` bound in `Circuit` with `HugrView` (from https://github.com/CQCL-DEV/hugr/pull/498), which is implemented by `Hugr`. - This simplifies the tests, where we don't need to wrap `Hugr`s anymore. - We still need a lifetime for the Commands iterator, so `Circuit::Comands` now has a generic lifetime. - Uses the new `PetgraphWrapper` when needed instead of requiring Petgraph traits on everything. - This lets us drop the lifetime from `Circuit<'a>`. Closes #84. ~Blocked by https://github.com/CQCL-DEV/hugr/pull/498.~ --- Cargo.toml | 2 +- benches/benchmarks/hash.rs | 4 +-- compile-matcher/src/main.rs | 2 +- src/circuit.rs | 64 +++++++++++++++---------------------- src/circuit/command.rs | 17 +++------- src/circuit/hash.rs | 10 +++--- src/json.rs | 4 +-- src/json/encoder.rs | 2 +- src/ops.rs | 8 ++--- src/portmatching/matcher.rs | 62 +++++++++++++---------------------- src/portmatching/pattern.rs | 20 +++--------- src/portmatching/pyo3.rs | 4 +-- src/rewrite.rs | 2 +- src/rewrite/ecc_rewriter.rs | 2 +- 14 files changed, 76 insertions(+), 127 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9b0f6292..e9d0ce53 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,7 +69,7 @@ members = ["pyrs", "compile-matcher"] [workspace.dependencies] -quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "5a97a635" } +quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", rev = "e23323d" } portgraph = { version = "0.9", features = ["serde"] } pyo3 = { version = "0.19" } itertools = { version = "0.11.0" } diff --git a/benches/benchmarks/hash.rs b/benches/benchmarks/hash.rs index ad62304c..a9f28c87 100644 --- a/benches/benchmarks/hash.rs +++ b/benches/benchmarks/hash.rs @@ -1,7 +1,7 @@ use criterion::{black_box, criterion_group, AxisScale, BenchmarkId, Criterion, PlotConfiguration}; -use hugr::hugr::views::SiblingGraph; +use hugr::hugr::views::{HierarchyView, SiblingGraph}; use hugr::HugrView; -use tket2::circuit::{CircuitHash, HierarchyView}; +use tket2::circuit::CircuitHash; use super::generators::make_cnot_layers; diff --git a/compile-matcher/src/main.rs b/compile-matcher/src/main.rs index 5c0f59dc..2078a94f 100644 --- a/compile-matcher/src/main.rs +++ b/compile-matcher/src/main.rs @@ -2,7 +2,7 @@ use std::fs; use std::path::Path; use clap::Parser; -use hugr::hugr::views::{HierarchyView, SiblingGraph}; +use hugr::hugr::views::SiblingGraph; use hugr::ops::handle::DfgID; use hugr::HugrView; use itertools::Itertools; diff --git a/src/circuit.rs b/src/circuit.rs index a36fd2da..e7ffc4fc 100644 --- a/src/circuit.rs +++ b/src/circuit.rs @@ -18,12 +18,10 @@ use hugr::hugr::{CircuitUnit, NodeType}; use hugr::ops::OpTrait; use hugr::HugrView; -pub use hugr::hugr::views::HierarchyView; pub use hugr::ops::OpType; use hugr::types::TypeBound; pub use hugr::types::{EdgeKind, Signature, Type, TypeRow}; pub use hugr::{Node, Port, Wire}; -use petgraph::visit::{GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers}; /// An object behaving like a quantum circuit. // @@ -32,9 +30,11 @@ use petgraph::visit::{GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers}; // - Vertical slice iterator // - Gate count map // - Depth -pub trait Circuit<'circ>: HugrView { +pub trait Circuit: HugrView { /// An iterator over the commands in the circuit. - type Commands: Iterator; + type Commands<'a>: Iterator + where + Self: 'a; /// An iterator over the commands applied to an unit. type UnitCommands: Iterator; @@ -67,10 +67,10 @@ pub trait Circuit<'circ>: HugrView { /// Returns all the commands in the circuit, in some topological order. /// /// Ignores the Input and Output nodes. - fn commands(&'circ self) -> Self::Commands; + fn commands(&self) -> Self::Commands<'_>; /// Returns all the commands applied to the given unit, in order. - fn unit_commands(&'circ self) -> Self::UnitCommands; + fn unit_commands(&self) -> Self::UnitCommands; /// Returns the [`NodeType`] of a command. fn command_nodetype(&self, command: &Command) -> &NodeType { @@ -86,12 +86,11 @@ pub trait Circuit<'circ>: HugrView { fn num_gates(&self) -> usize; } -impl<'circ, T> Circuit<'circ> for T +impl Circuit for T where - T: 'circ + HierarchyView<'circ>, - for<'a> &'a T: GraphBase + IntoNeighborsDirected + IntoNodeIdentifiers, + T: HugrView, { - type Commands = CommandIterator<'circ, T>; + type Commands<'a> = CommandIterator<'a, T> where Self: 'a; type UnitCommands = std::iter::Empty; #[inline] @@ -129,12 +128,12 @@ where } } - fn commands(&'circ self) -> Self::Commands { + fn commands(&self) -> Self::Commands<'_> { // Traverse the circuit in topological order. CommandIterator::new(self) } - fn unit_commands(&'circ self) -> Self::UnitCommands { + fn unit_commands(&self) -> Self::UnitCommands { // TODO Can we associate linear i/o with the corresponding unit without // doing the full toposort? unimplemented!() @@ -158,35 +157,24 @@ where #[cfg(test)] mod tests { - use std::sync::OnceLock; - - use hugr::{ - hugr::views::{DescendantsGraph, HierarchyView}, - ops::handle::DfgID, - Hugr, HugrView, - }; + use hugr::Hugr; use crate::{circuit::Circuit, json::load_tk1_json_str}; - static CIRC: OnceLock = OnceLock::new(); - - fn test_circuit() -> DescendantsGraph<'static, DfgID> { - let hugr = CIRC.get_or_init(|| { - load_tk1_json_str( - r#"{ - "phase": "0", - "bits": [], - "qubits": [["q", [0]], ["q", [1]]], - "commands": [ - {"args": [["q", [0]]], "op": {"type": "H"}}, - {"args": [["q", [0]], ["q", [1]]], "op": {"type": "CX"}} - ], - "implicit_permutation": [[["q", [0]], ["q", [0]]], [["q", [1]], ["q", [1]]]] - }"#, - ) - .unwrap() - }); - DescendantsGraph::new(hugr, hugr.root()) + fn test_circuit() -> Hugr { + load_tk1_json_str( + r#"{ + "phase": "0", + "bits": [], + "qubits": [["q", [0]], ["q", [1]]], + "commands": [ + {"args": [["q", [0]]], "op": {"type": "H"}}, + {"args": [["q", [0]], ["q", [1]]], "op": {"type": "CX"}} + ], + "implicit_permutation": [[["q", [0]], ["q", [0]]], [["q", [1]], ["q", [1]]]] + }"#, + ) + .unwrap() } #[test] diff --git a/src/circuit/command.rs b/src/circuit/command.rs index 747a47cb..570c8bb2 100644 --- a/src/circuit/command.rs +++ b/src/circuit/command.rs @@ -5,9 +5,7 @@ use std::collections::HashMap; use std::iter::FusedIterator; -use hugr::hugr::views::HierarchyView; use hugr::ops::{OpTag, OpTrait}; -use petgraph::visit::{GraphBase, IntoNeighborsDirected, IntoNodeIdentifiers}; use super::Circuit; @@ -59,8 +57,7 @@ pub struct CommandIterator<'circ, Circ> { impl<'circ, Circ> CommandIterator<'circ, Circ> where - Circ: HierarchyView<'circ>, - for<'a> &'a Circ: GraphBase + IntoNeighborsDirected + IntoNodeIdentifiers, + Circ: Circuit, { /// Create a new iterator over the commands of a circuit. pub(super) fn new(circ: &'circ Circ) -> Self { @@ -77,7 +74,7 @@ where }) .collect(); - let nodes = petgraph::algo::toposort(circ, None).unwrap(); + let nodes = petgraph::algo::toposort(&circ.as_petgraph(), None).unwrap(); Self { circ, nodes, @@ -157,8 +154,7 @@ where impl<'circ, Circ> Iterator for CommandIterator<'circ, Circ> where - Circ: HierarchyView<'circ>, - for<'a> &'a Circ: GraphBase + IntoNeighborsDirected + IntoNodeIdentifiers, + Circ: Circuit, { type Item = Command; @@ -182,12 +178,7 @@ where } } -impl<'circ, Circ> FusedIterator for CommandIterator<'circ, Circ> -where - Circ: HierarchyView<'circ>, - for<'a> &'a Circ: GraphBase + IntoNeighborsDirected + IntoNodeIdentifiers, -{ -} +impl<'circ, Circ> FusedIterator for CommandIterator<'circ, Circ> where Circ: Circuit {} #[cfg(test)] mod test { diff --git a/src/circuit/hash.rs b/src/circuit/hash.rs index b62ddd46..e9a0109d 100644 --- a/src/circuit/hash.rs +++ b/src/circuit/hash.rs @@ -4,7 +4,6 @@ use core::panic; use std::hash::{Hash, Hasher}; use fxhash::{FxHashMap, FxHasher64}; -use hugr::hugr::views::HierarchyView; use hugr::ops::{LeafOp, OpName, OpTag, OpTrait, OpType}; use hugr::types::TypeBound; use hugr::{HugrView, Node, Port}; @@ -29,14 +28,15 @@ pub trait CircuitHash<'circ>: HugrView { impl<'circ, T> CircuitHash<'circ> for T where - T: HugrView + HierarchyView<'circ>, - for<'a> &'a T: - pg::GraphBase + pg::IntoNeighborsDirected + pg::IntoNodeIdentifiers, + T: HugrView, { fn circuit_hash(&'circ self) -> u64 { let mut hash_state = HashState::default(); - for node in pg::Topo::new(self).iter(self).filter(|&n| n != self.root()) { + for node in pg::Topo::new(&self.as_petgraph()) + .iter(&self.as_petgraph()) + .filter(|&n| n != self.root()) + { let hash = hash_node(self, node, &mut hash_state); hash_state.set_node(self, node, hash); } diff --git a/src/json.rs b/src/json.rs index 9afc2189..ff9ebc85 100644 --- a/src/json.rs +++ b/src/json.rs @@ -38,7 +38,7 @@ pub trait TKETDecode: Sized { /// Convert the serialized circuit to a [`Hugr`]. fn decode(self) -> Result; /// Convert a [`Hugr`] to a new serialized circuit. - fn encode<'circ>(circuit: &'circ impl Circuit<'circ>) -> Result; + fn encode(circuit: &impl Circuit) -> Result; } impl TKETDecode for SerialCircuit { @@ -60,7 +60,7 @@ impl TKETDecode for SerialCircuit { Ok(decoder.finish()) } - fn encode<'circ>(circ: &'circ impl Circuit<'circ>) -> Result { + fn encode(circ: &impl Circuit) -> Result { let mut encoder = JsonEncoder::new(circ); for com in circ.commands() { let optype = circ.command_optype(&com); diff --git a/src/json/encoder.rs b/src/json/encoder.rs index ad81b653..e087d147 100644 --- a/src/json/encoder.rs +++ b/src/json/encoder.rs @@ -40,7 +40,7 @@ pub(super) struct JsonEncoder { impl JsonEncoder { /// Create a new [`JsonEncoder`] from a [`Circuit`]. - pub fn new<'circ>(circ: &impl Circuit<'circ>) -> Self { + pub fn new(circ: &impl Circuit) -> Self { let name = circ.name().map(str::to_string); // Compute the linear qubit and bit registers. Each one have independent diff --git a/src/ops.rs b/src/ops.rs index 8783e698..f29894c5 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -289,12 +289,8 @@ pub(crate) mod test { use std::sync::Arc; - use hugr::{ - extension::OpDef, - hugr::views::{HierarchyView, SiblingGraph}, - ops::handle::DfgID, - Hugr, HugrView, - }; + use hugr::hugr::views::HierarchyView; + use hugr::{extension::OpDef, hugr::views::SiblingGraph, ops::handle::DfgID, Hugr, HugrView}; use rstest::{fixture, rstest}; use crate::{circuit::Circuit, ops::SimpleOpEnum, utils::build_simple_circuit}; diff --git a/src/portmatching/matcher.rs b/src/portmatching/matcher.rs index 4dcd8915..fe3655ff 100644 --- a/src/portmatching/matcher.rs +++ b/src/portmatching/matcher.rs @@ -73,7 +73,7 @@ pub struct PatternMatch<'a, C> { pub(super) root: Node, } -impl<'a, C: Circuit<'a> + Clone> PatternMatch<'a, C> { +impl<'a, C: Circuit + Clone> PatternMatch<'a, C> { /// The matcher's pattern ID of the match. pub fn pattern_id(&self) -> PatternID { self.pattern @@ -187,7 +187,7 @@ impl<'a, C: Circuit<'a> + Clone> PatternMatch<'a, C> { } } -impl<'a, C: Circuit<'a>> Debug for PatternMatch<'a, C> { +impl<'a, C: Circuit> Debug for PatternMatch<'a, C> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("PatternMatch") .field("root", &self.root) @@ -237,10 +237,7 @@ impl PatternMatcher { } /// Find all convex pattern matches in a circuit. - pub fn find_matches<'a, C: Circuit<'a> + Clone>( - &self, - circuit: &'a C, - ) -> Vec> { + pub fn find_matches<'a, C: Circuit + Clone>(&self, circuit: &'a C) -> Vec> { let mut checker = ConvexChecker::new(circuit); circuit .commands() @@ -249,7 +246,7 @@ impl PatternMatcher { } /// Find all convex pattern matches in a circuit rooted at a given node. - fn find_rooted_matches<'a, C: Circuit<'a> + Clone>( + fn find_rooted_matches<'a, C: Circuit + Clone>( &self, circ: &'a C, root: Node, @@ -385,8 +382,8 @@ fn compatible_offsets((_, pout): &(Port, Port), (pin, _): &(Port, Port)) -> bool } /// Check if an edge `e` is valid in a portgraph `g` without weights. -pub(crate) fn validate_unweighted_edge<'circ>( - circ: &impl Circuit<'circ>, +pub(crate) fn validate_unweighted_edge( + circ: &impl Circuit, ) -> impl for<'a> Fn(Node, &'a PEdge) -> Option + '_ { move |src, &(src_port, tgt_port)| { let (next_node, _) = circ @@ -397,8 +394,8 @@ pub(crate) fn validate_unweighted_edge<'circ>( } /// Check if a node `n` is valid in a weighted portgraph `g`. -pub(crate) fn validate_weighted_node<'circ>( - circ: &impl Circuit<'circ>, +pub(crate) fn validate_weighted_node( + circ: &impl Circuit, ) -> impl for<'a> Fn(Node, &PNode) -> bool + '_ { move |v, prop| { let v_weight = MatchOp::try_from(circ.get_optype(v).clone()); @@ -425,11 +422,7 @@ fn handle_match_error(match_res: Result, root: Node) #[cfg(test)] mod tests { - use std::sync::OnceLock; - - use hugr::hugr::views::{DescendantsGraph, HierarchyView}; - use hugr::ops::handle::DfgID; - use hugr::{Hugr, HugrView}; + use hugr::Hugr; use itertools::Itertools; use crate::utils::build_simple_circuit; @@ -437,31 +430,22 @@ mod tests { use super::{CircuitPattern, PatternMatcher}; - static H_CX: OnceLock = OnceLock::new(); - static CX_CX: OnceLock = OnceLock::new(); - - fn h_cx<'a>() -> DescendantsGraph<'a, DfgID> { - let circ = H_CX.get_or_init(|| { - build_simple_circuit(2, |circ| { - circ.append(T2Op::CX, [0, 1]).unwrap(); - circ.append(T2Op::H, [0]).unwrap(); - Ok(()) - }) - .unwrap() - }); - DescendantsGraph::new(circ, circ.root()) + fn h_cx() -> Hugr { + build_simple_circuit(2, |circ| { + circ.append(T2Op::CX, [0, 1]).unwrap(); + circ.append(T2Op::H, [0]).unwrap(); + Ok(()) + }) + .unwrap() } - fn cx_xc<'a>() -> DescendantsGraph<'a, DfgID> { - let circ = CX_CX.get_or_init(|| { - build_simple_circuit(2, |circ| { - circ.append(T2Op::CX, [0, 1]).unwrap(); - circ.append(T2Op::CX, [1, 0]).unwrap(); - Ok(()) - }) - .unwrap() - }); - DescendantsGraph::new(circ, circ.root()) + fn cx_xc() -> Hugr { + build_simple_circuit(2, |circ| { + circ.append(T2Op::CX, [0, 1]).unwrap(); + circ.append(T2Op::CX, [1, 0]).unwrap(); + Ok(()) + }) + .unwrap() } #[test] diff --git a/src/portmatching/pattern.rs b/src/portmatching/pattern.rs index 36e3b56f..0d123f10 100644 --- a/src/portmatching/pattern.rs +++ b/src/portmatching/pattern.rs @@ -33,9 +33,7 @@ impl CircuitPattern { } /// Construct a pattern from a circuit. - pub fn try_from_circuit<'circ, C: Circuit<'circ>>( - circuit: &'circ C, - ) -> Result { + pub fn try_from_circuit(circuit: &C) -> Result { if circuit.num_gates() == 0 { return Err(InvalidPattern::EmptyCircuit); } @@ -79,11 +77,7 @@ impl CircuitPattern { } /// Compute the map from pattern nodes to circuit nodes in `circ`. - pub fn get_match_map<'a, C: Circuit<'a>>( - &self, - root: Node, - circ: &C, - ) -> Option> { + pub fn get_match_map(&self, root: Node, circ: &C) -> Option> { let single_matcher = SinglePatternMatcher::from_pattern(self.pattern.clone()); single_matcher .get_match_map( @@ -121,9 +115,7 @@ impl From for InvalidPattern { #[cfg(test)] mod tests { - use hugr::hugr::views::{DescendantsGraph, HierarchyView, SiblingGraph}; - use hugr::ops::handle::DfgID; - use hugr::{Hugr, HugrView}; + use hugr::Hugr; use itertools::Itertools; use crate::utils::build_simple_circuit; @@ -143,9 +135,8 @@ mod tests { #[test] fn construct_pattern() { let hugr = h_cx(); - let circ: DescendantsGraph<'_, DfgID> = DescendantsGraph::new(&hugr, hugr.root()); - let p = CircuitPattern::try_from_circuit(&circ).unwrap(); + let p = CircuitPattern::try_from_circuit(&hugr).unwrap(); let edges = p .pattern @@ -163,13 +154,12 @@ mod tests { #[test] fn disconnected_pattern() { - let hugr = build_simple_circuit(2, |circ| { + let circ = build_simple_circuit(2, |circ| { circ.append(T2Op::X, [0])?; circ.append(T2Op::T, [1])?; Ok(()) }) .unwrap(); - let circ: SiblingGraph<'_, DfgID> = SiblingGraph::new(&hugr, hugr.root()); assert_eq!( CircuitPattern::try_from_circuit(&circ).unwrap_err(), InvalidPattern::NotConnected diff --git a/src/portmatching/pyo3.rs b/src/portmatching/pyo3.rs index 41faed6c..a08e90a3 100644 --- a/src/portmatching/pyo3.rs +++ b/src/portmatching/pyo3.rs @@ -121,8 +121,8 @@ impl PyPatternMatch { /// /// Requires references to the circuit and pattern to resolve indices /// into these objects. - pub fn try_from_rust<'circ, C: Circuit<'circ> + Clone>( - m: PatternMatch<'circ, C>, + pub fn try_from_rust( + m: PatternMatch, circ: &C, matcher: &PatternMatcher, ) -> PyResult { diff --git a/src/rewrite.rs b/src/rewrite.rs index d4919878..07f0ee7f 100644 --- a/src/rewrite.rs +++ b/src/rewrite.rs @@ -52,5 +52,5 @@ impl CircuitRewrite { /// Generate rewrite rules for circuits. pub trait Rewriter { /// Get the rewrite rules for a circuit. - fn get_rewrites<'a, C: Circuit<'a> + Clone>(&'a self, circ: &'a C) -> Vec; + fn get_rewrites<'a, C: Circuit + Clone>(&'a self, circ: &'a C) -> Vec; } diff --git a/src/rewrite/ecc_rewriter.rs b/src/rewrite/ecc_rewriter.rs index 32d181ed..af4a3eeb 100644 --- a/src/rewrite/ecc_rewriter.rs +++ b/src/rewrite/ecc_rewriter.rs @@ -96,7 +96,7 @@ impl ECCRewriter { } impl Rewriter for ECCRewriter { - fn get_rewrites<'a, C: Circuit<'a> + Clone>(&'a self, circ: &'a C) -> Vec { + fn get_rewrites<'a, C: Circuit + Clone>(&'a self, circ: &'a C) -> Vec { let matches = self.matcher.find_matches(circ); matches .into_iter()