diff --git a/src/hugr/rewrite/outline_cfg.rs b/src/hugr/rewrite/outline_cfg.rs index dd84b1a44..0c1e02aea 100644 --- a/src/hugr/rewrite/outline_cfg.rs +++ b/src/hugr/rewrite/outline_cfg.rs @@ -36,10 +36,11 @@ impl OutlineCfg { if !matches!(o, OpType::CFG(_)) { return Err(OutlineCfgError::ParentNotCfg(cfg_n, o.clone())); }; + let cfg_entry = h.children(cfg_n).next().unwrap(); let mut entry = None; let mut exit_succ = None; for &n in self.0.iter() { - if h.input_neighbours(n).any(|pred| !self.0.contains(&pred)) { + if n == cfg_entry || h.input_neighbours(n).any(|pred| !self.0.contains(&pred)) { match entry { None => { entry = Some(n); @@ -89,6 +90,7 @@ impl Rewrite for OutlineCfg { OpType::BasicBlock(b) => b.dataflow_input().clone(), _ => panic!("External successor not a basic block"), }; + let is_outer_entry = h.children(h.get_parent(entry).unwrap()).next().unwrap() == entry; // 2. New CFG node will be contained in new single-successor BB let mut existing_cfg = { @@ -111,7 +113,8 @@ impl Rewrite for OutlineCfg { let pred_wire = new_block.load_const(&predicate).unwrap(); let new_block = new_block .finish_with_outputs(pred_wire, cfg_outputs) - .unwrap(); + .unwrap() + .node(); // 4. Entry edges. Change any edges into entry_block from outside, to target new_block let h = existing_cfg.hugr_mut(); @@ -122,9 +125,17 @@ impl Rewrite for OutlineCfg { for (pred, br) in preds { if !self.0.contains(&pred) { h.disconnect(pred, br).unwrap(); - h.connect(pred, br.index(), new_block.node(), 0).unwrap(); + h.connect(pred, br.index(), new_block, 0).unwrap(); } } + if is_outer_entry { + // new_block must be the entry node, i.e. first child, of the enclosing CFG + // (the current entry node will be reparented inside new_block below) + let parent = h.hierarchy.detach(new_block.index).unwrap(); + h.hierarchy + .push_front_child(new_block.index, parent) + .unwrap(); + } // 5. Children of new CFG. // Entry node must be first @@ -155,7 +166,7 @@ impl Rewrite for OutlineCfg { h.disconnect(exit, exit_port).unwrap(); h.connect(exit, exit_port.index(), inner_exit, 0).unwrap(); // And connect new_block to outside instead - h.connect(new_block.node(), 0, outside, 0).unwrap(); + h.connect(new_block, 0, outside, 0).unwrap(); Ok(()) }