From 201d1a26a64eaa19e36e41bf70e09dbaa41018be Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 13 Nov 2023 14:06:49 +0000 Subject: [PATCH] refactor: use type schemes in extension definitions wherever possible (#678) Closes #658 As a drive by improve extensions test coverage as I went. Definition can be a bit unintuitive, should improve after #676 is done. --- src/extension/op_def.rs | 17 + src/extension/prelude.rs | 5 + src/hugr/validate/test.rs | 10 +- src/std_extensions/arithmetic/conversions.rs | 61 ++- src/std_extensions/arithmetic/float_ops.rs | 79 ++-- src/std_extensions/arithmetic/float_types.rs | 5 + src/std_extensions/arithmetic/int_ops.rs | 376 +++++++++---------- src/std_extensions/arithmetic/int_types.rs | 54 ++- src/std_extensions/collections.rs | 65 ++-- src/std_extensions/logic.rs | 47 +-- src/utils.rs | 41 +- 11 files changed, 387 insertions(+), 373 deletions(-) diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 6a6ba0f97..60916914a 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -331,6 +331,23 @@ impl Extension { SignatureFunc::TypeScheme(type_scheme), ) } + + /// Create an OpDef with a signature (inputs+outputs) read from e.g. + /// declarative YAML; and no "misc" or "lowering functions" defined. + pub fn add_op_type_scheme_simple( + &mut self, + name: SmolStr, + description: String, + type_scheme: PolyFuncType, + ) -> Result<&OpDef, ExtensionBuildError> { + self.add_op( + name, + description, + Default::default(), + vec![], + SignatureFunc::TypeScheme(type_scheme), + ) + } } #[cfg(test)] diff --git a/src/extension/prelude.rs b/src/extension/prelude.rs index 98038fc18..37ad705bc 100644 --- a/src/extension/prelude.rs +++ b/src/extension/prelude.rs @@ -136,6 +136,11 @@ pub const ERROR_TYPE: Type = Type::new_extension(CustomType::new_simple( /// The string name of the error type. pub const ERROR_TYPE_NAME: SmolStr = SmolStr::new_inline("error"); +/// Return a Sum type with the first variant as the given type and the second an Error. +pub fn sum_with_error(ty: Type) -> Type { + Type::new_sum(vec![ty, ERROR_TYPE]) +} + #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] /// Structure for holding constant usize values. pub struct ConstUsize(u64); diff --git a/src/hugr/validate/test.rs b/src/hugr/validate/test.rs index a4a4dd1a2..2cf5d1812 100644 --- a/src/hugr/validate/test.rs +++ b/src/hugr/validate/test.rs @@ -13,7 +13,7 @@ use crate::macros::const_extension_ids; use crate::ops::dataflow::IOTrait; use crate::ops::{self, LeafOp, OpType}; use crate::std_extensions::logic; -use crate::std_extensions::logic::test::{and_op, not_op}; +use crate::std_extensions::logic::test::{and_op, not_op, or_op}; use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; use crate::types::{CustomType, FunctionType, Type, TypeBound, TypeRow}; use crate::{type_row, Direction, IncomingPort, Node}; @@ -612,12 +612,12 @@ fn dfg_with_cycles() -> Result<(), HugrError> { type_row![BOOL_T], )); let [input, output] = h.get_io(h.root()).unwrap(); - let and = h.add_node_with_parent(h.root(), and_op())?; + let or = h.add_node_with_parent(h.root(), or_op())?; let not1 = h.add_node_with_parent(h.root(), not_op())?; let not2 = h.add_node_with_parent(h.root(), not_op())?; - h.connect(input, 0, and, 0)?; - h.connect(and, 0, not1, 0)?; - h.connect(not1, 0, and, 1)?; + h.connect(input, 0, or, 0)?; + h.connect(or, 0, not1, 0)?; + h.connect(not1, 0, or, 1)?; h.connect(input, 1, not2, 0)?; h.connect(not2, 0, output, 0)?; // The graph contains a cycle: diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index d07b97f62..63207c219 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -1,36 +1,34 @@ //! Conversions between integer and floating-point values. use crate::{ - extension::{ExtensionId, ExtensionSet, SignatureError}, + extension::{ + prelude::sum_with_error, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, + PRELUDE, + }, type_row, - types::{type_param::TypeArg, FunctionType, Type}, - utils::collect_array, + types::{FunctionType, PolyFuncType}, Extension, }; -use super::int_types::int_type; +use super::int_types::int_type_var; use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM}; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions"); -fn ftoi_sig(arg_values: &[TypeArg]) -> Result { - let [arg] = collect_array(arg_values); - Ok(FunctionType::new( +fn ftoi_sig(temp_reg: &ExtensionRegistry) -> Result { + let body = FunctionType::new( type_row![FLOAT64_TYPE], - vec![Type::new_sum(vec![ - int_type(arg.clone()), - crate::extension::prelude::ERROR_TYPE, - ])], - )) + vec![sum_with_error(int_type_var(0))], + ); + + PolyFuncType::new_validated(vec![LOG_WIDTH_TYPE_PARAM], body, temp_reg) } -fn itof_sig(arg_values: &[TypeArg]) -> Result { - let [arg] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg.clone())], - type_row![FLOAT64_TYPE], - )) +fn itof_sig(temp_reg: &ExtensionRegistry) -> Result { + let body = FunctionType::new(vec![int_type_var(0)], type_row![FLOAT64_TYPE]); + + PolyFuncType::new_validated(vec![LOG_WIDTH_TYPE_PARAM], body, temp_reg) } /// Extension for basic arithmetic operations. @@ -42,37 +40,38 @@ pub fn extension() -> Extension { super::float_types::EXTENSION_ID, ]), ); - + let temp_reg: ExtensionRegistry = [ + super::int_types::EXTENSION.to_owned(), + super::float_types::extension(), + PRELUDE.to_owned(), + ] + .into(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "trunc_u".into(), "float to unsigned int".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ftoi_sig, + ftoi_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "trunc_s".into(), "float to signed int".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ftoi_sig, + ftoi_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "convert_u".into(), "unsigned int to float".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - itof_sig, + itof_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "convert_s".into(), "signed int to float".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - itof_sig, + itof_sig(&temp_reg).unwrap(), ) .unwrap(); diff --git a/src/std_extensions/arithmetic/float_ops.rs b/src/std_extensions/arithmetic/float_ops.rs index 5a8907300..7dc20f40d 100644 --- a/src/std_extensions/arithmetic/float_ops.rs +++ b/src/std_extensions/arithmetic/float_ops.rs @@ -1,9 +1,9 @@ //! Basic floating-point operations. use crate::{ - extension::{ExtensionId, ExtensionSet, SignatureError}, + extension::{ExtensionId, ExtensionSet}, type_row, - types::{type_param::TypeArg, FunctionType}, + types::{FunctionType, PolyFuncType}, Extension, }; @@ -12,27 +12,6 @@ use super::float_types::FLOAT64_TYPE; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.float"); -fn fcmp_sig(_arg_values: &[TypeArg]) -> Result { - Ok(FunctionType::new( - type_row![FLOAT64_TYPE; 2], - type_row![crate::extension::prelude::BOOL_T], - )) -} - -fn fbinop_sig(_arg_values: &[TypeArg]) -> Result { - Ok(FunctionType::new( - type_row![FLOAT64_TYPE; 2], - type_row![FLOAT64_TYPE], - )) -} - -fn funop_sig(_arg_values: &[TypeArg]) -> Result { - Ok(FunctionType::new( - type_row![FLOAT64_TYPE], - type_row![FLOAT64_TYPE], - )) -} - /// Extension for basic arithmetic operations. pub fn extension() -> Extension { let mut extension = Extension::new_with_reqs( @@ -40,78 +19,82 @@ pub fn extension() -> Extension { ExtensionSet::singleton(&super::float_types::EXTENSION_ID), ); + let fcmp_sig: PolyFuncType = FunctionType::new( + type_row![FLOAT64_TYPE; 2], + type_row![crate::extension::prelude::BOOL_T], + ) + .into(); + let fbinop_sig: PolyFuncType = + FunctionType::new(type_row![FLOAT64_TYPE; 2], type_row![FLOAT64_TYPE]).into(); + let funop_sig: PolyFuncType = + FunctionType::new(type_row![FLOAT64_TYPE], type_row![FLOAT64_TYPE]).into(); extension - .add_op_custom_sig_simple("feq".into(), "equality test".to_owned(), vec![], fcmp_sig) + .add_op_type_scheme_simple("feq".into(), "equality test".to_owned(), fcmp_sig.clone()) .unwrap(); extension - .add_op_custom_sig_simple("fne".into(), "inequality test".to_owned(), vec![], fcmp_sig) + .add_op_type_scheme_simple("fne".into(), "inequality test".to_owned(), fcmp_sig.clone()) .unwrap(); extension - .add_op_custom_sig_simple("flt".into(), "\"less than\"".to_owned(), vec![], fcmp_sig) + .add_op_type_scheme_simple("flt".into(), "\"less than\"".to_owned(), fcmp_sig.clone()) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "fgt".into(), "\"greater than\"".to_owned(), - vec![], - fcmp_sig, + fcmp_sig.clone(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "fle".into(), "\"less than or equal\"".to_owned(), - vec![], - fcmp_sig, + fcmp_sig.clone(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "fge".into(), "\"greater than or equal\"".to_owned(), - vec![], fcmp_sig, ) .unwrap(); extension - .add_op_custom_sig_simple("fmax".into(), "maximum".to_owned(), vec![], fbinop_sig) + .add_op_type_scheme_simple("fmax".into(), "maximum".to_owned(), fbinop_sig.clone()) .unwrap(); extension - .add_op_custom_sig_simple("fmin".into(), "minimum".to_owned(), vec![], fbinop_sig) + .add_op_type_scheme_simple("fmin".into(), "minimum".to_owned(), fbinop_sig.clone()) .unwrap(); extension - .add_op_custom_sig_simple("fadd".into(), "addition".to_owned(), vec![], fbinop_sig) + .add_op_type_scheme_simple("fadd".into(), "addition".to_owned(), fbinop_sig.clone()) .unwrap(); extension - .add_op_custom_sig_simple("fsub".into(), "subtraction".to_owned(), vec![], fbinop_sig) + .add_op_type_scheme_simple("fsub".into(), "subtraction".to_owned(), fbinop_sig.clone()) .unwrap(); extension - .add_op_custom_sig_simple("fneg".into(), "negation".to_owned(), vec![], funop_sig) + .add_op_type_scheme_simple("fneg".into(), "negation".to_owned(), funop_sig.clone()) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "fabs".into(), "absolute value".to_owned(), - vec![], - funop_sig, + funop_sig.clone(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "fmul".into(), "multiplication".to_owned(), - vec![], - fbinop_sig, + fbinop_sig.clone(), ) .unwrap(); extension - .add_op_custom_sig_simple("fdiv".into(), "division".to_owned(), vec![], fbinop_sig) + .add_op_type_scheme_simple("fdiv".into(), "division".to_owned(), fbinop_sig) .unwrap(); extension - .add_op_custom_sig_simple("ffloor".into(), "floor".to_owned(), vec![], funop_sig) + .add_op_type_scheme_simple("ffloor".into(), "floor".to_owned(), funop_sig.clone()) .unwrap(); extension - .add_op_custom_sig_simple("fceil".into(), "ceiling".to_owned(), vec![], funop_sig) + .add_op_type_scheme_simple("fceil".into(), "ceiling".to_owned(), funop_sig) .unwrap(); extension diff --git a/src/std_extensions/arithmetic/float_types.rs b/src/std_extensions/arithmetic/float_types.rs index 5413b6183..32b7815ef 100644 --- a/src/std_extensions/arithmetic/float_types.rs +++ b/src/std_extensions/arithmetic/float_types.rs @@ -100,6 +100,11 @@ mod test { fn test_float_consts() { let const_f64_1 = ConstF64::new(1.0); let const_f64_2 = ConstF64::new(2.0); + + assert_eq!(const_f64_1.value(), 1.0); + assert_eq!(*const_f64_2, 2.0); + assert_eq!(const_f64_1.name(), "f64(1)"); + assert!(const_f64_1.equal_consts(&ConstF64::new(1.0))); assert_ne!(const_f64_1, const_f64_2); assert_eq!(const_f64_1, ConstF64::new(1.0)); } diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index 4a6aae198..cc09fc4a3 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -1,9 +1,10 @@ //! Basic integer operations. -use super::int_types::{get_log_width, int_type, type_arg, LOG_WIDTH_TYPE_PARAM}; -use crate::extension::prelude::{BOOL_T, ERROR_TYPE}; +use super::int_types::{get_log_width, int_type, int_type_var, LOG_WIDTH_TYPE_PARAM}; +use crate::extension::prelude::{sum_with_error, BOOL_T}; +use crate::extension::{ExtensionRegistry, PRELUDE}; use crate::type_row; -use crate::types::FunctionType; +use crate::types::{FunctionType, PolyFuncType}; use crate::utils::collect_array; use crate::{ extension::{ExtensionId, ExtensionSet, SignatureError}, @@ -36,104 +37,99 @@ fn inarrow_sig(arg_values: &[TypeArg]) -> Result { } Ok(FunctionType::new( vec![int_type(arg0.clone())], - vec![Type::new_sum(vec![int_type(arg1.clone()), ERROR_TYPE])], + vec![sum_with_error(int_type(arg1.clone()))], )) } -fn itob_sig(_arg_values: &[TypeArg]) -> Result { - Ok(FunctionType::new( - vec![int_type(type_arg(0))], - type_row![BOOL_T], - )) +fn int_polytype( + n_vars: usize, + input: impl Into, + output: impl Into, + temp_reg: &ExtensionRegistry, +) -> Result { + PolyFuncType::new_validated( + vec![LOG_WIDTH_TYPE_PARAM; n_vars], + FunctionType::new(input, output), + temp_reg, + ) } -fn btoi_sig(_arg_values: &[TypeArg]) -> Result { - Ok(FunctionType::new( - type_row![BOOL_T], - vec![int_type(type_arg(0))], - )) +fn itob_sig(temp_reg: &ExtensionRegistry) -> Result { + int_polytype(1, vec![int_type_var(0)], type_row![BOOL_T], temp_reg) } -fn icmp_sig(arg_values: &[TypeArg]) -> Result { - let [arg] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg.clone()); 2], - type_row![BOOL_T], - )) +fn btoi_sig(temp_reg: &ExtensionRegistry) -> Result { + int_polytype(1, type_row![BOOL_T], vec![int_type_var(0)], temp_reg) } -fn ibinop_sig(arg_values: &[TypeArg]) -> Result { - let [arg] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg.clone()); 2], - vec![int_type(arg.clone())], - )) +fn icmp_sig(temp_reg: &ExtensionRegistry) -> Result { + int_polytype(1, vec![int_type_var(0); 2], type_row![BOOL_T], temp_reg) } -fn iunop_sig(arg_values: &[TypeArg]) -> Result { - let [arg] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg.clone())], - vec![int_type(arg.clone())], - )) +fn ibinop_sig(temp_reg: &ExtensionRegistry) -> Result { + let int_type_var = int_type_var(0); + + int_polytype( + 1, + vec![int_type_var.clone(); 2], + vec![int_type_var], + temp_reg, + ) } -fn idivmod_checked_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - let intpair: TypeRow = vec![int_type(arg0.clone()), int_type(arg1.clone())].into(); - Ok(FunctionType::new( - intpair.clone(), - vec![Type::new_sum(vec![Type::new_tuple(intpair), ERROR_TYPE])], - )) +fn iunop_sig(temp_reg: &ExtensionRegistry) -> Result { + let int_type_var = int_type_var(0); + int_polytype(1, vec![int_type_var.clone()], vec![int_type_var], temp_reg) } -fn idivmod_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - let intpair: TypeRow = vec![int_type(arg0.clone()), int_type(arg1.clone())].into(); - Ok(FunctionType::new( +fn idivmod_checked_sig(temp_reg: &ExtensionRegistry) -> Result { + let intpair: TypeRow = vec![int_type_var(0), int_type_var(1)].into(); + int_polytype( + 2, intpair.clone(), - vec![Type::new_tuple(intpair)], - )) + vec![sum_with_error(Type::new_tuple(intpair))], + temp_reg, + ) } -fn idiv_checked_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg0.clone()), int_type(arg1.clone())], - vec![Type::new_sum(vec![int_type(arg0.clone()), ERROR_TYPE])], - )) +fn idivmod_sig(temp_reg: &ExtensionRegistry) -> Result { + let intpair: TypeRow = vec![int_type_var(0), int_type_var(1)].into(); + int_polytype(2, intpair.clone(), vec![Type::new_tuple(intpair)], temp_reg) } -fn idiv_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg0.clone()), int_type(arg1.clone())], - vec![int_type(arg0.clone())], - )) +fn idiv_checked_sig(temp_reg: &ExtensionRegistry) -> Result { + int_polytype( + 2, + vec![int_type_var(1)], + vec![sum_with_error(int_type_var(0))], + temp_reg, + ) } -fn imod_checked_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg0.clone()), int_type(arg1.clone())], - vec![Type::new_sum(vec![int_type(arg1.clone()), ERROR_TYPE])], - )) +fn idiv_sig(temp_reg: &ExtensionRegistry) -> Result { + int_polytype(2, vec![int_type_var(1)], vec![int_type_var(0)], temp_reg) } -fn imod_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg0.clone()), int_type(arg1.clone())], - vec![int_type(arg1.clone())], - )) +fn imod_checked_sig(temp_reg: &ExtensionRegistry) -> Result { + int_polytype( + 2, + vec![int_type_var(0), int_type_var(1).clone()], + vec![sum_with_error(int_type_var(1))], + temp_reg, + ) } -fn ish_sig(arg_values: &[TypeArg]) -> Result { - let [arg0, arg1] = collect_array(arg_values); - Ok(FunctionType::new( - vec![int_type(arg0.clone()), int_type(arg1.clone())], - vec![int_type(arg0.clone())], - )) +fn imod_sig(temp_reg: &ExtensionRegistry) -> Result { + int_polytype( + 2, + vec![int_type_var(0), int_type_var(1).clone()], + vec![int_type_var(1)], + temp_reg, + ) +} + +fn ish_sig(temp_reg: &ExtensionRegistry) -> Result { + int_polytype(2, vec![int_type_var(1)], vec![int_type_var(0)], temp_reg) } /// Extension for basic integer operations. @@ -143,6 +139,9 @@ pub fn extension() -> Extension { ExtensionSet::singleton(&super::int_types::EXTENSION_ID), ); + let temp_reg: ExtensionRegistry = + [super::int_types::EXTENSION.to_owned(), PRELUDE.to_owned()].into(); + extension .add_op_custom_sig_simple( "iwiden_u".into(), @@ -177,347 +176,306 @@ pub fn extension() -> Extension { ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "itobool".into(), "convert to bool (1 is true, 0 is false)".to_owned(), - vec![], - itob_sig, + itob_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ifrombool".into(), "convert from bool (1 is true, 0 is false)".to_owned(), - vec![], - btoi_sig, + btoi_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ieq".into(), "equality test".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ine".into(), "inequality test".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ilt_u".into(), "\"less than\" as unsigned integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ilt_s".into(), "\"less than\" as signed integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "igt_u".into(), "\"greater than\" as unsigned integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "igt_s".into(), "\"greater than\" as signed integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ile_u".into(), "\"less than or equal\" as unsigned integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ile_s".into(), "\"less than or equal\" as signed integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ige_u".into(), "\"greater than or equal\" as unsigned integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ige_s".into(), "\"greater than or equal\" as signed integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - icmp_sig, + icmp_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "imax_u".into(), "maximum of unsigned integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ibinop_sig, + ibinop_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "imax_s".into(), "maximum of signed integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ibinop_sig, + ibinop_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "imin_u".into(), "minimum of unsigned integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ibinop_sig, + ibinop_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "imin_s".into(), "minimum of signed integers".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ibinop_sig, + ibinop_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "iadd".into(), "addition modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ibinop_sig, + ibinop_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "isub".into(), "subtraction modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ibinop_sig, + ibinop_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "ineg".into(), "negation modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - iunop_sig, + iunop_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "imul".into(), "multiplication modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - vec![LOG_WIDTH_TYPE_PARAM], - ibinop_sig, + ibinop_sig(&temp_reg).unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( "idivmod_checked_u".into(), "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^M, generates unsigned q, r where \ q*m+r=n, 0<=r TypeArg { + TypeArg::BoundedNat { n } + } + #[test] + fn test_binary_signatures() { + let sig = iwiden_sig(&[ta(3), ta(4)]).unwrap(); + assert_eq!( + sig, + FunctionType::new(vec![int_type(ta(3))], vec![int_type(ta(4))],) + ); + + iwiden_sig(&[ta(4), ta(3)]).unwrap_err(); + + let sig = inarrow_sig(&[ta(2), ta(1)]).unwrap(); + assert_eq!( + sig, + FunctionType::new(vec![int_type(ta(2))], vec![sum_with_error(int_type(ta(1)))],) + ); + + inarrow_sig(&[ta(1), ta(2)]).unwrap_err(); + } } diff --git a/src/std_extensions/arithmetic/int_types.rs b/src/std_extensions/arithmetic/int_types.rs index 536be3774..7a67de28a 100644 --- a/src/std_extensions/arithmetic/int_types.rs +++ b/src/std_extensions/arithmetic/int_types.rs @@ -18,7 +18,7 @@ use lazy_static::lazy_static; pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.int.types"); /// Identifier for the integer type. -const INT_TYPE_ID: SmolStr = SmolStr::new_inline("int"); +pub const INT_TYPE_ID: SmolStr = SmolStr::new_inline("int"); fn int_custom_type(width_arg: TypeArg) -> CustomType { CustomType::new(INT_TYPE_ID, [width_arg], EXTENSION_ID, TypeBound::Eq) @@ -198,6 +198,21 @@ pub fn extension() -> Extension { extension } +lazy_static! { + /// Lazy reference to int types extension. + pub static ref EXTENSION: Extension = extension(); +} + +/// get an integer type variable, given the integer type definition +pub(super) fn int_type_var(var_id: usize) -> Type { + Type::new_extension( + EXTENSION + .get_type(&INT_TYPE_ID) + .unwrap() + .instantiate(vec![TypeArg::new_var_use(var_id, LOG_WIDTH_TYPE_PARAM)]) + .unwrap(), + ) +} #[cfg(test)] mod test { use cool_asserts::assert_matches; @@ -232,6 +247,7 @@ mod test { assert_ne!(const_u32_7, const_u64_7); assert_ne!(const_u32_7, const_u32_8); assert_eq!(const_u32_7, ConstIntU::new(5, 7)); + assert_matches!( ConstIntU::new(3, 256), Err(ConstTypeError::CustomCheckFail(_)) @@ -244,6 +260,40 @@ mod test { ConstIntS::new(3, 128), Err(ConstTypeError::CustomCheckFail(_)) ); - assert_matches!(ConstIntS::new(3, -128), Ok(_)); + assert!(ConstIntS::new(3, -128).is_ok()); + + let const_u32_7 = const_u32_7.unwrap(); + assert!(const_u32_7.equal_consts(&ConstIntU::new(5, 7).unwrap())); + assert_eq!(const_u32_7.log_width(), 5); + assert_eq!(const_u32_7.value(), 7); + assert!(const_u32_7 + .check_custom_type(&int_custom_type(TypeArg::BoundedNat { n: 5 })) + .is_ok()); + assert!(const_u32_7 + .check_custom_type(&int_custom_type(TypeArg::BoundedNat { n: 6 })) + .is_err()); + + assert_eq!(const_u32_7.name(), "u5(7)"); + assert!(const_u32_7 + .check_custom_type(&int_custom_type(TypeArg::BoundedNat { n: 19 })) + .is_err()); + + let const_i32_2 = ConstIntS::new(5, -2).unwrap(); + assert!(const_i32_2.equal_consts(&ConstIntS::new(5, -2).unwrap())); + assert_eq!(const_i32_2.log_width(), 5); + assert_eq!(const_i32_2.value(), -2); + assert!(const_i32_2 + .check_custom_type(&int_custom_type(TypeArg::BoundedNat { n: 5 })) + .is_ok()); + assert!(const_i32_2 + .check_custom_type(&int_custom_type(TypeArg::BoundedNat { n: 6 })) + .is_err()); + assert!(const_i32_2 + .check_custom_type(&int_custom_type(TypeArg::BoundedNat { n: 19 })) + .is_err()); + assert_eq!(const_i32_2.name(), "i5(-2)"); + + ConstIntS::new(50, -2).unwrap_err(); + ConstIntU::new(50, 2).unwrap_err(); } } diff --git a/src/std_extensions/collections.rs b/src/std_extensions/collections.rs index 40b10f6a2..cc830f930 100644 --- a/src/std_extensions/collections.rs +++ b/src/std_extensions/collections.rs @@ -5,10 +5,10 @@ use serde::{Deserialize, Serialize}; use smol_str::SmolStr; use crate::{ - extension::{ExtensionId, ExtensionSet, SignatureError, TypeDef, TypeDefBound}, + extension::{ExtensionId, ExtensionRegistry, TypeDef, TypeDefBound}, types::{ type_param::{TypeArg, TypeParam}, - CustomCheckFailure, CustomType, FunctionType, Type, TypeBound, TypeRow, + CustomCheckFailure, CustomType, FunctionType, PolyFuncType, Type, TypeBound, }, values::{CustomConst, Value}, Extension, @@ -59,6 +59,7 @@ impl CustomConst for ListValue { crate::values::downcast_equal_consts(self, other) } } +const TP: TypeParam = TypeParam::Type(TypeBound::Any); fn extension() -> Extension { let mut extension = Extension::new(EXTENSION_NAME); @@ -66,43 +67,42 @@ fn extension() -> Extension { extension .add_type( LIST_TYPENAME, - vec![TypeParam::Type(TypeBound::Any)], + vec![TP], "Generic dynamically sized list of type T.".into(), TypeDefBound::FromParams(vec![0]), ) .unwrap(); + let temp_reg: ExtensionRegistry = [extension.clone()].into(); + let list_type_def = extension.get_type(&LIST_TYPENAME).unwrap(); + + let (l, e) = list_and_elem_type(list_type_def); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( POP_NAME, "Pop from back of list".into(), - vec![TypeParam::Type(TypeBound::Any)], - move |args: &[TypeArg]| { - let (list_type, element_type) = list_types(args)?; - Ok(FunctionType { - input: TypeRow::from(vec![list_type.clone()]), - output: TypeRow::from(vec![list_type, element_type]), - extension_reqs: ExtensionSet::singleton(&EXTENSION_NAME), - }) - }, + PolyFuncType::new_validated( + vec![TP], + FunctionType::new(vec![l.clone()], vec![l.clone(), e.clone()]), + &temp_reg, + ) + .unwrap(), ) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( PUSH_NAME, "Push to back of list".into(), - vec![TypeParam::Type(TypeBound::Any)], - move |args: &[TypeArg]| { - let (list_type, element_type) = list_types(args)?; - Ok(FunctionType { - output: TypeRow::from(vec![list_type.clone()]), - input: TypeRow::from(vec![list_type, element_type]), - extension_reqs: ExtensionSet::singleton(&EXTENSION_NAME), - }) - }, + PolyFuncType::new_validated( + vec![TP], + FunctionType::new(vec![l.clone(), e], vec![l]), + &temp_reg, + ) + .unwrap(), ) .unwrap(); extension } + lazy_static! { /// Collections extension definition. pub static ref EXTENSION: Extension = extension(); @@ -112,16 +112,15 @@ fn get_type(name: &str) -> &TypeDef { EXTENSION.get_type(name).unwrap() } -fn list_types(args: &[TypeArg]) -> Result<(Type, Type), SignatureError> { - let list_custom_type = get_type(&LIST_TYPENAME).instantiate(args)?; - let [TypeArg::Type { ty: element_type }] = args else { - panic!("should be checked by def.") - }; - - let list_type: Type = Type::new_extension(list_custom_type); - Ok((list_type, element_type.clone())) +fn list_and_elem_type(list_type_def: &TypeDef) -> (Type, Type) { + let elem_type = Type::new_var_use(0, TypeBound::Any); + let list_type = Type::new_extension( + list_type_def + .instantiate(vec![TypeArg::new_var_use(0, TP)]) + .unwrap(), + ); + (list_type, elem_type) } - #[cfg(test)] mod test { use crate::{ @@ -130,7 +129,7 @@ mod test { OpDef, PRELUDE, }, std_extensions::arithmetic::float_types::{self, ConstF64, FLOAT64_TYPE}, - types::{type_param::TypeArg, Type}, + types::{type_param::TypeArg, Type, TypeRow}, Extension, }; diff --git a/src/std_extensions/logic.rs b/src/std_extensions/logic.rs index 978c520c8..794b98c9d 100644 --- a/src/std_extensions/logic.rs +++ b/src/std_extensions/logic.rs @@ -7,7 +7,7 @@ use crate::{ extension::{prelude::BOOL_T, ExtensionId}, ops, type_row, types::{ - type_param::{TypeArg, TypeArgError, TypeParam}, + type_param::{TypeArg, TypeParam}, FunctionType, }, Extension, @@ -34,11 +34,10 @@ fn extension() -> Extension { let mut extension = Extension::new(EXTENSION_ID); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( SmolStr::new_inline(NOT_NAME), "logical 'not'".into(), - vec![], - |_arg_values: &[TypeArg]| Ok(FunctionType::new(type_row![BOOL_T], type_row![BOOL_T])), + FunctionType::new(type_row![BOOL_T], type_row![BOOL_T]).into(), ) .unwrap(); @@ -48,19 +47,12 @@ fn extension() -> Extension { "logical 'and'".into(), vec![H_INT], |arg_values: &[TypeArg]| { - let a = arg_values.iter().exactly_one().unwrap(); - let n: u64 = match a { - TypeArg::BoundedNat { n } => *n, - _ => { - return Err(TypeArgError::TypeMismatch { - arg: a.clone(), - param: H_INT, - } - .into()); - } + let Ok(TypeArg::BoundedNat { n }) = arg_values.iter().exactly_one() else { + panic!("should be covered by validation.") }; + Ok(FunctionType::new( - vec![BOOL_T; n as usize], + vec![BOOL_T; *n as usize], type_row![BOOL_T], )) }, @@ -73,19 +65,12 @@ fn extension() -> Extension { "logical 'or'".into(), vec![H_INT], |arg_values: &[TypeArg]| { - let a = arg_values.iter().exactly_one().unwrap(); - let n: u64 = match a { - TypeArg::BoundedNat { n } => *n, - _ => { - return Err(TypeArgError::TypeMismatch { - arg: a.clone(), - param: H_INT, - } - .into()); - } + let Ok(TypeArg::BoundedNat { n }) = arg_values.iter().exactly_one() else { + panic!("should be covered by validation.") }; + Ok(FunctionType::new( - vec![BOOL_T; n as usize], + vec![BOOL_T; *n as usize], type_row![BOOL_T], )) }, @@ -115,7 +100,7 @@ pub(crate) mod test { Extension, }; - use super::{extension, AND_NAME, EXTENSION, FALSE_NAME, NOT_NAME, TRUE_NAME}; + use super::{extension, AND_NAME, EXTENSION, FALSE_NAME, NOT_NAME, OR_NAME, TRUE_NAME}; #[test] fn test_logic_extension() { @@ -144,6 +129,14 @@ pub(crate) mod test { .into() } + /// Generate a logic extension and "or" operation over [`crate::prelude::BOOL_T`] + pub(crate) fn or_op() -> LeafOp { + EXTENSION + .instantiate_extension_op(OR_NAME, [TypeArg::BoundedNat { n: 2 }], &EMPTY_REG) + .unwrap() + .into() + } + /// Generate a logic extension and "not" operation over [`crate::prelude::BOOL_T`] pub(crate) fn not_op() -> LeafOp { EXTENSION diff --git a/src/utils.rs b/src/utils.rs index de37f348c..9ce62ba37 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -34,23 +34,23 @@ pub(crate) mod test_quantum_extension { use crate::{ extension::{ prelude::{BOOL_T, QB_T}, - ExtensionId, ExtensionRegistry, SignatureError, PRELUDE, + ExtensionId, ExtensionRegistry, PRELUDE, }, ops::LeafOp, std_extensions::arithmetic::float_types::FLOAT64_TYPE, type_row, - types::{FunctionType, TypeArg}, + types::{FunctionType, PolyFuncType}, Extension, }; use lazy_static::lazy_static; - fn one_qb_func(_: &[TypeArg]) -> Result { - Ok(FunctionType::new_linear(type_row![QB_T])) + fn one_qb_func() -> PolyFuncType { + FunctionType::new_linear(type_row![QB_T]).into() } - fn two_qb_func(_: &[TypeArg]) -> Result { - Ok(FunctionType::new_linear(type_row![QB_T, QB_T])) + fn two_qb_func() -> PolyFuncType { + FunctionType::new_linear(type_row![QB_T, QB_T]).into() } /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("test.quantum"); @@ -58,42 +58,25 @@ pub(crate) mod test_quantum_extension { let mut extension = Extension::new(EXTENSION_ID); extension - .add_op_custom_sig_simple( - SmolStr::new_inline("H"), - "Hadamard".into(), - vec![], - one_qb_func, - ) + .add_op_type_scheme_simple(SmolStr::new_inline("H"), "Hadamard".into(), one_qb_func()) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( SmolStr::new_inline("RzF64"), "Rotation specified by float".into(), - vec![], - |_: &[_]| { - Ok(FunctionType::new( - type_row![QB_T, FLOAT64_TYPE], - type_row![QB_T], - )) - }, + FunctionType::new(type_row![QB_T, FLOAT64_TYPE], type_row![QB_T]).into(), ) .unwrap(); extension - .add_op_custom_sig_simple(SmolStr::new_inline("CX"), "CX".into(), vec![], two_qb_func) + .add_op_type_scheme_simple(SmolStr::new_inline("CX"), "CX".into(), two_qb_func()) .unwrap(); extension - .add_op_custom_sig_simple( + .add_op_type_scheme_simple( SmolStr::new_inline("Measure"), "Measure a qubit, returning the qubit and the measurement result.".into(), - vec![], - |_arg_values: &[TypeArg]| { - Ok(FunctionType::new(type_row![QB_T], type_row![QB_T, BOOL_T])) - // TODO add logic as an extension delta when inference is - // done? - // https://github.com/CQCL-DEV/hugr/issues/425 - }, + FunctionType::new(type_row![QB_T], type_row![QB_T, BOOL_T]).into(), ) .unwrap();