Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Operation and constructor declarations in hugr-model #1605

Merged
merged 9 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 76 additions & 10 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
//! Exporting HUGR graphs to their `hugr-model` representation.
use crate::{
extension::ExtensionSet,
extension::{ExtensionId, ExtensionSet, OpDef, SignatureFunc},
hugr::IdentList,
ops::{DataflowBlock, OpTrait, OpType},
ops::{DataflowBlock, OpName, OpTrait, OpType},
types::{
type_param::{TypeArgVariable, TypeParam},
type_row::TypeRowBase,
Expand Down Expand Up @@ -38,11 +38,14 @@ struct Context<'a> {
/// Mapping from ports to link indices.
/// This only includes the minimum port among groups of linked ports.
links: FxIndexSet<(Node, Port)>,
/// The arena in which the model is allocated.
bump: &'a Bump,
/// Stores the terms that we have already seen to avoid duplicates.
term_map: FxHashMap<model::Term<'a>, model::TermId>,
/// The current scope for local variables.
local_scope: Option<model::NodeId>,
/// Mapping from extension operations to their declarations.
decl_operations: FxHashMap<(ExtensionId, OpName), model::NodeId>,
}

impl<'a> Context<'a> {
Expand All @@ -57,23 +60,26 @@ impl<'a> Context<'a> {
links: IndexSet::default(),
term_map: FxHashMap::default(),
local_scope: None,
decl_operations: FxHashMap::default(),
}
}

/// Exports the root module of the HUGR graph.
pub fn export_root(&mut self) {
let hugr_children = self.hugr.children(self.hugr.root());
let mut children = BumpVec::with_capacity_in(hugr_children.len(), self.bump);
let mut children = Vec::with_capacity(hugr_children.len());

for child in self.hugr.children(self.hugr.root()) {
children.push(self.export_node(child));
}

children.extend(self.decl_operations.values().copied());

let root = self.module.insert_region(model::Region {
kind: model::RegionKind::DataFlow,
kind: model::RegionKind::Module,
sources: &[],
targets: &[],
children: children.into_bump_slice(),
children: self.bump.alloc_slice_copy(&children),
meta: &[], // TODO: Export metadata
signature: None,
});
Expand Down Expand Up @@ -123,15 +129,23 @@ impl<'a> Context<'a> {
.or_insert_with(|| self.module.insert_term(term))
}

pub fn make_named_global_ref(
pub fn make_qualified_name(
&mut self,
extension: &IdentList,
extension: &ExtensionId,
name: impl AsRef<str>,
) -> model::GlobalRef<'a> {
) -> &'a str {
let capacity = extension.len() + name.as_ref().len() + 1;
let mut output = BumpString::with_capacity_in(capacity, self.bump);
let _ = write!(&mut output, "{}.{}", extension, name.as_ref());
model::GlobalRef::Named(output.into_bump_str())
output.into_bump_str()
}

pub fn make_named_global_ref(
&mut self,
extension: &IdentList,
name: impl AsRef<str>,
) -> model::GlobalRef<'a> {
model::GlobalRef::Named(self.make_qualified_name(extension, name))
}

