diff --git a/src/builder/cfg.rs b/src/builder/cfg.rs index 85df2a4f6..362cfdbcf 100644 --- a/src/builder/cfg.rs +++ b/src/builder/cfg.rs @@ -1,5 +1,3 @@ -use itertools::Itertools; - use super::{ build_traits::SubContainer, dataflow::{DFGBuilder, DFGWrapper}, @@ -99,21 +97,6 @@ impl + AsRef> CFGBuilder { }) } - /// Create a CFGBuilder for an existing CFG node (that already has entry + exit nodes) - pub(crate) fn from_existing(base: B, cfg_node: Node) -> Result { - let OpType::CFG(crate::ops::controlflow::CFG {outputs, ..}) = base.get_optype(cfg_node) - else {return Err(BuildError::UnexpectedType{node: cfg_node, op_desc: "Any CFG"});}; - let n_out_wires = outputs.len(); - let (_, exit_node) = base.children(cfg_node).take(2).collect_tuple().unwrap(); - Ok(Self { - base, - cfg_node, - inputs: None, // This will prevent creating an entry node - exit_node, - n_out_wires, - }) - } - /// Return a builder for a non-entry [`BasicBlock::DFB`] child graph with `inputs` /// and `outputs` and the variants of the branching predicate Sum value /// specified by `predicate_variants`. @@ -257,8 +240,7 @@ impl + AsRef> BlockBuilder { let db = DFGBuilder::create_with_io(base, block_n, signature)?; Ok(BlockBuilder::from_dfg_builder(db)) } -} -impl + AsRef> BlockBuilder { + /// [Set outputs](BlockBuilder::set_outputs) and [finish](`BlockBuilder::finish_sub_container`). pub fn finish_with_outputs( mut self, @@ -293,12 +275,20 @@ impl BlockBuilder { let root = base.root(); Self::create(base, root, predicate_variants, other_outputs, inputs) } + + /// [Set outputs](BlockBuilder::set_outputs) and [finish_hugr](`BlockBuilder::finish_hugr`). + pub fn finish_hugr_with_outputs( + mut self, + branch_wire: Wire, + outputs: impl IntoIterator, + ) -> Result { + self.set_outputs(branch_wire, outputs)?; + self.finish_hugr().map_err(BuildError::InvalidHUGR) + } } #[cfg(test)] mod test { - use std::collections::HashSet; - use crate::builder::build_traits::HugrBuilder; use crate::builder::{DataflowSubContainer, ModuleBuilder}; use crate::macros::classic_row; @@ -341,35 +331,6 @@ mod test { Ok(()) } - #[test] - fn from_existing() -> Result<(), BuildError> { - let mut cfg_builder = CFGBuilder::new(type_row![NAT], type_row![NAT])?; - build_basic_cfg(&mut cfg_builder)?; - let h = cfg_builder.finish_hugr()?; - - let mut new_builder = CFGBuilder::from_existing(h.clone(), h.root())?; - assert_matches!(new_builder.simple_entry_builder(type_row![NAT], 1), Err(_)); - let h2 = new_builder.finish_hugr()?; - assert_eq!(h, h2); // No new nodes added - - let mut new_builder = CFGBuilder::from_existing(h.clone(), h.root())?; - let block_builder = new_builder.simple_block_builder( - vec![SimpleType::new_simple_predicate(1), NAT].into(), - type_row![NAT], - 1, - )?; - let new_bb = block_builder.container_node(); - let [pred, nat]: [Wire; 2] = block_builder.input_wires_arr(); - block_builder.finish_with_outputs(pred, [nat])?; - let h2 = new_builder.finish_hugr()?; - let expected_nodes = h - .children(h.root()) - .chain([new_bb]) - .collect::>(); - assert_eq!(expected_nodes, HashSet::from_iter(h2.children(h2.root()))); - - Ok(()) - } fn build_basic_cfg + AsRef>( cfg_builder: &mut CFGBuilder, diff --git a/src/hugr/rewrite/outline_cfg.rs b/src/hugr/rewrite/outline_cfg.rs index dc9051ae5..771d75882 100644 --- a/src/hugr/rewrite/outline_cfg.rs +++ b/src/hugr/rewrite/outline_cfg.rs @@ -4,11 +4,10 @@ use std::collections::HashSet; use itertools::Itertools; use thiserror::Error; -use crate::builder::{CFGBuilder, Container, Dataflow, SubContainer}; +use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer}; use crate::hugr::rewrite::Rewrite; use crate::hugr::{HugrMut, HugrView}; -use crate::ops::handle::NodeHandle; -use crate::ops::{BasicBlock, ConstValue, OpType}; +use crate::ops::{BasicBlock, ConstValue, OpTag, OpTrait, OpType}; use crate::{type_row, Hugr, Node}; /// Moves part of a Control-flow Sibling Graph into a new CFG-node @@ -97,35 +96,36 @@ impl Rewrite for OutlineCfg { OpType::BasicBlock(b) => b.dataflow_input().clone(), _ => panic!("External successor not a basic block"), }; - let outer_entry = h.children(h.get_parent(entry).unwrap()).next().unwrap(); + let outer_cfg = h.get_parent(entry).unwrap(); + let outer_entry = h.children(outer_cfg).next().unwrap(); - // 2. New CFG node will be contained in new single-successor BB - let mut existing_cfg = { - let parent = h.get_parent(entry).unwrap(); - CFGBuilder::from_existing(h, parent).unwrap() + // 2. new_block contains input node, sub-cfg, exit node all connected + let new_block = { + let mut new_block_bldr = + BlockBuilder::new(inputs.clone(), vec![type_row![]], outputs.clone()).unwrap(); + let wires_in = inputs.iter().cloned().zip(new_block_bldr.input_wires()); + let cfg = new_block_bldr.cfg_builder(wires_in, outputs).unwrap(); + cfg.exit_block(); // Makes inner exit block (but no entry block) + let cfg_outputs = cfg.finish_sub_container().unwrap().outputs(); + let predicate = new_block_bldr + .add_constant(ConstValue::simple_predicate(0, 1)) + .unwrap(); + let pred_wire = new_block_bldr.load_const(&predicate).unwrap(); + let new_block_hugr = new_block_bldr + .finish_hugr_with_outputs(pred_wire, cfg_outputs) + .unwrap(); + h.insert_hugr(outer_cfg, new_block_hugr).unwrap() }; - let mut new_block = existing_cfg - .block_builder(inputs.clone(), vec![type_row![]], outputs.clone()) - .unwrap(); - // 3. new_block contains input node, sub-cfg, exit node all connected - let wires_in = inputs.iter().cloned().zip(new_block.input_wires()); - let cfg = new_block.cfg_builder(wires_in, outputs).unwrap(); - let cfg_node = cfg.container_node(); - let inner_exit = cfg.exit_block().node(); - let cfg_outputs = cfg.finish_sub_container().unwrap().outputs(); - let predicate = new_block - .add_constant(ConstValue::simple_predicate(0, 1)) + // 3. Extract Cfg node created above (it moved when we called insert_hugr) + let cfg_node = h + .children(new_block) + .filter(|n| h.get_optype(*n).tag() == OpTag::Cfg) + .exactly_one() .unwrap(); - let pred_wire = new_block.load_const(&predicate).unwrap(); - let new_block = new_block - .finish_with_outputs(pred_wire, cfg_outputs) - .unwrap() - .node(); + let inner_exit = h.children(cfg_node).exactly_one().unwrap(); // 4. Entry edges. Change any edges into entry_block from outside, to target new_block - let h = existing_cfg.hugr_mut(); - let preds: Vec<_> = h .linked_ports(entry, h.node_inputs(entry).exactly_one().unwrap()) .collect();