Skip to content

Commit

Permalink
fix: extract_dfg inserting the output node with an invalid child or…
Browse files Browse the repository at this point in the history
…der (#442)

When replacing a CFG's output with the correct DFG output operation, we
inserted the new node at the end of the container's children, instead of
just after the input.

Includes a new unit test which simulates a circuit definition from
guppy, where we can run the `lower_to_pytket` pass.
  • Loading branch information
aborgna-q authored Jun 28, 2024
1 parent 525b63f commit b98d6bc
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 4 deletions.
7 changes: 6 additions & 1 deletion tket2/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::iter::Sum;
pub use command::{Command, CommandIterator};
pub use hash::CircuitHash;
use hugr::hugr::views::{DescendantsGraph, ExtractHugr, HierarchyView};
use hugr_core::hugr::internal::HugrMutInternals;
use itertools::Either::{Left, Right};

use hugr::hugr::hugrmut::HugrMut;
Expand Down Expand Up @@ -317,7 +318,11 @@ impl<T: HugrView> Circuit<T> {
} else {
let view: DescendantsGraph = DescendantsGraph::try_new(&self.hugr, self.parent)
.expect("Circuit parent was not a dataflow container.");
view.extract_hugr().into()
let mut hugr = view.extract_hugr();
// TODO: Remove this once hugr 0.6.0 gets released.
// https://github.com/CQCL/hugr/pull/1239
hugr.set_num_ports(hugr.root(), 0, 0);
hugr.into()
};
extract_dfg::rewrite_into_dfg(&mut circ)?;
Ok(circ)
Expand Down
6 changes: 3 additions & 3 deletions tket2/src/circuit/extract_dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ fn remove_cfg_empty_output_tuple(
signature: FunctionType,
) -> Result<FunctionType, CircuitMutError> {
let sig = signature;
let parent = circ.parent();
let input_node = circ.input_node();

let output_node = circ.output_node();
let output_nodetype = circ.hugr.get_nodetype(output_node).clone();
Expand Down Expand Up @@ -89,8 +89,8 @@ fn remove_cfg_empty_output_tuple(
let new_op = Output {
types: new_types.clone().into(),
};
let new_node = hugr.add_node_with_parent(
parent,
let new_node = hugr.add_node_after(
input_node,
NodeType::new(
new_op,
output_nodetype
Expand Down
130 changes: 130 additions & 0 deletions tket2/src/passes/pytket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,133 @@ pub enum PytketLoweringError {
#[error("Non-local operations found. Function calls are not supported.")]
NonLocalOperations,
}

#[cfg(test)]
mod test {
use crate::extension::REGISTRY;
use crate::Tk2Op;

use super::*;
use hugr::builder::{
Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder, SubContainer,
};
use hugr::extension::prelude::QB_T;
use hugr::extension::{ExtensionSet, PRELUDE_REGISTRY};
use hugr::ops::handle::NodeHandle;
use hugr::ops::{MakeTuple, OpType, Tag, UnpackTuple};
use hugr::types::{FunctionType, TypeRow};
use hugr::{type_row, HugrView};
use rstest::{fixture, rstest};

/// Builds a circuit in the style of guppy's output.
///
/// This is composed of a `Module`, containing a `FuncDefn`, containing a
/// `CFG`, containing an `Exit` and a `DataflowBlock` with the actual
/// circuit.
#[fixture]
fn guppy_like_circuit() -> Circuit {
fn build() -> Result<Circuit, hugr::builder::BuildError> {
let two_qbs = type_row![QB_T, QB_T];
let circ_signature = FunctionType::new_endo(two_qbs.clone());
let circ;

let mut builder = ModuleBuilder::new();
let _func = {
let mut func = builder.define_function("main", circ_signature.into())?;
let [q1, q2] = func.input_wires_arr();

let cfg = {
let mut cfg = func.cfg_builder(
[(QB_T, q1), (QB_T, q2)],
None,
two_qbs.clone(),
ExtensionSet::new(),
)?;

circ = {
let mut dfg =
cfg.simple_entry_builder(two_qbs.clone(), 1, ExtensionSet::new())?;
let [q1, q2] = dfg.input_wires_arr();

let [q1] = dfg.add_dataflow_op(Tk2Op::H, [q1])?.outputs_arr();
let [q1, q2] = dfg.add_dataflow_op(Tk2Op::CX, [q1, q2])?.outputs_arr();

let [tup] = dfg
.add_dataflow_op(MakeTuple::new(two_qbs.clone()), [q1, q2])?
.outputs_arr();
let [q1, q2] = dfg
.add_dataflow_op(UnpackTuple::new(two_qbs), [tup])?
.outputs_arr();

// Adds an empty Unit branch.
let [branch] = dfg
.add_dataflow_op(Tag::new(0, vec![TypeRow::new()]), [])?
.outputs_arr();

dfg.finish_with_outputs(branch, [q1, q2])?
};
cfg.branch(&circ, 0, &cfg.exit_block())?;

cfg.finish_sub_container()?
};
let [q1, q2] = cfg.outputs_arr();

func.finish_with_outputs([q1, q2])?
};

let hugr = builder.finish_hugr(&PRELUDE_REGISTRY)?;
Ok(Circuit::new(hugr, circ.node()))
}
build().unwrap()
}

#[rstest]
#[case::guppy_like_circuit(guppy_like_circuit())]
fn test_pytket_lowering(#[case] circ: Circuit) {
use cool_asserts::assert_matches;

let lowered_circ = lower_to_pytket(&circ).unwrap();
lowered_circ.hugr().validate(&REGISTRY).unwrap();

assert_eq!(lowered_circ.parent(), lowered_circ.hugr().root());
assert_matches!(
lowered_circ.hugr().get_optype(lowered_circ.parent()),
OpType::DFG(_)
);
assert_matches!(
lowered_circ.hugr().get_optype(lowered_circ.input_node()),
OpType::Input(_)
);
assert_matches!(
lowered_circ.hugr().get_optype(lowered_circ.output_node()),
OpType::Output(_)
);
assert_eq!(lowered_circ.num_operations(), circ.num_operations());

// Check that the circuit signature is preserved.
let original_sig = circ.circuit_signature();
let lowered_sig = lowered_circ.circuit_signature();
assert_eq!(lowered_sig.input(), original_sig.input());

// The output signature may have changed due CFG branch tag removal.
let output_count_diff =
original_sig.output().len() as isize - lowered_sig.output().len() as isize;
assert!(
output_count_diff == 0 || output_count_diff == 1,
"Output count mismatch. Original: {}, Lowered: {}",
original_sig,
lowered_sig
);
assert_eq!(
lowered_sig.output()[..],
original_sig.output()[output_count_diff as usize..]
);

// Check that the output node was successfully updated
let output_sig = lowered_circ
.hugr()
.signature(lowered_circ.output_node())
.unwrap();
assert_eq!(lowered_sig.output(), output_sig.input());
}
}

0 comments on commit b98d6bc

Please sign in to comment.