From 9fe478e5e8794827ebc6aa2eaea06c6fc7c9d8b3 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Mon, 11 Sep 2023 11:45:02 +0100 Subject: [PATCH 01/16] feat: Extension inference for CFGs --- src/extension/infer.rs | 197 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 194 insertions(+), 3 deletions(-) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index f9f2f1751..84a12a4d9 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -302,6 +302,18 @@ impl UnificationContext { } } + if node_type.tag() == OpTag::Cfg { + let mut children = hugr.children(node); + let entry = children.next().unwrap(); + let exit = children.next().unwrap(); + let m_entry = self.make_or_get_meta(entry, Direction::Incoming); + let m_exit_in = self.make_or_get_meta(exit, Direction::Incoming); + let m_exit_out = self.make_or_get_meta(exit, Direction::Outgoing); + self.add_constraint(m_input, Constraint::Equal(m_entry)); + self.add_constraint(m_output, Constraint::Equal(m_exit_in)); + self.add_constraint(m_output, Constraint::Equal(m_exit_out)); + } + match node_type.signature() { // Input extensions are open None => { @@ -324,8 +336,12 @@ impl UnificationContext { let sig: &OpType = hugr.get_nodetype(tgt_node).into(); // Incoming ports with a dataflow edge for port in hugr.node_inputs(tgt_node).filter(|src_port| { - matches!(sig.port_kind(*src_port), Some(EdgeKind::Value(_))) - || matches!(sig.port_kind(*src_port), Some(EdgeKind::Static(_))) + matches!( + sig.port_kind(*src_port), + Some(EdgeKind::Value(_)) + | Some(EdgeKind::Static(_)) + | Some(EdgeKind::ControlFlow) + ) }) { for (src_node, _) in hugr.linked_ports(tgt_node, port) { let m_src = self @@ -1133,8 +1149,183 @@ mod test { for (src, tgt) in nodes.into_iter().tuple_windows() { hugr.connect(src, 0, tgt, 0)?; } - hugr.infer_and_validate(&PRELUDE_REGISTRY)?; + Ok(()) + } + + fn make_opaque(extension: impl Into, signature: FunctionType) -> ops::LeafOp { + let opaque = + ops::custom::OpaqueOp::new(extension.into(), "", "".into(), vec![], Some(signature)); + ops::custom::ExternalOp::from(opaque).into() + } + + /// A CFG rooted hugr adding resources at each basic block. + /// Looks like this: + /// + /// +-------------+ + /// | Entry | + /// | (Adds [A]) | + /// +-/---------\-+ + /// / \ + /// +-------/-----+ +-\-------------+ + /// | BB0 | | BB1 | + /// | (Adds [BC]) | | (Adds [B]) | + /// +----\--------+ +---/------\----+ + /// \ / \ + /// \ / \ + /// \ +----/-------+ +-\---------+ + /// \ | BB10 | | BB11 | + /// \ | (Adds [C]) | | (Adds [C])| + /// \ +----+-------+ +/----------+ + /// \ | / + /// +-----\-------+---------/-+ + /// | Exit | + /// +-------------------------+ + #[test] + fn infer_cfg_test() -> Result<(), Box> { + let a = ExtensionSet::singleton(&A); + let abc = ExtensionSet::from_iter([A, B, C]); + let bc = ExtensionSet::from_iter([B, C]); + let b = ExtensionSet::singleton(&B); + let c = ExtensionSet::singleton(&C); + + let oneway = vec![Type::new_predicate([type_row![NAT]])]; + let twoway = vec![Type::new_predicate([type_row![NAT], type_row![NAT]])]; + + let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG { + signature: FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&abc), + })); + + let root = hugr.root(); + + let [entry, entry_in, entry_out] = create_with_io( + &mut hugr, + root, + ops::BasicBlock::DFB { + inputs: type_row![NAT], + other_outputs: type_row![], + predicate_variants: vec![type_row![NAT], type_row![NAT]], + extension_delta: a.clone(), + }, + )?; + + let exit = hugr.add_node_with_parent( + hugr.root(), + NodeType::open_extensions(ops::BasicBlock::Exit { + cfg_outputs: type_row![NAT], + }), + )?; + + let mkpred = hugr.add_node_with_parent( + entry, + NodeType::open_extensions(make_opaque( + A, + FunctionType::new(vec![NAT], twoway.clone()).with_extension_delta(&a), + )), + )?; + + let [bb0, bb0_in, bb0_out] = create_with_io( + &mut hugr, + root, + ops::BasicBlock::DFB { + inputs: type_row![NAT], + other_outputs: type_row![], + predicate_variants: vec![type_row![NAT]], + extension_delta: bc.clone(), + }, + )?; + + let [bb1, bb1_in, bb1_out] = create_with_io( + &mut hugr, + root, + ops::BasicBlock::DFB { + inputs: type_row![NAT], + other_outputs: type_row![], + predicate_variants: vec![type_row![NAT], type_row![NAT]], + extension_delta: b.clone(), + }, + )?; + + let add_bc = hugr.add_node_with_parent( + bb0, + NodeType::open_extensions(make_opaque( + B, + FunctionType::new(vec![NAT], oneway.clone()).with_extension_delta(&bc), + )), + )?; + + let add_b = hugr.add_node_with_parent( + bb1, + NodeType::open_extensions(make_opaque( + B, + FunctionType::new(vec![NAT], twoway.clone()).with_extension_delta(&b), + )), + )?; + + let [bb10, bb10_in, bb10_out] = create_with_io( + &mut hugr, + root, + ops::BasicBlock::DFB { + inputs: type_row![NAT], + other_outputs: type_row![], + predicate_variants: vec![type_row![NAT]], + extension_delta: c.clone(), + }, + )?; + + let [bb11, bb11_in, bb11_out] = create_with_io( + &mut hugr, + root, + ops::BasicBlock::DFB { + inputs: type_row![NAT], + other_outputs: type_row![], + predicate_variants: vec![type_row![NAT]], + extension_delta: c.clone(), + }, + )?; + + let add_c0 = hugr.add_node_with_parent( + bb10, + NodeType::open_extensions(make_opaque( + C, + FunctionType::new(vec![NAT], oneway.clone()).with_extension_delta(&c), + )), + )?; + + let add_c1 = hugr.add_node_with_parent( + bb11, + NodeType::open_extensions(make_opaque( + C, + FunctionType::new(vec![NAT], oneway).with_extension_delta(&c), + )), + )?; + + // Internal wiring for DFGs + hugr.connect(entry_in, 0, mkpred, 0)?; + hugr.connect(mkpred, 0, entry_out, 0)?; + + hugr.connect(bb0_in, 0, add_bc, 0)?; + hugr.connect(add_bc, 0, bb0_out, 0)?; + + hugr.connect(bb1_in, 0, add_b, 0)?; + hugr.connect(add_b, 0, bb1_out, 0)?; + + hugr.connect(bb10_in, 0, add_c0, 0)?; + hugr.connect(add_c0, 0, bb10_out, 0)?; + hugr.connect(bb11_in, 0, add_c1, 0)?; + hugr.connect(add_c1, 0, bb11_out, 0)?; + + // CFG Wiring + hugr.connect(entry, 0, bb0, 0)?; + hugr.connect(entry, 1, bb1, 0)?; + hugr.connect(bb1, 0, bb10, 0)?; + hugr.connect(bb1, 1, bb11, 0)?; + + hugr.connect(bb0, 0, exit, 0)?; + hugr.connect(bb10, 0, exit, 0)?; + hugr.connect(bb11, 0, exit, 0)?; + + hugr.infer_extensions()?; Ok(()) } From 16be9bceb2b28a7483305b932fa8d83a0fc44bbf Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Wed, 13 Sep 2023 09:15:23 +0100 Subject: [PATCH 02/16] refactor: Add helper method for CFG inference test --- src/extension/infer.rs | 75 ++++++++++++++++++------------------------ 1 file changed, 32 insertions(+), 43 deletions(-) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 84a12a4d9..a41aa6caf 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -1159,6 +1159,22 @@ mod test { ops::custom::ExternalOp::from(opaque).into() } + fn make_block( + hugr: &mut Hugr, + bb_parent: Node, + dfb: impl Into, + op: impl Into, + ) -> Result> { + let [bb, bb_in, bb_out] = create_with_io(hugr, bb_parent, dfb)?; + + let dfg = hugr.add_node_with_parent(bb, NodeType::open_extensions(op))?; + + hugr.connect(bb_in, 0, dfg, 0)?; + hugr.connect(dfg, 0, bb_out, 0)?; + + Ok(bb) + } + /// A CFG rooted hugr adding resources at each basic block. /// Looks like this: /// @@ -1224,7 +1240,7 @@ mod test { )), )?; - let [bb0, bb0_in, bb0_out] = create_with_io( + let bb0 = make_block( &mut hugr, root, ops::BasicBlock::DFB { @@ -1233,9 +1249,13 @@ mod test { predicate_variants: vec![type_row![NAT]], extension_delta: bc.clone(), }, + make_opaque( + B, + FunctionType::new(vec![NAT], oneway.clone()).with_extension_delta(&bc), + ), )?; - let [bb1, bb1_in, bb1_out] = create_with_io( + let bb1 = make_block( &mut hugr, root, ops::BasicBlock::DFB { @@ -1244,25 +1264,13 @@ mod test { predicate_variants: vec![type_row![NAT], type_row![NAT]], extension_delta: b.clone(), }, - )?; - - let add_bc = hugr.add_node_with_parent( - bb0, - NodeType::open_extensions(make_opaque( - B, - FunctionType::new(vec![NAT], oneway.clone()).with_extension_delta(&bc), - )), - )?; - - let add_b = hugr.add_node_with_parent( - bb1, - NodeType::open_extensions(make_opaque( + make_opaque( B, FunctionType::new(vec![NAT], twoway.clone()).with_extension_delta(&b), - )), + ), )?; - let [bb10, bb10_in, bb10_out] = create_with_io( + let bb10 = make_block( &mut hugr, root, ops::BasicBlock::DFB { @@ -1271,9 +1279,13 @@ mod test { predicate_variants: vec![type_row![NAT]], extension_delta: c.clone(), }, + make_opaque( + C, + FunctionType::new(vec![NAT], oneway.clone()).with_extension_delta(&c), + ), )?; - let [bb11, bb11_in, bb11_out] = create_with_io( + let bb11 = make_block( &mut hugr, root, ops::BasicBlock::DFB { @@ -1282,39 +1294,16 @@ mod test { predicate_variants: vec![type_row![NAT]], extension_delta: c.clone(), }, - )?; - - let add_c0 = hugr.add_node_with_parent( - bb10, - NodeType::open_extensions(make_opaque( - C, - FunctionType::new(vec![NAT], oneway.clone()).with_extension_delta(&c), - )), - )?; - - let add_c1 = hugr.add_node_with_parent( - bb11, - NodeType::open_extensions(make_opaque( + make_opaque( C, FunctionType::new(vec![NAT], oneway).with_extension_delta(&c), - )), + ), )?; // Internal wiring for DFGs hugr.connect(entry_in, 0, mkpred, 0)?; hugr.connect(mkpred, 0, entry_out, 0)?; - hugr.connect(bb0_in, 0, add_bc, 0)?; - hugr.connect(add_bc, 0, bb0_out, 0)?; - - hugr.connect(bb1_in, 0, add_b, 0)?; - hugr.connect(add_b, 0, bb1_out, 0)?; - - hugr.connect(bb10_in, 0, add_c0, 0)?; - hugr.connect(add_c0, 0, bb10_out, 0)?; - hugr.connect(bb11_in, 0, add_c1, 0)?; - hugr.connect(add_c1, 0, bb11_out, 0)?; - // CFG Wiring hugr.connect(entry, 0, bb0, 0)?; hugr.connect(entry, 1, bb1, 0)?; From aa5b3a6b86ed498e1bb5d390bf82e5e71b352246 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Wed, 13 Sep 2023 16:30:45 +0100 Subject: [PATCH 03/16] docs: Drive by change to comments --- src/algorithm/nest_cfgs.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/algorithm/nest_cfgs.rs b/src/algorithm/nest_cfgs.rs index 56aff219b..d9c67e0fd 100644 --- a/src/algorithm/nest_cfgs.rs +++ b/src/algorithm/nest_cfgs.rs @@ -429,7 +429,7 @@ pub(crate) mod test { let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; let pred_const = - cfg_builder.add_constant(Const::simple_predicate(0, 2), ExtensionSet::new())?; // Nothing here cares which + cfg_builder.add_constant(Const::simple_predicate(0, 2), ExtensionSet::new())?; // Nothing here cares which branch let const_unit = cfg_builder.add_constant(Const::simple_unary_predicate(), ExtensionSet::new())?; @@ -678,7 +678,7 @@ pub(crate) mod test { ) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> { let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; let pred_const = - cfg_builder.add_constant(Const::simple_predicate(0, 2), ExtensionSet::new())?; // Nothing here cares which + cfg_builder.add_constant(Const::simple_predicate(0, 2), ExtensionSet::new())?; // Nothing here cares which branch let const_unit = cfg_builder.add_constant(Const::simple_unary_predicate(), ExtensionSet::new())?; @@ -722,7 +722,7 @@ pub(crate) mod test { separate_headers: bool, ) -> Result<(BasicBlockID, BasicBlockID), BuildError> { let pred_const = - cfg_builder.add_constant(Const::simple_predicate(0, 2), ExtensionSet::new())?; // Nothing here cares which + cfg_builder.add_constant(Const::simple_predicate(0, 2), ExtensionSet::new())?; // Nothing here cares which branch let const_unit = cfg_builder.add_constant(Const::simple_unary_predicate(), ExtensionSet::new())?; From aa7c43608461831f77fa6d5d1f823bbcd2a2a187 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Wed, 27 Sep 2023 12:54:58 +0100 Subject: [PATCH 04/16] doc: Update a comment --- src/extension/infer.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index a41aa6caf..09e45099c 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -669,6 +669,8 @@ impl UnificationContext { } /// Instantiate all variables in the graph with the empty extension set. + /// Instantiate all variables in the graph with the empty extension set, or + /// the smallest solution possible given their constraints. /// This is done to solve metas which depend on variables, which allows /// us to come up with a fully concrete solution to pass into validation. pub fn instantiate_variables(&mut self) { From 05833515846f7ab364025ffc8fa84404f9c42941 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Fri, 22 Sep 2023 11:55:02 +0100 Subject: [PATCH 05/16] fixup: CFG Inference test --- src/extension/infer.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 09e45099c..c418d333c 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -1308,9 +1308,9 @@ mod test { // CFG Wiring hugr.connect(entry, 0, bb0, 0)?; - hugr.connect(entry, 1, bb1, 0)?; + hugr.connect(entry, 0, bb1, 0)?; hugr.connect(bb1, 0, bb10, 0)?; - hugr.connect(bb1, 1, bb11, 0)?; + hugr.connect(bb1, 0, bb11, 0)?; hugr.connect(bb0, 0, exit, 0)?; hugr.connect(bb10, 0, exit, 0)?; From adfbbafff801830b0e0d27b3177245bbda9831f6 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Mon, 18 Sep 2023 10:28:29 +0100 Subject: [PATCH 06/16] tests: More CFG inference tests --- src/extension/infer.rs | 395 ++++++++++++++++++++++++++++++++++++++--- 1 file changed, 372 insertions(+), 23 deletions(-) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index c418d333c..3023e55a1 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -307,11 +307,9 @@ impl UnificationContext { let entry = children.next().unwrap(); let exit = children.next().unwrap(); let m_entry = self.make_or_get_meta(entry, Direction::Incoming); - let m_exit_in = self.make_or_get_meta(exit, Direction::Incoming); - let m_exit_out = self.make_or_get_meta(exit, Direction::Outgoing); + let m_exit = self.make_or_get_meta(exit, Direction::Outgoing); self.add_constraint(m_input, Constraint::Equal(m_entry)); - self.add_constraint(m_output, Constraint::Equal(m_exit_in)); - self.add_constraint(m_output, Constraint::Equal(m_exit_out)); + self.add_constraint(m_output, Constraint::Equal(m_exit)); } match node_type.signature() { @@ -705,7 +703,10 @@ mod test { use super::*; use crate::builder::test::closed_dfg_root_hugr; - use crate::extension::{ExtensionSet, PRELUDE_REGISTRY}; + use crate::extension::{ + prelude::{PRELUDE_ID, PRELUDE_REGISTRY}, + ExtensionSet, + }; use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType}; use crate::macros::const_extension_ids; use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle, OpTrait}; @@ -993,20 +994,21 @@ mod test { hugr: &mut Hugr, parent: Node, op: impl Into, + op_sig: FunctionType, ) -> Result<[Node; 3], Box> { let op: OpType = op.into(); - let input_types = op.signature().input; - let output_types = op.signature().output; let node = hugr.add_node_with_parent(parent, NodeType::open_extensions(op))?; let input = hugr.add_node_with_parent( node, - NodeType::open_extensions(ops::Input { types: input_types }), + NodeType::open_extensions(ops::Input { + types: op_sig.input, + }), )?; let output = hugr.add_node_with_parent( node, NodeType::open_extensions(ops::Output { - types: output_types, + types: op_sig.output, }), )?; Ok([node, input, output]) @@ -1021,7 +1023,12 @@ mod test { first_ext: ExtensionId, second_ext: ExtensionId, ) -> Result> { - let [case, case_in, case_out] = create_with_io(hugr, conditional_node, op)?; + let [case, case_in, case_out] = create_with_io( + hugr, + conditional_node, + op.clone(), + Into::::into(op).signature(), + )?; let lift1 = hugr.add_node_with_parent( case, @@ -1118,14 +1125,16 @@ mod test { let df_nodes: Vec = vec![A, A, B, B, A, B] .into_iter() .map(|ext| { + let dfg_sig = df_sig + .clone() + .with_extension_delta(&ExtensionSet::singleton(&ext)); let [node, input, output] = create_with_io( &mut hugr, root, ops::DFG { - signature: df_sig - .clone() - .with_extension_delta(&ExtensionSet::singleton(&ext)), + signature: dfg_sig.clone(), }, + dfg_sig, ) .unwrap(); @@ -1165,9 +1174,10 @@ mod test { hugr: &mut Hugr, bb_parent: Node, dfb: impl Into, - op: impl Into, + dfb_sig: FunctionType, + op: impl Into, // A single op contained in the DFB, with inputs and outputs wired to it ) -> Result> { - let [bb, bb_in, bb_out] = create_with_io(hugr, bb_parent, dfb)?; + let [bb, bb_in, bb_out] = create_with_io(hugr, bb_parent, dfb, dfb_sig)?; let dfg = hugr.add_node_with_parent(bb, NodeType::open_extensions(op))?; @@ -1177,6 +1187,14 @@ mod test { Ok(bb) } + fn oneway(ty: Type) -> Vec { + vec![Type::new_predicate([vec![ty]])] + } + + fn twoway(ty: Type) -> Vec { + vec![Type::new_predicate([vec![ty.clone()], vec![ty]])] + } + /// A CFG rooted hugr adding resources at each basic block. /// Looks like this: /// @@ -1207,9 +1225,6 @@ mod test { let b = ExtensionSet::singleton(&B); let c = ExtensionSet::singleton(&C); - let oneway = vec![Type::new_predicate([type_row![NAT]])]; - let twoway = vec![Type::new_predicate([type_row![NAT], type_row![NAT]])]; - let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG { signature: FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&abc), })); @@ -1225,6 +1240,7 @@ mod test { predicate_variants: vec![type_row![NAT], type_row![NAT]], extension_delta: a.clone(), }, + FunctionType::new(vec![NAT], twoway(NAT)), )?; let exit = hugr.add_node_with_parent( @@ -1238,7 +1254,7 @@ mod test { entry, NodeType::open_extensions(make_opaque( A, - FunctionType::new(vec![NAT], twoway.clone()).with_extension_delta(&a), + FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(&a), )), )?; @@ -1251,9 +1267,10 @@ mod test { predicate_variants: vec![type_row![NAT]], extension_delta: bc.clone(), }, + FunctionType::new(vec![NAT], oneway(NAT)), make_opaque( B, - FunctionType::new(vec![NAT], oneway.clone()).with_extension_delta(&bc), + FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&bc), ), )?; @@ -1266,9 +1283,10 @@ mod test { predicate_variants: vec![type_row![NAT], type_row![NAT]], extension_delta: b.clone(), }, + FunctionType::new(vec![NAT], twoway(NAT)), make_opaque( B, - FunctionType::new(vec![NAT], twoway.clone()).with_extension_delta(&b), + FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(&b), ), )?; @@ -1281,9 +1299,10 @@ mod test { predicate_variants: vec![type_row![NAT]], extension_delta: c.clone(), }, + FunctionType::new(vec![NAT], oneway(NAT)), make_opaque( C, - FunctionType::new(vec![NAT], oneway.clone()).with_extension_delta(&c), + FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&c), ), )?; @@ -1296,9 +1315,10 @@ mod test { predicate_variants: vec![type_row![NAT]], extension_delta: c.clone(), }, + FunctionType::new(vec![NAT], oneway(NAT)), make_opaque( C, - FunctionType::new(vec![NAT], oneway).with_extension_delta(&c), + FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&c), ), )?; @@ -1320,4 +1340,333 @@ mod test { Ok(()) } + + #[test] + fn multi_entry() -> Result<(), Box> { + let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG { + signature: FunctionType::new(type_row![NAT], type_row![NAT]), // maybe add extensions? + })); + let cfg = hugr.root(); + let exit = hugr.add_node_with_parent( + cfg, + NodeType::open_extensions(ops::BasicBlock::Exit { + cfg_outputs: type_row![NAT], + }), + )?; + let entry = { + let dfb = ops::BasicBlock::DFB { + inputs: type_row![NAT], + other_outputs: type_row![], + predicate_variants: vec![type_row![NAT], type_row![NAT]], + extension_delta: ExtensionSet::new(), + }; + + let entry = hugr.add_node_before(exit, NodeType::open_extensions(dfb))?; + + let entry_in = hugr.add_node_with_parent( + entry, + NodeType::open_extensions(ops::Input { + types: type_row![NAT], + }), + )?; + + let entry_out = hugr.add_node_with_parent( + entry, + NodeType::open_extensions(ops::Output { + types: twoway(NAT).into(), + }), + )?; + + let mid = hugr.add_node_with_parent( + entry, + NodeType::open_extensions(make_opaque( + PRELUDE_ID, + FunctionType::new(vec![NAT], twoway(NAT)), + )), + )?; + + hugr.connect(entry_in, 0, mid, 0)?; + hugr.connect(mid, 0, entry_out, 0)?; + + entry + }; + + let bb0 = make_block( + &mut hugr, + cfg, + ops::BasicBlock::DFB { + inputs: type_row![NAT], + other_outputs: type_row![], + predicate_variants: vec![type_row![NAT]], + extension_delta: ExtensionSet::new(), + }, + FunctionType::new(vec![NAT], oneway(NAT)), + make_opaque(PRELUDE_ID, FunctionType::new(vec![NAT], oneway(NAT))), + )?; + + let bb1 = make_block( + &mut hugr, + cfg, + ops::BasicBlock::DFB { + inputs: type_row![NAT], + other_outputs: type_row![], + predicate_variants: vec![type_row![NAT]], + extension_delta: ExtensionSet::new(), + }, + FunctionType::new(vec![NAT], oneway(NAT)), + make_opaque(PRELUDE_ID, FunctionType::new(vec![NAT], oneway(NAT))), + )?; + + // Mult + let bb2 = make_block( + &mut hugr, + cfg, + ops::BasicBlock::DFB { + inputs: type_row![NAT], + other_outputs: type_row![], + predicate_variants: vec![type_row![NAT]], + extension_delta: ExtensionSet::new(), + }, + FunctionType::new(vec![NAT], oneway(NAT)), + make_opaque(PRELUDE_ID, FunctionType::new(vec![NAT], oneway(NAT))), + )?; + + hugr.connect(entry, 0, bb0, 0)?; + hugr.connect(entry, 0, bb1, 0)?; + hugr.connect(bb0, 0, bb2, 0)?; + hugr.connect(bb1, 0, bb2, 0)?; + hugr.connect(bb2, 0, exit, 0)?; + + hugr.infer_and_validate(&PRELUDE_REGISTRY)?; + + Ok(()) + } + + /// +-----------+ + /// +--->| Entry | + /// | | (BB0) | + /// | | | + /// | +-----+-----+ + /// | | + /// | V + /// | +------------+ + /// | | BB1 | + /// | | +---+ + /// | +-----+------+ | + /// | | | + /// | V | + /// | +------------+ | + /// +----+ BB2 | | + /// | | | + /// +------------+ | + /// | + /// +------------+ | + /// | Exit | | + /// | |<--+ + /// +------------+ + /// + /// + /// + fn make_cfg_loop_test( + entry_ext: ExtensionSet, + bb1_ext: ExtensionSet, + bb2_ext: ExtensionSet, + ) -> Result> { + let hugr_delta = entry_ext.clone().union(&bb1_ext).union(&bb2_ext); + + let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG { + signature: FunctionType::new(type_row![NAT], type_row![NAT]) + .with_extension_delta(&hugr_delta), + })); + + let root = hugr.root(); + + let exit = hugr.add_node_with_parent( + root, + NodeType::open_extensions(ops::BasicBlock::Exit { + cfg_outputs: type_row![NAT], + }), + )?; + + let entry = { + let dfb: OpType = ops::BasicBlock::DFB { + inputs: type_row![NAT], + other_outputs: type_row![], + predicate_variants: vec![type_row![NAT]], + extension_delta: entry_ext.clone(), + } + .into(); + + let entry = hugr.add_node_before(exit, NodeType::open_extensions(dfb))?; + let entry_in = hugr.add_node_with_parent( + entry, + NodeType::open_extensions(ops::Input { + types: type_row![NAT], + }), + )?; + let entry_out = hugr.add_node_with_parent( + entry, + NodeType::open_extensions(ops::Output { + types: oneway(NAT).into(), + }), + )?; + + let dfg = hugr.add_node_with_parent( + entry, + NodeType::open_extensions(make_opaque( + PRELUDE_ID, + FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&entry_ext), + )), + )?; + + hugr.connect(entry_in, 0, dfg, 0)?; + hugr.connect(dfg, 0, entry_out, 0)?; + entry + }; + + let bb1 = make_block( + &mut hugr, + root, + ops::BasicBlock::DFB { + inputs: type_row![NAT], + other_outputs: type_row![], + predicate_variants: vec![type_row![NAT], type_row![NAT]], + extension_delta: bb1_ext.clone(), + }, + FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(&bb1_ext), + make_opaque( + PRELUDE_ID, + FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(&bb1_ext), + ), + )?; + + let bb2 = make_block( + &mut hugr, + root, + ops::BasicBlock::DFB { + inputs: type_row![NAT], + other_outputs: type_row![], + predicate_variants: vec![type_row![NAT]], + extension_delta: bb2_ext.clone(), + }, + FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&bb2_ext), + make_opaque( + PRELUDE_ID, + FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&bb2_ext), + ), + )?; + + hugr.connect(entry, 0, bb1, 0)?; + hugr.connect(bb1, 0, bb2, 0)?; + hugr.connect(bb1, 0, exit, 0)?; + hugr.connect(bb2, 0, entry, 0)?; + + Ok(hugr) + } + + #[test] + fn test_cfg_loops() -> Result<(), Box> { + let just_a = ExtensionSet::singleton(&A); + let variants = vec![ + ( + ExtensionSet::new(), + ExtensionSet::new(), + ExtensionSet::new(), + ), + (just_a.clone(), ExtensionSet::new(), ExtensionSet::new()), + (ExtensionSet::new(), just_a.clone(), ExtensionSet::new()), + (ExtensionSet::new(), ExtensionSet::new(), just_a.clone()), + ]; + + for (bb0, bb1, bb2) in variants.into_iter() { + println!("{}, {}, {}", bb0, bb1, bb2); + let mut hugr = make_cfg_loop_test(bb0, bb1, bb2)?; + hugr.infer_and_validate(&PRELUDE_REGISTRY)?; + } + Ok(()) + } + + #[test] + /// A control flow graph consisting of an entry node and a single block + /// which adds a resource and links to both itself and the exit node. + fn simple_cfg_loop() -> Result<(), Box> { + let just_a = ExtensionSet::singleton(&A); + + let mut hugr = Hugr::new(NodeType::new( + ops::CFG { + signature: FunctionType::new(type_row![NAT], type_row![NAT]) + .with_extension_delta(&just_a), + }, + just_a.clone(), + )); + + let root = hugr.root(); + + let exit = hugr.add_node_with_parent( + root, + NodeType::open_extensions(ops::BasicBlock::Exit { + cfg_outputs: type_row![NAT], + }), + )?; + + let entry = { + let dfb: OpType = ops::BasicBlock::DFB { + inputs: type_row![NAT], + other_outputs: type_row![], + predicate_variants: vec![type_row![NAT]], + extension_delta: ExtensionSet::new(), + } + .into(); + + let entry = hugr.add_node_before(exit, NodeType::open_extensions(dfb))?; + let entry_in = hugr.add_node_with_parent( + entry, + NodeType::open_extensions(ops::Input { + types: type_row![NAT], + }), + )?; + let entry_out = hugr.add_node_with_parent( + entry, + NodeType::open_extensions(ops::Output { + types: oneway(NAT).into(), + }), + )?; + + let mid = hugr.add_node_with_parent( + entry, + NodeType::open_extensions(make_opaque( + PRELUDE_ID, + FunctionType::new(vec![NAT], oneway(NAT)), + )), + )?; + + hugr.connect(entry_in, 0, mid, 0)?; + hugr.connect(mid, 0, entry_out, 0)?; + entry + }; + + let bb = make_block( + &mut hugr, + root, + ops::BasicBlock::DFB { + inputs: type_row![NAT], + other_outputs: type_row![], + predicate_variants: vec![type_row![NAT], type_row![NAT]], + extension_delta: just_a.clone(), + }, + FunctionType::new(vec![NAT], twoway(NAT)), + make_opaque( + PRELUDE_ID, + FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(&just_a), + ), + )?; + + hugr.connect(entry, 0, bb, 0)?; + hugr.connect(bb, 0, bb, 0)?; + hugr.connect(bb, 0, exit, 0)?; + + hugr.infer_and_validate(&PRELUDE_REGISTRY)?; + + Ok(()) + } } From 52b0bc9f1d2c5ec8f40abc3be8ae3ab3a222f87c Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Wed, 27 Sep 2023 13:21:22 +0100 Subject: [PATCH 07/16] doc: Remove redundant whitespace from comment --- src/extension/infer.rs | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 3023e55a1..86aae3185 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -1444,29 +1444,21 @@ mod test { /// +-----------+ /// +--->| Entry | - /// | | (BB0) | - /// | | | /// | +-----+-----+ /// | | /// | V /// | +------------+ - /// | | BB1 | - /// | | +---+ + /// | | BB1 +---+ /// | +-----+------+ | /// | | | /// | V | /// | +------------+ | /// +----+ BB2 | | - /// | | | /// +------------+ | /// | /// +------------+ | - /// | Exit | | - /// | |<--+ + /// | Exit |<--+ /// +------------+ - /// - /// - /// fn make_cfg_loop_test( entry_ext: ExtensionSet, bb1_ext: ExtensionSet, From aeffaddc7f562a92e4e11037bcb01143ae82b6ed Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 3 Oct 2023 09:29:59 +0100 Subject: [PATCH 08/16] doc: Fix typo --- src/extension/infer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 86aae3185..ce05de026 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -328,7 +328,7 @@ impl UnificationContext { } } } - // Seperate loop so that we can assume that a metavariable has been + // Separate loop so that we can assume that a metavariable has been // added for every (Node, Direction) in the graph already. for tgt_node in hugr.nodes() { let sig: &OpType = hugr.get_nodetype(tgt_node).into(); From d9ea72b482051be9e687ec81f89752aa438163e2 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 3 Oct 2023 09:30:41 +0100 Subject: [PATCH 09/16] doc: Update comment --- src/extension/infer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index ce05de026..5f4d8d9db 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -332,7 +332,7 @@ impl UnificationContext { // added for every (Node, Direction) in the graph already. for tgt_node in hugr.nodes() { let sig: &OpType = hugr.get_nodetype(tgt_node).into(); - // Incoming ports with a dataflow edge + // Incoming ports with an edge that should mean equal extension reqs for port in hugr.node_inputs(tgt_node).filter(|src_port| { matches!( sig.port_kind(*src_port), From 375437ccc56b6ffc0546ba5a68da85ee7863fce9 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 3 Oct 2023 09:35:55 +0100 Subject: [PATCH 10/16] cosmetic: Move dfg wiring to be closer to dfg node creation --- src/extension/infer.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 5f4d8d9db..83a2ba718 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -1258,6 +1258,10 @@ mod test { )), )?; + // Internal wiring for DFGs + hugr.connect(entry_in, 0, mkpred, 0)?; + hugr.connect(mkpred, 0, entry_out, 0)?; + let bb0 = make_block( &mut hugr, root, @@ -1322,10 +1326,6 @@ mod test { ), )?; - // Internal wiring for DFGs - hugr.connect(entry_in, 0, mkpred, 0)?; - hugr.connect(mkpred, 0, entry_out, 0)?; - // CFG Wiring hugr.connect(entry, 0, bb0, 0)?; hugr.connect(entry, 0, bb1, 0)?; From 69fb81183a4341950c67cd40fd2ede603fc3a4b8 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 3 Oct 2023 09:38:54 +0100 Subject: [PATCH 11/16] Remove comment --- src/extension/infer.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 83a2ba718..c58a50570 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -1417,7 +1417,6 @@ mod test { make_opaque(PRELUDE_ID, FunctionType::new(vec![NAT], oneway(NAT))), )?; - // Mult let bb2 = make_block( &mut hugr, cfg, From ead72ec976ebb0e40d0020fd09b15b35b3ff3eae Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 3 Oct 2023 09:44:36 +0100 Subject: [PATCH 12/16] cleanup: Remove debug prints --- src/extension/infer.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index c58a50570..f8253c956 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -1570,7 +1570,6 @@ mod test { ]; for (bb0, bb1, bb2) in variants.into_iter() { - println!("{}, {}, {}", bb0, bb1, bb2); let mut hugr = make_cfg_loop_test(bb0, bb1, bb2)?; hugr.infer_and_validate(&PRELUDE_REGISTRY)?; } From e7ec8ca853770a72ac54b855b55eb7d17daaf7ef Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 3 Oct 2023 10:11:29 +0100 Subject: [PATCH 13/16] Add comment to and rename function --- src/extension/infer.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index f8253c956..2f0231a76 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -1441,6 +1441,9 @@ mod test { Ok(()) } + /// Create a CFG of the form below, with the extension deltas for `Entry`, + /// `BB1`, and `BB2` specified by arguments to the function. + /// /// +-----------+ /// +--->| Entry | /// | +-----+-----+ @@ -1458,7 +1461,7 @@ mod test { /// +------------+ | /// | Exit |<--+ /// +------------+ - fn make_cfg_loop_test( + fn make_looping_cfg( entry_ext: ExtensionSet, bb1_ext: ExtensionSet, bb2_ext: ExtensionSet, @@ -1570,7 +1573,7 @@ mod test { ]; for (bb0, bb1, bb2) in variants.into_iter() { - let mut hugr = make_cfg_loop_test(bb0, bb1, bb2)?; + let mut hugr = make_looping_cfg(bb0, bb1, bb2)?; hugr.infer_and_validate(&PRELUDE_REGISTRY)?; } Ok(()) From f0c8b2ab2402c3165b422eb7bf5fc8bb8f87df27 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 3 Oct 2023 10:24:12 +0100 Subject: [PATCH 14/16] Add ascii art comment --- src/extension/infer.rs | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 2f0231a76..1a1fcba5a 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -1341,6 +1341,28 @@ mod test { Ok(()) } + /// A test case for a CFG with a node (BB2) which has multiple predecessors, + /// Like so: + /// + /// +-----------------+ + /// | Entry | + /// +------/--\-------+ + /// / \ + /// / \ + /// / \ + /// +---------/--+ +----\-------+ + /// | BB0 | | BB1 | + /// +--------\---+ +----/-------+ + /// \ / + /// \ / + /// \ / + /// +------\---/--------+ + /// | BB2 | + /// +---------+---------+ + /// | + /// +---------+----------+ + /// | Exit | + /// +--------------------+ #[test] fn multi_entry() -> Result<(), Box> { let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG { From b2167351bde8bc044c09833241b4b5076103e890 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 3 Oct 2023 14:39:34 +0100 Subject: [PATCH 15/16] refactor: Make `make_block` do more --- src/extension/infer.rs | 150 +++++++++++++---------------------------- 1 file changed, 45 insertions(+), 105 deletions(-) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index 1a1fcba5a..d5f3a3df0 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -711,7 +711,7 @@ mod test { use crate::macros::const_extension_ids; use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle, OpTrait}; use crate::type_row; - use crate::types::{FunctionType, Type}; + use crate::types::{FunctionType, Type, TypeRow}; use cool_asserts::assert_matches; use itertools::Itertools; @@ -1173,10 +1173,21 @@ mod test { fn make_block( hugr: &mut Hugr, bb_parent: Node, - dfb: impl Into, - dfb_sig: FunctionType, - op: impl Into, // A single op contained in the DFB, with inputs and outputs wired to it + inputs: TypeRow, + predicate_variants: Vec, + extension_delta: ExtensionSet, ) -> Result> { + let predicate_type = Type::new_predicate(predicate_variants.clone()); + let dfb_sig = FunctionType::new(inputs.clone(), vec![predicate_type]) + .with_extension_delta(&extension_delta.clone()); + let dfb = ops::BasicBlock::DFB { + inputs, + other_outputs: type_row![], + predicate_variants, + extension_delta, + }; + let op = make_opaque(PRELUDE_ID, dfb_sig.clone()); + let [bb, bb_in, bb_out] = create_with_io(hugr, bb_parent, dfb, dfb_sig)?; let dfg = hugr.add_node_with_parent(bb, NodeType::open_extensions(op))?; @@ -1265,65 +1276,33 @@ mod test { let bb0 = make_block( &mut hugr, root, - ops::BasicBlock::DFB { - inputs: type_row![NAT], - other_outputs: type_row![], - predicate_variants: vec![type_row![NAT]], - extension_delta: bc.clone(), - }, - FunctionType::new(vec![NAT], oneway(NAT)), - make_opaque( - B, - FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&bc), - ), + type_row![NAT], + vec![type_row![NAT]], + bc.clone(), )?; let bb1 = make_block( &mut hugr, root, - ops::BasicBlock::DFB { - inputs: type_row![NAT], - other_outputs: type_row![], - predicate_variants: vec![type_row![NAT], type_row![NAT]], - extension_delta: b.clone(), - }, - FunctionType::new(vec![NAT], twoway(NAT)), - make_opaque( - B, - FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(&b), - ), + type_row![NAT], + vec![type_row![NAT], type_row![NAT]], + b.clone(), )?; let bb10 = make_block( &mut hugr, root, - ops::BasicBlock::DFB { - inputs: type_row![NAT], - other_outputs: type_row![], - predicate_variants: vec![type_row![NAT]], - extension_delta: c.clone(), - }, - FunctionType::new(vec![NAT], oneway(NAT)), - make_opaque( - C, - FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&c), - ), + type_row![NAT], + vec![type_row![NAT]], + c.clone(), )?; let bb11 = make_block( &mut hugr, root, - ops::BasicBlock::DFB { - inputs: type_row![NAT], - other_outputs: type_row![], - predicate_variants: vec![type_row![NAT]], - extension_delta: c.clone(), - }, - FunctionType::new(vec![NAT], oneway(NAT)), - make_opaque( - C, - FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&c), - ), + type_row![NAT], + vec![type_row![NAT]], + c.clone(), )?; // CFG Wiring @@ -1416,40 +1395,25 @@ mod test { let bb0 = make_block( &mut hugr, cfg, - ops::BasicBlock::DFB { - inputs: type_row![NAT], - other_outputs: type_row![], - predicate_variants: vec![type_row![NAT]], - extension_delta: ExtensionSet::new(), - }, - FunctionType::new(vec![NAT], oneway(NAT)), - make_opaque(PRELUDE_ID, FunctionType::new(vec![NAT], oneway(NAT))), + type_row![NAT], + vec![type_row![NAT]], + ExtensionSet::new(), )?; let bb1 = make_block( &mut hugr, cfg, - ops::BasicBlock::DFB { - inputs: type_row![NAT], - other_outputs: type_row![], - predicate_variants: vec![type_row![NAT]], - extension_delta: ExtensionSet::new(), - }, - FunctionType::new(vec![NAT], oneway(NAT)), - make_opaque(PRELUDE_ID, FunctionType::new(vec![NAT], oneway(NAT))), + type_row![NAT], + vec![type_row![NAT]], + ExtensionSet::new(), )?; let bb2 = make_block( &mut hugr, cfg, - ops::BasicBlock::DFB { - inputs: type_row![NAT], - other_outputs: type_row![], - predicate_variants: vec![type_row![NAT]], - extension_delta: ExtensionSet::new(), - }, - FunctionType::new(vec![NAT], oneway(NAT)), - make_opaque(PRELUDE_ID, FunctionType::new(vec![NAT], oneway(NAT))), + type_row![NAT], + vec![type_row![NAT]], + ExtensionSet::new(), )?; hugr.connect(entry, 0, bb0, 0)?; @@ -1543,33 +1507,17 @@ mod test { let bb1 = make_block( &mut hugr, root, - ops::BasicBlock::DFB { - inputs: type_row![NAT], - other_outputs: type_row![], - predicate_variants: vec![type_row![NAT], type_row![NAT]], - extension_delta: bb1_ext.clone(), - }, - FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(&bb1_ext), - make_opaque( - PRELUDE_ID, - FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(&bb1_ext), - ), + type_row![NAT], + vec![type_row![NAT], type_row![NAT]], + bb1_ext.clone(), )?; let bb2 = make_block( &mut hugr, root, - ops::BasicBlock::DFB { - inputs: type_row![NAT], - other_outputs: type_row![], - predicate_variants: vec![type_row![NAT]], - extension_delta: bb2_ext.clone(), - }, - FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&bb2_ext), - make_opaque( - PRELUDE_ID, - FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&bb2_ext), - ), + type_row![NAT], + vec![type_row![NAT]], + bb2_ext.clone(), )?; hugr.connect(entry, 0, bb1, 0)?; @@ -1663,17 +1611,9 @@ mod test { let bb = make_block( &mut hugr, root, - ops::BasicBlock::DFB { - inputs: type_row![NAT], - other_outputs: type_row![], - predicate_variants: vec![type_row![NAT], type_row![NAT]], - extension_delta: just_a.clone(), - }, - FunctionType::new(vec![NAT], twoway(NAT)), - make_opaque( - PRELUDE_ID, - FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(&just_a), - ), + type_row![NAT], + vec![type_row![NAT], type_row![NAT]], + just_a.clone(), )?; hugr.connect(entry, 0, bb, 0)?; From df75915e358c86d2819904fb3c118e67f9f88609 Mon Sep 17 00:00:00 2001 From: Craig Roy Date: Tue, 3 Oct 2023 15:10:37 +0100 Subject: [PATCH 16/16] tests: Add helper function for making control flow graphs --- src/extension/infer.rs | 218 +++++++++++++++++------------------------ 1 file changed, 88 insertions(+), 130 deletions(-) diff --git a/src/extension/infer.rs b/src/extension/infer.rs index d5f3a3df0..b51eebd4d 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -1206,6 +1206,44 @@ mod test { vec![Type::new_predicate([vec![ty.clone()], vec![ty]])] } + fn create_entry_exit( + hugr: &mut Hugr, + root: Node, + inputs: TypeRow, + entry_predicates: Vec, + entry_extensions: ExtensionSet, + exit_types: impl Into, + ) -> Result<([Node; 3], Node), Box> { + let entry_predicate_type = Type::new_predicate(entry_predicates.clone()); + let dfb = ops::BasicBlock::DFB { + inputs: inputs.clone(), + other_outputs: type_row![], + predicate_variants: entry_predicates, + extension_delta: entry_extensions, + }; + + let exit = hugr.add_node_with_parent( + root, + NodeType::open_extensions(ops::BasicBlock::Exit { + cfg_outputs: exit_types.into(), + }), + )?; + + let entry = hugr.add_node_before(exit, NodeType::open_extensions(dfb))?; + let entry_in = hugr.add_node_with_parent( + entry, + NodeType::open_extensions(ops::Input { types: inputs }), + )?; + let entry_out = hugr.add_node_with_parent( + entry, + NodeType::open_extensions(ops::Output { + types: vec![entry_predicate_type].into(), + }), + )?; + + Ok(([entry, entry_in, entry_out], exit)) + } + /// A CFG rooted hugr adding resources at each basic block. /// Looks like this: /// @@ -1242,23 +1280,13 @@ mod test { let root = hugr.root(); - let [entry, entry_in, entry_out] = create_with_io( + let ([entry, entry_in, entry_out], exit) = create_entry_exit( &mut hugr, root, - ops::BasicBlock::DFB { - inputs: type_row![NAT], - other_outputs: type_row![], - predicate_variants: vec![type_row![NAT], type_row![NAT]], - extension_delta: a.clone(), - }, - FunctionType::new(vec![NAT], twoway(NAT)), - )?; - - let exit = hugr.add_node_with_parent( - hugr.root(), - NodeType::open_extensions(ops::BasicBlock::Exit { - cfg_outputs: type_row![NAT], - }), + type_row![NAT], + vec![type_row![NAT], type_row![NAT]], + a.clone(), + type_row![NAT], )?; let mkpred = hugr.add_node_with_parent( @@ -1348,49 +1376,25 @@ mod test { signature: FunctionType::new(type_row![NAT], type_row![NAT]), // maybe add extensions? })); let cfg = hugr.root(); - let exit = hugr.add_node_with_parent( + let ([entry, entry_in, entry_out], exit) = create_entry_exit( + &mut hugr, cfg, - NodeType::open_extensions(ops::BasicBlock::Exit { - cfg_outputs: type_row![NAT], - }), + type_row![NAT], + vec![type_row![NAT], type_row![NAT]], + ExtensionSet::new(), + type_row![NAT], )?; - let entry = { - let dfb = ops::BasicBlock::DFB { - inputs: type_row![NAT], - other_outputs: type_row![], - predicate_variants: vec![type_row![NAT], type_row![NAT]], - extension_delta: ExtensionSet::new(), - }; - - let entry = hugr.add_node_before(exit, NodeType::open_extensions(dfb))?; - - let entry_in = hugr.add_node_with_parent( - entry, - NodeType::open_extensions(ops::Input { - types: type_row![NAT], - }), - )?; - - let entry_out = hugr.add_node_with_parent( - entry, - NodeType::open_extensions(ops::Output { - types: twoway(NAT).into(), - }), - )?; - - let mid = hugr.add_node_with_parent( - entry, - NodeType::open_extensions(make_opaque( - PRELUDE_ID, - FunctionType::new(vec![NAT], twoway(NAT)), - )), - )?; - hugr.connect(entry_in, 0, mid, 0)?; - hugr.connect(mid, 0, entry_out, 0)?; + let entry_mid = hugr.add_node_with_parent( + entry, + NodeType::open_extensions(make_opaque( + PRELUDE_ID, + FunctionType::new(vec![NAT], twoway(NAT)), + )), + )?; - entry - }; + hugr.connect(entry_in, 0, entry_mid, 0)?; + hugr.connect(entry_mid, 0, entry_out, 0)?; let bb0 = make_block( &mut hugr, @@ -1461,48 +1465,25 @@ mod test { let root = hugr.root(); - let exit = hugr.add_node_with_parent( + let ([entry, entry_in, entry_out], exit) = create_entry_exit( + &mut hugr, root, - NodeType::open_extensions(ops::BasicBlock::Exit { - cfg_outputs: type_row![NAT], - }), + type_row![NAT], + vec![type_row![NAT]], + entry_ext.clone(), + type_row![NAT], )?; - let entry = { - let dfb: OpType = ops::BasicBlock::DFB { - inputs: type_row![NAT], - other_outputs: type_row![], - predicate_variants: vec![type_row![NAT]], - extension_delta: entry_ext.clone(), - } - .into(); - - let entry = hugr.add_node_before(exit, NodeType::open_extensions(dfb))?; - let entry_in = hugr.add_node_with_parent( - entry, - NodeType::open_extensions(ops::Input { - types: type_row![NAT], - }), - )?; - let entry_out = hugr.add_node_with_parent( - entry, - NodeType::open_extensions(ops::Output { - types: oneway(NAT).into(), - }), - )?; - - let dfg = hugr.add_node_with_parent( - entry, - NodeType::open_extensions(make_opaque( - PRELUDE_ID, - FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&entry_ext), - )), - )?; + let entry_dfg = hugr.add_node_with_parent( + entry, + NodeType::open_extensions(make_opaque( + PRELUDE_ID, + FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&entry_ext), + )), + )?; - hugr.connect(entry_in, 0, dfg, 0)?; - hugr.connect(dfg, 0, entry_out, 0)?; - entry - }; + hugr.connect(entry_in, 0, entry_dfg, 0)?; + hugr.connect(entry_dfg, 0, entry_out, 0)?; let bb1 = make_block( &mut hugr, @@ -1565,48 +1546,25 @@ mod test { let root = hugr.root(); - let exit = hugr.add_node_with_parent( + let ([entry, entry_in, entry_out], exit) = create_entry_exit( + &mut hugr, root, - NodeType::open_extensions(ops::BasicBlock::Exit { - cfg_outputs: type_row![NAT], - }), + type_row![NAT], + vec![type_row![NAT]], + ExtensionSet::new(), + type_row![NAT], )?; - let entry = { - let dfb: OpType = ops::BasicBlock::DFB { - inputs: type_row![NAT], - other_outputs: type_row![], - predicate_variants: vec![type_row![NAT]], - extension_delta: ExtensionSet::new(), - } - .into(); - - let entry = hugr.add_node_before(exit, NodeType::open_extensions(dfb))?; - let entry_in = hugr.add_node_with_parent( - entry, - NodeType::open_extensions(ops::Input { - types: type_row![NAT], - }), - )?; - let entry_out = hugr.add_node_with_parent( - entry, - NodeType::open_extensions(ops::Output { - types: oneway(NAT).into(), - }), - )?; - - let mid = hugr.add_node_with_parent( - entry, - NodeType::open_extensions(make_opaque( - PRELUDE_ID, - FunctionType::new(vec![NAT], oneway(NAT)), - )), - )?; + let entry_mid = hugr.add_node_with_parent( + entry, + NodeType::open_extensions(make_opaque( + PRELUDE_ID, + FunctionType::new(vec![NAT], oneway(NAT)), + )), + )?; - hugr.connect(entry_in, 0, mid, 0)?; - hugr.connect(mid, 0, entry_out, 0)?; - entry - }; + hugr.connect(entry_in, 0, entry_mid, 0)?; + hugr.connect(entry_mid, 0, entry_out, 0)?; let bb = make_block( &mut hugr,