diff --git a/hugr-core/src/extension/simple_op.rs b/hugr-core/src/extension/simple_op.rs index 41e94ee58..c77ed5edc 100644 --- a/hugr-core/src/extension/simple_op.rs +++ b/hugr-core/src/extension/simple_op.rs @@ -26,6 +26,8 @@ pub enum OpLoadError { NotMember(String), #[error("Type args invalid: {0}.")] InvalidArgs(#[from] SignatureError), + #[error("OpDef belongs to extension {0}, expected {1}.")] + WrongExtension(ExtensionId, ExtensionId), } impl NamedOp for T @@ -51,6 +53,9 @@ pub trait MakeOpDef: NamedOp { /// Return the signature (polymorphic function type) of the operation. fn signature(&self) -> SignatureFunc; + /// The ID of the extension this operation is defined in. + fn extension(&self) -> ExtensionId; + /// Description of the operation. By default, the same as `self.name()`. fn description(&self) -> String { self.name().to_string() @@ -138,11 +143,20 @@ impl MakeExtensionOp for T { /// Load an [MakeOpDef] from its name. /// See [strum_macros::EnumString]. -pub fn try_from_name(name: &OpNameRef) -> Result +pub fn try_from_name(name: &OpNameRef, def_extension: &ExtensionId) -> Result where T: std::str::FromStr + MakeOpDef, { - T::from_str(name).map_err(|_| OpLoadError::NotMember(name.to_string())) + let op = T::from_str(name).map_err(|_| OpLoadError::NotMember(name.to_string()))?; + let expected_extension = op.extension(); + if def_extension != &expected_extension { + return Err(OpLoadError::WrongExtension( + def_extension.clone(), + expected_extension, + )); + } + + Ok(op) } /// Wrap an [MakeExtensionOp] with an extension registry to allow type computation. @@ -245,6 +259,10 @@ mod test { fn from_def(_op_def: &OpDef) -> Result { Ok(Self::Dumb) } + + fn extension(&self) -> ExtensionId { + EXT_ID.to_owned() + } } const_extension_ids! { const EXT_ID: ExtensionId = "DummyExt"; diff --git a/hugr-core/src/std_extensions/arithmetic/conversions.rs b/hugr-core/src/std_extensions/arithmetic/conversions.rs index b6dcab87b..5125d3a60 100644 --- a/hugr-core/src/std_extensions/arithmetic/conversions.rs +++ b/hugr-core/src/std_extensions/arithmetic/conversions.rs @@ -36,7 +36,11 @@ pub enum ConvertOpDef { impl MakeOpDef for ConvertOpDef { fn from_def(op_def: &OpDef) -> Result { - crate::extension::simple_op::try_from_name(op_def.name()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + } + + fn extension(&self) -> ExtensionId { + EXTENSION_ID.to_owned() } fn signature(&self) -> SignatureFunc { diff --git a/hugr-core/src/std_extensions/arithmetic/float_ops.rs b/hugr-core/src/std_extensions/arithmetic/float_ops.rs index 72601c2a3..b098d89dd 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_ops.rs @@ -44,7 +44,11 @@ pub enum FloatOps { impl MakeOpDef for FloatOps { fn from_def(op_def: &OpDef) -> Result { - crate::extension::simple_op::try_from_name(op_def.name()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + } + + fn extension(&self) -> ExtensionId { + EXTENSION_ID.to_owned() } fn signature(&self) -> SignatureFunc { diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops.rs b/hugr-core/src/std_extensions/arithmetic/int_ops.rs index 546b2b6c3..a302100b2 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops.rs @@ -100,7 +100,11 @@ pub enum IntOpDef { impl MakeOpDef for IntOpDef { fn from_def(op_def: &OpDef) -> Result { - crate::extension::simple_op::try_from_name(op_def.name()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + } + + fn extension(&self) -> ExtensionId { + EXTENSION_ID.to_owned() } fn signature(&self) -> SignatureFunc { diff --git a/hugr-core/src/std_extensions/logic.rs b/hugr-core/src/std_extensions/logic.rs index d6e51811f..52e3904ad 100644 --- a/hugr-core/src/std_extensions/logic.rs +++ b/hugr-core/src/std_extensions/logic.rs @@ -72,7 +72,11 @@ impl MakeOpDef for NaryLogic { } fn from_def(op_def: &OpDef) -> Result { - try_from_name(op_def.name()) + try_from_name(op_def.name(), op_def.extension()) + } + + fn extension(&self) -> ExtensionId { + EXTENSION_ID.to_owned() } fn post_opdef(&self, def: &mut OpDef) { @@ -127,6 +131,10 @@ impl MakeOpDef for NotOp { } } + fn extension(&self) -> ExtensionId { + EXTENSION_ID.to_owned() + } + fn signature(&self) -> SignatureFunc { FunctionType::new_endo(type_row![BOOL_T]).into() }