diff --git a/src/builder/cfg.rs b/src/builder/cfg.rs index 9de97bc9c..7203d7355 100644 --- a/src/builder/cfg.rs +++ b/src/builder/cfg.rs @@ -5,7 +5,7 @@ use super::{ BasicBlockID, BuildError, CfgID, Container, Dataflow, HugrBuilder, Wire, }; -use crate::ops::{self, BasicBlock, OpType}; +use crate::ops::{self, DataflowBlock, ExitBlock, OpType}; use crate::{ extension::{ExtensionRegistry, ExtensionSet}, types::FunctionType, @@ -86,7 +86,7 @@ impl + AsRef> CFGBuilder { output: TypeRow, ) -> Result { let n_out_wires = output.len(); - let exit_block_type = OpType::BasicBlock(BasicBlock::Exit { + let exit_block_type = OpType::ExitBlock(ExitBlock { cfg_outputs: output, }); let exit_node = base @@ -102,7 +102,7 @@ impl + AsRef> CFGBuilder { }) } - /// Return a builder for a non-entry [`BasicBlock::DFB`] child graph with `inputs` + /// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs` /// and `outputs` and the variants of the branching TupleSum value /// specified by `tuple_sum_rows`. /// @@ -134,7 +134,7 @@ impl + AsRef> CFGBuilder { entry: bool, ) -> Result, BuildError> { let tuple_sum_rows: Vec<_> = tuple_sum_rows.into_iter().collect(); - let op = OpType::BasicBlock(BasicBlock::DFB { + let op = OpType::DataflowBlock(DataflowBlock { inputs: inputs.clone(), other_outputs: other_outputs.clone(), tuple_sum_rows: tuple_sum_rows.clone(), @@ -159,7 +159,7 @@ impl + AsRef> CFGBuilder { ) } - /// Return a builder for a non-entry [`BasicBlock::DFB`] child graph with `inputs` + /// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs` /// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types. /// /// # Errors @@ -178,7 +178,7 @@ impl + AsRef> CFGBuilder { ) } - /// Return a builder for the entry [`BasicBlock::DFB`] child graph with `inputs` + /// Return a builder for the entry [`DataflowBlock`] child graph with `inputs` /// and `outputs` and the variants of the branching TupleSum value /// specified by `tuple_sum_rows`. /// @@ -198,7 +198,7 @@ impl + AsRef> CFGBuilder { self.any_block_builder(inputs, tuple_sum_rows, other_outputs, extension_delta, true) } - /// Return a builder for the entry [`BasicBlock::DFB`] child graph with `inputs` + /// Return a builder for the entry [`DataflowBlock`] child graph with `inputs` /// and `outputs` and a UnitSum type: a Sum of `n_cases` unit types. /// /// # Errors @@ -235,7 +235,7 @@ impl + AsRef> CFGBuilder { } } -/// Builder for a [`BasicBlock::DFB`] child graph. +/// Builder for a [`DataflowBlock`] child graph. pub type BlockBuilder = DFGWrapper; impl + AsRef> BlockBuilder { @@ -248,6 +248,7 @@ impl + AsRef> BlockBuilder { ) -> Result<(), BuildError> { Dataflow::set_outputs(self, [branch_wire].into_iter().chain(outputs)) } + fn create( base: B, block_n: Node, @@ -284,7 +285,7 @@ impl + AsRef> BlockBuilder { } impl BlockBuilder { - /// Initialize a [`BasicBlock::DFB`] rooted HUGR builder + /// Initialize a [`DataflowBlock`] rooted HUGR builder pub fn new( inputs: impl Into, input_extensions: impl Into>, @@ -295,7 +296,7 @@ impl BlockBuilder { let inputs = inputs.into(); let tuple_sum_rows: Vec<_> = tuple_sum_rows.into_iter().collect(); let other_outputs = other_outputs.into(); - let op = BasicBlock::DFB { + let op = DataflowBlock { inputs: inputs.clone(), other_outputs: other_outputs.clone(), tuple_sum_rows: tuple_sum_rows.clone(), diff --git a/src/extension/infer/test.rs b/src/extension/infer/test.rs index 7653c7578..1ac40455a 100644 --- a/src/extension/infer/test.rs +++ b/src/extension/infer/test.rs @@ -461,7 +461,7 @@ fn make_block( let tuple_sum_type = Type::new_tuple_sum(tuple_sum_rows.clone()); let dfb_sig = FunctionType::new(inputs.clone(), vec![tuple_sum_type]) .with_extension_delta(&extension_delta.clone()); - let dfb = ops::BasicBlock::DFB { + let dfb = ops::DataflowBlock { inputs, other_outputs: type_row![], tuple_sum_rows, @@ -496,7 +496,7 @@ fn create_entry_exit( exit_types: impl Into, ) -> Result<([Node; 3], Node), Box> { let entry_tuple_sum = Type::new_tuple_sum(entry_variants.clone()); - let dfb = ops::BasicBlock::DFB { + let dfb = ops::DataflowBlock { inputs: inputs.clone(), other_outputs: type_row![], tuple_sum_rows: entry_variants, @@ -505,7 +505,7 @@ fn create_entry_exit( let exit = hugr.add_node_with_parent( root, - ops::BasicBlock::Exit { + ops::ExitBlock { cfg_outputs: exit_types.into(), }, )?; diff --git a/src/hugr/rewrite/outline_cfg.rs b/src/hugr/rewrite/outline_cfg.rs index c13a4183f..7ab612548 100644 --- a/src/hugr/rewrite/outline_cfg.rs +++ b/src/hugr/rewrite/outline_cfg.rs @@ -12,9 +12,10 @@ use crate::hugr::rewrite::Rewrite; use crate::hugr::views::sibling::SiblingMut; use crate::hugr::{HugrMut, HugrView}; use crate::ops; +use crate::ops::controlflow::BasicBlock; use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle}; -use crate::ops::{BasicBlock, OpType}; +use crate::ops::{DataflowBlock, OpType}; use crate::PortIndex; use crate::{type_row, Node}; @@ -114,12 +115,13 @@ impl Rewrite for OutlineCfg { self.compute_entry_exit_outside_extensions(h)?; // 1. Compute signature // These panic()s only happen if the Hugr would not have passed validate() - let OpType::BasicBlock(BasicBlock::DFB { inputs, .. }) = h.get_optype(entry) else { + let OpType::DataflowBlock(DataflowBlock { inputs, .. }) = h.get_optype(entry) else { panic!("Entry node is not a basic block") }; let inputs = inputs.clone(); let outputs = match h.get_optype(outside) { - OpType::BasicBlock(b) => b.dataflow_input().clone(), + OpType::DataflowBlock(dfb) => dfb.dataflow_input().clone(), + OpType::ExitBlock(exit) => exit.dataflow_input().clone(), _ => panic!("External successor not a basic block"), }; let outer_cfg = h.get_parent(entry).unwrap(); @@ -265,7 +267,6 @@ mod test { use crate::hugr::views::sibling::SiblingMut; use crate::hugr::HugrMut; use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle}; - use crate::ops::{BasicBlock, OpType}; use crate::types::FunctionType; use crate::{type_row, HugrView, Node}; use cool_asserts::assert_matches; @@ -348,12 +349,9 @@ mod test { h.output_neighbours(tail).take(2).collect::>(), HashSet::from([exit, new_block]) ); - assert_matches!( - h.get_optype(new_block), - OpType::BasicBlock(BasicBlock::DFB { .. }) - ); + assert!(h.get_optype(new_block).is_dataflow_block()); assert_eq!(h.base_hugr().get_parent(new_cfg), Some(new_block)); - assert_matches!(h.base_hugr().get_optype(new_cfg), OpType::CFG(_)); + assert!(h.base_hugr().get_optype(new_cfg).is_cfg()); } #[test] @@ -409,12 +407,9 @@ mod test { .unwrap(); h.update_validate(&PRELUDE_REGISTRY).unwrap(); assert_eq!(new_block, h.children(h.root()).next().unwrap()); - assert_matches!( - h.get_optype(new_block), - OpType::BasicBlock(BasicBlock::DFB { .. }) - ); + assert!(h.get_optype(new_block).is_dataflow_block()); assert_eq!(h.get_parent(new_cfg), Some(new_block)); - assert_matches!(h.get_optype(new_cfg), OpType::CFG(_)); + assert!(h.get_optype(new_cfg).is_cfg()); for n in other_blocks { assert_eq!(depth(&h, n), 1); } diff --git a/src/hugr/rewrite/replace.rs b/src/hugr/rewrite/replace.rs index 017407c97..1e9e9f3a4 100644 --- a/src/hugr/rewrite/replace.rs +++ b/src/hugr/rewrite/replace.rs @@ -449,7 +449,7 @@ mod test { use crate::ops::custom::{ExternalOp, OpaqueOp}; use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle}; - use crate::ops::{self, BasicBlock, Case, LeafOp, OpTag, OpType, DFG}; + use crate::ops::{self, Case, DataflowBlock, LeafOp, OpTag, OpType, DFG}; use crate::std_extensions::collections; use crate::types::{FunctionType, Type, TypeArg, TypeRow}; use crate::{type_row, Direction, Hugr, HugrView, OutgoingPort}; @@ -504,14 +504,8 @@ mod test { let popp = h.get_parent(pop).unwrap(); let pushp = h.get_parent(push).unwrap(); assert_ne!(popp, pushp); // Two different BBs - assert!(matches!( - h.get_optype(popp), - OpType::BasicBlock(BasicBlock::DFB { .. }) - )); - assert!(matches!( - h.get_optype(pushp), - OpType::BasicBlock(BasicBlock::DFB { .. }) - )); + assert!(h.get_optype(popp).is_dataflow_block()); + assert!(h.get_optype(pushp).is_dataflow_block()); assert_eq!(h.get_parent(popp).unwrap(), h.get_parent(pushp).unwrap()); } @@ -523,7 +517,7 @@ mod test { })); let r_bb = replacement.add_node_with_parent( replacement.root(), - BasicBlock::DFB { + DataflowBlock { inputs: vec![listy.clone()].into(), tuple_sum_rows: vec![type_row![]], other_outputs: vec![listy.clone()].into(), @@ -596,10 +590,7 @@ mod test { let grandp = h.get_parent(popp).unwrap(); assert_eq!(grandp, h.get_parent(pushp).unwrap()); - assert!(matches!( - h.get_optype(grandp), - OpType::BasicBlock(BasicBlock::DFB { .. }) - )); + assert!(h.get_optype(grandp).is_dataflow_block()); } Ok(()) diff --git a/src/hugr/validate/test.rs b/src/hugr/validate/test.rs index dc8e9add2..520359a3a 100644 --- a/src/hugr/validate/test.rs +++ b/src/hugr/validate/test.rs @@ -289,7 +289,7 @@ fn cfg_children_restrictions() { let block = b .add_node_with_parent( cfg, - ops::BasicBlock::DFB { + ops::DataflowBlock { inputs: type_row![BOOL_T], tuple_sum_rows: vec![type_row![]], other_outputs: type_row![BOOL_T], @@ -301,7 +301,7 @@ fn cfg_children_restrictions() { let exit = b .add_node_with_parent( cfg, - ops::BasicBlock::Exit { + ops::ExitBlock { cfg_outputs: type_row![BOOL_T], }, ) @@ -315,7 +315,7 @@ fn cfg_children_restrictions() { let exit2 = b .add_node_after( exit, - ops::BasicBlock::Exit { + ops::ExitBlock { cfg_outputs: type_row![BOOL_T], }, ) @@ -330,7 +330,7 @@ fn cfg_children_restrictions() { // Change the types in the BasicBlock node to work on qubits instead of bits b.replace_op( block, - NodeType::new_pure(ops::BasicBlock::DFB { + NodeType::new_pure(ops::DataflowBlock { inputs: type_row![Q], tuple_sum_rows: vec![type_row![]], other_outputs: type_row![Q], diff --git a/src/hugr/views/root_checked.rs b/src/hugr/views/root_checked.rs index 6b3f7aba3..242ac8ee7 100644 --- a/src/hugr/views/root_checked.rs +++ b/src/hugr/views/root_checked.rs @@ -74,7 +74,7 @@ mod test { use crate::hugr::hugrmut::sealed::HugrMutInternals; use crate::hugr::{HugrError, HugrMut, NodeType}; use crate::ops::handle::{BasicBlockID, CfgID, DataflowParentID, DfgID}; - use crate::ops::{BasicBlock, LeafOp, OpTag}; + use crate::ops::{DataflowBlock, LeafOp, OpTag}; use crate::{ops, type_row, types::FunctionType, Hugr, HugrView}; #[test] @@ -94,7 +94,7 @@ mod test { let mut dfg_v = RootChecked::<&mut Hugr, DfgID>::try_new(&mut h).unwrap(); // That is a HugrMutInternal, so we can try: let root = dfg_v.root(); - let bb = NodeType::new_pure(BasicBlock::DFB { + let bb = NodeType::new_pure(DataflowBlock { inputs: type_row![], other_outputs: type_row![], tuple_sum_rows: vec![type_row![]], diff --git a/src/ops.rs b/src/ops.rs index 9c44f7f00..5c52983db 100644 --- a/src/ops.rs +++ b/src/ops.rs @@ -21,7 +21,7 @@ use smol_str::SmolStr; use enum_dispatch::enum_dispatch; pub use constant::Const; -pub use controlflow::{BasicBlock, Case, Conditional, TailLoop, CFG}; +pub use controlflow::{Case, Conditional, DataflowBlock, ExitBlock, TailLoop, CFG}; pub use dataflow::{Call, CallIndirect, Input, LoadConstant, Output, DFG}; pub use leaf::LeafOp; pub use module::{AliasDecl, AliasDefn, FuncDecl, FuncDefn, Module}; @@ -48,7 +48,8 @@ pub enum OpType { LoadConstant, DFG, LeafOp, - BasicBlock, + DataflowBlock, + ExitBlock, TailLoop, CFG, Conditional, @@ -93,7 +94,8 @@ impl_op_ref_try_into!(CallIndirect); impl_op_ref_try_into!(LoadConstant); impl_op_ref_try_into!(DFG, dfg); impl_op_ref_try_into!(LeafOp); -impl_op_ref_try_into!(BasicBlock); +impl_op_ref_try_into!(DataflowBlock); +impl_op_ref_try_into!(ExitBlock); impl_op_ref_try_into!(TailLoop); impl_op_ref_try_into!(CFG, cfg); impl_op_ref_try_into!(Conditional); diff --git a/src/ops/controlflow.rs b/src/ops/controlflow.rs index 4489cc4fe..4f6fe22af 100644 --- a/src/ops/controlflow.rs +++ b/src/ops/controlflow.rs @@ -115,50 +115,50 @@ impl DataflowOpTrait for CFG { } #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] -#[serde(tag = "block")] -/// Basic block ops - nodes valid in control flow graphs. +/// A CFG basic block node. The signature is that of the internal Dataflow graph. #[allow(missing_docs)] -pub enum BasicBlock { - /// A CFG basic block node. The signature is that of the internal Dataflow graph. - DFB { - inputs: TypeRow, - other_outputs: TypeRow, - tuple_sum_rows: Vec, - extension_delta: ExtensionSet, - }, - /// The single exit node of the CFG, has no children, - /// stores the types of the CFG node output. - Exit { cfg_outputs: TypeRow }, +pub struct DataflowBlock { + pub inputs: TypeRow, + pub other_outputs: TypeRow, + pub tuple_sum_rows: Vec, + pub extension_delta: ExtensionSet, +} + +#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +/// The single exit node of the CFG. Has no children, +/// stores the types of the CFG node output. +pub struct ExitBlock { + /// Output type row of the CFG. + pub cfg_outputs: TypeRow, } -impl OpName for BasicBlock { - /// The name of the operation. +impl OpName for DataflowBlock { fn name(&self) -> SmolStr { - match self { - BasicBlock::DFB { .. } => "DFB".into(), - BasicBlock::Exit { .. } => "Exit".into(), - } + "DataflowBlock".into() } } -impl StaticTag for BasicBlock { +impl OpName for ExitBlock { + fn name(&self) -> SmolStr { + "ExitBlock".into() + } +} + +impl StaticTag for DataflowBlock { const TAG: OpTag = OpTag::BasicBlock; } -impl OpTrait for BasicBlock { - /// The description of the operation. +impl StaticTag for ExitBlock { + const TAG: OpTag = OpTag::BasicBlockExit; +} + +impl OpTrait for DataflowBlock { fn description(&self) -> &str { - match self { - BasicBlock::DFB { .. } => "A CFG basic block node", - BasicBlock::Exit { .. } => "A CFG exit block node", - } + "A CFG basic block node" } /// Tag identifying the operation. fn tag(&self) -> OpTag { - match self { - BasicBlock::DFB { .. } => OpTag::BasicBlock, - BasicBlock::Exit { .. } => OpTag::BasicBlockExit, - } + Self::TAG } fn other_input(&self) -> Option { @@ -170,43 +170,73 @@ impl OpTrait for BasicBlock { } fn dataflow_signature(&self) -> Option { - Some(match self { - BasicBlock::DFB { - extension_delta, .. - } => FunctionType::new(type_row![], type_row![]).with_extension_delta(extension_delta), - BasicBlock::Exit { .. } => FunctionType::new(type_row![], type_row![]), - }) + Some( + FunctionType::new(type_row![], type_row![]).with_extension_delta(&self.extension_delta), + ) } fn non_df_port_count(&self, dir: Direction) -> usize { - match self { - Self::DFB { tuple_sum_rows, .. } if dir == Direction::Outgoing => tuple_sum_rows.len(), - Self::Exit { .. } if dir == Direction::Outgoing => 0, - _ => 1, + match dir { + Direction::Incoming => 1, + Direction::Outgoing => self.tuple_sum_rows.len(), } } } -impl BasicBlock { - /// The input signature of the contained dataflow graph. - pub fn dataflow_input(&self) -> &TypeRow { - match self { - BasicBlock::DFB { inputs, .. } => inputs, - BasicBlock::Exit { cfg_outputs } => cfg_outputs, +impl OpTrait for ExitBlock { + fn description(&self) -> &str { + "A CFG exit block node" + } + /// Tag identifying the operation. + fn tag(&self) -> OpTag { + Self::TAG + } + + fn other_input(&self) -> Option { + Some(EdgeKind::ControlFlow) + } + + fn other_output(&self) -> Option { + Some(EdgeKind::ControlFlow) + } + + fn dataflow_signature(&self) -> Option { + Some(FunctionType::new(type_row![], type_row![])) + } + + fn non_df_port_count(&self, dir: Direction) -> usize { + match dir { + Direction::Incoming => 1, + Direction::Outgoing => 0, } } +} + +/// Functionality shared by DataflowBlock and Exit CFG block types. +pub trait BasicBlock { + /// The input dataflow signature of the CFG block. + fn dataflow_input(&self) -> &TypeRow; +} +impl BasicBlock for DataflowBlock { + fn dataflow_input(&self) -> &TypeRow { + &self.inputs + } +} +impl DataflowBlock { /// The correct inputs of any successors. Returns None if successor is not a /// valid index. pub fn successor_input(&self, successor: usize) -> Option { - match self { - BasicBlock::DFB { - tuple_sum_rows, - other_outputs: outputs, - .. - } => Some(tuple_sum_first(tuple_sum_rows.get(successor)?, outputs)), - BasicBlock::Exit { .. } => panic!("Exit should have no successors"), - } + Some(tuple_sum_first( + self.tuple_sum_rows.get(successor)?, + &self.other_outputs, + )) + } +} + +impl BasicBlock for ExitBlock { + fn dataflow_input(&self) -> &TypeRow { + &self.cfg_outputs } } diff --git a/src/ops/handle.rs b/src/ops/handle.rs index f62524f4e..5b58dab3c 100644 --- a/src/ops/handle.rs +++ b/src/ops/handle.rs @@ -103,7 +103,7 @@ impl AliasID { pub struct ConstID(Node); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] -/// Handle to a [BasicBlock](crate::ops::BasicBlock) node. +/// Handle to a [DataflowBlock](crate::ops::DataflowBlock) or [Exit](crate::ops::ExitBlock) node. pub struct BasicBlockID(Node); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] diff --git a/src/ops/validate.rs b/src/ops/validate.rs index 1d6500430..d48d59ab4 100644 --- a/src/ops/validate.rs +++ b/src/ops/validate.rs @@ -12,7 +12,8 @@ use thiserror::Error; use crate::types::{FunctionType, Type, TypeRow}; -use super::{impl_validate_op, BasicBlock, OpTag, OpTrait, OpType, ValidateOp}; +use super::controlflow::BasicBlock; +use super::{impl_validate_op, DataflowBlock, ExitBlock, OpTag, OpTrait, OpType, ValidateOp}; /// A set of property flags required for an operation. #[non_exhaustive] @@ -285,21 +286,16 @@ pub struct ChildrenEdgeData { /// Target port. pub target_port: PortOffset, } - -impl ValidateOp for BasicBlock { +impl ValidateOp for DataflowBlock { /// Returns the set of allowed parent operation types. fn validity_flags(&self) -> OpValidityFlags { - match self { - BasicBlock::DFB { .. } => OpValidityFlags { - allowed_children: OpTag::DataflowChild, - allowed_first_child: OpTag::Input, - allowed_second_child: OpTag::Output, - requires_children: true, - requires_dag: true, - ..Default::default() - }, - // Default flags are valid for non-container operations - BasicBlock::Exit { .. } => Default::default(), + OpValidityFlags { + allowed_children: OpTag::DataflowChild, + allowed_first_child: OpTag::Input, + allowed_second_child: OpTag::Output, + requires_children: true, + requires_dag: true, + ..Default::default() } } @@ -308,23 +304,16 @@ impl ValidateOp for BasicBlock { &self, children: impl DoubleEndedIterator, ) -> Result<(), ChildrenValidationError> { - match self { - BasicBlock::DFB { - inputs, - tuple_sum_rows: tuple_sum_variants, - other_outputs: outputs, - extension_delta: _, - } => { - let tuple_sum_type = Type::new_tuple_sum(tuple_sum_variants.clone()); - let node_outputs: TypeRow = [&[tuple_sum_type], outputs.as_ref()].concat().into(); - validate_io_nodes(inputs, &node_outputs, "basic block graph", children) - } - // Exit nodes do not have children - BasicBlock::Exit { .. } => Ok(()), - } + let tuple_sum_type = Type::new_tuple_sum(self.tuple_sum_rows.clone()); + let node_outputs: TypeRow = [&[tuple_sum_type], self.other_outputs.as_ref()] + .concat() + .into(); + validate_io_nodes(&self.inputs, &node_outputs, "basic block graph", children) } } +impl ValidateOp for ExitBlock {} + impl ValidateOp for super::Case { /// Returns the set of allowed parent operation types. fn validity_flags(&self) -> OpValidityFlags { @@ -412,14 +401,18 @@ fn validate_io_nodes<'a>( /// Validate an edge between two basic blocks in a CFG sibling graph. fn validate_cfg_edge(edge: ChildrenEdgeData) -> Result<(), EdgeValidationError> { - let [source, target]: [&BasicBlock; 2] = [&edge.source_op, &edge.target_op].map(|op| { - let block_op = op - .as_basic_block() - .expect("CFG sibling graphs can only contain basic block operations."); - block_op - }); - - if source.successor_input(edge.source_port.index()).as_ref() != Some(target.dataflow_input()) { + let source = &edge + .source_op + .as_dataflow_block() + .expect("CFG sibling graphs can only contain basic block operations."); + + let target_input = match &edge.target_op { + OpType::DataflowBlock(dfb) => dfb.dataflow_input(), + OpType::ExitBlock(exit) => exit.dataflow_input(), + _ => panic!("CFG sibling graphs can only contain basic block operations."), + }; + + if source.successor_input(edge.source_port.index()).as_ref() != Some(target_input) { return Err(EdgeValidationError::CFGEdgeSignatureMismatch { edge }); }