diff --git a/src/builder/build_traits.rs b/src/builder/build_traits.rs index 300900c0b..dfdc86488 100644 --- a/src/builder/build_traits.rs +++ b/src/builder/build_traits.rs @@ -2,7 +2,7 @@ use crate::hugr::hugrmut::InsertionResult; use crate::hugr::validate::InterGraphEdgeError; use crate::hugr::views::HugrView; use crate::hugr::{NodeMetadata, ValidationError}; -use crate::ops::{self, LeafOp, OpTrait, OpType}; +use crate::ops::{self, LeafOp, OpTag, OpTrait, OpType}; use crate::{IncomingPort, Node, OutgoingPort}; use std::iter; @@ -666,6 +666,7 @@ fn wire_up( let base = data_builder.hugr_mut(); let src_parent = base.get_parent(src); + let src_parent_parent = src_parent.and_then(|src| base.get_parent(src)); let dst_parent = base.get_parent(dst); let local_source = src_parent == dst_parent; if let EdgeKind::Value(typ) = base.get_optype(src).port_kind(src_port).unwrap() { @@ -687,7 +688,10 @@ fn wire_up( let Some(src_sibling) = iter::successors(dst_parent, |&p| base.get_parent(p)) .tuple_windows() .find_map(|(ancestor, ancestor_parent)| { - (ancestor_parent == src_parent).then_some(ancestor) + (ancestor_parent == src_parent || + // Dom edge - in CFGs + Some(ancestor_parent) == src_parent_parent) + .then_some(ancestor) }) else { let val_err: ValidationError = InterGraphEdgeError::NoRelation { @@ -700,9 +704,12 @@ fn wire_up( return Err(val_err.into()); }; - // TODO: Avoid adding duplicate edges - // This should be easy with https://github.com/CQCL-DEV/hugr/issues/130 - base.add_other_edge(src, src_sibling)?; + if !OpTag::BasicBlock.is_superset(base.get_optype(src).tag()) + && !OpTag::BasicBlock.is_superset(base.get_optype(src_sibling).tag()) + { + // Add a state order constraint unless one of the nodes is a CFG BasicBlock + base.add_other_edge(src, src_sibling)?; + } } else if !typ.copyable() & base.linked_ports(src, src_port).next().is_some() { // Don't copy linear edges. return Err(BuildError::NoCopyLinear(typ)); diff --git a/src/builder/cfg.rs b/src/builder/cfg.rs index de99f5de9..fe2954bd5 100644 --- a/src/builder/cfg.rs +++ b/src/builder/cfg.rs @@ -325,6 +325,8 @@ mod test { use crate::builder::build_traits::HugrBuilder; use crate::builder::{DataflowSubContainer, ModuleBuilder}; + use crate::hugr::validate::InterGraphEdgeError; + use crate::hugr::ValidationError; use crate::{builder::test::NAT, type_row}; use cool_asserts::assert_matches; @@ -393,4 +395,67 @@ mod test { cfg_builder.branch(&entry, 1, &exit)?; Ok(()) } + #[test] + fn test_dom_edge() -> Result<(), BuildError> { + let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; + let sum_tuple_const = + cfg_builder.add_constant(ops::Const::unary_unit_sum(), ExtensionSet::new())?; + let sum_variants = vec![type_row![]]; + + let mut entry_b = + cfg_builder.entry_builder(sum_variants.clone(), type_row![], ExtensionSet::new())?; + let [inw] = entry_b.input_wires_arr(); + let entry = { + let sum = entry_b.load_const(&sum_tuple_const)?; + + entry_b.finish_with_outputs(sum, [])? + }; + let mut middle_b = + cfg_builder.simple_block_builder(FunctionType::new(type_row![], type_row![NAT]), 1)?; + let middle = { + let c = middle_b.load_const(&sum_tuple_const)?; + middle_b.finish_with_outputs(c, [inw])? + }; + let exit = cfg_builder.exit_block(); + cfg_builder.branch(&entry, 0, &middle)?; + cfg_builder.branch(&middle, 0, &exit)?; + assert_matches!(cfg_builder.finish_prelude_hugr(), Ok(_)); + + Ok(()) + } + + #[test] + fn test_non_dom_edge() -> Result<(), BuildError> { + let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; + let sum_tuple_const = + cfg_builder.add_constant(ops::Const::unary_unit_sum(), ExtensionSet::new())?; + let sum_variants = vec![type_row![]]; + let mut middle_b = cfg_builder + .simple_block_builder(FunctionType::new(type_row![NAT], type_row![NAT]), 1)?; + let [inw] = middle_b.input_wires_arr(); + let middle = { + let c = middle_b.load_const(&sum_tuple_const)?; + middle_b.finish_with_outputs(c, [inw])? + }; + + let mut entry_b = + cfg_builder.entry_builder(sum_variants.clone(), type_row![NAT], ExtensionSet::new())?; + let entry = { + let sum = entry_b.load_const(&sum_tuple_const)?; + // entry block uses wire from middle block even though middle block + // does not dominate entry + entry_b.finish_with_outputs(sum, [inw])? + }; + let exit = cfg_builder.exit_block(); + cfg_builder.branch(&entry, 0, &middle)?; + cfg_builder.branch(&middle, 0, &exit)?; + assert_matches!( + cfg_builder.finish_prelude_hugr(), + Err(ValidationError::InterGraphEdgeError( + InterGraphEdgeError::NonDominatedAncestor { .. } + )) + ); + + Ok(()) + } } diff --git a/src/builder/dataflow.rs b/src/builder/dataflow.rs index b5f1960b8..d4fa1eafe 100644 --- a/src/builder/dataflow.rs +++ b/src/builder/dataflow.rs @@ -219,6 +219,7 @@ pub(crate) mod test { use crate::std_extensions::logic::test::and_op; use crate::std_extensions::quantum::test::h_gate; + use crate::types::Type; use crate::{ builder::{ test::{n_identity, BIT, NAT, QB}, @@ -500,4 +501,55 @@ pub(crate) mod test { Ok(()) } + + #[test] + fn non_cfg_ancestor() -> Result<(), BuildError> { + let unit_sig = FunctionType::new(type_row![Type::UNIT], type_row![Type::UNIT]); + let mut b = DFGBuilder::new(unit_sig.clone())?; + let b_child = b.dfg_builder(unit_sig.clone(), None, [b.input().out_wire(0)])?; + let b_child_in_wire = b_child.input().out_wire(0); + b_child.finish_with_outputs([])?; + let b_child_2 = b.dfg_builder(unit_sig.clone(), None, [])?; + + // DFG block has edge coming a sibling block, which is only valid for + // CFGs + let b_child_2_handle = b_child_2.finish_with_outputs([b_child_in_wire])?; + + let res = b.finish_prelude_hugr_with_outputs([b_child_2_handle.out_wire(0)]); + + assert_matches!( + res, + Err(BuildError::InvalidHUGR( + ValidationError::InterGraphEdgeError(InterGraphEdgeError::NonCFGAncestor { .. }) + )) + ); + Ok(()) + } + + #[test] + fn no_relation_edge() -> Result<(), BuildError> { + let unit_sig = FunctionType::new(type_row![Type::UNIT], type_row![Type::UNIT]); + let mut b = DFGBuilder::new(unit_sig.clone())?; + let mut b_child = b.dfg_builder(unit_sig.clone(), None, [b.input().out_wire(0)])?; + let b_child_child = + b_child.dfg_builder(unit_sig.clone(), None, [b_child.input().out_wire(0)])?; + let b_child_child_in_wire = b_child_child.input().out_wire(0); + + b_child_child.finish_with_outputs([])?; + b_child.finish_with_outputs([])?; + + let mut b_child_2 = b.dfg_builder(unit_sig.clone(), None, [])?; + let b_child_2_child = + b_child_2.dfg_builder(unit_sig.clone(), None, [b_child_2.input().out_wire(0)])?; + + let res = b_child_2_child.finish_with_outputs([b_child_child_in_wire]); + + assert_matches!( + res.map(|h| h.handle().node()), // map to something that implements Debug + Err(BuildError::InvalidHUGR( + ValidationError::InterGraphEdgeError(InterGraphEdgeError::NoRelation { .. }) + )) + ); + Ok(()) + } } diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index 7f5a8bbe8..ec601dfa6 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -737,758 +737,4 @@ pub enum InterGraphEdgeError { } #[cfg(test)] -mod test { - use cool_asserts::assert_matches; - - use super::*; - use crate::builder::test::closed_dfg_root_hugr; - use crate::builder::{BuildError, Container, Dataflow, DataflowSubContainer, ModuleBuilder}; - use crate::extension::prelude::{BOOL_T, PRELUDE, USIZE_T}; - use crate::extension::{ - Extension, ExtensionId, ExtensionSet, TypeDefBound, EMPTY_REG, PRELUDE_REGISTRY, - }; - use crate::hugr::hugrmut::sealed::HugrMutInternals; - use crate::hugr::{HugrError, HugrMut, NodeType}; - use crate::macros::const_extension_ids; - use crate::ops::dataflow::IOTrait; - use crate::ops::{self, LeafOp, OpType}; - use crate::std_extensions::logic; - use crate::std_extensions::logic::test::{and_op, not_op}; - use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; - use crate::types::{CustomType, FunctionType, Type, TypeBound, TypeRow}; - use crate::{type_row, Direction, IncomingPort, Node}; - - const NAT: Type = crate::extension::prelude::USIZE_T; - const Q: Type = crate::extension::prelude::QB_T; - - /// Creates a hugr with a single function definition that copies a bit `copies` times. - /// - /// Returns the hugr and the node index of the definition. - fn make_simple_hugr(copies: usize) -> (Hugr, Node) { - let def_op: OpType = ops::FuncDefn { - name: "main".into(), - signature: FunctionType::new(type_row![BOOL_T], vec![BOOL_T; copies]), - } - .into(); - - let mut b = Hugr::default(); - let root = b.root(); - - let def = b.add_node_with_parent(root, def_op).unwrap(); - let _ = add_df_children(&mut b, def, copies); - - (b, def) - } - - /// Adds an input{BOOL_T}, copy{BOOL_T -> BOOL_T^copies}, and output{BOOL_T^copies} operation to a dataflow container. - /// - /// Returns the node indices of each of the operations. - fn add_df_children(b: &mut Hugr, parent: Node, copies: usize) -> (Node, Node, Node) { - let input = b - .add_node_with_parent(parent, ops::Input::new(type_row![BOOL_T])) - .unwrap(); - let output = b - .add_node_with_parent(parent, ops::Output::new(vec![BOOL_T; copies])) - .unwrap(); - let copy = b - .add_node_with_parent(parent, LeafOp::Noop { ty: BOOL_T }) - .unwrap(); - - b.connect(input, 0, copy, 0).unwrap(); - for i in 0..copies { - b.connect(copy, 0, output, i).unwrap(); - } - - (input, copy, output) - } - - /// Adds an input{BOOL_T}, tag_constant(0, BOOL_T^tuple_sum_size), tag(BOOL_T^tuple_sum_size), and - /// output{Sum{unit^tuple_sum_size}, BOOL_T} operation to a dataflow container. - /// Intended to be used to populate a BasicBlock node in a CFG. - /// - /// Returns the node indices of each of the operations. - fn add_block_children( - b: &mut Hugr, - parent: Node, - tuple_sum_size: usize, - ) -> (Node, Node, Node, Node) { - let const_op = ops::Const::unit_sum(0, tuple_sum_size as u8); - let tag_type = Type::new_unit_sum(tuple_sum_size as u8); - - let input = b - .add_node_with_parent(parent, ops::Input::new(type_row![BOOL_T])) - .unwrap(); - let output = b - .add_node_with_parent(parent, ops::Output::new(vec![tag_type.clone(), BOOL_T])) - .unwrap(); - let tag_def = b.add_node_with_parent(b.root(), const_op).unwrap(); - let tag = b - .add_node_with_parent(parent, ops::LoadConstant { datatype: tag_type }) - .unwrap(); - - b.connect(tag_def, 0, tag, 0).unwrap(); - b.add_other_edge(input, tag).unwrap(); - b.connect(tag, 0, output, 0).unwrap(); - b.connect(input, 0, output, 1).unwrap(); - - (input, tag_def, tag, output) - } - - #[test] - fn invalid_root() { - let declare_op: OpType = ops::FuncDecl { - name: "main".into(), - signature: Default::default(), - } - .into(); - - let mut b = Hugr::default(); - let root = b.root(); - assert_eq!(b.validate(&EMPTY_REG), Ok(())); - - // Add another hierarchy root - let other = b.add_node(ops::Module.into()); - assert_matches!( - b.validate(&EMPTY_REG), - Err(ValidationError::NoParent { node }) => assert_eq!(node, other) - ); - b.set_parent(other, root).unwrap(); - b.replace_op(other, NodeType::new_pure(declare_op)).unwrap(); - b.add_ports(other, Direction::Outgoing, 1); - assert_eq!(b.validate(&EMPTY_REG), Ok(())); - - // Make the hugr root not a hierarchy root - { - let mut hugr = b.clone(); - hugr.root = other.pg_index(); - assert_matches!( - hugr.validate(&EMPTY_REG), - Err(ValidationError::RootNotRoot { node }) => assert_eq!(node, other) - ); - } - } - - #[test] - fn leaf_root() { - let leaf_op: OpType = LeafOp::Noop { ty: USIZE_T }.into(); - - let b = Hugr::new(NodeType::new_pure(leaf_op)); - assert_eq!(b.validate(&EMPTY_REG), Ok(())); - } - - #[test] - fn dfg_root() { - let dfg_op: OpType = ops::DFG { - signature: FunctionType::new_linear(type_row![BOOL_T]), - } - .into(); - - let mut b = Hugr::new(NodeType::new_pure(dfg_op)); - let root = b.root(); - add_df_children(&mut b, root, 1); - assert_eq!(b.update_validate(&EMPTY_REG), Ok(())); - } - - #[test] - fn simple_hugr() { - let mut b = make_simple_hugr(2).0; - assert_eq!(b.update_validate(&EMPTY_REG), Ok(())); - } - - #[test] - /// General children restrictions. - fn children_restrictions() { - let (mut b, def) = make_simple_hugr(2); - let root = b.root(); - let (_input, copy, _output) = b - .hierarchy - .children(def.pg_index()) - .map_into() - .collect_tuple() - .unwrap(); - - // Add a definition without children - let def_sig = FunctionType::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]); - let new_def = b - .add_node_with_parent( - root, - ops::FuncDefn { - signature: def_sig, - name: "main".into(), - }, - ) - .unwrap(); - assert_matches!( - b.update_validate(&EMPTY_REG), - Err(ValidationError::ContainerWithoutChildren { node, .. }) => assert_eq!(node, new_def) - ); - - // Add children to the definition, but move it to be a child of the copy - add_df_children(&mut b, new_def, 2); - b.set_parent(new_def, copy).unwrap(); - assert_matches!( - b.update_validate(&EMPTY_REG), - Err(ValidationError::NonContainerWithChildren { node, .. }) => assert_eq!(node, copy) - ); - let closure = b.infer_extensions().unwrap(); - b.set_parent(new_def, root).unwrap(); - - // After moving the previous definition to a valid place, - // add an input node to the module subgraph - let new_input = b - .add_node_with_parent(root, ops::Input::new(type_row![])) - .unwrap(); - assert_matches!( - b.validate_with_extension_closure(closure, &EMPTY_REG), - Err(ValidationError::InvalidParentOp { parent, child, .. }) => {assert_eq!(parent, root); assert_eq!(child, new_input)} - ); - } - - #[test] - /// Validation errors in a dataflow subgraph. - fn df_children_restrictions() { - let (mut b, def) = make_simple_hugr(2); - let (_input, output, copy) = b - .hierarchy - .children(def.pg_index()) - .map_into() - .collect_tuple() - .unwrap(); - - // Replace the output operation of the df subgraph with a copy - b.replace_op(output, NodeType::new_pure(LeafOp::Noop { ty: NAT })) - .unwrap(); - assert_matches!( - b.validate(&EMPTY_REG), - Err(ValidationError::InvalidInitialChild { parent, .. }) => assert_eq!(parent, def) - ); - - // Revert it back to an output, but with the wrong number of ports - b.replace_op( - output, - NodeType::new_pure(ops::Output::new(type_row![BOOL_T])), - ) - .unwrap(); - assert_matches!( - b.validate(&EMPTY_REG), - Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::IOSignatureMismatch { child, .. }, .. }) - => {assert_eq!(parent, def); assert_eq!(child, output.pg_index())} - ); - b.replace_op( - output, - NodeType::new_pure(ops::Output::new(type_row![BOOL_T, BOOL_T])), - ) - .unwrap(); - - // After fixing the output back, replace the copy with an output op - b.replace_op( - copy, - NodeType::new_pure(ops::Output::new(type_row![BOOL_T, BOOL_T])), - ) - .unwrap(); - assert_matches!( - b.validate(&EMPTY_REG), - Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalIOChildren { child, .. }, .. }) - => {assert_eq!(parent, def); assert_eq!(child, copy.pg_index())} - ); - } - - #[test] - /// Validation errors in a dataflow subgraph. - fn cfg_children_restrictions() { - let (mut b, def) = make_simple_hugr(1); - let (_input, _output, copy) = b - .hierarchy - .children(def.pg_index()) - .map_into() - .collect_tuple() - .unwrap(); - // Write Extension annotations into the Hugr while it's still well-formed - // enough for us to compute them - let closure = b.infer_extensions().unwrap(); - b.instantiate_extensions(closure); - b.validate(&EMPTY_REG).unwrap(); - b.replace_op( - copy, - NodeType::new_pure(ops::CFG { - signature: FunctionType::new(type_row![BOOL_T], type_row![BOOL_T]), - }), - ) - .unwrap(); - assert_matches!( - b.validate(&EMPTY_REG), - Err(ValidationError::ContainerWithoutChildren { .. }) - ); - let cfg = copy; - - // Construct a valid CFG, with one BasicBlock node and one exit node - let block = b - .add_node_with_parent( - cfg, - ops::BasicBlock::DFB { - inputs: type_row![BOOL_T], - tuple_sum_rows: vec![type_row![]], - other_outputs: type_row![BOOL_T], - extension_delta: ExtensionSet::new(), - }, - ) - .unwrap(); - add_block_children(&mut b, block, 1); - let exit = b - .add_node_with_parent( - cfg, - ops::BasicBlock::Exit { - cfg_outputs: type_row![BOOL_T], - }, - ) - .unwrap(); - b.add_other_edge(block, exit).unwrap(); - assert_eq!(b.update_validate(&EMPTY_REG), Ok(())); - - // Test malformed errors - - // Add an internal exit node - let exit2 = b - .add_node_after( - exit, - ops::BasicBlock::Exit { - cfg_outputs: type_row![BOOL_T], - }, - ) - .unwrap(); - assert_matches!( - b.validate(&EMPTY_REG), - Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalExitChildren { child, .. }, .. }) - => {assert_eq!(parent, cfg); assert_eq!(child, exit2.pg_index())} - ); - b.remove_node(exit2).unwrap(); - - // Change the types in the BasicBlock node to work on qubits instead of bits - b.replace_op( - block, - NodeType::new_pure(ops::BasicBlock::DFB { - inputs: type_row![Q], - tuple_sum_rows: vec![type_row![]], - other_outputs: type_row![Q], - extension_delta: ExtensionSet::new(), - }), - ) - .unwrap(); - let mut block_children = b.hierarchy.children(block.pg_index()); - let block_input = block_children.next().unwrap().into(); - let block_output = block_children.next_back().unwrap().into(); - b.replace_op( - block_input, - NodeType::new_pure(ops::Input::new(type_row![Q])), - ) - .unwrap(); - b.replace_op( - block_output, - NodeType::new_pure(ops::Output::new(type_row![Type::new_unit_sum(1), Q])), - ) - .unwrap(); - assert_matches!( - b.validate(&EMPTY_REG), - Err(ValidationError::InvalidEdges { parent, source: EdgeValidationError::CFGEdgeSignatureMismatch { .. }, .. }) - => assert_eq!(parent, cfg) - ); - } - - #[test] - fn test_ext_edge() -> Result<(), HugrError> { - let mut h = closed_dfg_root_hugr(FunctionType::new( - type_row![BOOL_T, BOOL_T], - type_row![BOOL_T], - )); - let [input, output] = h.get_io(h.root()).unwrap(); - - // Nested DFG BOOL_T -> BOOL_T - let sub_dfg = h.add_node_with_parent( - h.root(), - ops::DFG { - signature: FunctionType::new_linear(type_row![BOOL_T]), - }, - )?; - // this Xor has its 2nd input unconnected - let sub_op = { - let sub_input = h.add_node_with_parent(sub_dfg, ops::Input::new(type_row![BOOL_T]))?; - let sub_output = - h.add_node_with_parent(sub_dfg, ops::Output::new(type_row![BOOL_T]))?; - let sub_op = h.add_node_with_parent(sub_dfg, and_op())?; - h.connect(sub_input, 0, sub_op, 0)?; - h.connect(sub_op, 0, sub_output, 0)?; - sub_op - }; - - h.connect(input, 0, sub_dfg, 0)?; - h.connect(sub_dfg, 0, output, 0)?; - - assert_matches!( - h.update_validate(&EMPTY_REG), - Err(ValidationError::UnconnectedPort { .. }) - ); - - h.connect(input, 1, sub_op, 1)?; - assert_matches!( - h.update_validate(&EMPTY_REG), - Err(ValidationError::InterGraphEdgeError( - InterGraphEdgeError::MissingOrderEdge { .. } - )) - ); - //Order edge. This will need metadata indicating its purpose. - h.add_other_edge(input, sub_dfg)?; - h.update_validate(&EMPTY_REG).unwrap(); - Ok(()) - } - - const_extension_ids! { - const XA: ExtensionId = "A"; - const XB: ExtensionId = "BOOL_EXT"; - } - - #[test] - fn test_local_const() -> Result<(), HugrError> { - let mut h = closed_dfg_root_hugr(FunctionType::new(type_row![BOOL_T], type_row![BOOL_T])); - let [input, output] = h.get_io(h.root()).unwrap(); - let and = h.add_node_with_parent(h.root(), and_op())?; - h.connect(input, 0, and, 0)?; - h.connect(and, 0, output, 0)?; - assert_eq!( - h.update_validate(&EMPTY_REG), - Err(ValidationError::UnconnectedPort { - node: and, - port: IncomingPort::from(1).into(), - port_kind: EdgeKind::Value(BOOL_T) - }) - ); - let const_op: ops::Const = logic::EXTENSION - .get_value(logic::TRUE_NAME) - .unwrap() - .typed_value() - .clone(); - // Second input of Xor from a constant - let cst = h.add_node_with_parent(h.root(), const_op)?; - let lcst = h.add_node_with_parent(h.root(), ops::LoadConstant { datatype: BOOL_T })?; - h.connect(cst, 0, lcst, 0)?; - h.connect(lcst, 0, and, 1)?; - // There is no edge from Input to LoadConstant, but that's OK: - h.update_validate(&EMPTY_REG).unwrap(); - Ok(()) - } - - #[test] - /// A wire with no extension requirements is wired into a node which has - /// [A,BOOL_T] extensions required on its inputs and outputs. This could be fixed - /// by adding a lift node, but for validation this is an error. - fn missing_lift_node() -> Result<(), BuildError> { - let mut module_builder = ModuleBuilder::new(); - let mut main = module_builder.define_function( - "main", - FunctionType::new(type_row![NAT], type_row![NAT]).pure(), - )?; - let [main_input] = main.input_wires_arr(); - - let inner_sig = FunctionType::new(type_row![NAT], type_row![NAT]) - // Inner DFG has extension requirements that the wire wont satisfy - .with_input_extensions(ExtensionSet::from_iter([XA, XB])); - - let f_builder = main.dfg_builder( - inner_sig.signature, - Some(inner_sig.input_extensions), - [main_input], - )?; - let f_inputs = f_builder.input_wires(); - let f_handle = f_builder.finish_with_outputs(f_inputs)?; - let [f_output] = f_handle.outputs_arr(); - main.finish_with_outputs([f_output])?; - let handle = module_builder.hugr().validate(&PRELUDE_REGISTRY); - - assert_matches!( - handle, - Err(ValidationError::ExtensionError( - ExtensionError::TgtExceedsSrcExtensionsAtPort { .. } - )) - ); - Ok(()) - } - - #[test] - /// A wire with extension requirement `[A]` is wired into a an output with no - /// extension req. In the validation extension typechecking, we don't do any - /// unification, so don't allow open extension variables on the function - /// signature, so this fails. - fn too_many_extension() -> Result<(), BuildError> { - let mut module_builder = ModuleBuilder::new(); - - let main_sig = FunctionType::new(type_row![NAT], type_row![NAT]).pure(); - - let mut main = module_builder.define_function("main", main_sig)?; - let [main_input] = main.input_wires_arr(); - - let inner_sig = FunctionType::new(type_row![NAT], type_row![NAT]) - .with_extension_delta(&ExtensionSet::singleton(&XA)) - .with_input_extensions(ExtensionSet::new()); - - let f_builder = main.dfg_builder( - inner_sig.signature, - Some(inner_sig.input_extensions), - [main_input], - )?; - let f_inputs = f_builder.input_wires(); - let f_handle = f_builder.finish_with_outputs(f_inputs)?; - let [f_output] = f_handle.outputs_arr(); - main.finish_with_outputs([f_output])?; - let handle = module_builder.hugr().validate(&PRELUDE_REGISTRY); - assert_matches!( - handle, - Err(ValidationError::ExtensionError( - ExtensionError::SrcExceedsTgtExtensionsAtPort { .. } - )) - ); - Ok(()) - } - - #[test] - /// A wire with extension requirements `[A]` and another with requirements - /// `[BOOL_T]` are both wired into a node which requires its inputs to have - /// requirements `[A,BOOL_T]`. A slightly more complex test of the error from - /// `missing_lift_node`. - fn extensions_mismatch() -> Result<(), BuildError> { - let mut module_builder = ModuleBuilder::new(); - - let all_rs = ExtensionSet::from_iter([XA, XB]); - - let main_sig = FunctionType::new(type_row![], type_row![NAT]) - .with_extension_delta(&all_rs) - .with_input_extensions(ExtensionSet::new()); - - let mut main = module_builder.define_function("main", main_sig)?; - - let inner_left_sig = FunctionType::new(type_row![], type_row![NAT]) - .with_input_extensions(ExtensionSet::singleton(&XA)); - - let inner_right_sig = FunctionType::new(type_row![], type_row![NAT]) - .with_input_extensions(ExtensionSet::singleton(&XB)); - - let inner_mult_sig = - FunctionType::new(type_row![NAT, NAT], type_row![NAT]).with_input_extensions(all_rs); - - let [left_wire] = main - .dfg_builder( - inner_left_sig.signature, - Some(inner_left_sig.input_extensions), - [], - )? - .finish_with_outputs([])? - .outputs_arr(); - - let [right_wire] = main - .dfg_builder( - inner_right_sig.signature, - Some(inner_right_sig.input_extensions), - [], - )? - .finish_with_outputs([])? - .outputs_arr(); - - let builder = main.dfg_builder( - inner_mult_sig.signature, - Some(inner_mult_sig.input_extensions), - [left_wire, right_wire], - )?; - let [_left, _right] = builder.input_wires_arr(); - let [output] = builder.finish_with_outputs([])?.outputs_arr(); - - main.finish_with_outputs([output])?; - let handle = module_builder.hugr().validate(&PRELUDE_REGISTRY); - assert_matches!( - handle, - Err(ValidationError::ExtensionError( - ExtensionError::ParentIOExtensionMismatch { .. } - )) - ); - Ok(()) - } - - #[test] - fn parent_signature_mismatch() -> Result<(), BuildError> { - let rs = ExtensionSet::singleton(&XA); - - let main_signature = - FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs); - - let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG { - signature: main_signature, - })); - let input = hugr.add_node_with_parent( - hugr.root(), - NodeType::new_pure(ops::Input { - types: type_row![NAT], - }), - )?; - let output = hugr.add_node_with_parent( - hugr.root(), - NodeType::new( - ops::Output { - types: type_row![NAT], - }, - rs, - ), - )?; - hugr.connect(input, 0, output, 0)?; - - assert_matches!( - hugr.validate(&PRELUDE_REGISTRY), - Err(ValidationError::ExtensionError( - ExtensionError::TgtExceedsSrcExtensionsAtPort { .. } - )) - ); - Ok(()) - } - - #[test] - fn dfg_with_cycles() -> Result<(), HugrError> { - let mut h = closed_dfg_root_hugr(FunctionType::new( - type_row![BOOL_T, BOOL_T], - type_row![BOOL_T], - )); - let [input, output] = h.get_io(h.root()).unwrap(); - let and = h.add_node_with_parent(h.root(), and_op())?; - let not1 = h.add_node_with_parent(h.root(), not_op())?; - let not2 = h.add_node_with_parent(h.root(), not_op())?; - h.connect(input, 0, and, 0)?; - h.connect(and, 0, not1, 0)?; - h.connect(not1, 0, and, 1)?; - h.connect(input, 1, not2, 0)?; - h.connect(not2, 0, output, 0)?; - // The graph contains a cycle: - assert_matches!(h.validate(&EMPTY_REG), Err(ValidationError::NotADag { .. })); - Ok(()) - } - - fn identity_hugr_with_type(t: Type) -> (Hugr, Node) { - let mut b = Hugr::default(); - let row: TypeRow = vec![t].into(); - - let def = b - .add_node_with_parent( - b.root(), - ops::FuncDefn { - name: "main".into(), - signature: FunctionType::new(row.clone(), row.clone()), - }, - ) - .unwrap(); - - let input = b - .add_node_with_parent(def, ops::Input::new(row.clone())) - .unwrap(); - let output = b.add_node_with_parent(def, ops::Output::new(row)).unwrap(); - b.connect(input, 0, output, 0).unwrap(); - (b, def) - } - #[test] - fn unregistered_extension() { - let (mut h, def) = identity_hugr_with_type(USIZE_T); - assert_eq!( - h.validate(&EMPTY_REG), - Err(ValidationError::SignatureError { - node: def, - cause: SignatureError::ExtensionNotFound(PRELUDE.name.clone()) - }) - ); - h.update_validate(&PRELUDE_REGISTRY).unwrap(); - } - - #[test] - fn invalid_types() { - let name: ExtensionId = "MyExt".try_into().unwrap(); - let mut e = Extension::new(name.clone()); - e.add_type( - "MyContainer".into(), - vec![TypeParam::Type(TypeBound::Copyable)], - "".into(), - TypeDefBound::Explicit(TypeBound::Any), - ) - .unwrap(); - let reg: ExtensionRegistry = [e, PRELUDE.to_owned()].into(); - - let validate_to_sig_error = |t: CustomType| { - let (h, def) = identity_hugr_with_type(Type::new_extension(t)); - match h.validate(®) { - Err(ValidationError::SignatureError { node, cause }) if node == def => cause, - e => panic!("Expected SignatureError at def node, got {:?}", e), - } - }; - - let valid = Type::new_extension(CustomType::new( - "MyContainer", - vec![TypeArg::Type { ty: USIZE_T }], - name.clone(), - TypeBound::Any, - )); - assert_eq!( - identity_hugr_with_type(valid.clone()) - .0 - .update_validate(®), - Ok(()) - ); - - // valid is Any, so is not allowed as an element of an outer MyContainer. - let element_outside_bound = CustomType::new( - "MyContainer", - vec![TypeArg::Type { ty: valid.clone() }], - name.clone(), - TypeBound::Any, - ); - assert_eq!( - validate_to_sig_error(element_outside_bound), - SignatureError::TypeArgMismatch(TypeArgError::TypeMismatch { - param: TypeParam::Type(TypeBound::Copyable), - arg: TypeArg::Type { ty: valid } - }) - ); - - let bad_bound = CustomType::new( - "MyContainer", - vec![TypeArg::Type { ty: USIZE_T }], - name.clone(), - TypeBound::Copyable, - ); - assert_eq!( - validate_to_sig_error(bad_bound.clone()), - SignatureError::WrongBound { - actual: TypeBound::Copyable, - expected: TypeBound::Any - } - ); - - // bad_bound claims to be Copyable, which is valid as an element for the outer MyContainer. - let nested = CustomType::new( - "MyContainer", - vec![TypeArg::Type { - ty: Type::new_extension(bad_bound), - }], - name.clone(), - TypeBound::Any, - ); - assert_eq!( - validate_to_sig_error(nested), - SignatureError::WrongBound { - actual: TypeBound::Copyable, - expected: TypeBound::Any - } - ); - - let too_many_type_args = CustomType::new( - "MyContainer", - vec![TypeArg::Type { ty: USIZE_T }, TypeArg::BoundedNat { n: 3 }], - name.clone(), - TypeBound::Any, - ); - assert_eq!( - validate_to_sig_error(too_many_type_args), - SignatureError::TypeArgMismatch(TypeArgError::WrongNumberArgs(2, 1)) - ); - } -} +mod test; diff --git a/src/hugr/validate/test.rs b/src/hugr/validate/test.rs new file mode 100644 index 000000000..a4a4dd1a2 --- /dev/null +++ b/src/hugr/validate/test.rs @@ -0,0 +1,752 @@ +use cool_asserts::assert_matches; + +use super::*; +use crate::builder::test::closed_dfg_root_hugr; +use crate::builder::{BuildError, Container, Dataflow, DataflowSubContainer, ModuleBuilder}; +use crate::extension::prelude::{BOOL_T, PRELUDE, USIZE_T}; +use crate::extension::{ + Extension, ExtensionId, ExtensionSet, TypeDefBound, EMPTY_REG, PRELUDE_REGISTRY, +}; +use crate::hugr::hugrmut::sealed::HugrMutInternals; +use crate::hugr::{HugrError, HugrMut, NodeType}; +use crate::macros::const_extension_ids; +use crate::ops::dataflow::IOTrait; +use crate::ops::{self, LeafOp, OpType}; +use crate::std_extensions::logic; +use crate::std_extensions::logic::test::{and_op, not_op}; +use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; +use crate::types::{CustomType, FunctionType, Type, TypeBound, TypeRow}; +use crate::{type_row, Direction, IncomingPort, Node}; + +const NAT: Type = crate::extension::prelude::USIZE_T; +const Q: Type = crate::extension::prelude::QB_T; + +/// Creates a hugr with a single function definition that copies a bit `copies` times. +/// +/// Returns the hugr and the node index of the definition. +fn make_simple_hugr(copies: usize) -> (Hugr, Node) { + let def_op: OpType = ops::FuncDefn { + name: "main".into(), + signature: FunctionType::new(type_row![BOOL_T], vec![BOOL_T; copies]), + } + .into(); + + let mut b = Hugr::default(); + let root = b.root(); + + let def = b.add_node_with_parent(root, def_op).unwrap(); + let _ = add_df_children(&mut b, def, copies); + + (b, def) +} + +/// Adds an input{BOOL_T}, copy{BOOL_T -> BOOL_T^copies}, and output{BOOL_T^copies} operation to a dataflow container. +/// +/// Returns the node indices of each of the operations. +fn add_df_children(b: &mut Hugr, parent: Node, copies: usize) -> (Node, Node, Node) { + let input = b + .add_node_with_parent(parent, ops::Input::new(type_row![BOOL_T])) + .unwrap(); + let output = b + .add_node_with_parent(parent, ops::Output::new(vec![BOOL_T; copies])) + .unwrap(); + let copy = b + .add_node_with_parent(parent, LeafOp::Noop { ty: BOOL_T }) + .unwrap(); + + b.connect(input, 0, copy, 0).unwrap(); + for i in 0..copies { + b.connect(copy, 0, output, i).unwrap(); + } + + (input, copy, output) +} + +/// Adds an input{BOOL_T}, tag_constant(0, BOOL_T^tuple_sum_size), tag(BOOL_T^tuple_sum_size), and +/// output{Sum{unit^tuple_sum_size}, BOOL_T} operation to a dataflow container. +/// Intended to be used to populate a BasicBlock node in a CFG. +/// +/// Returns the node indices of each of the operations. +fn add_block_children( + b: &mut Hugr, + parent: Node, + tuple_sum_size: usize, +) -> (Node, Node, Node, Node) { + let const_op = ops::Const::unit_sum(0, tuple_sum_size as u8); + let tag_type = Type::new_unit_sum(tuple_sum_size as u8); + + let input = b + .add_node_with_parent(parent, ops::Input::new(type_row![BOOL_T])) + .unwrap(); + let output = b + .add_node_with_parent(parent, ops::Output::new(vec![tag_type.clone(), BOOL_T])) + .unwrap(); + let tag_def = b.add_node_with_parent(b.root(), const_op).unwrap(); + let tag = b + .add_node_with_parent(parent, ops::LoadConstant { datatype: tag_type }) + .unwrap(); + + b.connect(tag_def, 0, tag, 0).unwrap(); + b.add_other_edge(input, tag).unwrap(); + b.connect(tag, 0, output, 0).unwrap(); + b.connect(input, 0, output, 1).unwrap(); + + (input, tag_def, tag, output) +} + +#[test] +fn invalid_root() { + let declare_op: OpType = ops::FuncDecl { + name: "main".into(), + signature: Default::default(), + } + .into(); + + let mut b = Hugr::default(); + let root = b.root(); + assert_eq!(b.validate(&EMPTY_REG), Ok(())); + + // Add another hierarchy root + let other = b.add_node(ops::Module.into()); + assert_matches!( + b.validate(&EMPTY_REG), + Err(ValidationError::NoParent { node }) => assert_eq!(node, other) + ); + b.set_parent(other, root).unwrap(); + b.replace_op(other, NodeType::new_pure(declare_op)).unwrap(); + b.add_ports(other, Direction::Outgoing, 1); + assert_eq!(b.validate(&EMPTY_REG), Ok(())); + + // Make the hugr root not a hierarchy root + { + let mut hugr = b.clone(); + hugr.root = other.pg_index(); + assert_matches!( + hugr.validate(&EMPTY_REG), + Err(ValidationError::RootNotRoot { node }) => assert_eq!(node, other) + ); + } +} + +#[test] +fn leaf_root() { + let leaf_op: OpType = LeafOp::Noop { ty: USIZE_T }.into(); + + let b = Hugr::new(NodeType::new_pure(leaf_op)); + assert_eq!(b.validate(&EMPTY_REG), Ok(())); +} + +#[test] +fn dfg_root() { + let dfg_op: OpType = ops::DFG { + signature: FunctionType::new_linear(type_row![BOOL_T]), + } + .into(); + + let mut b = Hugr::new(NodeType::new_pure(dfg_op)); + let root = b.root(); + add_df_children(&mut b, root, 1); + assert_eq!(b.update_validate(&EMPTY_REG), Ok(())); +} + +#[test] +fn simple_hugr() { + let mut b = make_simple_hugr(2).0; + assert_eq!(b.update_validate(&EMPTY_REG), Ok(())); +} + +#[test] +/// General children restrictions. +fn children_restrictions() { + let (mut b, def) = make_simple_hugr(2); + let root = b.root(); + let (_input, copy, _output) = b + .hierarchy + .children(def.pg_index()) + .map_into() + .collect_tuple() + .unwrap(); + + // Add a definition without children + let def_sig = FunctionType::new(type_row![BOOL_T], type_row![BOOL_T, BOOL_T]); + let new_def = b + .add_node_with_parent( + root, + ops::FuncDefn { + signature: def_sig, + name: "main".into(), + }, + ) + .unwrap(); + assert_matches!( + b.update_validate(&EMPTY_REG), + Err(ValidationError::ContainerWithoutChildren { node, .. }) => assert_eq!(node, new_def) + ); + + // Add children to the definition, but move it to be a child of the copy + add_df_children(&mut b, new_def, 2); + b.set_parent(new_def, copy).unwrap(); + assert_matches!( + b.update_validate(&EMPTY_REG), + Err(ValidationError::NonContainerWithChildren { node, .. }) => assert_eq!(node, copy) + ); + let closure = b.infer_extensions().unwrap(); + b.set_parent(new_def, root).unwrap(); + + // After moving the previous definition to a valid place, + // add an input node to the module subgraph + let new_input = b + .add_node_with_parent(root, ops::Input::new(type_row![])) + .unwrap(); + assert_matches!( + b.validate_with_extension_closure(closure, &EMPTY_REG), + Err(ValidationError::InvalidParentOp { parent, child, .. }) => {assert_eq!(parent, root); assert_eq!(child, new_input)} + ); +} + +#[test] +/// Validation errors in a dataflow subgraph. +fn df_children_restrictions() { + let (mut b, def) = make_simple_hugr(2); + let (_input, output, copy) = b + .hierarchy + .children(def.pg_index()) + .map_into() + .collect_tuple() + .unwrap(); + + // Replace the output operation of the df subgraph with a copy + b.replace_op(output, NodeType::new_pure(LeafOp::Noop { ty: NAT })) + .unwrap(); + assert_matches!( + b.validate(&EMPTY_REG), + Err(ValidationError::InvalidInitialChild { parent, .. }) => assert_eq!(parent, def) + ); + + // Revert it back to an output, but with the wrong number of ports + b.replace_op( + output, + NodeType::new_pure(ops::Output::new(type_row![BOOL_T])), + ) + .unwrap(); + assert_matches!( + b.validate(&EMPTY_REG), + Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::IOSignatureMismatch { child, .. }, .. }) + => {assert_eq!(parent, def); assert_eq!(child, output.pg_index())} + ); + b.replace_op( + output, + NodeType::new_pure(ops::Output::new(type_row![BOOL_T, BOOL_T])), + ) + .unwrap(); + + // After fixing the output back, replace the copy with an output op + b.replace_op( + copy, + NodeType::new_pure(ops::Output::new(type_row![BOOL_T, BOOL_T])), + ) + .unwrap(); + assert_matches!( + b.validate(&EMPTY_REG), + Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalIOChildren { child, .. }, .. }) + => {assert_eq!(parent, def); assert_eq!(child, copy.pg_index())} + ); +} + +#[test] +/// Validation errors in a dataflow subgraph. +fn cfg_children_restrictions() { + let (mut b, def) = make_simple_hugr(1); + let (_input, _output, copy) = b + .hierarchy + .children(def.pg_index()) + .map_into() + .collect_tuple() + .unwrap(); + // Write Extension annotations into the Hugr while it's still well-formed + // enough for us to compute them + let closure = b.infer_extensions().unwrap(); + b.instantiate_extensions(closure); + b.validate(&EMPTY_REG).unwrap(); + b.replace_op( + copy, + NodeType::new_pure(ops::CFG { + signature: FunctionType::new(type_row![BOOL_T], type_row![BOOL_T]), + }), + ) + .unwrap(); + assert_matches!( + b.validate(&EMPTY_REG), + Err(ValidationError::ContainerWithoutChildren { .. }) + ); + let cfg = copy; + + // Construct a valid CFG, with one BasicBlock node and one exit node + let block = b + .add_node_with_parent( + cfg, + ops::BasicBlock::DFB { + inputs: type_row![BOOL_T], + tuple_sum_rows: vec![type_row![]], + other_outputs: type_row![BOOL_T], + extension_delta: ExtensionSet::new(), + }, + ) + .unwrap(); + add_block_children(&mut b, block, 1); + let exit = b + .add_node_with_parent( + cfg, + ops::BasicBlock::Exit { + cfg_outputs: type_row![BOOL_T], + }, + ) + .unwrap(); + b.add_other_edge(block, exit).unwrap(); + assert_eq!(b.update_validate(&EMPTY_REG), Ok(())); + + // Test malformed errors + + // Add an internal exit node + let exit2 = b + .add_node_after( + exit, + ops::BasicBlock::Exit { + cfg_outputs: type_row![BOOL_T], + }, + ) + .unwrap(); + assert_matches!( + b.validate(&EMPTY_REG), + Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalExitChildren { child, .. }, .. }) + => {assert_eq!(parent, cfg); assert_eq!(child, exit2.pg_index())} + ); + b.remove_node(exit2).unwrap(); + + // Change the types in the BasicBlock node to work on qubits instead of bits + b.replace_op( + block, + NodeType::new_pure(ops::BasicBlock::DFB { + inputs: type_row![Q], + tuple_sum_rows: vec![type_row![]], + other_outputs: type_row![Q], + extension_delta: ExtensionSet::new(), + }), + ) + .unwrap(); + let mut block_children = b.hierarchy.children(block.pg_index()); + let block_input = block_children.next().unwrap().into(); + let block_output = block_children.next_back().unwrap().into(); + b.replace_op( + block_input, + NodeType::new_pure(ops::Input::new(type_row![Q])), + ) + .unwrap(); + b.replace_op( + block_output, + NodeType::new_pure(ops::Output::new(type_row![Type::new_unit_sum(1), Q])), + ) + .unwrap(); + assert_matches!( + b.validate(&EMPTY_REG), + Err(ValidationError::InvalidEdges { parent, source: EdgeValidationError::CFGEdgeSignatureMismatch { .. }, .. }) + => assert_eq!(parent, cfg) + ); +} + +#[test] +fn test_ext_edge() -> Result<(), HugrError> { + let mut h = closed_dfg_root_hugr(FunctionType::new( + type_row![BOOL_T, BOOL_T], + type_row![BOOL_T], + )); + let [input, output] = h.get_io(h.root()).unwrap(); + + // Nested DFG BOOL_T -> BOOL_T + let sub_dfg = h.add_node_with_parent( + h.root(), + ops::DFG { + signature: FunctionType::new_linear(type_row![BOOL_T]), + }, + )?; + // this Xor has its 2nd input unconnected + let sub_op = { + let sub_input = h.add_node_with_parent(sub_dfg, ops::Input::new(type_row![BOOL_T]))?; + let sub_output = h.add_node_with_parent(sub_dfg, ops::Output::new(type_row![BOOL_T]))?; + let sub_op = h.add_node_with_parent(sub_dfg, and_op())?; + h.connect(sub_input, 0, sub_op, 0)?; + h.connect(sub_op, 0, sub_output, 0)?; + sub_op + }; + + h.connect(input, 0, sub_dfg, 0)?; + h.connect(sub_dfg, 0, output, 0)?; + + assert_matches!( + h.update_validate(&EMPTY_REG), + Err(ValidationError::UnconnectedPort { .. }) + ); + + h.connect(input, 1, sub_op, 1)?; + assert_matches!( + h.update_validate(&EMPTY_REG), + Err(ValidationError::InterGraphEdgeError( + InterGraphEdgeError::MissingOrderEdge { .. } + )) + ); + //Order edge. This will need metadata indicating its purpose. + h.add_other_edge(input, sub_dfg)?; + h.update_validate(&EMPTY_REG).unwrap(); + Ok(()) +} + +const_extension_ids! { + const XA: ExtensionId = "A"; + const XB: ExtensionId = "BOOL_EXT"; +} + +#[test] +fn test_local_const() -> Result<(), HugrError> { + let mut h = closed_dfg_root_hugr(FunctionType::new(type_row![BOOL_T], type_row![BOOL_T])); + let [input, output] = h.get_io(h.root()).unwrap(); + let and = h.add_node_with_parent(h.root(), and_op())?; + h.connect(input, 0, and, 0)?; + h.connect(and, 0, output, 0)?; + assert_eq!( + h.update_validate(&EMPTY_REG), + Err(ValidationError::UnconnectedPort { + node: and, + port: IncomingPort::from(1).into(), + port_kind: EdgeKind::Value(BOOL_T) + }) + ); + let const_op: ops::Const = logic::EXTENSION + .get_value(logic::TRUE_NAME) + .unwrap() + .typed_value() + .clone(); + // Second input of Xor from a constant + let cst = h.add_node_with_parent(h.root(), const_op)?; + let lcst = h.add_node_with_parent(h.root(), ops::LoadConstant { datatype: BOOL_T })?; + h.connect(cst, 0, lcst, 0)?; + h.connect(lcst, 0, and, 1)?; + // There is no edge from Input to LoadConstant, but that's OK: + h.update_validate(&EMPTY_REG).unwrap(); + Ok(()) +} + +#[test] +/// A wire with no extension requirements is wired into a node which has +/// [A,BOOL_T] extensions required on its inputs and outputs. This could be fixed +/// by adding a lift node, but for validation this is an error. +fn missing_lift_node() -> Result<(), BuildError> { + let mut module_builder = ModuleBuilder::new(); + let mut main = module_builder.define_function( + "main", + FunctionType::new(type_row![NAT], type_row![NAT]).pure(), + )?; + let [main_input] = main.input_wires_arr(); + + let inner_sig = FunctionType::new(type_row![NAT], type_row![NAT]) + // Inner DFG has extension requirements that the wire wont satisfy + .with_input_extensions(ExtensionSet::from_iter([XA, XB])); + + let f_builder = main.dfg_builder( + inner_sig.signature, + Some(inner_sig.input_extensions), + [main_input], + )?; + let f_inputs = f_builder.input_wires(); + let f_handle = f_builder.finish_with_outputs(f_inputs)?; + let [f_output] = f_handle.outputs_arr(); + main.finish_with_outputs([f_output])?; + let handle = module_builder.hugr().validate(&PRELUDE_REGISTRY); + + assert_matches!( + handle, + Err(ValidationError::ExtensionError( + ExtensionError::TgtExceedsSrcExtensionsAtPort { .. } + )) + ); + Ok(()) +} + +#[test] +/// A wire with extension requirement `[A]` is wired into a an output with no +/// extension req. In the validation extension typechecking, we don't do any +/// unification, so don't allow open extension variables on the function +/// signature, so this fails. +fn too_many_extension() -> Result<(), BuildError> { + let mut module_builder = ModuleBuilder::new(); + + let main_sig = FunctionType::new(type_row![NAT], type_row![NAT]).pure(); + + let mut main = module_builder.define_function("main", main_sig)?; + let [main_input] = main.input_wires_arr(); + + let inner_sig = FunctionType::new(type_row![NAT], type_row![NAT]) + .with_extension_delta(&ExtensionSet::singleton(&XA)) + .with_input_extensions(ExtensionSet::new()); + + let f_builder = main.dfg_builder( + inner_sig.signature, + Some(inner_sig.input_extensions), + [main_input], + )?; + let f_inputs = f_builder.input_wires(); + let f_handle = f_builder.finish_with_outputs(f_inputs)?; + let [f_output] = f_handle.outputs_arr(); + main.finish_with_outputs([f_output])?; + let handle = module_builder.hugr().validate(&PRELUDE_REGISTRY); + assert_matches!( + handle, + Err(ValidationError::ExtensionError( + ExtensionError::SrcExceedsTgtExtensionsAtPort { .. } + )) + ); + Ok(()) +} + +#[test] +/// A wire with extension requirements `[A]` and another with requirements +/// `[BOOL_T]` are both wired into a node which requires its inputs to have +/// requirements `[A,BOOL_T]`. A slightly more complex test of the error from +/// `missing_lift_node`. +fn extensions_mismatch() -> Result<(), BuildError> { + let mut module_builder = ModuleBuilder::new(); + + let all_rs = ExtensionSet::from_iter([XA, XB]); + + let main_sig = FunctionType::new(type_row![], type_row![NAT]) + .with_extension_delta(&all_rs) + .with_input_extensions(ExtensionSet::new()); + + let mut main = module_builder.define_function("main", main_sig)?; + + let inner_left_sig = FunctionType::new(type_row![], type_row![NAT]) + .with_input_extensions(ExtensionSet::singleton(&XA)); + + let inner_right_sig = FunctionType::new(type_row![], type_row![NAT]) + .with_input_extensions(ExtensionSet::singleton(&XB)); + + let inner_mult_sig = + FunctionType::new(type_row![NAT, NAT], type_row![NAT]).with_input_extensions(all_rs); + + let [left_wire] = main + .dfg_builder( + inner_left_sig.signature, + Some(inner_left_sig.input_extensions), + [], + )? + .finish_with_outputs([])? + .outputs_arr(); + + let [right_wire] = main + .dfg_builder( + inner_right_sig.signature, + Some(inner_right_sig.input_extensions), + [], + )? + .finish_with_outputs([])? + .outputs_arr(); + + let builder = main.dfg_builder( + inner_mult_sig.signature, + Some(inner_mult_sig.input_extensions), + [left_wire, right_wire], + )?; + let [_left, _right] = builder.input_wires_arr(); + let [output] = builder.finish_with_outputs([])?.outputs_arr(); + + main.finish_with_outputs([output])?; + let handle = module_builder.hugr().validate(&PRELUDE_REGISTRY); + assert_matches!( + handle, + Err(ValidationError::ExtensionError( + ExtensionError::ParentIOExtensionMismatch { .. } + )) + ); + Ok(()) +} + +#[test] +fn parent_signature_mismatch() -> Result<(), BuildError> { + let rs = ExtensionSet::singleton(&XA); + + let main_signature = + FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&rs); + + let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG { + signature: main_signature, + })); + let input = hugr.add_node_with_parent( + hugr.root(), + NodeType::new_pure(ops::Input { + types: type_row![NAT], + }), + )?; + let output = hugr.add_node_with_parent( + hugr.root(), + NodeType::new( + ops::Output { + types: type_row![NAT], + }, + rs, + ), + )?; + hugr.connect(input, 0, output, 0)?; + + assert_matches!( + hugr.validate(&PRELUDE_REGISTRY), + Err(ValidationError::ExtensionError( + ExtensionError::TgtExceedsSrcExtensionsAtPort { .. } + )) + ); + Ok(()) +} + +#[test] +fn dfg_with_cycles() -> Result<(), HugrError> { + let mut h = closed_dfg_root_hugr(FunctionType::new( + type_row![BOOL_T, BOOL_T], + type_row![BOOL_T], + )); + let [input, output] = h.get_io(h.root()).unwrap(); + let and = h.add_node_with_parent(h.root(), and_op())?; + let not1 = h.add_node_with_parent(h.root(), not_op())?; + let not2 = h.add_node_with_parent(h.root(), not_op())?; + h.connect(input, 0, and, 0)?; + h.connect(and, 0, not1, 0)?; + h.connect(not1, 0, and, 1)?; + h.connect(input, 1, not2, 0)?; + h.connect(not2, 0, output, 0)?; + // The graph contains a cycle: + assert_matches!(h.validate(&EMPTY_REG), Err(ValidationError::NotADag { .. })); + Ok(()) +} + +fn identity_hugr_with_type(t: Type) -> (Hugr, Node) { + let mut b = Hugr::default(); + let row: TypeRow = vec![t].into(); + + let def = b + .add_node_with_parent( + b.root(), + ops::FuncDefn { + name: "main".into(), + signature: FunctionType::new(row.clone(), row.clone()), + }, + ) + .unwrap(); + + let input = b + .add_node_with_parent(def, ops::Input::new(row.clone())) + .unwrap(); + let output = b.add_node_with_parent(def, ops::Output::new(row)).unwrap(); + b.connect(input, 0, output, 0).unwrap(); + (b, def) +} +#[test] +fn unregistered_extension() { + let (mut h, def) = identity_hugr_with_type(USIZE_T); + assert_eq!( + h.validate(&EMPTY_REG), + Err(ValidationError::SignatureError { + node: def, + cause: SignatureError::ExtensionNotFound(PRELUDE.name.clone()) + }) + ); + h.update_validate(&PRELUDE_REGISTRY).unwrap(); +} + +#[test] +fn invalid_types() { + let name: ExtensionId = "MyExt".try_into().unwrap(); + let mut e = Extension::new(name.clone()); + e.add_type( + "MyContainer".into(), + vec![TypeParam::Type(TypeBound::Copyable)], + "".into(), + TypeDefBound::Explicit(TypeBound::Any), + ) + .unwrap(); + let reg: ExtensionRegistry = [e, PRELUDE.to_owned()].into(); + + let validate_to_sig_error = |t: CustomType| { + let (h, def) = identity_hugr_with_type(Type::new_extension(t)); + match h.validate(®) { + Err(ValidationError::SignatureError { node, cause }) if node == def => cause, + e => panic!("Expected SignatureError at def node, got {:?}", e), + } + }; + + let valid = Type::new_extension(CustomType::new( + "MyContainer", + vec![TypeArg::Type { ty: USIZE_T }], + name.clone(), + TypeBound::Any, + )); + assert_eq!( + identity_hugr_with_type(valid.clone()) + .0 + .update_validate(®), + Ok(()) + ); + + // valid is Any, so is not allowed as an element of an outer MyContainer. + let element_outside_bound = CustomType::new( + "MyContainer", + vec![TypeArg::Type { ty: valid.clone() }], + name.clone(), + TypeBound::Any, + ); + assert_eq!( + validate_to_sig_error(element_outside_bound), + SignatureError::TypeArgMismatch(TypeArgError::TypeMismatch { + param: TypeParam::Type(TypeBound::Copyable), + arg: TypeArg::Type { ty: valid } + }) + ); + + let bad_bound = CustomType::new( + "MyContainer", + vec![TypeArg::Type { ty: USIZE_T }], + name.clone(), + TypeBound::Copyable, + ); + assert_eq!( + validate_to_sig_error(bad_bound.clone()), + SignatureError::WrongBound { + actual: TypeBound::Copyable, + expected: TypeBound::Any + } + ); + + // bad_bound claims to be Copyable, which is valid as an element for the outer MyContainer. + let nested = CustomType::new( + "MyContainer", + vec![TypeArg::Type { + ty: Type::new_extension(bad_bound), + }], + name.clone(), + TypeBound::Any, + ); + assert_eq!( + validate_to_sig_error(nested), + SignatureError::WrongBound { + actual: TypeBound::Copyable, + expected: TypeBound::Any + } + ); + + let too_many_type_args = CustomType::new( + "MyContainer", + vec![TypeArg::Type { ty: USIZE_T }, TypeArg::BoundedNat { n: 3 }], + name.clone(), + TypeBound::Any, + ); + assert_eq!( + validate_to_sig_error(too_many_type_args), + SignatureError::TypeArgMismatch(TypeArgError::WrongNumberArgs(2, 1)) + ); +} diff --git a/src/types/signature.rs b/src/types/signature.rs index b2995dec0..bcd28d2ea 100644 --- a/src/types/signature.rs +++ b/src/types/signature.rs @@ -103,10 +103,11 @@ impl FunctionType { } } - /// Returns the type of a value [`Port`]. Returns `None` if the port is out - /// of bounds. + /// Returns a mutable reference to the type of a value [`Port`]. + /// Returns `None` if the port is out of bounds. #[inline] - pub fn get_mut(&mut self, port: Port) -> Option<&mut Type> { + pub fn get_mut(&mut self, port: impl Into) -> Option<&mut Type> { + let port = port.into(); match port.direction() { Direction::Incoming => self.input.get_mut(port), Direction::Outgoing => self.output.get_mut(port), @@ -169,23 +170,6 @@ impl FunctionType { } impl FunctionType { - /// Returns the linear part of the signature - /// TODO: This fails when mixing different linear types. - #[inline(always)] - pub fn linear(&self) -> impl Iterator { - debug_assert_eq!( - self.input - .iter() - .filter(|t| !t.copyable()) - .collect::>(), - self.output - .iter() - .filter(|t| !t.copyable()) - .collect::>() - ); - self.input.iter().filter(|t| !t.copyable()) - } - /// Returns the `Port`s in the signature for a given direction. #[inline] pub fn ports(&self, dir: Direction) -> impl Iterator { @@ -247,3 +231,54 @@ impl Display for Signature { } } } + +#[cfg(test)] +mod test { + use crate::{ + extension::{prelude::USIZE_T, ExtensionId}, + type_row, + }; + + use super::*; + #[test] + fn test_function_type() { + let mut f_type = FunctionType::new(type_row![Type::UNIT], type_row![Type::UNIT]); + assert_eq!(f_type.input_count(), 1); + assert_eq!(f_type.output_count(), 1); + + assert_eq!(f_type.input_types(), &[Type::UNIT]); + + assert_eq!( + f_type.get(Port::new(Direction::Incoming, 0)), + Some(&Type::UNIT) + ); + + let out = Port::new(Direction::Outgoing, 0); + *(f_type.get_mut(out).unwrap()) = USIZE_T; + + assert_eq!(f_type.get(out), Some(&USIZE_T)); + + assert_eq!(f_type.input_types(), &[Type::UNIT]); + assert_eq!(f_type.output_types(), &[USIZE_T]); + } + + #[test] + fn test_signature() { + let f_type = FunctionType::new(type_row![Type::UNIT], type_row![USIZE_T]); + + let sig: Signature = f_type.pure(); + + assert_eq!(sig.input(), &type_row![Type::UNIT]); + assert_eq!(sig.output(), &type_row![USIZE_T]); + } + + #[test] + fn test_display() { + let f_type = FunctionType::new(type_row![Type::UNIT], type_row![USIZE_T]); + assert_eq!(f_type.to_string(), "[Tuple([])] -> [[]][usize([])]"); + let sig: Signature = f_type.with_input_extensions(ExtensionSet::singleton( + &ExtensionId::new("Example").unwrap(), + )); + assert_eq!(sig.to_string(), "[Tuple([])] -> [[]][usize([])]"); + } +}