Skip to content

Commit

Permalink
OutlineCfg rewrite returns new BB and CFG nodes (#554)
Browse files Browse the repository at this point in the history
This seems like a good use of the facility for a rewrite to return a
rewrite-specific result, and preliminary to other upcoming PRs
(specifically #247)

The new CFG node is non-trivial to find given its parent (basic block)
as that has several children: not only Input and Output but also
constant nodes to build the predicate.
  • Loading branch information
acl-cqc authored Sep 25, 2023
1 parent 19ed0fc commit a9fdcaa
Showing 1 changed file with 28 additions and 8 deletions.
36 changes: 28 additions & 8 deletions src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,17 @@ impl OutlineCfg {

impl Rewrite for OutlineCfg {
type Error = OutlineCfgError;
type ApplyResult = ();
/// The newly-created basic block, and the [CFG] node inside it
///
/// [CFG]: OpType::CFG
type ApplyResult = (Node, Node);

const UNCHANGED_ON_FAILURE: bool = true;
fn verify(&self, h: &impl HugrView) -> Result<(), OutlineCfgError> {
self.compute_entry_exit_outside_extensions(h)?;
Ok(())
}
fn apply(self, h: &mut impl HugrMut) -> Result<(), OutlineCfgError> {
fn apply(self, h: &mut impl HugrMut) -> Result<(Node, Node), OutlineCfgError> {
let (entry, exit, outside, extension_delta) =
self.compute_entry_exit_outside_extensions(h)?;
// 1. Compute signature
Expand Down Expand Up @@ -206,7 +209,7 @@ impl Rewrite for OutlineCfg {
SiblingMut::try_new(&mut in_bb_view, cfg_node).unwrap();
in_cfg_view.connect(exit, exit_port, inner_exit, 0).unwrap();

Ok(())
Ok((new_block, cfg_node))
}
}

Expand Down Expand Up @@ -251,6 +254,7 @@ mod test {
use crate::hugr::views::sibling::SiblingMut;
use crate::hugr::HugrMut;
use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle};
use crate::ops::{BasicBlock, OpType};
use crate::types::FunctionType;
use crate::{type_row, Hugr, HugrView, Node};
use cool_asserts::assert_matches;
Expand Down Expand Up @@ -321,11 +325,14 @@ mod test {
assert_eq!(depth(h.base_hugr(), n), expected_depth);
}
let blocks = [head, left, right, merge];
h.apply_rewrite(OutlineCfg::new(blocks)).unwrap();
let (new_block, new_cfg) = h.apply_rewrite(OutlineCfg::new(blocks)).unwrap();
for n in blocks {
assert_eq!(depth(h.base_hugr(), n), expected_depth + 2);
}
let new_block = h.output_neighbours(entry).exactly_one().ok().unwrap();
assert_eq!(
new_block,
h.output_neighbours(entry).exactly_one().ok().unwrap()
);
for n in [entry, exit, tail, new_block] {
assert_eq!(depth(h.base_hugr(), n), expected_depth);
}
Expand All @@ -337,6 +344,12 @@ mod test {
h.output_neighbours(tail).take(2).collect::<HashSet<Node>>(),
HashSet::from([exit, new_block])
);
assert_matches!(
h.get_optype(new_block),
OpType::BasicBlock(BasicBlock::DFB { .. })
);
assert_eq!(h.base_hugr().get_parent(new_cfg), Some(new_block));
assert_matches!(h.base_hugr().get_optype(new_cfg), OpType::CFG(_));
}

#[test]
Expand Down Expand Up @@ -387,15 +400,22 @@ mod test {
for &n in blocks_to_move.iter().chain(other_blocks.iter()) {
assert_eq!(depth(&h, n), 1);
}
h.apply_rewrite(OutlineCfg::new(blocks_to_move.iter().copied()))
let (new_block, new_cfg) = h
.apply_rewrite(OutlineCfg::new(blocks_to_move.iter().copied()))
.unwrap();
h.validate(&PRELUDE_REGISTRY).unwrap();
let new_entry = h.children(h.root()).next().unwrap();
assert_eq!(new_block, h.children(h.root()).next().unwrap());
assert_matches!(
h.get_optype(new_block),
OpType::BasicBlock(BasicBlock::DFB { .. })
);
assert_eq!(h.get_parent(new_cfg), Some(new_block));
assert_matches!(h.get_optype(new_cfg), OpType::CFG(_));
for n in other_blocks {
assert_eq!(depth(&h, n), 1);
}
for n in blocks_to_move {
assert_eq!(h.get_parent(h.get_parent(n).unwrap()).unwrap(), new_entry);
assert_eq!(h.get_parent(n).unwrap(), new_cfg);
}
}
}

0 comments on commit a9fdcaa

Please sign in to comment.