Skip to content

Commit

Permalink
test!: Improve coverage in signature and validate (#643)
Browse files Browse the repository at this point in the history
Builder couldn't handle non-local edges between cfg blocks so fix that
and add tests.

BREAKING CHANGES: FunctionType::linear no longer exists
  • Loading branch information
ss2165 authored Nov 8, 2023
1 parent b71cae6 commit 250f221
Show file tree
Hide file tree
Showing 6 changed files with 937 additions and 780 deletions.
17 changes: 12 additions & 5 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::hugr::hugrmut::InsertionResult;
use crate::hugr::validate::InterGraphEdgeError;
use crate::hugr::views::HugrView;
use crate::hugr::{NodeMetadata, ValidationError};
use crate::ops::{self, LeafOp, OpTrait, OpType};
use crate::ops::{self, LeafOp, OpTag, OpTrait, OpType};
use crate::{IncomingPort, Node, OutgoingPort};

use std::iter;
Expand Down Expand Up @@ -666,6 +666,7 @@ fn wire_up<T: Dataflow + ?Sized>(
let base = data_builder.hugr_mut();

let src_parent = base.get_parent(src);
let src_parent_parent = src_parent.and_then(|src| base.get_parent(src));
let dst_parent = base.get_parent(dst);
let local_source = src_parent == dst_parent;
if let EdgeKind::Value(typ) = base.get_optype(src).port_kind(src_port).unwrap() {
Expand All @@ -687,7 +688,10 @@ fn wire_up<T: Dataflow + ?Sized>(
let Some(src_sibling) = iter::successors(dst_parent, |&p| base.get_parent(p))
.tuple_windows()
.find_map(|(ancestor, ancestor_parent)| {
(ancestor_parent == src_parent).then_some(ancestor)
(ancestor_parent == src_parent ||
// Dom edge - in CFGs
Some(ancestor_parent) == src_parent_parent)
.then_some(ancestor)
})
else {
let val_err: ValidationError = InterGraphEdgeError::NoRelation {
Expand All @@ -700,9 +704,12 @@ fn wire_up<T: Dataflow + ?Sized>(
return Err(val_err.into());
};

// TODO: Avoid adding duplicate edges
// This should be easy with https://github.com/CQCL-DEV/hugr/issues/130
base.add_other_edge(src, src_sibling)?;
if !OpTag::BasicBlock.is_superset(base.get_optype(src).tag())
&& !OpTag::BasicBlock.is_superset(base.get_optype(src_sibling).tag())
{
// Add a state order constraint unless one of the nodes is a CFG BasicBlock
base.add_other_edge(src, src_sibling)?;
}
} else if !typ.copyable() & base.linked_ports(src, src_port).next().is_some() {
// Don't copy linear edges.
return Err(BuildError::NoCopyLinear(typ));
Expand Down
65 changes: 65 additions & 0 deletions src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,8 @@ mod test {
use crate::builder::build_traits::HugrBuilder;
use crate::builder::{DataflowSubContainer, ModuleBuilder};

use crate::hugr::validate::InterGraphEdgeError;
use crate::hugr::ValidationError;
use crate::{builder::test::NAT, type_row};
use cool_asserts::assert_matches;

Expand Down Expand Up @@ -393,4 +395,67 @@ mod test {
cfg_builder.branch(&entry, 1, &exit)?;
Ok(())
}
#[test]
fn test_dom_edge() -> Result<(), BuildError> {
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
let sum_tuple_const =
cfg_builder.add_constant(ops::Const::unary_unit_sum(), ExtensionSet::new())?;
let sum_variants = vec![type_row![]];

let mut entry_b =
cfg_builder.entry_builder(sum_variants.clone(), type_row![], ExtensionSet::new())?;
let [inw] = entry_b.input_wires_arr();
let entry = {
let sum = entry_b.load_const(&sum_tuple_const)?;

entry_b.finish_with_outputs(sum, [])?
};
let mut middle_b =
cfg_builder.simple_block_builder(FunctionType::new(type_row![], type_row![NAT]), 1)?;
let middle = {
let c = middle_b.load_const(&sum_tuple_const)?;
middle_b.finish_with_outputs(c, [inw])?
};
let exit = cfg_builder.exit_block();
cfg_builder.branch(&entry, 0, &middle)?;
cfg_builder.branch(&middle, 0, &exit)?;
assert_matches!(cfg_builder.finish_prelude_hugr(), Ok(_));

Ok(())
}

#[test]
fn test_non_dom_edge() -> Result<(), BuildError> {
let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?;
let sum_tuple_const =
cfg_builder.add_constant(ops::Const::unary_unit_sum(), ExtensionSet::new())?;
let sum_variants = vec![type_row![]];
let mut middle_b = cfg_builder
.simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?;
let [inw] = middle_b.input_wires_arr();
let middle = {
let c = middle_b.load_const(&sum_tuple_const)?;
middle_b.finish_with_outputs(c, [inw])?
};

let mut entry_b =
cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT], ExtensionSet::new())?;
let entry = {
let sum = entry_b.load_const(&sum_tuple_const)?;
// entry block uses wire from middle block even though middle block
// does not dominate entry
entry_b.finish_with_outputs(sum, [inw])?
};
let exit = cfg_builder.exit_block();
cfg_builder.branch(&entry, 0, &middle)?;
cfg_builder.branch(&middle, 0, &exit)?;
assert_matches!(
cfg_builder.finish_prelude_hugr(),
Err(ValidationError::InterGraphEdgeError(
InterGraphEdgeError::NonDominatedAncestor { .. }
))
);

Ok(())
}
}
52 changes: 52 additions & 0 deletions src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ pub(crate) mod test {

use crate::std_extensions::logic::test::and_op;
use crate::std_extensions::quantum::test::h_gate;
use crate::types::Type;
use crate::{
builder::{
test::{n_identity, BIT, NAT, QB},
Expand Down Expand Up @@ -500,4 +501,55 @@ pub(crate) mod test {

Ok(())
}

#[test]
fn non_cfg_ancestor() -> Result<(), BuildError> {
let unit_sig = FunctionType::new(type_row![Type::UNIT], type_row![Type::UNIT]);
let mut b = DFGBuilder::new(unit_sig.clone())?;
let b_child = b.dfg_builder(unit_sig.clone(), None, [b.input().out_wire(0)])?;
let b_child_in_wire = b_child.input().out_wire(0);
b_child.finish_with_outputs([])?;
let b_child_2 = b.dfg_builder(unit_sig.clone(), None, [])?;

// DFG block has edge coming a sibling block, which is only valid for
// CFGs
let b_child_2_handle = b_child_2.finish_with_outputs([b_child_in_wire])?;

let res = b.finish_prelude_hugr_with_outputs([b_child_2_handle.out_wire(0)]);

assert_matches!(
res,
Err(BuildError::InvalidHUGR(
ValidationError::InterGraphEdgeError(InterGraphEdgeError::NonCFGAncestor { .. })
))
);
Ok(())
}

#[test]
fn no_relation_edge() -> Result<(), BuildError> {
let unit_sig = FunctionType::new(type_row![Type::UNIT], type_row![Type::UNIT]);
let mut b = DFGBuilder::new(unit_sig.clone())?;
let mut b_child = b.dfg_builder(unit_sig.clone(), None, [b.input().out_wire(0)])?;
let b_child_child =
b_child.dfg_builder(unit_sig.clone(), None, [b_child.input().out_wire(0)])?;
let b_child_child_in_wire = b_child_child.input().out_wire(0);

b_child_child.finish_with_outputs([])?;
b_child.finish_with_outputs([])?;

let mut b_child_2 = b.dfg_builder(unit_sig.clone(), None, [])?;
let b_child_2_child =
b_child_2.dfg_builder(unit_sig.clone(), None, [b_child_2.input().out_wire(0)])?;

let res = b_child_2_child.finish_with_outputs([b_child_child_in_wire]);

assert_matches!(
res.map(|h| h.handle().node()), // map to something that implements Debug
Err(BuildError::InvalidHUGR(
ValidationError::InterGraphEdgeError(InterGraphEdgeError::NoRelation { .. })
))
);
Ok(())
}
}
Loading

0 comments on commit 250f221

Please sign in to comment.