/// Get the node that declares or defines the function associated with the given
Expand Down Expand Up @@ -315,7 +329,7 @@ impl<'a> Context<'a> {
// regions of potentially different kinds. At the moment, we check if the node has any
// children, in which case we create a dataflow region with those children.
OpType::ExtensionOp(op) => {
let operation = self.make_named_global_ref(op.def().extension(), op.def().name());
let operation = self.export_opdef(op.def());

params = self
.bump
Expand Down Expand Up @@ -392,6 +406,58 @@ impl<'a> Context<'a> {
node_id
}

/// Export an `OpDef` as an operation declaration.
///
/// Operations that allow a declarative form are exported as a reference to
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
/// an operation declaration node, and this node is reused for all instances
/// of the operation. The node is added to the `decl_operations` map so that
/// at the end of the export, the operation declaration nodes can be added
/// to the module as children of the module region.
pub fn export_opdef(&mut self, opdef: &OpDef) -> model::GlobalRef<'a> {
use std::collections::hash_map::Entry;

let poly_func_type = match opdef.signature_func() {
SignatureFunc::PolyFuncType(poly_func_type) => poly_func_type,
_ => return self.make_named_global_ref(opdef.extension(), opdef.name()),
};

let key = (opdef.extension().clone(), opdef.name().clone());
let entry = self.decl_operations.entry(key);

let node = match entry {
Entry::Occupied(occupied_entry) => {
return model::GlobalRef::Direct(*occupied_entry.get())
}
Entry::Vacant(vacant_entry) => {
*vacant_entry.insert(self.module.insert_node(model::Node {
operation: model::Operation::Invalid,
inputs: &[],
outputs: &[],
params: &[],
regions: &[],
meta: &[], // TODO: Metadata
signature: None,
}))
}
};

let decl = self.with_local_scope(node, |this| {
ss2165 marked this conversation as resolved.
Show resolved Hide resolved
let name = this.make_qualified_name(opdef.extension(), opdef.name());
let (params, r#type) = this.export_poly_func_type(poly_func_type);
let decl = this.bump.alloc(model::OperationDecl {
name,
params,
r#type,
});
decl
});

self.module.get_node_mut(node).unwrap().operation =
model::Operation::DeclareOperation { decl };

model::GlobalRef::Direct(node)
}

/// Export the signature of a `DataflowBlock`. Here we can't use `OpType::dataflow_signature`
/// like for the other nodes since the ports are control flow ports.
pub fn export_block_signature(&mut self, block: &DataflowBlock) -> model::TermId {
Expand Down
5 changes: 5 additions & 0 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,11 @@ impl OpDef {
) -> ConstFoldResult {
(self.constant_folder.as_ref())?.fold(type_args, consts)
}

/// Returns a reference to the signature function of this [`OpDef`].
pub fn signature_func(&self) -> &SignatureFunc {
&self.signature_func
}
}

impl Extension {
Expand Down
76 changes: 51 additions & 25 deletions hugr-core/src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,8 @@ impl<'a> Context<'a> {
match item {
NamedItem::FuncDecl(node) => Ok(*node),
NamedItem::FuncDefn(node) => Ok(*node),
NamedItem::CtrDecl(node) => Ok(*node),
NamedItem::OperationDecl(node) => Ok(*node),
}
}
}
Expand All @@ -299,6 +301,8 @@ impl<'a> Context<'a> {
model::Operation::DeclareFunc { decl } => decl.name,
model::Operation::DefineAlias { decl, .. } => decl.name,
model::Operation::DeclareAlias { decl } => decl.name,
model::Operation::DeclareConstructor { decl } => decl.name,
model::Operation::DeclareOperation { decl } => decl.name,
_ => {
return Err(model::ModelError::InvalidGlobal(global_ref.to_string()).into());
}
Expand Down Expand Up @@ -334,7 +338,11 @@ impl<'a> Context<'a> {
Ok(())
}

fn import_node(&mut self, node_id: model::NodeId, parent: Node) -> Result<Node, ImportError> {
fn import_node(
&mut self,
node_id: model::NodeId,
parent: Node,
) -> Result<Option<Node>, ImportError> {
let node_data = self.get_node(node_id)?;

match node_data.operation {
Expand All @@ -349,7 +357,7 @@ impl<'a> Context<'a> {
};

self.import_dfg_region(node_id, *region, node)?;
Ok(node)
Ok(Some(node))
}

model::Operation::Cfg => {
Expand All @@ -362,10 +370,13 @@ impl<'a> Context<'a> {
};

self.import_cfg_region(node_id, *region, node)?;
Ok(node)
Ok(Some(node))
}

model::Operation::Block => self.import_cfg_block(node_id, parent),
model::Operation::Block => {
let node = self.import_cfg_block(node_id, parent)?;
Ok(Some(node))
}

model::Operation::DefineFunc { decl } => {
self.import_poly_func_type(*decl, |ctx, signature| {
Expand All @@ -382,7 +393,7 @@ impl<'a> Context<'a> {

ctx.import_dfg_region(node_id, *region, node)?;

Ok(node)
Ok(Some(node))
})
}

Expand All @@ -395,7 +406,7 @@ impl<'a> Context<'a> {

let node = ctx.make_node(node_id, optype, parent)?;

Ok(node)
Ok(Some(node))
})
}

Expand All @@ -415,7 +426,8 @@ impl<'a> Context<'a> {
self.static_edges.push((func_node, node_id));
let optype = OpType::Call(Call::try_new(func_sig, type_args, self.extensions)?);

self.make_node(node_id, optype, parent)
let node = self.make_node(node_id, optype, parent)?;
Ok(Some(node))
}

model::Operation::LoadFunc { func } => {
Expand All @@ -439,18 +451,26 @@ impl<'a> Context<'a> {
self.extensions,
)?);

self.make_node(node_id, optype, parent)
let node = self.make_node(node_id, optype, parent)?;
Ok(Some(node))
}

model::Operation::TailLoop => self.import_tail_loop(node_id, parent),
model::Operation::Conditional => self.import_conditional(node_id, parent),
model::Operation::TailLoop => {
let node = self.import_tail_loop(node_id, parent)?;
Ok(Some(node))
}
model::Operation::Conditional => {
let node = self.import_conditional(node_id, parent)?;
Ok(Some(node))
}

model::Operation::CustomFull {
operation: GlobalRef::Named(name),
} if name == OP_FUNC_CALL_INDIRECT => {
let signature = self.get_node_signature(node_id)?;
let optype = OpType::CallIndirect(CallIndirect { signature });
self.make_node(node_id, optype, parent)
let node = self.make_node(node_id, optype, parent)?;
Ok(Some(node))
}

model::Operation::CustomFull { operation } => {
Expand All @@ -461,15 +481,7 @@ impl<'a> Context<'a> {
.map(|param| self.import_type_arg(*param))
.collect::<Result<Vec<_>, _>>()?;

let name = match operation {
GlobalRef::Direct(_) => {
return Err(error_unsupported!(
"custom operation with direct reference to declaring node"
))
}
GlobalRef::Named(name) => name,
};

let name = self.get_global_name(operation)?;
let (extension, name) = self.import_custom_name(name)?;

// TODO: Currently we do not have the description or any other metadata for
Expand All @@ -493,7 +505,7 @@ impl<'a> Context<'a> {
_ => return Err(error_unsupported!("multiple regions in custom operation")),
}

Ok(node)
Ok(Some(node))
}

model::Operation::Custom { .. } => Err(error_unsupported!(
Expand All @@ -512,7 +524,8 @@ impl<'a> Context<'a> {
definition: ctx.import_type(value)?,
});

ctx.make_node(node_id, optype, parent)
let node = ctx.make_node(node_id, optype, parent)?;
Ok(Some(node))
}),

model::Operation::DeclareAlias { decl } => self.with_local_socpe(|ctx| {
Expand All @@ -527,7 +540,8 @@ impl<'a> Context<'a> {
bound: TypeBound::Copyable,
});

ctx.make_node(node_id, optype, parent)
let node = ctx.make_node(node_id, optype, parent)?;
Ok(Some(node))
}),

model::Operation::Tag { tag } => {
Expand All @@ -536,15 +550,19 @@ impl<'a> Context<'a> {
.ok_or_else(|| error_uninferred!("node signature"))?;
let (_, outputs, _) = self.get_func_type(signature)?;
let (variants, _) = self.import_adt_and_rest(node_id, outputs)?;
self.make_node(
let node = self.make_node(
node_id,
OpType::Tag(Tag {
variants,
tag: tag as _,
}),
parent,
)
)?;
Ok(Some(node))
}

model::Operation::DeclareConstructor { .. } => Ok(None),
model::Operation::DeclareOperation { .. } => Ok(None),
}
}

Expand Down Expand Up @@ -1188,6 +1206,8 @@ impl<'a> Context<'a> {
enum NamedItem {
FuncDecl(model::NodeId),
FuncDefn(model::NodeId),
CtrDecl(model::NodeId),
OperationDecl(model::NodeId),
}

struct Names<'a> {
Expand All @@ -1208,6 +1228,12 @@ impl<'a> Names<'a> {
model::Operation::DeclareFunc { decl } => {
Some((decl.name, NamedItem::FuncDefn(node_id)))
}
model::Operation::DeclareConstructor { decl } => {
Some((decl.name, NamedItem::CtrDecl(node_id)))
}
model::Operation::DeclareOperation { decl } => {
Some((decl.name, NamedItem::OperationDecl(node_id)))
}
_ => None,
};

Expand Down
Loading
Loading