Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[1 of 4] Rewrites use HugrMut methods rather than .hierarchy/.graph #280

Merged
merged 8 commits into from
Jul 24, 2023
17 changes: 5 additions & 12 deletions src/hugr/rewrite/outline_cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ impl Rewrite for OutlineCfg {
OpType::BasicBlock(b) => b.dataflow_input().clone(),
_ => panic!("External successor not a basic block"),
};
let is_outer_entry = h.children(h.get_parent(entry).unwrap()).next().unwrap() == entry;
let outer_entry = h.children(h.get_parent(entry).unwrap()).next().unwrap();

// 2. New CFG node will be contained in new single-successor BB
let mut existing_cfg = {
Expand Down Expand Up @@ -135,27 +135,20 @@ impl Rewrite for OutlineCfg {
h.connect(pred, br.index(), new_block, 0).unwrap();
}
}
if is_outer_entry {
if entry == outer_entry {
// new_block must be the entry node, i.e. first child, of the enclosing CFG
// (the current entry node will be reparented inside new_block below)
let parent = h.hierarchy.detach(new_block.index).unwrap();
h.hierarchy
.push_front_child(new_block.index, parent)
.unwrap();
h.move_before_sibling(new_block, outer_entry).unwrap();
}

// 5. Children of new CFG.
// Entry node must be first
h.hierarchy.detach(entry.index);
h.hierarchy
.insert_before(entry.index, inner_exit.index)
.unwrap();
h.move_before_sibling(entry, inner_exit).unwrap();
// And remaining nodes
for n in self.blocks {
// Do not move the entry node, as we have already
if n != entry {
h.hierarchy.detach(n.index);
h.hierarchy.push_child(n.index, cfg_node.index).unwrap();
h.set_parent(n, cfg_node).unwrap();
}
}

Expand Down
147 changes: 47 additions & 100 deletions src/hugr/rewrite/simple_replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
use std::collections::{HashMap, HashSet};

use itertools::Itertools;
use portgraph::{LinkMut, LinkView, MultiMut, NodeIndex, PortView};

use crate::hugr::{HugrMut, HugrView, NodeMetadata};
use crate::{
Expand Down Expand Up @@ -64,16 +63,14 @@ impl Rewrite for SimpleReplacement {
}
// 2. Check that all the to-be-removed nodes are children of it and are leaves.
for node in &self.removal {
if h.hierarchy.parent(node.index) != Some(self.parent.index)
|| h.hierarchy.has_children(node.index)
{
if h.get_parent(*node) != Some(self.parent) || h.children(*node).next().is_some() {
return Err(SimpleReplacementError::InvalidRemovedNode());
}
}
// 3. Do the replacement.
// 3.1. Add copies of all replacement nodes and edges to h. Exclude Input/Output nodes.
// Create map from old NodeIndex (in self.replacement) to new NodeIndex (in self).
let mut index_map: HashMap<NodeIndex, NodeIndex> = HashMap::new();
let mut index_map: HashMap<Node, Node> = HashMap::new();
let replacement_nodes = self
.replacement
.children(self.replacement.root())
Expand All @@ -92,13 +89,13 @@ impl Rewrite for SimpleReplacement {
return Err(SimpleReplacementError::InvalidReplacementNode());
}
}
let self_output_node_index = h.children(self.parent).nth(1).unwrap();
let self_output_node = h.children(self.parent).nth(1).unwrap();
let replacement_output_node = *replacement_nodes.get(1).unwrap();
for &node in replacement_inner_nodes {
// Add the nodes.
let op: &OpType = self.replacement.get_optype(node);
let new_node_index = h.add_op_after(self_output_node_index, op.clone()).unwrap();
index_map.insert(node.index, new_node_index.index);
let new_node = h.add_op_after(self_output_node, op.clone()).unwrap();
index_map.insert(node, new_node);

// Move the metadata
let meta: &NodeMetadata = self.replacement.get_metadata(node);
Expand All @@ -107,34 +104,12 @@ impl Rewrite for SimpleReplacement {
// Add edges between all newly added nodes matching those in replacement.
// TODO This will probably change when implicit copies are implemented.
for &node in replacement_inner_nodes {
let new_node_index = index_map.get(&node.index).unwrap();
for node_successor in self.replacement.output_neighbours(node).unique() {
if self.replacement.get_optype(node_successor).tag() != OpTag::Output {
let new_node_successor_index = index_map.get(&node_successor.index).unwrap();
for connection in self
.replacement
.graph
.get_connections(node.index, node_successor.index)
{
let src_offset = self
.replacement
.graph
.port_offset(connection.0)
.unwrap()
.index();
let tgt_offset = self
.replacement
.graph
.port_offset(connection.1)
.unwrap()
.index();
h.graph
.link_nodes(
*new_node_index,
src_offset,
*new_node_successor_index,
tgt_offset,
)
let new_node = index_map.get(&node).unwrap();
for outport in self.replacement.node_outputs(node) {
for target in self.replacement.linked_ports(node, outport) {
if self.replacement.get_optype(target.0).tag() != OpTag::Output {
let new_target = index_map.get(&target.0).unwrap();
h.connect(*new_node, outport.index(), *new_target, target.1.index())
.unwrap();
}
}
Expand All @@ -144,66 +119,40 @@ impl Rewrite for SimpleReplacement {
// predecessor of p to (the new copy of) q.
for ((rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port)) in &self.nu_inp {
if self.replacement.get_optype(*rep_inp_node).tag() != OpTag::Output {
let new_inp_node_index = index_map.get(&rep_inp_node.index).unwrap();
// add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port)
let rem_inp_port_index = h
.graph
.port_index(rem_inp_node.index, rem_inp_port.offset)
.unwrap();
let rem_inp_predecessor_subport = h.graph.port_link(rem_inp_port_index).unwrap();
let rem_inp_predecessor_port_index = rem_inp_predecessor_subport.port();
let new_inp_port_index = h
.graph
.port_index(*new_inp_node_index, rep_inp_port.offset)
.unwrap();
h.graph.unlink_subport(rem_inp_predecessor_subport);
h.graph
.link_ports(rem_inp_predecessor_port_index, new_inp_port_index)
let (rem_inp_pred_node, rem_inp_pred_port) = h
.linked_ports(*rem_inp_node, *rem_inp_port)
.exactly_one()
.unwrap();
h.disconnect(*rem_inp_node, *rem_inp_port).unwrap();
let new_inp_node = index_map.get(rep_inp_node).unwrap();
h.connect(
rem_inp_pred_node,
rem_inp_pred_port.index(),
*new_inp_node,
rep_inp_port.offset.index(),
)
.unwrap();
}
}
// 3.3. For each q = self.nu_out[p] such that the predecessor of q is not an Input port, add an
// edge from (the new copy of) the predecessor of q to p.
for ((rem_out_node, rem_out_port), rep_out_port) in &self.nu_out {
let rem_out_port_index = h
.graph
.port_index(rem_out_node.index, rem_out_port.offset)
.unwrap();
let rep_out_port_index = self
.replacement
.graph
.port_index(replacement_output_node.index, rep_out_port.offset)
.unwrap();
let rep_out_predecessor_port_index = self
let (rep_out_pred_node, rep_out_pred_port) = self
.replacement
.graph
.port_link(rep_out_port_index)
.linked_ports(replacement_output_node, *rep_out_port)
.exactly_one()
.unwrap();
let rep_out_predecessor_node_index = self
.replacement
.graph
.port_node(rep_out_predecessor_port_index)
if self.replacement.get_optype(rep_out_pred_node).tag() != OpTag::Input {
let new_out_node = index_map.get(&rep_out_pred_node).unwrap();
h.disconnect(*rem_out_node, *rem_out_port).unwrap();
h.connect(
*new_out_node,
rep_out_pred_port.index(),
*rem_out_node,
rem_out_port.index(),
)
.unwrap();
if self
.replacement
.get_optype(rep_out_predecessor_node_index.into())
.tag()
!= OpTag::Input
{
let rep_out_predecessor_port_offset = self
.replacement
.graph
.port_offset(rep_out_predecessor_port_index)
.unwrap();
let new_out_node_index = index_map.get(&rep_out_predecessor_node_index).unwrap();
let new_out_port_index = h
.graph
.port_index(*new_out_node_index, rep_out_predecessor_port_offset)
.unwrap();
h.graph.unlink_port(rem_out_port_index);
h.graph
.link_ports(new_out_port_index, rem_out_port_index)
.unwrap();
}
}
// 3.4. For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the predecessor of p0
Expand All @@ -212,21 +161,19 @@ impl Rewrite for SimpleReplacement {
let rem_inp_nodeport = self.nu_inp.get(&(replacement_output_node, rep_out_port));
if let Some((rem_inp_node, rem_inp_port)) = rem_inp_nodeport {
// add edge from predecessor of (rem_inp_node, rem_inp_port) to (rem_out_node, rem_out_port):
let rem_inp_port_index = h
.graph
.port_index(rem_inp_node.index, rem_inp_port.offset)
.unwrap();
let rem_inp_predecessor_port_index =
h.graph.port_link(rem_inp_port_index).unwrap().port();
let rem_out_port_index = h
.graph
.port_index(rem_out_node.index, rem_out_port.offset)
.unwrap();
h.graph.unlink_port(rem_inp_port_index);
h.graph.unlink_port(rem_out_port_index);
h.graph
.link_ports(rem_inp_predecessor_port_index, rem_out_port_index)
let (rem_inp_pred_node, rem_inp_pred_port) = h
.linked_ports(*rem_inp_node, *rem_inp_port)
.exactly_one()
.unwrap();
h.disconnect(*rem_inp_node, *rem_inp_port).unwrap();
h.disconnect(*rem_out_node, *rem_out_port).unwrap();
h.connect(
rem_inp_pred_node,
rem_inp_pred_port.index(),
*rem_out_node,
rem_out_port.index(),
)
.unwrap();
}
}
// 3.5. Remove all nodes in self.removal and edges between them.
Expand Down