diff --git a/hugr-core/src/builder/build_traits.rs b/hugr-core/src/builder/build_traits.rs index 9703d65c3..9279a6ba5 100644 --- a/hugr-core/src/builder/build_traits.rs +++ b/hugr-core/src/builder/build_traits.rs @@ -429,6 +429,7 @@ pub trait Dataflow: Container { just_inputs: impl IntoIterator, inputs_outputs: impl IntoIterator, just_out_types: TypeRow, + extension_delta: ExtensionSet, ) -> Result, BuildError> { let (input_types, mut input_wires): (Vec, Vec) = just_inputs.into_iter().unzip(); @@ -440,6 +441,7 @@ pub trait Dataflow: Container { just_inputs: input_types.into(), just_outputs: just_out_types, rest: rest_types.into(), + extension_delta, }; // TODO: Make input extensions a parameter let (loop_node, _) = add_node_with_wires(self, tail_loop.clone(), input_wires)?; diff --git a/hugr-core/src/builder/tail_loop.rs b/hugr-core/src/builder/tail_loop.rs index 901324fe9..3f41574b5 100644 --- a/hugr-core/src/builder/tail_loop.rs +++ b/hugr-core/src/builder/tail_loop.rs @@ -1,3 +1,4 @@ +use crate::extension::ExtensionSet; use crate::ops; use crate::hugr::{views::HugrView, NodeType}; @@ -74,11 +75,13 @@ impl TailLoopBuilder { just_inputs: impl Into, inputs_outputs: impl Into, just_outputs: impl Into, + extension_delta: ExtensionSet, ) -> Result { let tail_loop = ops::TailLoop { just_inputs: just_inputs.into(), just_outputs: just_outputs.into(), rest: inputs_outputs.into(), + extension_delta, }; // TODO: Allow input extensions to be specified let base = Hugr::new(NodeType::new_open(tail_loop.clone())); @@ -97,7 +100,6 @@ mod test { DataflowSubContainer, HugrBuilder, ModuleBuilder, }, extension::prelude::{ConstUsize, PRELUDE_ID, USIZE_T}, - extension::ExtensionSet, hugr::ValidationError, ops::Value, type_row, @@ -107,7 +109,8 @@ mod test { #[test] fn basic_loop() -> Result<(), BuildError> { let build_result: Result = { - let mut loop_b = TailLoopBuilder::new(vec![], vec![BIT], vec![USIZE_T])?; + let mut loop_b = + TailLoopBuilder::new(vec![], vec![BIT], vec![USIZE_T], PRELUDE_ID.into())?; let [i1] = loop_b.input_wires_arr(); let const_wire = loop_b.add_load_value(ConstUsize::new(1)); @@ -141,8 +144,12 @@ mod test { )? .outputs_arr(); let loop_id = { - let mut loop_b = - fbuild.tail_loop_builder(vec![(BIT, b1)], vec![], type_row![NAT])?; + let mut loop_b = fbuild.tail_loop_builder( + vec![(BIT, b1)], + vec![], + type_row![NAT], + PRELUDE_ID.into(), + )?; let signature = loop_b.loop_signature()?.clone(); let const_val = Value::true_val(); let const_wire = loop_b.add_load_const(Value::true_val()); @@ -161,7 +168,7 @@ mod test { ([type_row![], type_row![]], const_wire), vec![(BIT, b1)], output_row, - ExtensionSet::new(), + PRELUDE_ID.into(), )?; let mut branch_0 = conditional_b.case_builder(0)?; diff --git a/hugr-core/src/ops.rs b/hugr-core/src/ops.rs index 2e48d1f9d..76f3e54a4 100644 --- a/hugr-core/src/ops.rs +++ b/hugr-core/src/ops.rs @@ -432,7 +432,6 @@ impl OpParent for MakeTuple {} impl OpParent for UnpackTuple {} impl OpParent for Tag {} impl OpParent for Lift {} -impl OpParent for TailLoop {} impl OpParent for CFG {} impl OpParent for Conditional {} impl OpParent for FuncDecl {} diff --git a/hugr-core/src/ops/controlflow.rs b/hugr-core/src/ops/controlflow.rs index 2aba3667a..b1a6b0e37 100644 --- a/hugr-core/src/ops/controlflow.rs +++ b/hugr-core/src/ops/controlflow.rs @@ -18,6 +18,8 @@ pub struct TailLoop { pub just_outputs: TypeRow, /// Types that are appended to both input and output pub rest: TypeRow, + /// Extension requirements to execute the body + pub extension_delta: ExtensionSet, } impl_op_name!(TailLoop); @@ -32,7 +34,7 @@ impl DataflowOpTrait for TailLoop { fn signature(&self) -> FunctionType { let [inputs, outputs] = [&self.just_inputs, &self.just_outputs].map(|row| row.extend(self.rest.iter())); - FunctionType::new(inputs, outputs) + FunctionType::new(inputs, outputs).with_extension_delta(self.extension_delta.clone()) } } @@ -51,6 +53,13 @@ impl TailLoop { } } +impl DataflowParent for TailLoop { + fn inner_signature(&self) -> FunctionType { + FunctionType::new(self.body_input_row(), self.body_output_row()) + .with_extension_delta(self.extension_delta.clone()) + } +} + /// Conditional operation, defined by child `Case` nodes for each branch. #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] #[cfg_attr(test, derive(proptest_derive::Arbitrary))] diff --git a/hugr-core/src/ops/validate.rs b/hugr-core/src/ops/validate.rs index c87034d60..aaee46049 100644 --- a/hugr-core/src/ops/validate.rs +++ b/hugr-core/src/ops/validate.rs @@ -106,31 +106,6 @@ impl ValidateOp for super::Conditional { } } -impl ValidateOp for super::TailLoop { - fn validity_flags(&self) -> OpValidityFlags { - OpValidityFlags { - allowed_children: OpTag::DataflowChild, - allowed_first_child: OpTag::Input, - allowed_second_child: OpTag::Output, - requires_children: true, - requires_dag: true, - ..Default::default() - } - } - - fn validate_op_children<'a>( - &self, - children: impl DoubleEndedIterator, - ) -> Result<(), ChildrenValidationError> { - validate_io_nodes( - &self.body_input_row(), - &self.body_output_row(), - "tail-controlled loop graph", - children, - ) - } -} - impl ValidateOp for super::CFG { fn validity_flags(&self) -> OpValidityFlags { OpValidityFlags {