Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add extension deltas to CFG ops #503

Merged
merged 1 commit into from
Sep 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,8 @@ pub trait Dataflow: Container {
NodeType::open_extensions(ops::CFG {
inputs: inputs.clone(),
outputs: output_types.clone(),
// TODO: Make this a parameter
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll do another PR to make these methods take FunctionType as an argument and taking the extension_delta from there

extension_delta: ExtensionSet::new(),
}),
input_wires,
)?;
Expand Down
11 changes: 10 additions & 1 deletion src/builder/cfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ use super::{
};

use crate::ops::{self, BasicBlock, OpType};
use crate::{extension::ExtensionRegistry, types::FunctionType};
use crate::{
extension::{ExtensionRegistry, ExtensionSet},
types::FunctionType,
};
use crate::{hugr::views::HugrView, types::TypeRow};
use crate::{ops::handle::NodeHandle, types::Type};

Expand Down Expand Up @@ -60,6 +63,8 @@ impl CFGBuilder<Hugr> {
let cfg_op = ops::CFG {
inputs: input.clone(),
outputs: output.clone(),
// TODO: Make this a parameter
extension_delta: ExtensionSet::new(),
};

// TODO: Allow input extensions to be specified
Expand Down Expand Up @@ -130,6 +135,8 @@ impl<B: AsMut<Hugr> + AsRef<Hugr>> CFGBuilder<B> {
inputs: inputs.clone(),
other_outputs: other_outputs.clone(),
predicate_variants: predicate_variants.clone(),
// TODO: Make this a parameter
extension_delta: ExtensionSet::new(),
});
let parent = self.container_node();
let block_n = if entry {
Expand Down Expand Up @@ -277,6 +284,8 @@ impl BlockBuilder<Hugr> {
inputs: inputs.clone(),
other_outputs: other_outputs.clone(),
predicate_variants: predicate_variants.clone(),
// TODO: make this a parameter
extension_delta: ExtensionSet::new(),
};

// TODO: Allow input extensions to be specified
Expand Down
3 changes: 3 additions & 0 deletions src/hugr/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,7 @@ mod test {
NodeType::pure(ops::CFG {
inputs: type_row![BOOL_T],
outputs: type_row![BOOL_T],
extension_delta: ExtensionSet::new(),
}),
);
assert_matches!(
Expand All @@ -969,6 +970,7 @@ mod test {
inputs: type_row![BOOL_T],
predicate_variants: vec![type_row![]],
other_outputs: type_row![BOOL_T],
extension_delta: ExtensionSet::new(),
},
)
.unwrap();
Expand Down Expand Up @@ -1009,6 +1011,7 @@ mod test {
inputs: type_row![Q],
predicate_variants: vec![type_row![]],
other_outputs: type_row![Q],
extension_delta: ExtensionSet::new(),
}),
);
let mut block_children = b.hierarchy.children(block.index);
Expand Down
13 changes: 13 additions & 0 deletions src/ops/controlflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use smol_str::SmolStr;

use crate::extension::ExtensionSet;
use crate::type_row;
use crate::types::{EdgeKind, FunctionType, Type, TypeRow};

use super::dataflow::DataflowOpTrait;
Expand Down Expand Up @@ -97,6 +98,7 @@ impl Conditional {
pub struct CFG {
pub inputs: TypeRow,
pub outputs: TypeRow,
pub extension_delta: ExtensionSet,
}

impl_op_name!(CFG);
Expand All @@ -110,6 +112,7 @@ impl DataflowOpTrait for CFG {

fn signature(&self) -> FunctionType {
FunctionType::new(self.inputs.clone(), self.outputs.clone())
.with_extension_delta(&self.extension_delta)
}
}

Expand All @@ -123,6 +126,7 @@ pub enum BasicBlock {
inputs: TypeRow,
other_outputs: TypeRow,
predicate_variants: Vec<TypeRow>,
extension_delta: ExtensionSet,
},
/// The single exit node of the CFG, has no children,
/// stores the types of the CFG node output.
Expand Down Expand Up @@ -166,6 +170,15 @@ impl OpTrait for BasicBlock {
fn other_output(&self) -> Option<EdgeKind> {
Some(EdgeKind::ControlFlow)
}

fn signature(&self) -> FunctionType {
match self {
BasicBlock::DFB {
extension_delta, ..
} => FunctionType::new(type_row![], type_row![]).with_extension_delta(extension_delta),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a hack to make extension inference work! Previously the signature was just the empty default

BasicBlock::Exit { .. } => FunctionType::new(type_row![], type_row![]),
}
}
}

impl BasicBlock {
Expand Down
1 change: 1 addition & 0 deletions src/ops/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ impl ValidateOp for BasicBlock {
inputs,
predicate_variants,
other_outputs: outputs,
extension_delta: _,
} => {
let predicate_type = Type::new_predicate(predicate_variants.clone());
let node_outputs: TypeRow = [&[predicate_type], outputs.as_ref()].concat().into();
Expand Down