From eb1b1501c649cbeca3173663d6a89f7f8e51a369 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Thu, 27 Jun 2024 15:30:19 +0100 Subject: [PATCH 1/3] fix: `extract_dfg` re-adding the output node in the wrong child order --- tket2/src/circuit/extract_dfg.rs | 6 +- tket2/src/passes/pytket.rs | 100 +++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+), 3 deletions(-) diff --git a/tket2/src/circuit/extract_dfg.rs b/tket2/src/circuit/extract_dfg.rs index d7deb6de..17c10ed2 100644 --- a/tket2/src/circuit/extract_dfg.rs +++ b/tket2/src/circuit/extract_dfg.rs @@ -47,7 +47,7 @@ fn remove_cfg_empty_output_tuple( signature: FunctionType, ) -> Result { let sig = signature; - let parent = circ.parent(); + let input_node = circ.input_node(); let output_node = circ.output_node(); let output_nodetype = circ.hugr.get_nodetype(output_node).clone(); @@ -89,8 +89,8 @@ fn remove_cfg_empty_output_tuple( let new_op = Output { types: new_types.clone().into(), }; - let new_node = hugr.add_node_with_parent( - parent, + let new_node = hugr.add_node_after( + input_node, NodeType::new( new_op, output_nodetype diff --git a/tket2/src/passes/pytket.rs b/tket2/src/passes/pytket.rs index c4dd3ada..88906742 100644 --- a/tket2/src/passes/pytket.rs +++ b/tket2/src/passes/pytket.rs @@ -38,3 +38,103 @@ pub enum PytketLoweringError { #[error("Non-local operations found. Function calls are not supported.")] NonLocalOperations, } + +#[cfg(test)] +mod test { + use crate::Tk2Op; + + use super::*; + use hugr::builder::{ + Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder, SubContainer, + }; + use hugr::extension::prelude::QB_T; + use hugr::extension::{ExtensionSet, PRELUDE_REGISTRY}; + use hugr::ops::handle::NodeHandle; + use hugr::ops::{MakeTuple, OpType, UnpackTuple, Value}; + use hugr::types::FunctionType; + use hugr::{type_row, HugrView}; + use rstest::{fixture, rstest}; + + /// Builds a circuit in the style of guppy's output. + /// + /// This is composed of a `Module`, containing a `FuncDefn`, containing a + /// `CFG`, containing an `Exit` and a `DataflowBlock` with the actual + /// circuit. + #[fixture] + fn guppy_like_circuit() -> Circuit { + fn build() -> Result { + let two_qbs = type_row![QB_T, QB_T]; + let circ_signature = FunctionType::new_endo(two_qbs.clone()); + let circ; + + let mut builder = ModuleBuilder::new(); + let _func = { + let mut func = builder.define_function("main", circ_signature.into())?; + let [q1, q2] = func.input_wires_arr(); + + let cfg = { + let mut cfg = func.cfg_builder( + [(QB_T, q1), (QB_T, q2)], + None, + two_qbs.clone(), + ExtensionSet::new(), + )?; + + circ = { + let mut dfg = + cfg.simple_entry_builder(two_qbs.clone(), 1, ExtensionSet::new())?; + let [q1, q2] = dfg.input_wires_arr(); + + let [q1] = dfg.add_dataflow_op(Tk2Op::H, [q1])?.outputs_arr(); + let [q1, q2] = dfg.add_dataflow_op(Tk2Op::CX, [q1, q2])?.outputs_arr(); + + let [tup] = dfg + .add_dataflow_op(MakeTuple::new(two_qbs.clone()), [q1, q2])? + .outputs_arr(); + let [q1, q2] = dfg + .add_dataflow_op(UnpackTuple::new(two_qbs), [tup])? + .outputs_arr(); + + let branch = dfg.add_load_const(Value::tuple([])); + + dfg.finish_with_outputs(branch, [q1, q2])? + }; + cfg.branch(&circ, 0, &cfg.exit_block())?; + + cfg.finish_sub_container()? + }; + let [q1, q2] = cfg.outputs_arr(); + + func.finish_with_outputs([q1, q2])? + }; + + let hugr = builder.finish_hugr(&PRELUDE_REGISTRY)?; + Ok(Circuit::new(hugr, circ.node())) + } + build().unwrap() + } + + #[rstest] + #[case::guppy_like_circuit(guppy_like_circuit())] + fn test_pytket_lowering(#[case] circ: Circuit) { + use cool_asserts::assert_matches; + + assert_eq!(circ.num_operations(), 2); + + let lowered_circ = lower_to_pytket(&circ).unwrap(); + + assert_eq!(lowered_circ.num_operations(), 2); + assert_matches!( + lowered_circ.hugr().get_optype(lowered_circ.parent()), + OpType::DFG(_) + ); + assert_matches!( + lowered_circ.hugr().get_optype(lowered_circ.input_node()), + OpType::Input(_) + ); + assert_matches!( + lowered_circ.hugr().get_optype(lowered_circ.output_node()), + OpType::Output(_) + ); + } +} From 15745a1cd7f87813a80c91a37377ed192b59cf67 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Fri, 28 Jun 2024 11:43:04 +0100 Subject: [PATCH 2/3] Fix the guppy-like test not being guppy-like --- tket2/src/circuit.rs | 7 +++++- tket2/src/passes/pytket.rs | 44 ++++++++++++++++++++++++++++++++------ 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/tket2/src/circuit.rs b/tket2/src/circuit.rs index 92b72f6f..cfe3da4f 100644 --- a/tket2/src/circuit.rs +++ b/tket2/src/circuit.rs @@ -11,6 +11,7 @@ use std::iter::Sum; pub use command::{Command, CommandIterator}; pub use hash::CircuitHash; use hugr::hugr::views::{DescendantsGraph, ExtractHugr, HierarchyView}; +use hugr_core::hugr::internal::HugrMutInternals; use itertools::Either::{Left, Right}; use hugr::hugr::hugrmut::HugrMut; @@ -317,7 +318,11 @@ impl Circuit { } else { let view: DescendantsGraph = DescendantsGraph::try_new(&self.hugr, self.parent) .expect("Circuit parent was not a dataflow container."); - view.extract_hugr().into() + let mut hugr = view.extract_hugr(); + // TODO: Remove this once hugr 0.6.0 gets released. + // https://github.com/CQCL/hugr/pull/1239 + hugr.set_num_ports(hugr.root(), 0, 0); + hugr.into() }; extract_dfg::rewrite_into_dfg(&mut circ)?; Ok(circ) diff --git a/tket2/src/passes/pytket.rs b/tket2/src/passes/pytket.rs index 88906742..be72bbb0 100644 --- a/tket2/src/passes/pytket.rs +++ b/tket2/src/passes/pytket.rs @@ -41,6 +41,7 @@ pub enum PytketLoweringError { #[cfg(test)] mod test { + use crate::extension::REGISTRY; use crate::Tk2Op; use super::*; @@ -50,8 +51,8 @@ mod test { use hugr::extension::prelude::QB_T; use hugr::extension::{ExtensionSet, PRELUDE_REGISTRY}; use hugr::ops::handle::NodeHandle; - use hugr::ops::{MakeTuple, OpType, UnpackTuple, Value}; - use hugr::types::FunctionType; + use hugr::ops::{MakeTuple, OpType, Tag, UnpackTuple}; + use hugr::types::{FunctionType, TypeRow}; use hugr::{type_row, HugrView}; use rstest::{fixture, rstest}; @@ -95,7 +96,10 @@ mod test { .add_dataflow_op(UnpackTuple::new(two_qbs), [tup])? .outputs_arr(); - let branch = dfg.add_load_const(Value::tuple([])); + // Adds an empty Unit branch. + let [branch] = dfg + .add_dataflow_op(Tag::new(0, vec![TypeRow::new()]), [])? + .outputs_arr(); dfg.finish_with_outputs(branch, [q1, q2])? }; @@ -119,11 +123,10 @@ mod test { fn test_pytket_lowering(#[case] circ: Circuit) { use cool_asserts::assert_matches; - assert_eq!(circ.num_operations(), 2); - let lowered_circ = lower_to_pytket(&circ).unwrap(); + lowered_circ.hugr().validate(®ISTRY).unwrap(); - assert_eq!(lowered_circ.num_operations(), 2); + assert_eq!(lowered_circ.parent(), lowered_circ.hugr().root()); assert_matches!( lowered_circ.hugr().get_optype(lowered_circ.parent()), OpType::DFG(_) @@ -136,5 +139,34 @@ mod test { lowered_circ.hugr().get_optype(lowered_circ.output_node()), OpType::Output(_) ); + assert_eq!(lowered_circ.num_operations(), circ.num_operations()); + + // Check that the circuit signature is preserved. + let original_sig = circ.circuit_signature(); + let lowered_sig = lowered_circ.circuit_signature(); + assert_eq!(lowered_sig.input(), original_sig.input()); + + // The output signature may have changed due CFG branch tag removal. + let output_count_diff = + original_sig.output().len() as isize - lowered_sig.output().len() as isize; + assert!( + output_count_diff == 0 || output_count_diff == 1, + "Output count mismatch. Original: {}, Lowered: {}", + original_sig, + lowered_sig + ); + assert_eq!( + lowered_sig.output()[..], + original_sig.output()[output_count_diff as usize..] + ); + + // Check that the output node was successfully updated + let output_sig = lowered_circ + .hugr() + .signature(lowered_circ.output_node()) + .unwrap(); + assert_eq!(lowered_sig.output(), output_sig.input()); + println!("Lowered circuit: {}", lowered_sig); + println!("Output node: {}", output_sig); } } From 54a63ad9c939d99722a63117905de19bf80486e5 Mon Sep 17 00:00:00 2001 From: Agustin Borgna Date: Fri, 28 Jun 2024 11:48:48 +0100 Subject: [PATCH 3/3] Remove prints --- tket2/src/passes/pytket.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/tket2/src/passes/pytket.rs b/tket2/src/passes/pytket.rs index be72bbb0..9def01fa 100644 --- a/tket2/src/passes/pytket.rs +++ b/tket2/src/passes/pytket.rs @@ -166,7 +166,5 @@ mod test { .signature(lowered_circ.output_node()) .unwrap(); assert_eq!(lowered_sig.output(), output_sig.input()); - println!("Lowered circuit: {}", lowered_sig); - println!("Output node: {}", output_sig); } }