Skip to content

Commit

Permalink
feat: Extension inference for CFGs
Browse files Browse the repository at this point in the history
  • Loading branch information
croyzor committed Sep 13, 2023
1 parent b247378 commit 761feec
Showing 1 changed file with 195 additions and 2 deletions.
197 changes: 195 additions & 2 deletions src/extension/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,18 @@ impl UnificationContext {
}
}

if node_type.tag() == OpTag::Cfg {
let mut children = hugr.children(node);
let entry = children.next().unwrap();
let exit = children.next().unwrap();
let m_entry = self.make_or_get_meta(entry, Direction::Incoming);
let m_exit_in = self.make_or_get_meta(exit, Direction::Incoming);
let m_exit_out = self.make_or_get_meta(exit, Direction::Outgoing);
self.add_constraint(m_input, Constraint::Equal(m_entry));
self.add_constraint(m_output, Constraint::Equal(m_exit_in));
self.add_constraint(m_output, Constraint::Equal(m_exit_out));
}

match node_type.signature() {
// Input extensions are open
None => {
Expand All @@ -324,8 +336,12 @@ impl UnificationContext {
let sig: &OpType = hugr.get_nodetype(tgt_node).into();
// Incoming ports with a dataflow edge
for port in hugr.node_inputs(tgt_node).filter(|src_port| {
matches!(sig.port_kind(*src_port), Some(EdgeKind::Value(_)))
|| matches!(sig.port_kind(*src_port), Some(EdgeKind::Static(_)))
matches!(
sig.port_kind(*src_port),
Some(EdgeKind::Value(_))
| Some(EdgeKind::Static(_))
| Some(EdgeKind::ControlFlow)
)
}) {
for (src_node, _) in hugr.linked_ports(tgt_node, port) {
let m_src = self
Expand Down Expand Up @@ -1091,4 +1107,181 @@ mod test {
}
Ok(())
}

fn make_opaque(extension: impl Into<ExtensionId>, signature: FunctionType) -> ops::LeafOp {
let opaque =
ops::custom::OpaqueOp::new(extension.into(), "", "".into(), vec![], Some(signature));
ops::custom::ExternalOp::from(opaque).into()
}

/// A CFG rooted hugr adding resources at each basic block.
/// Looks like this:
///
/// +-------------+
/// | Entry |
/// | (Adds [A]) |
/// +-/---------\-+
/// / \
/// +-------/-----+ +-\-------------+
/// | BB0 | | BB1 |
/// | (Adds [BC]) | | (Adds [B]) |
/// +----\--------+ +---/------\----+
/// \ / \
/// \ / \
/// \ +----/-------+ +-\---------+
/// \ | BB10 | | BB11 |
/// \ | (Adds [C]) | | (Adds [C])|
/// \ +----+-------+ +/----------+
/// \ | /
/// +-----\-------+---------/-+
/// | Exit |
/// +-------------------------+
#[test]
fn infer_cfg_test() -> Result<(), Box<dyn Error>> {
let a = ExtensionSet::singleton(&A);
let abc = ExtensionSet::from_iter([A, B, C]);
let bc = ExtensionSet::from_iter([B, C]);
let b = ExtensionSet::singleton(&B);
let c = ExtensionSet::singleton(&C);

let oneway = vec![Type::new_predicate([type_row![NAT]])];
let twoway = vec![Type::new_predicate([type_row![NAT], type_row![NAT]])];

let mut hugr = Hugr::new(NodeType::open_extensions(ops::CFG {
signature: FunctionType::new(type_row![NAT], type_row![NAT]).with_extension_delta(&abc),
}));

let root = hugr.root();

let [entry, entry_in, entry_out] = create_with_io(
&mut hugr,
root,
ops::BasicBlock::DFB {
inputs: type_row![NAT],
other_outputs: type_row![],
predicate_variants: vec![type_row![NAT], type_row![NAT]],
extension_delta: a.clone(),
},
)?;

let exit = hugr.add_node_with_parent(
hugr.root(),
NodeType::open_extensions(ops::BasicBlock::Exit {
cfg_outputs: type_row![NAT],
}),
)?;

let mkpred = hugr.add_node_with_parent(
entry,
NodeType::open_extensions(make_opaque(
A,
FunctionType::new(vec![NAT], twoway.clone()).with_extension_delta(&a),
)),
)?;

let [bb0, bb0_in, bb0_out] = create_with_io(
&mut hugr,
root,
ops::BasicBlock::DFB {
inputs: type_row![NAT],
other_outputs: type_row![],
predicate_variants: vec![type_row![NAT]],
extension_delta: bc.clone(),
},
)?;

let [bb1, bb1_in, bb1_out] = create_with_io(
&mut hugr,
root,
ops::BasicBlock::DFB {
inputs: type_row![NAT],
other_outputs: type_row![],
predicate_variants: vec![type_row![NAT], type_row![NAT]],
extension_delta: b.clone(),
},
)?;

let add_bc = hugr.add_node_with_parent(
bb0,
NodeType::open_extensions(make_opaque(
B,
FunctionType::new(vec![NAT], oneway.clone()).with_extension_delta(&bc),
)),
)?;

let add_b = hugr.add_node_with_parent(
bb1,
NodeType::open_extensions(make_opaque(
B,
FunctionType::new(vec![NAT], twoway.clone()).with_extension_delta(&b),
)),
)?;

let [bb10, bb10_in, bb10_out] = create_with_io(
&mut hugr,
root,
ops::BasicBlock::DFB {
inputs: type_row![NAT],
other_outputs: type_row![],
predicate_variants: vec![type_row![NAT]],
extension_delta: c.clone(),
},
)?;

let [bb11, bb11_in, bb11_out] = create_with_io(
&mut hugr,
root,
ops::BasicBlock::DFB {
inputs: type_row![NAT],
other_outputs: type_row![],
predicate_variants: vec![type_row![NAT]],
extension_delta: c.clone(),
},
)?;

let add_c0 = hugr.add_node_with_parent(
bb10,
NodeType::open_extensions(make_opaque(
C,
FunctionType::new(vec![NAT], oneway.clone()).with_extension_delta(&c),
)),
)?;

let add_c1 = hugr.add_node_with_parent(
bb11,
NodeType::open_extensions(make_opaque(
C,
FunctionType::new(vec![NAT], oneway).with_extension_delta(&c),
)),
)?;

// Internal wiring for DFGs
hugr.connect(entry_in, 0, mkpred, 0)?;
hugr.connect(mkpred, 0, entry_out, 0)?;

hugr.connect(bb0_in, 0, add_bc, 0)?;
hugr.connect(add_bc, 0, bb0_out, 0)?;

hugr.connect(bb1_in, 0, add_b, 0)?;
hugr.connect(add_b, 0, bb1_out, 0)?;

hugr.connect(bb10_in, 0, add_c0, 0)?;
hugr.connect(add_c0, 0, bb10_out, 0)?;
hugr.connect(bb11_in, 0, add_c1, 0)?;
hugr.connect(add_c1, 0, bb11_out, 0)?;

// CFG Wiring
hugr.connect(entry, 0, bb0, 0)?;
hugr.connect(entry, 1, bb1, 0)?;
hugr.connect(bb1, 0, bb10, 0)?;
hugr.connect(bb1, 1, bb11, 0)?;

hugr.connect(bb0, 0, exit, 0)?;
hugr.connect(bb10, 0, exit, 0)?;
hugr.connect(bb11, 0, exit, 0)?;

hugr.infer_extensions()?;

Ok(())
}
}

0 comments on commit 761feec

Please sign in to comment.