From 41e15da4360eeaaba48e6bda74bffcf3d24bc067 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Thu, 16 Nov 2023 13:43:28 +0000 Subject: [PATCH] fix: FuncDefns don't require that their extensions match their children (#688) Resolves #673 --- src/extension/infer.rs | 14 ++++++- src/extension/infer/test.rs | 84 ++++++++++++++++++++++++++++++++++++- src/hugr/validate.rs | 14 ++++--- src/hugr/validate/test.rs | 52 ++++++++++++++++++++++- 4 files changed, 156 insertions(+), 8 deletions(-) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index bcea6175f..0b99789c2 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -279,7 +279,19 @@ impl UnificationContext { let m_input_node = self.make_or_get_meta(input, dir); self.add_constraint(m_input_node, Constraint::Equal(m_input)); let m_output_node = self.make_or_get_meta(output, dir); - self.add_constraint(m_output_node, Constraint::Equal(m_output)); + // If the parent node is a FuncDefn, it will have no + // op_signature, so the Incoming and Outgoing ports will + // have equal extension requirements. + // The function that it contains, however, may have an + // extension delta, so its output shouldn't be equal to the + // FuncDefn's output. + // + // TODO: Add a constraint that the extensions of the output + // node of a FuncDefn should be those of the input node plus + // the extension delta specified in the function signature. + if node_type.tag() != OpTag::FuncDefn { + self.add_constraint(m_output_node, Constraint::Equal(m_output)); + } } } diff --git a/src/extension/infer/test.rs b/src/extension/infer/test.rs index 5e6452219..c54f38206 100644 --- a/src/extension/infer/test.rs +++ b/src/extension/infer/test.rs @@ -2,7 +2,9 @@ use std::error::Error; use super::*; use crate::builder::test::closed_dfg_root_hugr; -use crate::builder::{DFGBuilder, Dataflow, DataflowHugr}; +use crate::builder::{ + Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder, +}; use crate::extension::prelude::QB_T; use crate::extension::ExtensionId; use crate::extension::{prelude::PRELUDE_REGISTRY, ExtensionSet}; @@ -940,3 +942,83 @@ fn sccs() { Some(&ExtensionSet::from_iter([A, B, C, UNKNOWN_EXTENSION])) ); } + +#[test] +/// Note: This test is relying on the builder's `define_function` doing the +/// right thing: it takes input resources via a [`Signature`], which it passes +/// to `create_with_io`, creating concrete resource sets. +/// Inference can still fail for a valid FuncDefn hugr created without using +/// the builder API. +fn simple_funcdefn() -> Result<(), Box> { + let mut builder = ModuleBuilder::new(); + let mut func_builder = builder.define_function( + "F", + FunctionType::new(vec![NAT], vec![NAT]) + .with_extension_delta(&ExtensionSet::singleton(&A)) + .pure(), + )?; + + let [w] = func_builder.input_wires_arr(); + let lift = func_builder.add_dataflow_op( + ops::LeafOp::Lift { + type_row: type_row![NAT], + new_extension: A, + }, + [w], + )?; + let [w] = lift.outputs_arr(); + func_builder.finish_with_outputs([w])?; + builder.finish_prelude_hugr()?; + Ok(()) +} + +#[test] +fn funcdefn_signature_mismatch() -> Result<(), Box> { + let mut builder = ModuleBuilder::new(); + let mut func_builder = builder.define_function( + "F", + FunctionType::new(vec![NAT], vec![NAT]) + .with_extension_delta(&ExtensionSet::singleton(&A)) + .pure(), + )?; + + let [w] = func_builder.input_wires_arr(); + let lift = func_builder.add_dataflow_op( + ops::LeafOp::Lift { + type_row: type_row![NAT], + new_extension: B, + }, + [w], + )?; + let [w] = lift.outputs_arr(); + func_builder.finish_with_outputs([w])?; + let result = builder.finish_prelude_hugr(); + assert_matches!( + result, + Err(ValidationError::CantInfer( + InferExtensionError::MismatchedConcreteWithLocations { .. } + )) + ); + Ok(()) +} + +#[test] +// Test that the difference between a FuncDefn's input and output nodes is being +// constrained to be the same as the extension delta in the FuncDefn signature. +// The FuncDefn here is declared to add resource "A", but its body just wires +// the input to the output. +fn funcdefn_signature_mismatch2() -> Result<(), Box> { + let mut builder = ModuleBuilder::new(); + let func_builder = builder.define_function( + "F", + FunctionType::new(vec![NAT], vec![NAT]) + .with_extension_delta(&ExtensionSet::singleton(&A)) + .pure(), + )?; + + let [w] = func_builder.input_wires_arr(); + func_builder.finish_with_outputs([w])?; + let result = builder.finish_prelude_hugr(); + assert_matches!(result, Err(ValidationError::CantInfer(..))); + Ok(()) +} diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index 0d0724379..2725c5e25 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -200,11 +200,15 @@ impl<'a, 'b> ValidationContext<'a, 'b> { // Secondly that the node has correct children self.validate_children(node, node_type)?; - // If this is a container with I/O nodes, check that the extension they - // define match the extensions of the container. - if let Some([input, output]) = self.hugr.get_io(node) { - self.extension_validator - .validate_io_extensions(node, input, output)?; + // FuncDefns have no resources since they're static nodes, but the + // functions they define can have any extension delta. + if node_type.tag() != OpTag::FuncDefn { + // If this is a container with I/O nodes, check that the extension they + // define match the extensions of the container. + if let Some([input, output]) = self.hugr.get_io(node) { + self.extension_validator + .validate_io_extensions(node, input, output)?; + } } Ok(()) diff --git a/src/hugr/validate/test.rs b/src/hugr/validate/test.rs index 542573175..edc80c955 100644 --- a/src/hugr/validate/test.rs +++ b/src/hugr/validate/test.rs @@ -565,7 +565,7 @@ fn extensions_mismatch() -> Result<(), BuildError> { assert_matches!( handle, Err(ValidationError::ExtensionError( - ExtensionError::ParentIOExtensionMismatch { .. } + ExtensionError::TgtExceedsSrcExtensionsAtPort { .. } )) ); Ok(()) @@ -752,3 +752,53 @@ fn invalid_types() { SignatureError::TypeArgMismatch(TypeArgError::WrongNumberArgs(2, 1)) ); } + +#[test] +fn parent_io_mismatch() { + // The DFG node declares that it has an empty extension delta, + // but it's child graph adds extension "XB", causing a mismatch. + let mut hugr = Hugr::new(NodeType::new_pure(ops::DFG { + signature: FunctionType::new(type_row![USIZE_T], type_row![USIZE_T]), + })); + + let input = hugr + .add_node_with_parent( + hugr.root(), + NodeType::new_pure(ops::Input { + types: type_row![USIZE_T], + }), + ) + .unwrap(); + let output = hugr + .add_node_with_parent( + hugr.root(), + NodeType::new( + ops::Output { + types: type_row![USIZE_T], + }, + ExtensionSet::singleton(&XB), + ), + ) + .unwrap(); + + let lift = hugr + .add_node_with_parent( + hugr.root(), + NodeType::new_pure(ops::LeafOp::Lift { + type_row: type_row![USIZE_T], + new_extension: XB, + }), + ) + .unwrap(); + + hugr.connect(input, 0, lift, 0).unwrap(); + hugr.connect(lift, 0, output, 0).unwrap(); + + let result = hugr.validate(&PRELUDE_REGISTRY); + assert_matches!( + result, + Err(ValidationError::ExtensionError( + ExtensionError::ParentIOExtensionMismatch { .. } + )) + ); +}