Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: extract_dfg inserting the output node with an invalid child order #442

Merged
merged 3 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion tket2/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use std::iter::Sum;
pub use command::{Command, CommandIterator};
pub use hash::CircuitHash;
use hugr::hugr::views::{DescendantsGraph, ExtractHugr, HierarchyView};
use hugr_core::hugr::internal::HugrMutInternals;
use itertools::Either::{Left, Right};

use hugr::hugr::hugrmut::HugrMut;
Expand Down Expand Up @@ -317,7 +318,11 @@ impl<T: HugrView> Circuit<T> {
} else {
let view: DescendantsGraph = DescendantsGraph::try_new(&self.hugr, self.parent)
.expect("Circuit parent was not a dataflow container.");
view.extract_hugr().into()
let mut hugr = view.extract_hugr();
// TODO: Remove this once hugr 0.6.0 gets released.
// https://github.com/CQCL/hugr/pull/1239
hugr.set_num_ports(hugr.root(), 0, 0);
hugr.into()
};
extract_dfg::rewrite_into_dfg(&mut circ)?;
Ok(circ)
Expand Down
6 changes: 3 additions & 3 deletions tket2/src/circuit/extract_dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ fn remove_cfg_empty_output_tuple(
signature: FunctionType,
) -> Result<FunctionType, CircuitMutError> {
let sig = signature;
let parent = circ.parent();
let input_node = circ.input_node();

let output_node = circ.output_node();
let output_nodetype = circ.hugr.get_nodetype(output_node).clone();
Expand Down Expand Up @@ -89,8 +89,8 @@ fn remove_cfg_empty_output_tuple(
let new_op = Output {
types: new_types.clone().into(),
};
let new_node = hugr.add_node_with_parent(
parent,
let new_node = hugr.add_node_after(
input_node,
Comment on lines +92 to +93
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be covered right? the test looks ok to me

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Turns out the "guppy-like" circuit was not guppy-like. I was loading a constant tag instead of using the Tag operation...

It should be fixed now.

NodeType::new(
new_op,
output_nodetype
Expand Down
130 changes: 130 additions & 0 deletions tket2/src/passes/pytket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,133 @@ pub enum PytketLoweringError {
#[error("Non-local operations found. Function calls are not supported.")]
NonLocalOperations,
}

#[cfg(test)]
mod test {
use crate::extension::REGISTRY;
use crate::Tk2Op;

use super::*;
use hugr::builder::{
Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder, SubContainer,
};
use hugr::extension::prelude::QB_T;
use hugr::extension::{ExtensionSet, PRELUDE_REGISTRY};
use hugr::ops::handle::NodeHandle;
use hugr::ops::{MakeTuple, OpType, Tag, UnpackTuple};
use hugr::types::{FunctionType, TypeRow};
use hugr::{type_row, HugrView};
use rstest::{fixture, rstest};

/// Builds a circuit in the style of guppy's output.
///
/// This is composed of a `Module`, containing a `FuncDefn`, containing a
/// `CFG`, containing an `Exit` and a `DataflowBlock` with the actual
/// circuit.
#[fixture]
fn guppy_like_circuit() -> Circuit {
fn build() -> Result<Circuit, hugr::builder::BuildError> {
let two_qbs = type_row![QB_T, QB_T];
let circ_signature = FunctionType::new_endo(two_qbs.clone());
let circ;

let mut builder = ModuleBuilder::new();
let _func = {
let mut func = builder.define_function("main", circ_signature.into())?;
let [q1, q2] = func.input_wires_arr();

let cfg = {
let mut cfg = func.cfg_builder(
[(QB_T, q1), (QB_T, q2)],
None,
two_qbs.clone(),
ExtensionSet::new(),
)?;

circ = {
let mut dfg =
cfg.simple_entry_builder(two_qbs.clone(), 1, ExtensionSet::new())?;
let [q1, q2] = dfg.input_wires_arr();

let [q1] = dfg.add_dataflow_op(Tk2Op::H, [q1])?.outputs_arr();
let [q1, q2] = dfg.add_dataflow_op(Tk2Op::CX, [q1, q2])?.outputs_arr();

let [tup] = dfg
.add_dataflow_op(MakeTuple::new(two_qbs.clone()), [q1, q2])?
.outputs_arr();
let [q1, q2] = dfg
.add_dataflow_op(UnpackTuple::new(two_qbs), [tup])?
.outputs_arr();

// Adds an empty Unit branch.
let [branch] = dfg
.add_dataflow_op(Tag::new(0, vec![TypeRow::new()]), [])?
.outputs_arr();

dfg.finish_with_outputs(branch, [q1, q2])?
};
cfg.branch(&circ, 0, &cfg.exit_block())?;

cfg.finish_sub_container()?
};
let [q1, q2] = cfg.outputs_arr();

func.finish_with_outputs([q1, q2])?
};

let hugr = builder.finish_hugr(&PRELUDE_REGISTRY)?;
Ok(Circuit::new(hugr, circ.node()))
}
build().unwrap()
}

#[rstest]
#[case::guppy_like_circuit(guppy_like_circuit())]
fn test_pytket_lowering(#[case] circ: Circuit) {
use cool_asserts::assert_matches;

let lowered_circ = lower_to_pytket(&circ).unwrap();
lowered_circ.hugr().validate(&REGISTRY).unwrap();

assert_eq!(lowered_circ.parent(), lowered_circ.hugr().root());
assert_matches!(
lowered_circ.hugr().get_optype(lowered_circ.parent()),
OpType::DFG(_)
);
assert_matches!(
lowered_circ.hugr().get_optype(lowered_circ.input_node()),
OpType::Input(_)
);
assert_matches!(
lowered_circ.hugr().get_optype(lowered_circ.output_node()),
OpType::Output(_)
);
assert_eq!(lowered_circ.num_operations(), circ.num_operations());

// Check that the circuit signature is preserved.
let original_sig = circ.circuit_signature();
let lowered_sig = lowered_circ.circuit_signature();
assert_eq!(lowered_sig.input(), original_sig.input());

// The output signature may have changed due CFG branch tag removal.
let output_count_diff =
original_sig.output().len() as isize - lowered_sig.output().len() as isize;
assert!(
output_count_diff == 0 || output_count_diff == 1,
"Output count mismatch. Original: {}, Lowered: {}",
original_sig,
lowered_sig
);
assert_eq!(
lowered_sig.output()[..],
original_sig.output()[output_count_diff as usize..]
);

// Check that the output node was successfully updated
let output_sig = lowered_circ
.hugr()
.signature(lowered_circ.output_node())
.unwrap();
assert_eq!(lowered_sig.output(), output_sig.input());
}
}
Loading