diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index b50ddb132..3ef05fe2c 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -350,6 +350,7 @@ impl OpDef { /// Fallibly returns a Hugr that may replace an instance of this OpDef /// given a set of available extensions that may be used in the Hugr. pub fn try_lower(&self, args: &[TypeArg], available_extensions: &ExtensionSet) -> Option { + // TODO test this self.lower_funcs .iter() .flat_map(|f| match f { @@ -395,6 +396,20 @@ impl OpDef { } Ok(()) } + + /// Add a lowering function to the [OpDef] + pub fn add_lower_func(&mut self, lower: LowerFunc) { + self.lower_funcs.push(lower); + } + + /// Insert miscellaneous data `v` to the [OpDef], keyed by `k`. + pub fn add_misc( + &mut self, + k: impl ToString, + v: serde_yaml::Value, + ) -> Option { + self.misc.insert(k.to_string(), v) + } } impl Extension { @@ -406,41 +421,22 @@ impl Extension { &mut self, name: SmolStr, description: String, - misc: HashMap, - lower_funcs: Vec, signature_func: impl Into, - ) -> Result<&OpDef, ExtensionBuildError> { + ) -> Result<&mut OpDef, ExtensionBuildError> { let op = OpDef { extension: self.name.clone(), name, description, - misc, signature_func: signature_func.into(), - lower_funcs, + misc: Default::default(), + lower_funcs: Default::default(), }; match self.operations.entry(op.name.clone()) { Entry::Occupied(_) => Err(ExtensionBuildError::OpDefExists(op.name)), - Entry::Vacant(ve) => Ok(ve.insert(Arc::new(op))), + Entry::Vacant(ve) => Ok(Arc::get_mut(ve.insert(Arc::new(op))).unwrap()), } } - - /// Create an OpDef with `PolyFuncType`, `impl CustomSignatureFunc` or `CustomValidator` - /// ; and no "misc" or "lowering functions" defined. - pub fn add_op_simple( - &mut self, - name: SmolStr, - description: String, - signature_func: impl Into, - ) -> Result<&OpDef, ExtensionBuildError> { - self.add_op( - name, - description, - HashMap::default(), - Vec::new(), - signature_func, - ) - } } #[cfg(test)] @@ -451,16 +447,18 @@ mod test { use super::SignatureFromArgs; use crate::builder::{DFGBuilder, Dataflow, DataflowHugr}; + use crate::extension::op_def::LowerFunc; use crate::extension::prelude::USIZE_T; - use crate::extension::{ - ExtensionRegistry, SignatureError, EMPTY_REG, PRELUDE, PRELUDE_REGISTRY, - }; + use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE}; + use crate::extension::{SignatureError, EMPTY_REG, PRELUDE_REGISTRY}; use crate::ops::custom::ExternalOp; use crate::ops::LeafOp; use crate::std_extensions::collections::{EXTENSION, LIST_TYPENAME}; use crate::types::Type; use crate::types::{type_param::TypeParam, FunctionType, PolyFuncType, TypeArg, TypeBound}; + use crate::Hugr; use crate::{const_extension_ids, Extension}; + const_extension_ids! { const EXT_ID: ExtensionId = "MyExt"; } @@ -474,7 +472,14 @@ mod test { Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?); const OP_NAME: SmolStr = SmolStr::new_inline("Reverse"); let type_scheme = PolyFuncType::new(vec![TP], FunctionType::new_endo(vec![list_of_var])); - e.add_op(OP_NAME, "".into(), Default::default(), vec![], type_scheme)?; + + let def = e.add_op(OP_NAME, "desc".into(), type_scheme)?; + def.add_lower_func(LowerFunc::FixedHugr(ExtensionSet::new(), Hugr::default())); + def.add_misc("key", Default::default()); + assert_eq!(def.description(), "desc"); + assert_eq!(def.lower_funcs.len(), 1); + assert_eq!(def.misc.len(), 1); + let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), EXTENSION.to_owned(), e]).unwrap(); let e = reg.get(&EXT_ID).unwrap(); @@ -526,7 +531,8 @@ mod test { } } let mut e = Extension::new(EXT_ID); - let def = e.add_op_simple("MyOp".into(), "".to_string(), SigFun())?; + let def: &mut crate::extension::OpDef = + e.add_op("MyOp".into(), "".to_string(), SigFun())?; // Base case, no type variables: let args = [TypeArg::BoundedNat { n: 3 }, USIZE_T.into()]; @@ -587,7 +593,7 @@ mod test { // Check that we can instantiate a PolyFuncType-scheme with an (external) // type variable let mut e = Extension::new(EXT_ID); - let def = e.add_op_simple( + let def = e.add_op( "SimpleOp".into(), "".into(), PolyFuncType::new( diff --git a/src/extension/prelude.rs b/src/extension/prelude.rs index f2db71503..12f95a384 100644 --- a/src/extension/prelude.rs +++ b/src/extension/prelude.rs @@ -63,7 +63,7 @@ lazy_static! { ) .unwrap(); prelude - .add_op_simple( + .add_op( SmolStr::new_inline(NEW_ARRAY_OP_ID), "Create a new array from elements".to_string(), ArrayOpCustom, diff --git a/src/lib.rs b/src/lib.rs index 2d0cee669..e1640f698 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -62,15 +62,15 @@ //! let mut extension = Extension::new(EXTENSION_ID); //! //! extension -//! .add_op_simple(SmolStr::new_inline("H"), "Hadamard".into(), one_qb_func()) +//! .add_op(SmolStr::new_inline("H"), "Hadamard".into(), one_qb_func()) //! .unwrap(); //! //! extension -//! .add_op_simple(SmolStr::new_inline("CX"), "CX".into(), two_qb_func()) +//! .add_op(SmolStr::new_inline("CX"), "CX".into(), two_qb_func()) //! .unwrap(); //! //! extension -//! .add_op_simple( +//! .add_op( //! SmolStr::new_inline("Measure"), //! "Measure a qubit, returning the qubit and the measurement result.".into(), //! FunctionType::new(type_row![QB_T], type_row![QB_T, BOOL_T]), diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index b9f79fa1c..56b636f19 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -36,24 +36,24 @@ pub fn extension() -> Extension { ]), ); extension - .add_op_simple( + .add_op( "trunc_u".into(), "float to unsigned int".to_owned(), ftoi_sig.clone(), ) .unwrap(); extension - .add_op_simple("trunc_s".into(), "float to signed int".to_owned(), ftoi_sig) + .add_op("trunc_s".into(), "float to signed int".to_owned(), ftoi_sig) .unwrap(); extension - .add_op_simple( + .add_op( "convert_u".into(), "unsigned int to float".to_owned(), itof_sig.clone(), ) .unwrap(); extension - .add_op_simple( + .add_op( "convert_s".into(), "signed int to float".to_owned(), itof_sig, diff --git a/src/std_extensions/arithmetic/float_ops.rs b/src/std_extensions/arithmetic/float_ops.rs index d3a7050b0..5cef5d19a 100644 --- a/src/std_extensions/arithmetic/float_ops.rs +++ b/src/std_extensions/arithmetic/float_ops.rs @@ -29,72 +29,72 @@ pub fn extension() -> Extension { let funop_sig: PolyFuncType = FunctionType::new(type_row![FLOAT64_TYPE], type_row![FLOAT64_TYPE]).into(); extension - .add_op_simple("feq".into(), "equality test".to_owned(), fcmp_sig.clone()) + .add_op("feq".into(), "equality test".to_owned(), fcmp_sig.clone()) .unwrap(); extension - .add_op_simple("fne".into(), "inequality test".to_owned(), fcmp_sig.clone()) + .add_op("fne".into(), "inequality test".to_owned(), fcmp_sig.clone()) .unwrap(); extension - .add_op_simple("flt".into(), "\"less than\"".to_owned(), fcmp_sig.clone()) + .add_op("flt".into(), "\"less than\"".to_owned(), fcmp_sig.clone()) .unwrap(); extension - .add_op_simple( + .add_op( "fgt".into(), "\"greater than\"".to_owned(), fcmp_sig.clone(), ) .unwrap(); extension - .add_op_simple( + .add_op( "fle".into(), "\"less than or equal\"".to_owned(), fcmp_sig.clone(), ) .unwrap(); extension - .add_op_simple( + .add_op( "fge".into(), "\"greater than or equal\"".to_owned(), fcmp_sig, ) .unwrap(); extension - .add_op_simple("fmax".into(), "maximum".to_owned(), fbinop_sig.clone()) + .add_op("fmax".into(), "maximum".to_owned(), fbinop_sig.clone()) .unwrap(); extension - .add_op_simple("fmin".into(), "minimum".to_owned(), fbinop_sig.clone()) + .add_op("fmin".into(), "minimum".to_owned(), fbinop_sig.clone()) .unwrap(); extension - .add_op_simple("fadd".into(), "addition".to_owned(), fbinop_sig.clone()) + .add_op("fadd".into(), "addition".to_owned(), fbinop_sig.clone()) .unwrap(); extension - .add_op_simple("fsub".into(), "subtraction".to_owned(), fbinop_sig.clone()) + .add_op("fsub".into(), "subtraction".to_owned(), fbinop_sig.clone()) .unwrap(); extension - .add_op_simple("fneg".into(), "negation".to_owned(), funop_sig.clone()) + .add_op("fneg".into(), "negation".to_owned(), funop_sig.clone()) .unwrap(); extension - .add_op_simple( + .add_op( "fabs".into(), "absolute value".to_owned(), funop_sig.clone(), ) .unwrap(); extension - .add_op_simple( + .add_op( "fmul".into(), "multiplication".to_owned(), fbinop_sig.clone(), ) .unwrap(); extension - .add_op_simple("fdiv".into(), "division".to_owned(), fbinop_sig) + .add_op("fdiv".into(), "division".to_owned(), fbinop_sig) .unwrap(); extension - .add_op_simple("ffloor".into(), "floor".to_owned(), funop_sig.clone()) + .add_op("ffloor".into(), "floor".to_owned(), funop_sig.clone()) .unwrap(); extension - .add_op_simple("fceil".into(), "ceiling".to_owned(), funop_sig) + .add_op("fceil".into(), "ceiling".to_owned(), funop_sig) .unwrap(); extension diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index 8333f1b86..9a77f5dfb 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -113,7 +113,7 @@ fn extension() -> Extension { ); extension - .add_op_simple( + .add_op( "iwiden_u".into(), "widen an unsigned integer to a wider one with the same value".to_owned(), CustomValidator::new_with_validator(widen_poly.clone(), IOValidator { f_gt_s: false }), @@ -121,14 +121,14 @@ fn extension() -> Extension { .unwrap(); extension - .add_op_simple( + .add_op( "iwiden_s".into(), "widen a signed integer to a wider one with the same value".to_owned(), CustomValidator::new_with_validator(widen_poly, IOValidator { f_gt_s: false }), ) .unwrap(); extension - .add_op_simple( + .add_op( "inarrow_u".into(), "narrow an unsigned integer to a narrower one with the same value if possible" .to_owned(), @@ -136,146 +136,146 @@ fn extension() -> Extension { ) .unwrap(); extension - .add_op_simple( + .add_op( "inarrow_s".into(), "narrow a signed integer to a narrower one with the same value if possible".to_owned(), CustomValidator::new_with_validator(narrow_poly, IOValidator { f_gt_s: true }), ) .unwrap(); extension - .add_op_simple( + .add_op( "itobool".into(), "convert to bool (1 is true, 0 is false)".to_owned(), itob_sig.clone(), ) .unwrap(); extension - .add_op_simple( + .add_op( "ifrombool".into(), "convert from bool (1 is true, 0 is false)".to_owned(), btoi_sig.clone(), ) .unwrap(); extension - .add_op_simple("ieq".into(), "equality test".to_owned(), icmp_sig.clone()) + .add_op("ieq".into(), "equality test".to_owned(), icmp_sig.clone()) .unwrap(); extension - .add_op_simple("ine".into(), "inequality test".to_owned(), icmp_sig.clone()) + .add_op("ine".into(), "inequality test".to_owned(), icmp_sig.clone()) .unwrap(); extension - .add_op_simple( + .add_op( "ilt_u".into(), "\"less than\" as unsigned integers".to_owned(), icmp_sig.clone(), ) .unwrap(); extension - .add_op_simple( + .add_op( "ilt_s".into(), "\"less than\" as signed integers".to_owned(), icmp_sig.clone(), ) .unwrap(); extension - .add_op_simple( + .add_op( "igt_u".into(), "\"greater than\" as unsigned integers".to_owned(), icmp_sig.clone(), ) .unwrap(); extension - .add_op_simple( + .add_op( "igt_s".into(), "\"greater than\" as signed integers".to_owned(), icmp_sig.clone(), ) .unwrap(); extension - .add_op_simple( + .add_op( "ile_u".into(), "\"less than or equal\" as unsigned integers".to_owned(), icmp_sig.clone(), ) .unwrap(); extension - .add_op_simple( + .add_op( "ile_s".into(), "\"less than or equal\" as signed integers".to_owned(), icmp_sig.clone(), ) .unwrap(); extension - .add_op_simple( + .add_op( "ige_u".into(), "\"greater than or equal\" as unsigned integers".to_owned(), icmp_sig.clone(), ) .unwrap(); extension - .add_op_simple( + .add_op( "ige_s".into(), "\"greater than or equal\" as signed integers".to_owned(), icmp_sig.clone(), ) .unwrap(); extension - .add_op_simple( + .add_op( "imax_u".into(), "maximum of unsigned integers".to_owned(), ibinop_sig(), ) .unwrap(); extension - .add_op_simple( + .add_op( "imax_s".into(), "maximum of signed integers".to_owned(), ibinop_sig(), ) .unwrap(); extension - .add_op_simple( + .add_op( "imin_u".into(), "minimum of unsigned integers".to_owned(), ibinop_sig(), ) .unwrap(); extension - .add_op_simple( + .add_op( "imin_s".into(), "minimum of signed integers".to_owned(), ibinop_sig(), ) .unwrap(); extension - .add_op_simple( + .add_op( "iadd".into(), "addition modulo 2^N (signed and unsigned versions are the same op)".to_owned(), ibinop_sig(), ) .unwrap(); extension - .add_op_simple( + .add_op( "isub".into(), "subtraction modulo 2^N (signed and unsigned versions are the same op)".to_owned(), ibinop_sig(), ) .unwrap(); extension - .add_op_simple( + .add_op( "ineg".into(), "negation modulo 2^N (signed and unsigned versions are the same op)".to_owned(), iunop_sig(), ) .unwrap(); extension - .add_op_simple( + .add_op( "imul".into(), "multiplication modulo 2^N (signed and unsigned versions are the same op)".to_owned(), ibinop_sig(), ) .unwrap(); extension - .add_op_simple( + .add_op( "idivmod_checked_u".into(), "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^M, generates unsigned q, r where \ q*m+r=n, 0<=r Extension { ) .unwrap(); extension - .add_op_simple( + .add_op( "idivmod_u".into(), "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^M, generates unsigned q, r where \ q*m+r=n, 0<=r Extension { ) .unwrap(); extension - .add_op_simple( + .add_op( "idivmod_checked_s".into(), "given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^M, generates \ signed q and unsigned r where q*m+r=n, 0<=r Extension { ) .unwrap(); extension - .add_op_simple( + .add_op( "idivmod_s".into(), "given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^M, generates \ signed q and unsigned r where q*m+r=n, 0<=r Extension { ) .unwrap(); extension - .add_op_simple( + .add_op( "idiv_checked_u".into(), "as idivmod_checked_u but discarding the second output".to_owned(), idiv_checked_sig.clone(), ) .unwrap(); extension - .add_op_simple( + .add_op( "idiv_u".into(), "as idivmod_u but discarding the second output".to_owned(), idiv_sig.clone(), ) .unwrap(); extension - .add_op_simple( + .add_op( "imod_checked_u".into(), "as idivmod_checked_u but discarding the first output".to_owned(), imod_checked_sig.clone(), ) .unwrap(); extension - .add_op_simple( + .add_op( "imod_u".into(), "as idivmod_u but discarding the first output".to_owned(), imod_sig.clone(), ) .unwrap(); extension - .add_op_simple( + .add_op( "idiv_checked_s".into(), "as idivmod_checked_s but discarding the second output".to_owned(), idiv_checked_sig.clone(), ) .unwrap(); extension - .add_op_simple( + .add_op( "idiv_s".into(), "as idivmod_s but discarding the second output".to_owned(), idiv_sig.clone(), ) .unwrap(); extension - .add_op_simple( + .add_op( "imod_checked_s".into(), "as idivmod_checked_s but discarding the first output".to_owned(), imod_checked_sig.clone(), ) .unwrap(); extension - .add_op_simple( + .add_op( "imod_s".into(), "as idivmod_s but discarding the first output".to_owned(), imod_sig.clone(), ) .unwrap(); extension - .add_op_simple( + .add_op( "iabs".into(), "convert signed to unsigned by taking absolute value".to_owned(), iunop_sig(), ) .unwrap(); extension - .add_op_simple("iand".into(), "bitwise AND".to_owned(), ibinop_sig()) + .add_op("iand".into(), "bitwise AND".to_owned(), ibinop_sig()) .unwrap(); extension - .add_op_simple("ior".into(), "bitwise OR".to_owned(), ibinop_sig()) + .add_op("ior".into(), "bitwise OR".to_owned(), ibinop_sig()) .unwrap(); extension - .add_op_simple("ixor".into(), "bitwise XOR".to_owned(), ibinop_sig()) + .add_op("ixor".into(), "bitwise XOR".to_owned(), ibinop_sig()) .unwrap(); extension - .add_op_simple("inot".into(), "bitwise NOT".to_owned(), iunop_sig()) + .add_op("inot".into(), "bitwise NOT".to_owned(), iunop_sig()) .unwrap(); extension - .add_op_simple( + .add_op( "ishl".into(), "shift first input left by k bits where k is unsigned interpretation of second input \ (leftmost bits dropped, rightmost bits set to zero" @@ -395,7 +395,7 @@ fn extension() -> Extension { ) .unwrap(); extension - .add_op_simple( + .add_op( "ishr".into(), "shift first input right by k bits where k is unsigned interpretation of second input \ (rightmost bits dropped, leftmost bits set to zero)" @@ -404,7 +404,7 @@ fn extension() -> Extension { ) .unwrap(); extension - .add_op_simple( + .add_op( "irotl".into(), "rotate first input left by k bits where k is unsigned interpretation of second input \ (leftmost bits replace rightmost bits)" @@ -413,7 +413,7 @@ fn extension() -> Extension { ) .unwrap(); extension - .add_op_simple( + .add_op( "irotr".into(), "rotate first input right by k bits where k is unsigned interpretation of second input \ (rightmost bits replace leftmost bits)" diff --git a/src/std_extensions/collections.rs b/src/std_extensions/collections.rs index fb2a3ebf6..522c16dfa 100644 --- a/src/std_extensions/collections.rs +++ b/src/std_extensions/collections.rs @@ -84,7 +84,7 @@ fn extension() -> Extension { let (l, e) = list_and_elem_type(list_type_def); extension - .add_op_simple( + .add_op( POP_NAME, "Pop from back of list".into(), PolyFuncType::new( @@ -94,7 +94,7 @@ fn extension() -> Extension { ) .unwrap(); extension - .add_op_simple( + .add_op( PUSH_NAME, "Push to back of list".into(), PolyFuncType::new(vec![TP], FunctionType::new(vec![l.clone(), e], vec![l])), diff --git a/src/std_extensions/logic.rs b/src/std_extensions/logic.rs index d0444513f..710d29b1b 100644 --- a/src/std_extensions/logic.rs +++ b/src/std_extensions/logic.rs @@ -55,7 +55,7 @@ fn extension() -> Extension { let mut extension = Extension::new(EXTENSION_ID); extension - .add_op_simple( + .add_op( SmolStr::new_inline(NOT_NAME), "logical 'not'".into(), FunctionType::new(type_row![BOOL_T], type_row![BOOL_T]), @@ -63,7 +63,7 @@ fn extension() -> Extension { .unwrap(); extension - .add_op_simple( + .add_op( SmolStr::new_inline(AND_NAME), "logical 'and'".into(), logic_op_sig(), @@ -71,7 +71,7 @@ fn extension() -> Extension { .unwrap(); extension - .add_op_simple( + .add_op( SmolStr::new_inline(OR_NAME), "logical 'or'".into(), logic_op_sig(), diff --git a/src/utils.rs b/src/utils.rs index 5ed2acd11..6c297b4bd 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -58,10 +58,10 @@ pub(crate) mod test_quantum_extension { let mut extension = Extension::new(EXTENSION_ID); extension - .add_op_simple(SmolStr::new_inline("H"), "Hadamard".into(), one_qb_func()) + .add_op(SmolStr::new_inline("H"), "Hadamard".into(), one_qb_func()) .unwrap(); extension - .add_op_simple( + .add_op( SmolStr::new_inline("RzF64"), "Rotation specified by float".into(), FunctionType::new(type_row![QB_T, float_types::FLOAT64_TYPE], type_row![QB_T]), @@ -69,11 +69,11 @@ pub(crate) mod test_quantum_extension { .unwrap(); extension - .add_op_simple(SmolStr::new_inline("CX"), "CX".into(), two_qb_func()) + .add_op(SmolStr::new_inline("CX"), "CX".into(), two_qb_func()) .unwrap(); extension - .add_op_simple( + .add_op( SmolStr::new_inline("Measure"), "Measure a qubit, returning the qubit and the measurement result.".into(), FunctionType::new(type_row![QB_T], type_row![QB_T, BOOL_T]),