diff --git a/src/algorithm/nest_cfgs.rs b/src/algorithm/nest_cfgs.rs index 56aff219b..d9c67e0fd 100644 --- a/src/algorithm/nest_cfgs.rs +++ b/src/algorithm/nest_cfgs.rs @@ -429,7 +429,7 @@ pub(crate) mod test { let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; let pred_const = - cfg_builder.add_constant(Const::simple_predicate(0, 2), ExtensionSet::new())?; // Nothing here cares which + cfg_builder.add_constant(Const::simple_predicate(0, 2), ExtensionSet::new())?; // Nothing here cares which branch let const_unit = cfg_builder.add_constant(Const::simple_unary_predicate(), ExtensionSet::new())?; @@ -678,7 +678,7 @@ pub(crate) mod test { ) -> Result<(Hugr, BasicBlockID, BasicBlockID), BuildError> { let mut cfg_builder = CFGBuilder::new(FunctionType::new(type_row![NAT], type_row![NAT]))?; let pred_const = - cfg_builder.add_constant(Const::simple_predicate(0, 2), ExtensionSet::new())?; // Nothing here cares which + cfg_builder.add_constant(Const::simple_predicate(0, 2), ExtensionSet::new())?; // Nothing here cares which branch let const_unit = cfg_builder.add_constant(Const::simple_unary_predicate(), ExtensionSet::new())?; @@ -722,7 +722,7 @@ pub(crate) mod test { separate_headers: bool, ) -> Result<(BasicBlockID, BasicBlockID), BuildError> { let pred_const = - cfg_builder.add_constant(Const::simple_predicate(0, 2), ExtensionSet::new())?; // Nothing here cares which + cfg_builder.add_constant(Const::simple_predicate(0, 2), ExtensionSet::new())?; // Nothing here cares which branch let const_unit = cfg_builder.add_constant(Const::simple_unary_predicate(), ExtensionSet::new())?; diff --git a/src/extension/infer.rs b/src/extension/infer.rs index f9f2f1751..b51eebd4d 100644 --- a/src/extension/infer.rs +++ b/src/extension/infer.rs @@ -302,6 +302,16 @@ impl UnificationContext { } } + if node_type.tag() == OpTag::Cfg { + let mut children = hugr.children(node); + let entry = children.next().unwrap(); + let exit = children.next().unwrap(); + let m_entry = self.make_or_get_meta(entry, Direction::Incoming); + let m_exit = self.make_or_get_meta(exit, Direction::Outgoing); + self.add_constraint(m_input, Constraint::Equal(m_entry)); + self.add_constraint(m_output, Constraint::Equal(m_exit)); + } + match node_type.signature() { // Input extensions are open None => { @@ -318,14 +328,18 @@ impl UnificationContext { } } } - // Seperate loop so that we can assume that a metavariable has been + // Separate loop so that we can assume that a metavariable has been // added for every (Node, Direction) in the graph already. for tgt_node in hugr.nodes() { let sig: &OpType = hugr.get_nodetype(tgt_node).into(); - // Incoming ports with a dataflow edge + // Incoming ports with an edge that should mean equal extension reqs for port in hugr.node_inputs(tgt_node).filter(|src_port| { - matches!(sig.port_kind(*src_port), Some(EdgeKind::Value(_))) - || matches!(sig.port_kind(*src_port), Some(EdgeKind::Static(_))) + matches!( + sig.port_kind(*src_port), + Some(EdgeKind::Value(_)) + | Some(EdgeKind::Static(_)) + | Some(EdgeKind::ControlFlow) + ) }) { for (src_node, _) in hugr.linked_ports(tgt_node, port) { let m_src = self @@ -653,6 +667,8 @@ impl UnificationContext { } /// Instantiate all variables in the graph with the empty extension set. + /// Instantiate all variables in the graph with the empty extension set, or + /// the smallest solution possible given their constraints. /// This is done to solve metas which depend on variables, which allows /// us to come up with a fully concrete solution to pass into validation. pub fn instantiate_variables(&mut self) { @@ -687,12 +703,15 @@ mod test { use super::*; use crate::builder::test::closed_dfg_root_hugr; - use crate::extension::{ExtensionSet, PRELUDE_REGISTRY}; + use crate::extension::{ + prelude::{PRELUDE_ID, PRELUDE_REGISTRY}, + ExtensionSet, + }; use crate::hugr::{validate::ValidationError, Hugr, HugrMut, HugrView, NodeType}; use crate::macros::const_extension_ids; use crate::ops::{self, dataflow::IOTrait, handle::NodeHandle, OpTrait}; use crate::type_row; - use crate::types::{FunctionType, Type}; + use crate::types::{FunctionType, Type, TypeRow}; use cool_asserts::assert_matches; use itertools::Itertools; @@ -975,20 +994,21 @@ mod test { hugr: &mut Hugr, parent: Node, op: impl Into, + op_sig: FunctionType, ) -> Result<[Node; 3], Box> { let op: OpType = op.into(); - let input_types = op.signature().input; - let output_types = op.signature().output; let node = hugr.add_node_with_parent(parent, NodeType::open_extensions(op))?; let input = hugr.add_node_with_parent( node, - NodeType::open_extensions(ops::Input { types: input_types }), + NodeType::open_extensions(ops::Input { + types: op_sig.input, + }), )?; let output = hugr.add_node_with_parent( node, NodeType::open_extensions(ops::Output { - types: output_types, + types: op_sig.output, }), )?; Ok([node, input, output]) @@ -1003,7 +1023,12 @@ mod test { first_ext: ExtensionId, second_ext: ExtensionId, ) -> Result> { - let [case, case_in, case_out] = create_with_io(hugr, conditional_node, op)?; + let [case, case_in, case_out] = create_with_io( + hugr, + conditional_node, + op.clone(), + Into::::into(op).signature(), + )?; let lift1 = hugr.add_node_with_parent( case, @@ -1100,14 +1125,16 @@ mod test { let df_nodes: Vec = vec![A, A, B, B, A, B] .into_iter() .map(|ext| { + let dfg_sig = df_sig + .clone() + .with_extension_delta(&ExtensionSet::singleton(&ext)); let [node, input, output] = create_with_io( &mut hugr, root, ops::DFG { - signature: df_sig - .clone() - .with_extension_delta(&ExtensionSet::singleton(&ext)), + signature: dfg_sig.clone(), }, + dfg_sig, ) .unwrap(); @@ -1133,6 +1160,423 @@ mod test { for (src, tgt) in nodes.into_iter().tuple_windows() { hugr.connect(src, 0, tgt, 0)?; } + hugr.infer_and_validate(&PRELUDE_REGISTRY)?; + Ok(()) + } + + fn make_opaque(extension: impl Into, signature: FunctionType) -> ops::LeafOp { + let opaque = + ops::custom::OpaqueOp::new(extension.into(), "", "".into(), vec![], Some(signature)); + ops::custom::ExternalOp::from(opaque).into() + } + + fn make_block( + hugr: &mut Hugr, + bb_parent: Node, + inputs: TypeRow, + predicate_variants: Vec, + extension_delta: ExtensionSet, + ) -> Result> { + let predicate_type = Type::new_predicate(predicate_variants.clone()); + let dfb_sig = FunctionType::new(inputs.clone(), vec![predicate_type]) + .with_extension_delta(&extension_delta.clone()); + let dfb = ops::BasicBlock::DFB { + inputs, + other_outputs: type_row![], + predicate_variants, + extension_delta, + }; + let op = make_opaque(PRELUDE_ID, dfb_sig.clone()); + + let [bb, bb_in, bb_out] = create_with_io(hugr, bb_parent, dfb, dfb_sig)?; + + let dfg = hugr.add_node_with_parent(bb, NodeType::open_extensions(op))?; + + hugr.connect(bb_in, 0, dfg, 0)?; + hugr.connect(dfg, 0, bb_out, 0)?; + + Ok(bb) + } + + fn oneway(ty: Type) -> Vec { + vec![Type::new_predicate([vec![ty]])] + } + + fn twoway(ty: Type) -> Vec { + vec![Type::new_predicate([vec![ty.clone()], vec![ty]])] + } + + fn create_entry_exit( + hugr: &mut Hugr, + root: Node, + inputs: TypeRow, + entry_predicates: Vec, + entry_extensions: ExtensionSet, + exit_types: impl Into, + ) -> Result<([Node; 3], Node), Box> { + let entry_predicate_type = Type::new_predicate(entry_predicates.clone()); + let dfb = ops::BasicBlock::DFB { + inputs: inputs.clone(), + other_outputs: type_row![], + predicate_variants: entry_predicates, + extension_delta: entry_extensions, + }; + + let exit = hugr.add_node_with_parent( + root, + NodeType::open_extensions(ops::BasicBlock::Exit { + cfg_outputs: exit_types.into(), + }), + )?; + + let entry = hugr.add_node_before(exit, NodeType::open_extensions(dfb))?; + let entry_in = hugr.add_node_with_parent( + entry, + NodeType::open_extensions(ops::Input { types: inputs }), + )?; + let entry_out = hugr.add_node_with_parent( + entry, + NodeType::open_extensions(ops::Output { + types: vec![entry_predicate_type].into(), + }), + )?; + + Ok(([entry, entry_in, entry_out], exit)) + } + + /// A CFG rooted hugr adding resources at each basic block. + /// Looks like this: + /// + /// +-------------+ + /// | Entry | + /// | (Adds [A]) | + /// +-/---------\-+ + /// / \ + /// +-------/-----+ +-\-------------+ + /// | BB0 | | BB1 | + /// | (Adds [BC]) | | (Adds [B]) | + /// +----\--------+ +---/------\----+ + /// \ / \ + /// \ / \ + /// \ +----/-------+ +-\---------+ + /// \ | BB10 | | BB11 | + /// \ | (Adds [C]) | | (Adds [C])| + /// \ +----+-------+ +/----------+ + /// \ | / + /// +-----\-------+---------/-+ + /// | Exit | + /// +-------------------------+ + #[test] + fn infer_cfg_test() -> Result<(), Box> { + let a = ExtensionSet::singleton(&A); + let abc = ExtensionSet::from_iter([A, B, C]); + let bc = ExtensionSet::from_iter([B, C]); + let b = ExtensionSet::singleton(&B); + let c = ExtensionSet::singleton(&C); + + let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG { + signature: FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&abc), + })); + + let root = hugr.root(); + + let ([entry, entry_in, entry_out], exit) = create_entry_exit( + &mut hugr, + root, + type_row![NAT], + vec![type_row![NAT], type_row![NAT]], + a.clone(), + type_row![NAT], + )?; + + let mkpred = hugr.add_node_with_parent( + entry, + NodeType::open_extensions(make_opaque( + A, + FunctionType::new(vec![NAT], twoway(NAT)).with_extension_delta(&a), + )), + )?; + + // Internal wiring for DFGs + hugr.connect(entry_in, 0, mkpred, 0)?; + hugr.connect(mkpred, 0, entry_out, 0)?; + + let bb0 = make_block( + &mut hugr, + root, + type_row![NAT], + vec![type_row![NAT]], + bc.clone(), + )?; + + let bb1 = make_block( + &mut hugr, + root, + type_row![NAT], + vec![type_row![NAT], type_row![NAT]], + b.clone(), + )?; + + let bb10 = make_block( + &mut hugr, + root, + type_row![NAT], + vec![type_row![NAT]], + c.clone(), + )?; + + let bb11 = make_block( + &mut hugr, + root, + type_row![NAT], + vec![type_row![NAT]], + c.clone(), + )?; + + // CFG Wiring + hugr.connect(entry, 0, bb0, 0)?; + hugr.connect(entry, 0, bb1, 0)?; + hugr.connect(bb1, 0, bb10, 0)?; + hugr.connect(bb1, 0, bb11, 0)?; + + hugr.connect(bb0, 0, exit, 0)?; + hugr.connect(bb10, 0, exit, 0)?; + hugr.connect(bb11, 0, exit, 0)?; + + hugr.infer_extensions()?; + + Ok(()) + } + + /// A test case for a CFG with a node (BB2) which has multiple predecessors, + /// Like so: + /// + /// +-----------------+ + /// | Entry | + /// +------/--\-------+ + /// / \ + /// / \ + /// / \ + /// +---------/--+ +----\-------+ + /// | BB0 | | BB1 | + /// +--------\---+ +----/-------+ + /// \ / + /// \ / + /// \ / + /// +------\---/--------+ + /// | BB2 | + /// +---------+---------+ + /// | + /// +---------+----------+ + /// | Exit | + /// +--------------------+ + #[test] + fn multi_entry() -> Result<(), Box> { + let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG { + signature: FunctionType::new(type_row![NAT], type_row![NAT]), // maybe add extensions? + })); + let cfg = hugr.root(); + let ([entry, entry_in, entry_out], exit) = create_entry_exit( + &mut hugr, + cfg, + type_row![NAT], + vec![type_row![NAT], type_row![NAT]], + ExtensionSet::new(), + type_row![NAT], + )?; + + let entry_mid = hugr.add_node_with_parent( + entry, + NodeType::open_extensions(make_opaque( + PRELUDE_ID, + FunctionType::new(vec![NAT], twoway(NAT)), + )), + )?; + + hugr.connect(entry_in, 0, entry_mid, 0)?; + hugr.connect(entry_mid, 0, entry_out, 0)?; + + let bb0 = make_block( + &mut hugr, + cfg, + type_row![NAT], + vec![type_row![NAT]], + ExtensionSet::new(), + )?; + + let bb1 = make_block( + &mut hugr, + cfg, + type_row![NAT], + vec![type_row![NAT]], + ExtensionSet::new(), + )?; + + let bb2 = make_block( + &mut hugr, + cfg, + type_row![NAT], + vec![type_row![NAT]], + ExtensionSet::new(), + )?; + + hugr.connect(entry, 0, bb0, 0)?; + hugr.connect(entry, 0, bb1, 0)?; + hugr.connect(bb0, 0, bb2, 0)?; + hugr.connect(bb1, 0, bb2, 0)?; + hugr.connect(bb2, 0, exit, 0)?; + + hugr.infer_and_validate(&PRELUDE_REGISTRY)?; + + Ok(()) + } + + /// Create a CFG of the form below, with the extension deltas for `Entry`, + /// `BB1`, and `BB2` specified by arguments to the function. + /// + /// +-----------+ + /// +--->| Entry | + /// | +-----+-----+ + /// | | + /// | V + /// | +------------+ + /// | | BB1 +---+ + /// | +-----+------+ | + /// | | | + /// | V | + /// | +------------+ | + /// +----+ BB2 | | + /// +------------+ | + /// | + /// +------------+ | + /// | Exit |<--+ + /// +------------+ + fn make_looping_cfg( + entry_ext: ExtensionSet, + bb1_ext: ExtensionSet, + bb2_ext: ExtensionSet, + ) -> Result> { + let hugr_delta = entry_ext.clone().union(&bb1_ext).union(&bb2_ext); + + let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG { + signature: FunctionType::new(type_row![NAT], type_row![NAT]) + .with_extension_delta(&hugr_delta), + })); + + let root = hugr.root(); + + let ([entry, entry_in, entry_out], exit) = create_entry_exit( + &mut hugr, + root, + type_row![NAT], + vec![type_row![NAT]], + entry_ext.clone(), + type_row![NAT], + )?; + + let entry_dfg = hugr.add_node_with_parent( + entry, + NodeType::open_extensions(make_opaque( + PRELUDE_ID, + FunctionType::new(vec![NAT], oneway(NAT)).with_extension_delta(&entry_ext), + )), + )?; + + hugr.connect(entry_in, 0, entry_dfg, 0)?; + hugr.connect(entry_dfg, 0, entry_out, 0)?; + + let bb1 = make_block( + &mut hugr, + root, + type_row![NAT], + vec![type_row![NAT], type_row![NAT]], + bb1_ext.clone(), + )?; + + let bb2 = make_block( + &mut hugr, + root, + type_row![NAT], + vec![type_row![NAT]], + bb2_ext.clone(), + )?; + + hugr.connect(entry, 0, bb1, 0)?; + hugr.connect(bb1, 0, bb2, 0)?; + hugr.connect(bb1, 0, exit, 0)?; + hugr.connect(bb2, 0, entry, 0)?; + + Ok(hugr) + } + + #[test] + fn test_cfg_loops() -> Result<(), Box> { + let just_a = ExtensionSet::singleton(&A); + let variants = vec![ + ( + ExtensionSet::new(), + ExtensionSet::new(), + ExtensionSet::new(), + ), + (just_a.clone(), ExtensionSet::new(), ExtensionSet::new()), + (ExtensionSet::new(), just_a.clone(), ExtensionSet::new()), + (ExtensionSet::new(), ExtensionSet::new(), just_a.clone()), + ]; + + for (bb0, bb1, bb2) in variants.into_iter() { + let mut hugr = make_looping_cfg(bb0, bb1, bb2)?; + hugr.infer_and_validate(&PRELUDE_REGISTRY)?; + } + Ok(()) + } + + #[test] + /// A control flow graph consisting of an entry node and a single block + /// which adds a resource and links to both itself and the exit node. + fn simple_cfg_loop() -> Result<(), Box> { + let just_a = ExtensionSet::singleton(&A); + + let mut hugr = Hugr::new(NodeType::new( + ops::CFG { + signature: FunctionType::new(type_row![NAT], type_row![NAT]) + .with_extension_delta(&just_a), + }, + just_a.clone(), + )); + + let root = hugr.root(); + + let ([entry, entry_in, entry_out], exit) = create_entry_exit( + &mut hugr, + root, + type_row![NAT], + vec![type_row![NAT]], + ExtensionSet::new(), + type_row![NAT], + )?; + + let entry_mid = hugr.add_node_with_parent( + entry, + NodeType::open_extensions(make_opaque( + PRELUDE_ID, + FunctionType::new(vec![NAT], oneway(NAT)), + )), + )?; + + hugr.connect(entry_in, 0, entry_mid, 0)?; + hugr.connect(entry_mid, 0, entry_out, 0)?; + + let bb = make_block( + &mut hugr, + root, + type_row![NAT], + vec![type_row![NAT], type_row![NAT]], + just_a.clone(), + )?; + + hugr.connect(entry, 0, bb, 0)?; + hugr.connect(bb, 0, bb, 0)?; + hugr.connect(bb, 0, exit, 0)?; hugr.infer_and_validate(&PRELUDE_REGISTRY)?;