diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index 89719a6c0..d30c9f88b 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -2,7 +2,7 @@ //! operations and constants. use lazy_static::lazy_static; -use crate::ops::constant::ValueName; +use crate::ops::constant::{CustomCheckFailure, ValueName}; use crate::ops::{CustomOp, OpName}; use crate::types::{SumType, TypeName}; use crate::{ @@ -343,6 +343,62 @@ impl CustomConst for ConstError { } } +#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +/// A structure for holding references to external symbols. +pub struct ConstExternalSymbol { + /// The symbol name that this value refers to. Must be nonempty. + pub symbol: String, + /// The type of the value found at this symbol reference. + pub typ: Type, + /// Whether the value at the symbol referenence is constant or mutable. + pub constant: bool, +} + +impl ConstExternalSymbol { + /// Construct a new [ConstExternalSymbol]. + pub fn new(symbol: impl Into, typ: impl Into, constant: bool) -> Self { + Self { + symbol: symbol.into(), + typ: typ.into(), + constant, + } + } +} + +impl PartialEq for ConstExternalSymbol { + fn eq(&self, other: &dyn CustomConst) -> bool { + self.equal_consts(other) + } +} + +#[typetag::serde] +impl CustomConst for ConstExternalSymbol { + fn name(&self) -> ValueName { + format!("@{}", &self.symbol).into() + } + + fn equal_consts(&self, other: &dyn CustomConst) -> bool { + crate::ops::constant::downcast_equal_consts(self, other) + } + + fn extension_reqs(&self) -> ExtensionSet { + ExtensionSet::singleton(&PRELUDE_ID) + } + fn get_type(&self) -> Type { + self.typ.clone() + } + + fn validate(&self) -> Result<(), CustomCheckFailure> { + if self.symbol.is_empty() { + Err(CustomCheckFailure::Message( + "External symbol name is empty.".into(), + )) + } else { + Ok(()) + } + } +} + #[cfg(test)] mod test { use crate::{ @@ -477,4 +533,24 @@ mod test { b.add_dataflow_op(print_op, [greeting_out]).unwrap(); b.finish_prelude_hugr_with_outputs([]).unwrap(); } + + #[test] + fn test_external_symbol() { + let subject = ConstExternalSymbol::new("foo", Type::UNIT, false); + assert_eq!(subject.get_type(), Type::UNIT); + assert_eq!(subject.name(), "@foo"); + assert!(subject.validate().is_ok()); + assert_eq!( + subject.extension_reqs(), + ExtensionSet::singleton(&PRELUDE_ID) + ); + assert!(subject.equal_consts(&ConstExternalSymbol::new("foo", Type::UNIT, false))); + assert!(!subject.equal_consts(&ConstExternalSymbol::new("bar", Type::UNIT, false))); + assert!(!subject.equal_consts(&ConstExternalSymbol::new("foo", STRING_TYPE, false))); + assert!(!subject.equal_consts(&ConstExternalSymbol::new("foo", Type::UNIT, true))); + + assert!(ConstExternalSymbol::new("", Type::UNIT, true) + .validate() + .is_err()) + } }