From 0bcab0ae9d7d542414fb886473378c6cfb0e1a92 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 21 Dec 2023 18:35:11 +0000 Subject: [PATCH] refactor!: use enum op traits for floats + conversions (#755) BREAKING CHANGES: extension() function replaced with EXTENSION static ref for float_ops and conversions --- src/ops/constant.rs | 2 +- src/std_extensions/arithmetic/conversions.rs | 164 ++++++++++----- src/std_extensions/arithmetic/float_ops.rs | 198 ++++++++++--------- src/std_extensions/arithmetic/float_types.rs | 34 ++-- src/std_extensions/arithmetic/int_ops.rs | 2 +- src/std_extensions/collections.rs | 2 +- src/utils.rs | 2 +- 7 files changed, 245 insertions(+), 159 deletions(-) diff --git a/src/ops/constant.rs b/src/ops/constant.rs index 6db0006b3..c66ab6f76 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -148,7 +148,7 @@ mod test { use super::*; fn test_registry() -> ExtensionRegistry { - ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::extension()]).unwrap() + ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]).unwrap() } #[test] diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index 4ae262b77..98e5df887 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -1,63 +1,131 @@ //! Conversions between integer and floating-point values. +use smol_str::SmolStr; +use strum_macros::{EnumIter, EnumString, IntoStaticStr}; + use crate::{ - extension::{prelude::sum_with_error, ExtensionId, ExtensionSet}, + extension::{ + prelude::sum_with_error, + simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError}, + ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureError, SignatureFunc, + }, + ops::{custom::ExtensionOp, OpName}, type_row, - types::{FunctionType, PolyFuncType}, + types::{FunctionType, PolyFuncType, TypeArg}, Extension, }; use super::int_types::int_tv; use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM}; +use lazy_static::lazy_static; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions"); -/// Extension for basic arithmetic operations. -pub fn extension() -> Extension { - let ftoi_sig = PolyFuncType::new( - vec![LOG_WIDTH_TYPE_PARAM], - FunctionType::new(type_row![FLOAT64_TYPE], vec![sum_with_error(int_tv(0))]), - ); - - let itof_sig = PolyFuncType::new( - vec![LOG_WIDTH_TYPE_PARAM], - FunctionType::new(vec![int_tv(0)], type_row![FLOAT64_TYPE]), - ); - - let mut extension = Extension::new_with_reqs( - EXTENSION_ID, - ExtensionSet::from_iter(vec![ - super::int_types::EXTENSION_ID, - super::float_types::EXTENSION_ID, - ]), - ); - extension - .add_op( - "trunc_u".into(), - "float to unsigned int".to_owned(), - ftoi_sig.clone(), - ) - .unwrap(); - extension - .add_op("trunc_s".into(), "float to signed int".to_owned(), ftoi_sig) - .unwrap(); - extension - .add_op( - "convert_u".into(), - "unsigned int to float".to_owned(), - itof_sig.clone(), - ) - .unwrap(); - extension - .add_op( - "convert_s".into(), - "signed int to float".to_owned(), - itof_sig, - ) - .unwrap(); - - extension +/// Extensiop for conversions between floats and integers. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] +#[allow(missing_docs, non_camel_case_types)] +pub enum ConvertOpDef { + trunc_u, + trunc_s, + convert_u, + convert_s, +} + +impl MakeOpDef for ConvertOpDef { + fn from_def(op_def: &OpDef) -> Result { + crate::extension::simple_op::try_from_name(op_def.name()) + } + + fn signature(&self) -> SignatureFunc { + use ConvertOpDef::*; + match self { + trunc_s | trunc_u => PolyFuncType::new( + vec![LOG_WIDTH_TYPE_PARAM], + FunctionType::new(type_row![FLOAT64_TYPE], vec![sum_with_error(int_tv(0))]), + ), + + convert_s | convert_u => PolyFuncType::new( + vec![LOG_WIDTH_TYPE_PARAM], + FunctionType::new(vec![int_tv(0)], type_row![FLOAT64_TYPE]), + ), + } + .into() + } + + fn description(&self) -> String { + use ConvertOpDef::*; + match self { + trunc_u => "float to unsigned int", + trunc_s => "float to signed int", + convert_u => "unsigned int to float", + convert_s => "signed int to float", + } + .to_string() + } +} + +/// Concrete convert operation with integer width set. +#[derive(Debug, Clone, PartialEq)] +pub struct ConvertOpType { + def: ConvertOpDef, + width: u64, +} + +impl OpName for ConvertOpType { + fn name(&self) -> SmolStr { + self.def.name() + } +} + +impl MakeExtensionOp for ConvertOpType { + fn from_extension_op(ext_op: &ExtensionOp) -> Result { + let def = ConvertOpDef::from_def(ext_op.def())?; + let width = match *ext_op.args() { + [TypeArg::BoundedNat { n }] => n, + _ => return Err(SignatureError::InvalidTypeArgs.into()), + }; + Ok(Self { def, width }) + } + + fn type_args(&self) -> Vec { + vec![TypeArg::BoundedNat { n: self.width }] + } +} + +lazy_static! { + /// Extension for conversions between integers and floats. + pub static ref EXTENSION: Extension = { + let mut extension = Extension::new_with_reqs( + EXTENSION_ID, + ExtensionSet::from_iter(vec![ + super::int_types::EXTENSION_ID, + super::float_types::EXTENSION_ID, + ]), + ); + + ConvertOpDef::load_all_ops(&mut extension).unwrap(); + + extension + }; + + /// Registry of extensions required to validate integer operations. + pub static ref CONVERT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ + super::int_types::EXTENSION.to_owned(), + super::float_types::EXTENSION.to_owned(), + EXTENSION.to_owned(), + ]) + .unwrap(); +} + +impl MakeRegisteredOp for ConvertOpType { + fn extension_id(&self) -> ExtensionId { + EXTENSION_ID.to_owned() + } + + fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { + &CONVERT_OPS_REGISTRY + } } #[cfg(test)] @@ -66,7 +134,7 @@ mod test { #[test] fn test_conversions_extension() { - let r = extension(); + let r = &EXTENSION; assert_eq!(r.name() as &str, "arithmetic.conversions"); assert_eq!(r.types().count(), 0); for (name, _) in r.operations() { diff --git a/src/std_extensions/arithmetic/float_ops.rs b/src/std_extensions/arithmetic/float_ops.rs index 5cef5d19a..87c87751b 100644 --- a/src/std_extensions/arithmetic/float_ops.rs +++ b/src/std_extensions/arithmetic/float_ops.rs @@ -1,103 +1,119 @@ //! Basic floating-point operations. +use strum_macros::{EnumIter, EnumString, IntoStaticStr}; + +use super::float_types::FLOAT64_TYPE; use crate::{ - extension::{ExtensionId, ExtensionSet}, + extension::{ + prelude::BOOL_T, + simple_op::{MakeOpDef, MakeRegisteredOp, OpLoadError}, + ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureFunc, PRELUDE, + }, type_row, - types::{FunctionType, PolyFuncType}, + types::FunctionType, Extension, }; - -use super::float_types::FLOAT64_TYPE; +use lazy_static::lazy_static; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float"); -/// Extension for basic arithmetic operations. -pub fn extension() -> Extension { - let mut extension = Extension::new_with_reqs( - EXTENSION_ID, - ExtensionSet::singleton(&super::float_types::EXTENSION_ID), - ); - - let fcmp_sig: PolyFuncType = FunctionType::new( - type_row![FLOAT64_TYPE; 2], - type_row![crate::extension::prelude::BOOL_T], - ) - .into(); - let fbinop_sig: PolyFuncType = - FunctionType::new(type_row![FLOAT64_TYPE; 2], type_row![FLOAT64_TYPE]).into(); - let funop_sig: PolyFuncType = - FunctionType::new(type_row![FLOAT64_TYPE], type_row![FLOAT64_TYPE]).into(); - extension - .add_op("feq".into(), "equality test".to_owned(), fcmp_sig.clone()) - .unwrap(); - extension - .add_op("fne".into(), "inequality test".to_owned(), fcmp_sig.clone()) - .unwrap(); - extension - .add_op("flt".into(), "\"less than\"".to_owned(), fcmp_sig.clone()) - .unwrap(); - extension - .add_op( - "fgt".into(), - "\"greater than\"".to_owned(), - fcmp_sig.clone(), - ) - .unwrap(); - extension - .add_op( - "fle".into(), - "\"less than or equal\"".to_owned(), - fcmp_sig.clone(), - ) - .unwrap(); - extension - .add_op( - "fge".into(), - "\"greater than or equal\"".to_owned(), - fcmp_sig, - ) - .unwrap(); - extension - .add_op("fmax".into(), "maximum".to_owned(), fbinop_sig.clone()) - .unwrap(); - extension - .add_op("fmin".into(), "minimum".to_owned(), fbinop_sig.clone()) - .unwrap(); - extension - .add_op("fadd".into(), "addition".to_owned(), fbinop_sig.clone()) - .unwrap(); - extension - .add_op("fsub".into(), "subtraction".to_owned(), fbinop_sig.clone()) - .unwrap(); - extension - .add_op("fneg".into(), "negation".to_owned(), funop_sig.clone()) - .unwrap(); - extension - .add_op( - "fabs".into(), - "absolute value".to_owned(), - funop_sig.clone(), - ) - .unwrap(); - extension - .add_op( - "fmul".into(), - "multiplication".to_owned(), - fbinop_sig.clone(), - ) - .unwrap(); - extension - .add_op("fdiv".into(), "division".to_owned(), fbinop_sig) - .unwrap(); - extension - .add_op("ffloor".into(), "floor".to_owned(), funop_sig.clone()) - .unwrap(); - extension - .add_op("fceil".into(), "ceiling".to_owned(), funop_sig) - .unwrap(); - - extension +/// Integer extension operation definitions. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] +#[allow(missing_docs, non_camel_case_types)] +pub enum FloatOps { + feq, + fne, + flt, + fgt, + fle, + fge, + fmax, + fmin, + fadd, + fsub, + fneg, + fabs, + fmul, + fdiv, + ffloor, + fceil, +} + +impl MakeOpDef for FloatOps { + fn from_def(op_def: &OpDef) -> Result { + crate::extension::simple_op::try_from_name(op_def.name()) + } + + fn signature(&self) -> SignatureFunc { + use FloatOps::*; + + match self { + feq | fne | flt | fgt | fle | fge => { + FunctionType::new(type_row![FLOAT64_TYPE; 2], type_row![BOOL_T]) + } + fmax | fmin | fadd | fsub | fmul | fdiv => { + FunctionType::new(type_row![FLOAT64_TYPE; 2], type_row![FLOAT64_TYPE]) + } + fneg | fabs | ffloor | fceil => FunctionType::new_endo(type_row![FLOAT64_TYPE]), + } + .into() + } + + fn description(&self) -> String { + use FloatOps::*; + match self { + feq => "equality test", + fne => "inequality test", + flt => "\"less than\"", + fgt => "\"greater than\"", + fle => "\"less than or equal\"", + fge => "\"greater than or equal\"", + fmax => "maximum", + fmin => "minimum", + fadd => "addition", + fsub => "subtraction", + fneg => "negation", + fabs => "absolute value", + fmul => "multiplication", + fdiv => "division", + ffloor => "floor", + fceil => "ceiling", + } + .to_string() + } +} + +lazy_static! { + /// Extension for basic float operations. + pub static ref EXTENSION: Extension = { + let mut extension = Extension::new_with_reqs( + EXTENSION_ID, + ExtensionSet::singleton(&super::int_types::EXTENSION_ID), + ); + + FloatOps::load_all_ops(&mut extension).unwrap(); + + extension + }; + + /// Registry of extensions required to validate float operations. + pub static ref FLOAT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::try_new([ + PRELUDE.to_owned(), + super::float_types::EXTENSION.to_owned(), + EXTENSION.to_owned(), + ]) + .unwrap(); +} + +impl MakeRegisteredOp for FloatOps { + fn extension_id(&self) -> ExtensionId { + EXTENSION_ID.to_owned() + } + + fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { + &FLOAT_OPS_REGISTRY + } } #[cfg(test)] @@ -106,7 +122,7 @@ mod test { #[test] fn test_float_ops_extension() { - let r = extension(); + let r = &EXTENSION; assert_eq!(r.name() as &str, "arithmetic.float"); assert_eq!(r.types().count(), 0); for (name, _) in r.operations() { diff --git a/src/std_extensions/arithmetic/float_types.rs b/src/std_extensions/arithmetic/float_types.rs index 582c0eee7..71f91bf87 100644 --- a/src/std_extensions/arithmetic/float_types.rs +++ b/src/std_extensions/arithmetic/float_types.rs @@ -8,6 +8,7 @@ use crate::{ values::{CustomConst, KnownTypeConst}, Extension, }; +use lazy_static::lazy_static; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float.types"); @@ -72,29 +73,30 @@ impl CustomConst for ConstF64 { } } -/// Extension for basic floating-point types. -pub fn extension() -> Extension { - let mut extension = Extension::new(EXTENSION_ID); - - extension - .add_type( - FLOAT_TYPE_ID, - vec![], - "64-bit IEEE 754-2019 floating-point value".to_owned(), - TypeBound::Copyable.into(), - ) - .unwrap(); - - extension +lazy_static! { + /// Extension defining the float type. + pub static ref EXTENSION: Extension = { + let mut extension = Extension::new(EXTENSION_ID); + + extension + .add_type( + FLOAT_TYPE_ID, + vec![], + "64-bit IEEE 754-2019 floating-point value".to_owned(), + TypeBound::Copyable.into(), + ) + .unwrap(); + + extension + }; } - #[cfg(test)] mod test { use super::*; #[test] fn test_float_types_extension() { - let r = extension(); + let r = &EXTENSION; assert_eq!(r.name() as &str, "arithmetic.float.types"); assert_eq!(r.types().count(), 1); assert_eq!(r.operations().count(), 0); diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index 267fde902..ae5160ffd 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -41,7 +41,7 @@ impl ValidateJustArgs for IOValidator { Ok(()) } } -/// Logic extension operation definitions. +/// Integer extension operation definitions. #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] #[allow(missing_docs, non_camel_case_types)] pub enum IntOpDef { diff --git a/src/std_extensions/collections.rs b/src/std_extensions/collections.rs index a78f4793a..ebec9bda7 100644 --- a/src/std_extensions/collections.rs +++ b/src/std_extensions/collections.rs @@ -177,7 +177,7 @@ mod test { let reg = ExtensionRegistry::try_new([ EXTENSION.to_owned(), PRELUDE.to_owned(), - float_types::extension(), + float_types::EXTENSION.to_owned(), ]) .unwrap(); let pop_sig = get_op(&POP_NAME) diff --git a/src/utils.rs b/src/utils.rs index 6c297b4bd..f62cee50f 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -86,7 +86,7 @@ pub(crate) mod test_quantum_extension { lazy_static! { /// Quantum extension definition. pub static ref EXTENSION: Extension = extension(); - static ref REG: ExtensionRegistry = ExtensionRegistry::try_new([EXTENSION.to_owned(), PRELUDE.to_owned(), float_types::extension()]).unwrap(); + static ref REG: ExtensionRegistry = ExtensionRegistry::try_new([EXTENSION.to_owned(), PRELUDE.to_owned(), float_types::EXTENSION.to_owned()]).unwrap(); } fn get_gate(gate_name: &str) -> LeafOp {