Skip to content

Commit

Permalink
feat: Operation and constructor declarations in hugr-model (#1605)
Browse files Browse the repository at this point in the history
This PR adds the ability to declare custom operations and constructors
(so static types, runtime types, constraints, etc.) to `hugr-model`. In
the case of operations this is used for deduplication when exporting.
  • Loading branch information
zrho authored and ss2165 committed Nov 22, 2024
1 parent ab8ce80 commit 0dc177b
Show file tree
Hide file tree
Showing 23 changed files with 442 additions and 87 deletions.
86 changes: 76 additions & 10 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
//! Exporting HUGR graphs to their `hugr-model` representation.
use crate::{
extension::ExtensionSet,
extension::{ExtensionId, ExtensionSet, OpDef, SignatureFunc},
hugr::IdentList,
ops::{DataflowBlock, OpTrait, OpType},
ops::{DataflowBlock, OpName, OpTrait, OpType},
types::{
type_param::{TypeArgVariable, TypeParam},
type_row::TypeRowBase,
Expand Down Expand Up @@ -38,11 +38,14 @@ struct Context<'a> {
/// Mapping from ports to link indices.
/// This only includes the minimum port among groups of linked ports.
links: FxIndexSet<(Node, Port)>,
/// The arena in which the model is allocated.
bump: &'a Bump,
/// Stores the terms that we have already seen to avoid duplicates.
term_map: FxHashMap<model::Term<'a>, model::TermId>,
/// The current scope for local variables.
local_scope: Option<model::NodeId>,
/// Mapping from extension operations to their declarations.
decl_operations: FxHashMap<(ExtensionId, OpName), model::NodeId>,
}

impl<'a> Context<'a> {
Expand All @@ -57,23 +60,26 @@ impl<'a> Context<'a> {
links: IndexSet::default(),
term_map: FxHashMap::default(),
local_scope: None,
decl_operations: FxHashMap::default(),
}
}

/// Exports the root module of the HUGR graph.
pub fn export_root(&mut self) {
let hugr_children = self.hugr.children(self.hugr.root());
let mut children = BumpVec::with_capacity_in(hugr_children.len(), self.bump);
let mut children = Vec::with_capacity(hugr_children.len());

for child in self.hugr.children(self.hugr.root()) {
children.push(self.export_node(child));
}

children.extend(self.decl_operations.values().copied());

let root = self.module.insert_region(model::Region {
kind: model::RegionKind::DataFlow,
kind: model::RegionKind::Module,
sources: &[],
targets: &[],
children: children.into_bump_slice(),
children: self.bump.alloc_slice_copy(&children),
meta: &[], // TODO: Export metadata
signature: None,
});
Expand Down Expand Up @@ -123,15 +129,23 @@ impl<'a> Context<'a> {
.or_insert_with(|| self.module.insert_term(term))
}

pub fn make_named_global_ref(
pub fn make_qualified_name(
&mut self,
extension: &IdentList,
extension: &ExtensionId,
name: impl AsRef<str>,
) -> model::GlobalRef<'a> {
) -> &'a str {
let capacity = extension.len() + name.as_ref().len() + 1;
let mut output = BumpString::with_capacity_in(capacity, self.bump);
let _ = write!(&mut output, "{}.{}", extension, name.as_ref());
model::GlobalRef::Named(output.into_bump_str())
output.into_bump_str()
}

pub fn make_named_global_ref(
&mut self,
extension: &IdentList,
name: impl AsRef<str>,
) -> model::GlobalRef<'a> {
model::GlobalRef::Named(self.make_qualified_name(extension, name))
}

/// Get the node that declares or defines the function associated with the given
Expand Down Expand Up @@ -315,7 +329,7 @@ impl<'a> Context<'a> {
// regions of potentially different kinds. At the moment, we check if the node has any
// children, in which case we create a dataflow region with those children.
OpType::ExtensionOp(op) => {
let operation = self.make_named_global_ref(op.def().extension(), op.def().name());
let operation = self.export_opdef(op.def());

params = self
.bump
Expand Down Expand Up @@ -392,6 +406,58 @@ impl<'a> Context<'a> {
node_id
}

/// Export an `OpDef` as an operation declaration.
///
/// Operations that allow a declarative form are exported as a reference to
/// an operation declaration node, and this node is reused for all instances
/// of the operation. The node is added to the `decl_operations` map so that
/// at the end of the export, the operation declaration nodes can be added
/// to the module as children of the module region.
pub fn export_opdef(&mut self, opdef: &OpDef) -> model::GlobalRef<'a> {
use std::collections::hash_map::Entry;

let poly_func_type = match opdef.signature_func() {
SignatureFunc::PolyFuncType(poly_func_type) => poly_func_type,
_ => return self.make_named_global_ref(opdef.extension(), opdef.name()),
};

let key = (opdef.extension().clone(), opdef.name().clone());
let entry = self.decl_operations.entry(key);

let node = match entry {
Entry::Occupied(occupied_entry) => {
return model::GlobalRef::Direct(*occupied_entry.get())
}
Entry::Vacant(vacant_entry) => {
*vacant_entry.insert(self.module.insert_node(model::Node {
operation: model::Operation::Invalid,
inputs: &[],
outputs: &[],
params: &[],
regions: &[],
meta: &[], // TODO: Metadata
signature: None,
}))
}
};

let decl = self.with_local_scope(node, |this| {
let name = this.make_qualified_name(opdef.extension(), opdef.name());
let (params, r#type) = this.export_poly_func_type(poly_func_type);
let decl = this.bump.alloc(model::OperationDecl {
name,
params,
r#type,
});
decl
});

self.module.get_node_mut(node).unwrap().operation =
model::Operation::DeclareOperation { decl };

model::GlobalRef::Direct(node)
}

/// Export the signature of a `DataflowBlock`. Here we can't use `OpType::dataflow_signature`
/// like for the other nodes since the ports are control flow ports.
pub fn export_block_signature(&mut self, block: &DataflowBlock) -> model::TermId {
Expand Down
5 changes: 5 additions & 0 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,11 @@ impl OpDef {
) -> ConstFoldResult {
(self.constant_folder.as_ref())?.fold(type_args, consts)
}

/// Returns a reference to the signature function of this [`OpDef`].
pub fn signature_func(&self) -> &SignatureFunc {
&self.signature_func
}
}

impl Extension {
Expand Down
76 changes: 51 additions & 25 deletions hugr-core/src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,8 @@ impl<'a> Context<'a> {
match item {
NamedItem::FuncDecl(node) => Ok(*node),
NamedItem::FuncDefn(node) => Ok(*node),
NamedItem::CtrDecl(node) => Ok(*node),
NamedItem::OperationDecl(node) => Ok(*node),
}
}
}
Expand All @@ -299,6 +301,8 @@ impl<'a> Context<'a> {
model::Operation::DeclareFunc { decl } => decl.name,
model::Operation::DefineAlias { decl, .. } => decl.name,
model::Operation::DeclareAlias { decl } => decl.name,
model::Operation::DeclareConstructor { decl } => decl.name,
model::Operation::DeclareOperation { decl } => decl.name,
_ => {
return Err(model::ModelError::InvalidGlobal(global_ref.to_string()).into());
}
Expand Down Expand Up @@ -334,7 +338,11 @@ impl<'a> Context<'a> {
Ok(())
}

fn import_node(&mut self, node_id: model::NodeId, parent: Node) -> Result<Node, ImportError> {
fn import_node(
&mut self,
node_id: model::NodeId,
parent: Node,
) -> Result<Option<Node>, ImportError> {
let node_data = self.get_node(node_id)?;

match node_data.operation {
Expand All @@ -349,7 +357,7 @@ impl<'a> Context<'a> {
};

self.import_dfg_region(node_id, *region, node)?;
Ok(node)
Ok(Some(node))
}

model::Operation::Cfg => {
Expand All @@ -362,10 +370,13 @@ impl<'a> Context<'a> {
};

self.import_cfg_region(node_id, *region, node)?;
Ok(node)
Ok(Some(node))
}

model::Operation::Block => self.import_cfg_block(node_id, parent),
model::Operation::Block => {
let node = self.import_cfg_block(node_id, parent)?;
Ok(Some(node))
}

model::Operation::DefineFunc { decl } => {
self.import_poly_func_type(*decl, |ctx, signature| {
Expand All @@ -382,7 +393,7 @@ impl<'a> Context<'a> {

ctx.import_dfg_region(node_id, *region, node)?;

Ok(node)
Ok(Some(node))
})
}

Expand All @@ -395,7 +406,7 @@ impl<'a> Context<'a> {

let node = ctx.make_node(node_id, optype, parent)?;

Ok(node)
Ok(Some(node))
})
}

Expand All @@ -415,7 +426,8 @@ impl<'a> Context<'a> {
self.static_edges.push((func_node, node_id));
let optype = OpType::Call(Call::try_new(func_sig, type_args, self.extensions)?);

self.make_node(node_id, optype, parent)
let node = self.make_node(node_id, optype, parent)?;
Ok(Some(node))
}

model::Operation::LoadFunc { func } => {
Expand All @@ -439,18 +451,26 @@ impl<'a> Context<'a> {
self.extensions,
)?);

self.make_node(node_id, optype, parent)
let node = self.make_node(node_id, optype, parent)?;
Ok(Some(node))
}

model::Operation::TailLoop => self.import_tail_loop(node_id, parent),
model::Operation::Conditional => self.import_conditional(node_id, parent),
model::Operation::TailLoop => {
let node = self.import_tail_loop(node_id, parent)?;
Ok(Some(node))
}
model::Operation::Conditional => {
let node = self.import_conditional(node_id, parent)?;
Ok(Some(node))
}

model::Operation::CustomFull {
operation: GlobalRef::Named(name),
} if name == OP_FUNC_CALL_INDIRECT => {
let signature = self.get_node_signature(node_id)?;
let optype = OpType::CallIndirect(CallIndirect { signature });
self.make_node(node_id, optype, parent)
let node = self.make_node(node_id, optype, parent)?;
Ok(Some(node))
}

model::Operation::CustomFull { operation } => {
Expand All @@ -461,15 +481,7 @@ impl<'a> Context<'a> {
.map(|param| self.import_type_arg(*param))
.collect::<Result<Vec<_>, _>>()?;

let name = match operation {
GlobalRef::Direct(_) => {
return Err(error_unsupported!(
"custom operation with direct reference to declaring node"
))
}
GlobalRef::Named(name) => name,
};

let name = self.get_global_name(operation)?;
let (extension, name) = self.import_custom_name(name)?;

// TODO: Currently we do not have the description or any other metadata for
Expand All @@ -493,7 +505,7 @@ impl<'a> Context<'a> {
_ => return Err(error_unsupported!("multiple regions in custom operation")),
}

Ok(node)
Ok(Some(node))
}

model::Operation::Custom { .. } => Err(error_unsupported!(
Expand All @@ -512,7 +524,8 @@ impl<'a> Context<'a> {
definition: ctx.import_type(value)?,
});

ctx.make_node(node_id, optype, parent)
let node = ctx.make_node(node_id, optype, parent)?;
Ok(Some(node))
}),

model::Operation::DeclareAlias { decl } => self.with_local_socpe(|ctx| {
Expand All @@ -527,7 +540,8 @@ impl<'a> Context<'a> {
bound: TypeBound::Copyable,
});

ctx.make_node(node_id, optype, parent)
let node = ctx.make_node(node_id, optype, parent)?;
Ok(Some(node))
}),

model::Operation::Tag { tag } => {
Expand All @@ -536,15 +550,19 @@ impl<'a> Context<'a> {
.ok_or_else(|| error_uninferred!("node signature"))?;
let (_, outputs, _) = self.get_func_type(signature)?;
let (variants, _) = self.import_adt_and_rest(node_id, outputs)?;
self.make_node(
let node = self.make_node(
node_id,
OpType::Tag(Tag {
variants,
tag: tag as _,
}),
parent,
)
)?;
Ok(Some(node))
}

model::Operation::DeclareConstructor { .. } => Ok(None),
model::Operation::DeclareOperation { .. } => Ok(None),
}
}

Expand Down Expand Up @@ -1188,6 +1206,8 @@ impl<'a> Context<'a> {
enum NamedItem {
FuncDecl(model::NodeId),
FuncDefn(model::NodeId),
CtrDecl(model::NodeId),
OperationDecl(model::NodeId),
}

struct Names<'a> {
Expand All @@ -1208,6 +1228,12 @@ impl<'a> Names<'a> {
model::Operation::DeclareFunc { decl } => {
Some((decl.name, NamedItem::FuncDefn(node_id)))
}
model::Operation::DeclareConstructor { decl } => {
Some((decl.name, NamedItem::CtrDecl(node_id)))
}
model::Operation::DeclareOperation { decl } => {
Some((decl.name, NamedItem::OperationDecl(node_id)))
}
_ => None,
};

Expand Down
Loading

0 comments on commit 0dc177b

Please sign in to comment.