Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OutlineCfg use insert_hugr; remove CfgBuilder::from_existing #298

Merged
merged 5 commits into from
Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 11 additions & 50 deletions src/builder/cfg.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use itertools::Itertools;

use super::{
build_traits::SubContainer,
dataflow::{DFGBuilder, DFGWrapper},
Expand Down Expand Up @@ -99,21 +97,6 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
})
}

/// Create a CFGBuilder for an existing CFG node (that already has entry + exit nodes)
pub(crate) fn from_existing(base: B, cfg_node: Node) -> Result<Self, BuildError> {
let OpType::CFG(crate::ops::controlflow::CFG {outputs, ..}) = base.get_optype(cfg_node)
else {return Err(BuildError::UnexpectedType{node: cfg_node, op_desc: "Any CFG"});};
let n_out_wires = outputs.len();
let (_, exit_node) = base.children(cfg_node).take(2).collect_tuple().unwrap();
Ok(Self {
base,
cfg_node,
inputs: None, // This will prevent creating an entry node
exit_node,
n_out_wires,
})
}

/// Return a builder for a non-entry [`BasicBlock::DFB`] child graph with `inputs`
/// and `outputs` and the variants of the branching predicate Sum value
/// specified by `predicate_variants`.
Expand Down Expand Up @@ -257,8 +240,7 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> BlockBuilder<B> {
let db = DFGBuilder::create_with_io(base, block_n, signature)?;
Ok(BlockBuilder::from_dfg_builder(db))
}
}
impl<B: AsMut<Hugr> + AsRef<Hugr>> BlockBuilder<B> {

/// [Set outputs](BlockBuilder::set_outputs) and [finish](`BlockBuilder::finish_sub_container`).
pub fn finish_with_outputs(
mut self,
Expand Down Expand Up @@ -293,12 +275,20 @@ impl BlockBuilder<Hugr> {
let root = base.root();
Self::create(base, root, predicate_variants, other_outputs, inputs)
}

/// [Set outputs](BlockBuilder::set_outputs) and [finish_hugr](`BlockBuilder::finish_hugr`).
pub fn finish_hugr_with_outputs(
mut self,
branch_wire: Wire,
outputs: impl IntoIterator<Item = Wire>,
) -> Result<Hugr, BuildError> {
self.set_outputs(branch_wire, outputs)?;
self.finish_hugr().map_err(BuildError::InvalidHUGR)
}
}

#[cfg(test)]
mod test {
use std::collections::HashSet;

use crate::builder::build_traits::HugrBuilder;
use crate::builder::{DataflowSubContainer, ModuleBuilder};
use crate::macros::classic_row;
Expand Down Expand Up @@ -341,35 +331,6 @@ mod test {

Ok(())
}
#[test]
fn from_existing() -> Result<(), BuildError> {
let mut cfg_builder = CFGBuilder::new(type_row![NAT], type_row![NAT])?;
build_basic_cfg(&mut cfg_builder)?;
let h = cfg_builder.finish_hugr()?;

let mut new_builder = CFGBuilder::from_existing(h.clone(), h.root())?;
assert_matches!(new_builder.simple_entry_builder(type_row![NAT], 1), Err(_));
let h2 = new_builder.finish_hugr()?;
assert_eq!(h, h2); // No new nodes added

let mut new_builder = CFGBuilder::from_existing(h.clone(), h.root())?;
let block_builder = new_builder.simple_block_builder(
vec![SimpleType::new_simple_predicate(1), NAT].into(),
type_row![NAT],
1,
)?;
let new_bb = block_builder.container_node();
let [pred, nat]: [Wire; 2] = block_builder.input_wires_arr();
block_builder.finish_with_outputs(pred, [nat])?;
let h2 = new_builder.finish_hugr()?;
let expected_nodes = h
.children(h.root())
.chain([new_bb])
.collect::<HashSet<Node>>();
assert_eq!(expected_nodes, HashSet::from_iter(h2.children(h2.root())));

Ok(())
}

fn build_basic_cfg<T: AsMut<Hugr> + AsRef<Hugr>>(
cfg_builder: &mut CFGBuilder<T>,
Expand Down
52 changes: 26 additions & 26 deletions src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ use std::collections::HashSet;
use itertools::Itertools;
use thiserror::Error;

use crate::builder::{CFGBuilder, Container, Dataflow, SubContainer};
use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer};
use crate::hugr::rewrite::Rewrite;
use crate::hugr::{HugrMut, HugrView};
use crate::ops::handle::NodeHandle;
use crate::ops::{BasicBlock, ConstValue, OpType};
use crate::ops::{BasicBlock, ConstValue, OpTag, OpTrait, OpType};
use crate::{type_row, Hugr, Node};

/// Moves part of a Control-flow Sibling Graph into a new CFG-node
Expand Down Expand Up @@ -97,35 +96,36 @@ impl Rewrite for OutlineCfg {
OpType::BasicBlock(b) => b.dataflow_input().clone(),
_ => panic!("External successor not a basic block"),
};
let outer_entry = h.children(h.get_parent(entry).unwrap()).next().unwrap();
let outer_cfg = h.get_parent(entry).unwrap();
let outer_entry = h.children(outer_cfg).next().unwrap();

// 2. New CFG node will be contained in new single-successor BB
let mut existing_cfg = {
let parent = h.get_parent(entry).unwrap();
CFGBuilder::from_existing(h, parent).unwrap()
// 2. new_block contains input node, sub-cfg, exit node all connected
let new_block = {
let mut new_block_bldr =
BlockBuilder::new(inputs.clone(), vec![type_row![]], outputs.clone()).unwrap();
let wires_in = inputs.iter().cloned().zip(new_block_bldr.input_wires());
let cfg = new_block_bldr.cfg_builder(wires_in, outputs).unwrap();
cfg.exit_block(); // Makes inner exit block (but no entry block)
let cfg_outputs = cfg.finish_sub_container().unwrap().outputs();
let predicate = new_block_bldr
.add_constant(ConstValue::simple_predicate(0, 1))
.unwrap();
let pred_wire = new_block_bldr.load_const(&predicate).unwrap();
let new_block_hugr = new_block_bldr
.finish_hugr_with_outputs(pred_wire, cfg_outputs)
.unwrap();
h.insert_hugr(outer_cfg, new_block_hugr).unwrap()
};
let mut new_block = existing_cfg
.block_builder(inputs.clone(), vec![type_row![]], outputs.clone())
.unwrap();

// 3. new_block contains input node, sub-cfg, exit node all connected
let wires_in = inputs.iter().cloned().zip(new_block.input_wires());
let cfg = new_block.cfg_builder(wires_in, outputs).unwrap();
let cfg_node = cfg.container_node();
let inner_exit = cfg.exit_block().node();
let cfg_outputs = cfg.finish_sub_container().unwrap().outputs();
let predicate = new_block
.add_constant(ConstValue::simple_predicate(0, 1))
// 3. Extract Cfg node created above (it moved when we called insert_hugr)
let cfg_node = h
.children(new_block)
.filter(|n| h.get_optype(*n).tag() == OpTag::Cfg)
.exactly_one()
.unwrap();
let pred_wire = new_block.load_const(&predicate).unwrap();
let new_block = new_block
.finish_with_outputs(pred_wire, cfg_outputs)
.unwrap()
.node();
let inner_exit = h.children(cfg_node).exactly_one().unwrap();

// 4. Entry edges. Change any edges into entry_block from outside, to target new_block
let h = existing_cfg.hugr_mut();

let preds: Vec<_> = h
.linked_ports(entry, h.node_inputs(entry).exactly_one().unwrap())
.collect();
Expand Down