Skip to content

Commit

Permalink
feat: Add extension deltas to CFG ops (#503)
Browse files Browse the repository at this point in the history
  • Loading branch information
croyzor authored Sep 7, 2023
1 parent 014a957 commit 71b6ff0
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,8 @@ pub trait Dataflow: Container {
NodeType::open_extensions(ops::CFG {
inputs: inputs.clone(),
outputs: output_types.clone(),
// TODO: Make this a parameter
extension_delta: ExtensionSet::new(),
}),
input_wires,
)?;
Expand Down
11 changes: 10 additions & 1 deletion src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use super::{
};

use crate::ops::{self, BasicBlock, OpType};
use crate::{extension::ExtensionRegistry, types::FunctionType};
use crate::{
extension::{ExtensionRegistry, ExtensionSet},
types::FunctionType,
};
use crate::{hugr::views::HugrView, types::TypeRow};
use crate::{ops::handle::NodeHandle, types::Type};

Expand Down Expand Up @@ -60,6 +63,8 @@ impl CFGBuilder<Hugr> {
let cfg_op = ops::CFG {
inputs: input.clone(),
outputs: output.clone(),
// TODO: Make this a parameter
extension_delta: ExtensionSet::new(),
};

// TODO: Allow input extensions to be specified
Expand Down Expand Up @@ -130,6 +135,8 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
inputs: inputs.clone(),
other_outputs: other_outputs.clone(),
predicate_variants: predicate_variants.clone(),
// TODO: Make this a parameter
extension_delta: ExtensionSet::new(),
});
let parent = self.container_node();
let block_n = if entry {
Expand Down Expand Up @@ -277,6 +284,8 @@ impl BlockBuilder<Hugr> {
inputs: inputs.clone(),
other_outputs: other_outputs.clone(),
predicate_variants: predicate_variants.clone(),
// TODO: make this a parameter
extension_delta: ExtensionSet::new(),
};

// TODO: Allow input extensions to be specified
Expand Down
3 changes: 3 additions & 0 deletions src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,7 @@ mod test {
NodeType::pure(ops::CFG {
inputs: type_row![BOOL_T],
outputs: type_row![BOOL_T],
extension_delta: ExtensionSet::new(),
}),
);
assert_matches!(
Expand All @@ -969,6 +970,7 @@ mod test {
inputs: type_row![BOOL_T],
predicate_variants: vec![type_row![]],
other_outputs: type_row![BOOL_T],
extension_delta: ExtensionSet::new(),
},
)
.unwrap();
Expand Down Expand Up @@ -1009,6 +1011,7 @@ mod test {
inputs: type_row![Q],
predicate_variants: vec![type_row![]],
other_outputs: type_row![Q],
extension_delta: ExtensionSet::new(),
}),
);
let mut block_children = b.hierarchy.children(block.index);
Expand Down
13 changes: 13 additions & 0 deletions src/ops/controlflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use smol_str::SmolStr;

use crate::extension::ExtensionSet;
use crate::type_row;
use crate::types::{EdgeKind, FunctionType, Type, TypeRow};

use super::dataflow::DataflowOpTrait;
Expand Down Expand Up @@ -97,6 +98,7 @@ impl Conditional {
pub struct CFG {
pub inputs: TypeRow,
pub outputs: TypeRow,
pub extension_delta: ExtensionSet,
}

impl_op_name!(CFG);
Expand All @@ -110,6 +112,7 @@ impl DataflowOpTrait for CFG {

fn signature(&self) -> FunctionType {
FunctionType::new(self.inputs.clone(), self.outputs.clone())
.with_extension_delta(&self.extension_delta)
}
}

Expand All @@ -123,6 +126,7 @@ pub enum BasicBlock {
inputs: TypeRow,
other_outputs: TypeRow,
predicate_variants: Vec<TypeRow>,
extension_delta: ExtensionSet,
},
/// The single exit node of the CFG, has no children,
/// stores the types of the CFG node output.
Expand Down Expand Up @@ -166,6 +170,15 @@ impl OpTrait for BasicBlock {
fn other_output(&self) -> Option<EdgeKind> {
Some(EdgeKind::ControlFlow)
}

fn signature(&self) -> FunctionType {
match self {
BasicBlock::DFB {
extension_delta, ..
} => FunctionType::new(type_row![], type_row![]).with_extension_delta(extension_delta),
BasicBlock::Exit { .. } => FunctionType::new(type_row![], type_row![]),
}
}
}

impl BasicBlock {
Expand Down
1 change: 1 addition & 0 deletions src/ops/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ impl ValidateOp for BasicBlock {
inputs,
predicate_variants,
other_outputs: outputs,
extension_delta: _,
} => {
let predicate_type = Type::new_predicate(predicate_variants.clone());
let node_outputs: TypeRow = [&[predicate_type], outputs.as_ref()].concat().into();
Expand Down

0 comments on commit 71b6ff0

Please sign in to comment.