Skip to content

Commit

Permalink
Fix tail_loop tests by adding extension_delta to TailLoop and impl Da…
Browse files Browse the repository at this point in the history
…taflowParent
  • Loading branch information
acl-cqc committed May 30, 2024
1 parent 07c09cd commit 08bb5a5
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 32 deletions.
2 changes: 2 additions & 0 deletions hugr-core/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,7 @@ pub trait Dataflow: Container {
just_inputs: impl IntoIterator<Item = (Type, Wire)>,
inputs_outputs: impl IntoIterator<Item = (Type, Wire)>,
just_out_types: TypeRow,
extension_delta: ExtensionSet,
) -> Result<TailLoopBuilder<&mut Hugr>, BuildError> {
let (input_types, mut input_wires): (Vec<Type>, Vec<Wire>) =
just_inputs.into_iter().unzip();
Expand All @@ -440,6 +441,7 @@ pub trait Dataflow: Container {
just_inputs: input_types.into(),
just_outputs: just_out_types,
rest: rest_types.into(),
extension_delta,
};
// TODO: Make input extensions a parameter
let (loop_node, _) = add_node_with_wires(self, tail_loop.clone(), input_wires)?;
Expand Down
17 changes: 12 additions & 5 deletions hugr-core/src/builder/tail_loop.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::extension::ExtensionSet;
use crate::ops;

use crate::hugr::{views::HugrView, NodeType};
Expand Down Expand Up @@ -74,11 +75,13 @@ impl TailLoopBuilder<Hugr> {
just_inputs: impl Into<TypeRow>,
inputs_outputs: impl Into<TypeRow>,
just_outputs: impl Into<TypeRow>,
extension_delta: ExtensionSet,
) -> Result<Self, BuildError> {
let tail_loop = ops::TailLoop {
just_inputs: just_inputs.into(),
just_outputs: just_outputs.into(),
rest: inputs_outputs.into(),
extension_delta,
};
// TODO: Allow input extensions to be specified
let base = Hugr::new(NodeType::new_open(tail_loop.clone()));
Expand All @@ -97,7 +100,6 @@ mod test {
DataflowSubContainer, HugrBuilder, ModuleBuilder,
},
extension::prelude::{ConstUsize, PRELUDE_ID, USIZE_T},
extension::ExtensionSet,
hugr::ValidationError,
ops::Value,
type_row,
Expand All @@ -107,7 +109,8 @@ mod test {
#[test]
fn basic_loop() -> Result<(), BuildError> {
let build_result: Result<Hugr, ValidationError> = {
let mut loop_b = TailLoopBuilder::new(vec![], vec![BIT], vec![USIZE_T])?;
let mut loop_b =
TailLoopBuilder::new(vec![], vec![BIT], vec![USIZE_T], PRELUDE_ID.into())?;
let [i1] = loop_b.input_wires_arr();
let const_wire = loop_b.add_load_value(ConstUsize::new(1));

Expand Down Expand Up @@ -141,8 +144,12 @@ mod test {
)?
.outputs_arr();
let loop_id = {
let mut loop_b =
fbuild.tail_loop_builder(vec![(BIT, b1)], vec![], type_row![NAT])?;
let mut loop_b = fbuild.tail_loop_builder(
vec![(BIT, b1)],
vec![],
type_row![NAT],
PRELUDE_ID.into(),
)?;
let signature = loop_b.loop_signature()?.clone();
let const_val = Value::true_val();
let const_wire = loop_b.add_load_const(Value::true_val());
Expand All @@ -161,7 +168,7 @@ mod test {
([type_row![], type_row![]], const_wire),
vec![(BIT, b1)],
output_row,
ExtensionSet::new(),
PRELUDE_ID.into(),
)?;

let mut branch_0 = conditional_b.case_builder(0)?;
Expand Down
1 change: 0 additions & 1 deletion hugr-core/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,6 @@ impl OpParent for MakeTuple {}
impl OpParent for UnpackTuple {}
impl OpParent for Tag {}
impl OpParent for Lift {}
impl OpParent for TailLoop {}
impl OpParent for CFG {}
impl OpParent for Conditional {}
impl OpParent for FuncDecl {}
Expand Down
11 changes: 10 additions & 1 deletion hugr-core/src/ops/controlflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ pub struct TailLoop {
pub just_outputs: TypeRow,
/// Types that are appended to both input and output
pub rest: TypeRow,
/// Extension requirements to execute the body
pub extension_delta: ExtensionSet,
}

impl_op_name!(TailLoop);
Expand All @@ -32,7 +34,7 @@ impl DataflowOpTrait for TailLoop {
fn signature(&self) -> FunctionType {
let [inputs, outputs] =
[&self.just_inputs, &self.just_outputs].map(|row| row.extend(self.rest.iter()));
FunctionType::new(inputs, outputs)
FunctionType::new(inputs, outputs).with_extension_delta(self.extension_delta.clone())
}
}

Expand All @@ -51,6 +53,13 @@ impl TailLoop {
}
}

impl DataflowParent for TailLoop {
fn inner_signature(&self) -> FunctionType {
FunctionType::new(self.body_input_row(), self.body_output_row())
.with_extension_delta(self.extension_delta.clone())
}
}

/// Conditional operation, defined by child `Case` nodes for each branch.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
#[cfg_attr(test, derive(proptest_derive::Arbitrary))]
Expand Down
25 changes: 0 additions & 25 deletions hugr-core/src/ops/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,31 +106,6 @@ impl ValidateOp for super::Conditional {
}
}

impl ValidateOp for super::TailLoop {
fn validity_flags(&self) -> OpValidityFlags {
OpValidityFlags {
allowed_children: OpTag::DataflowChild,
allowed_first_child: OpTag::Input,
allowed_second_child: OpTag::Output,
requires_children: true,
requires_dag: true,
..Default::default()
}
}

fn validate_op_children<'a>(
&self,
children: impl DoubleEndedIterator<Item = (NodeIndex, &'a OpType)>,
) -> Result<(), ChildrenValidationError> {
validate_io_nodes(
&self.body_input_row(),
&self.body_output_row(),
"tail-controlled loop graph",
children,
)
}
}

impl ValidateOp for super::CFG {
fn validity_flags(&self) -> OpValidityFlags {
OpValidityFlags {
Expand Down

0 comments on commit 08bb5a5

Please sign in to comment.