diff --git a/src/extension.rs b/src/extension.rs index 25efa58f1..5bfb81924 100644 --- a/src/extension.rs +++ b/src/extension.rs @@ -35,33 +35,38 @@ pub use prelude::{PRELUDE, PRELUDE_REGISTRY}; pub struct ExtensionRegistry(BTreeMap); impl ExtensionRegistry { - /// Makes a new (empty) registry. - pub const fn new() -> Self { - Self(BTreeMap::new()) - } - /// Gets the Extension with the given name pub fn get(&self, name: &str) -> Option<&Extension> { self.0.get(name) } -} - -/// An Extension Registry containing no extensions. -pub const EMPTY_REG: ExtensionRegistry = ExtensionRegistry::new(); -impl> From for ExtensionRegistry { - fn from(value: T) -> Self { - let mut reg = Self::new(); + /// Makes a new ExtensionRegistry, validating all the extensions in it + pub fn try_new( + value: impl IntoIterator, + ) -> Result { + let mut exts = BTreeMap::new(); for ext in value.into_iter() { - let prev = reg.0.insert(ext.name.clone(), ext); + let prev = exts.insert(ext.name.clone(), ext); if let Some(prev) = prev { panic!("Multiple extensions with same name: {}", prev.name) }; } - reg + // Note this potentially asks extensions to validate themselves against other extensions that + // may *not* be valid themselves yet. It'd be better to order these respecting dependencies, + // or at least to validate the types first - which we don't do at all yet: + // TODO https://github.com/CQCL/hugr/issues/624. However, parametrized types could be + // cyclically dependent, so there is no perfect solution, and this is at least simple. + let res = ExtensionRegistry(exts); + for ext in res.0.values() { + ext.validate(&res).map_err(|e| (ext.name().clone(), e))?; + } + Ok(res) } } +/// An Extension Registry containing no extensions. +pub const EMPTY_REG: ExtensionRegistry = ExtensionRegistry(BTreeMap::new()); + /// An error that can occur in computing the signature of a node. /// TODO: decide on failure modes #[derive(Debug, Clone, Error, PartialEq, Eq)] @@ -290,6 +295,16 @@ impl Extension { let op_def = self.get_op(op_name).expect("Op not found."); ExtensionOp::new(op_def.clone(), args, ext_reg) } + + // Validates against a registry, which we can assume includes this extension itself. + // (TODO deal with the registry itself containing invalid extensions!) + fn validate(&self, all_exts: &ExtensionRegistry) -> Result<(), SignatureError> { + // We should validate TypeParams of TypeDefs too - https://github.com/CQCL/hugr/issues/624 + for op_def in self.operations.values() { + op_def.validate(all_exts)?; + } + Ok(()) + } } impl PartialEq for Extension { diff --git a/src/extension/op_def.rs b/src/extension/op_def.rs index 4774d5c62..b9107acba 100644 --- a/src/extension/op_def.rs +++ b/src/extension/op_def.rs @@ -245,6 +245,15 @@ impl OpDef { SignatureFunc::CustomFunc { static_params, .. } => static_params, } } + + pub(super) fn validate(&self, exts: &ExtensionRegistry) -> Result<(), SignatureError> { + // TODO https://github.com/CQCL/hugr/issues/624 validate declared TypeParams + // for both type scheme and custom binary + if let SignatureFunc::TypeScheme(ts) = &self.signature_func { + ts.validate(exts, &[])?; + } + Ok(()) + } } impl Extension { @@ -356,7 +365,7 @@ mod test { use crate::builder::{DFGBuilder, Dataflow, DataflowHugr}; use crate::extension::prelude::USIZE_T; - use crate::extension::PRELUDE; + use crate::extension::{ExtensionRegistry, PRELUDE}; use crate::ops::custom::ExternalOp; use crate::ops::LeafOp; use crate::std_extensions::collections::{EXTENSION, LIST_TYPENAME}; @@ -370,34 +379,29 @@ mod test { #[test] fn op_def_with_type_scheme() -> Result<(), Box> { - let reg1 = [PRELUDE.to_owned(), EXTENSION.to_owned()].into(); let list_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap(); let mut e = Extension::new(EXT_ID); const TP: TypeParam = TypeParam::Type(TypeBound::Any); let list_of_var = Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?); const OP_NAME: SmolStr = SmolStr::new_inline("Reverse"); - let type_scheme = PolyFuncType::new_validated( - vec![TP], - FunctionType::new_endo(vec![list_of_var]), - ®1, - )?; + let type_scheme = PolyFuncType::new(vec![TP], FunctionType::new_endo(vec![list_of_var])); e.add_op_type_scheme(OP_NAME, "".into(), Default::default(), vec![], type_scheme)?; + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), EXTENSION.to_owned(), e]).unwrap(); + let e = reg.get(&EXT_ID).unwrap(); let list_usize = Type::new_extension(list_def.instantiate(vec![TypeArg::Type { ty: USIZE_T }])?); let mut dfg = DFGBuilder::new(FunctionType::new_endo(vec![list_usize]))?; let rev = dfg.add_dataflow_op( LeafOp::from(ExternalOp::Extension( - e.instantiate_extension_op(&OP_NAME, vec![TypeArg::Type { ty: USIZE_T }], ®1) + e.instantiate_extension_op(&OP_NAME, vec![TypeArg::Type { ty: USIZE_T }], ®) .unwrap(), )), dfg.input_wires(), )?; - dfg.finish_hugr_with_outputs( - rev.outputs(), - &[PRELUDE.to_owned(), EXTENSION.to_owned(), e].into(), - )?; + dfg.finish_hugr_with_outputs(rev.outputs(), ®)?; Ok(()) } diff --git a/src/extension/prelude.rs b/src/extension/prelude.rs index 37ad705bc..cfd66563d 100644 --- a/src/extension/prelude.rs +++ b/src/extension/prelude.rs @@ -76,7 +76,8 @@ lazy_static! { prelude }; /// An extension registry containing only the prelude - pub static ref PRELUDE_REGISTRY: ExtensionRegistry = [PRELUDE_DEF.to_owned()].into(); + pub static ref PRELUDE_REGISTRY: ExtensionRegistry = + ExtensionRegistry::try_new([PRELUDE_DEF.to_owned()]).unwrap(); /// Prelude extension pub static ref PRELUDE: &'static Extension = PRELUDE_REGISTRY.get(&PRELUDE_ID).unwrap(); diff --git a/src/hugr/rewrite/replace.rs b/src/hugr/rewrite/replace.rs index 16eeb9321..dd478969f 100644 --- a/src/hugr/rewrite/replace.rs +++ b/src/hugr/rewrite/replace.rs @@ -458,7 +458,9 @@ mod test { #[test] fn cfg() -> Result<(), Box> { - let reg: ExtensionRegistry = [PRELUDE.to_owned(), collections::EXTENSION.to_owned()].into(); + let reg = + ExtensionRegistry::try_new([PRELUDE.to_owned(), collections::EXTENSION.to_owned()]) + .unwrap(); let listy = Type::new_extension( collections::EXTENSION .get_type(collections::LIST_TYPENAME.as_str()) diff --git a/src/hugr/validate/test.rs b/src/hugr/validate/test.rs index edc80c955..c72c90615 100644 --- a/src/hugr/validate/test.rs +++ b/src/hugr/validate/test.rs @@ -672,7 +672,7 @@ fn invalid_types() { TypeDefBound::Explicit(TypeBound::Any), ) .unwrap(); - let reg: ExtensionRegistry = [e, PRELUDE.to_owned()].into(); + let reg = ExtensionRegistry::try_new([e, PRELUDE.to_owned()]).unwrap(); let validate_to_sig_error = |t: CustomType| { let (h, def) = identity_hugr_with_type(Type::new_extension(t)); diff --git a/src/lib.rs b/src/lib.rs index 2fca6ace7..772e6a5c0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -83,7 +83,8 @@ //! lazy_static! { //! /// Quantum extension definition. //! pub static ref EXTENSION: Extension = extension(); -//! static ref REG: ExtensionRegistry = [EXTENSION.to_owned(), PRELUDE.to_owned()].into(); +//! static ref REG: ExtensionRegistry = +//! ExtensionRegistry::try_new([EXTENSION.to_owned(), PRELUDE.to_owned()]).unwrap(); //! //! } //! fn get_gate(gate_name: &str) -> LeafOp { diff --git a/src/ops/constant.rs b/src/ops/constant.rs index bd8970b1e..5f87c96d2 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -126,11 +126,10 @@ mod test { builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr}, extension::{ prelude::{ConstUsize, USIZE_T}, - ExtensionId, ExtensionSet, + ExtensionId, ExtensionRegistry, ExtensionSet, PRELUDE, }, - std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}, + std_extensions::arithmetic::float_types::{self, ConstF64, FLOAT64_TYPE}, type_row, - types::test::test_registry, types::type_param::TypeArg, types::{CustomCheckFailure, CustomType, FunctionType, Type, TypeBound, TypeRow}, values::{ @@ -143,6 +142,10 @@ mod test { use super::*; + fn test_registry() -> ExtensionRegistry { + ExtensionRegistry::try_new([PRELUDE.to_owned(), float_types::extension()]).unwrap() + } + #[test] fn test_tuple_sum() -> Result<(), BuildError> { use crate::builder::Container; diff --git a/src/ops/leaf.rs b/src/ops/leaf.rs index 85de1a0a8..432790cf8 100644 --- a/src/ops/leaf.rs +++ b/src/ops/leaf.rs @@ -192,8 +192,8 @@ impl DataflowOpTrait for LeafOp { mod test { use crate::builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr}; use crate::extension::prelude::BOOL_T; - use crate::extension::SignatureError; use crate::extension::{prelude::USIZE_T, PRELUDE}; + use crate::extension::{ExtensionRegistry, SignatureError}; use crate::hugr::ValidationError; use crate::ops::handle::NodeHandle; use crate::std_extensions::collections::EXTENSION; @@ -206,7 +206,7 @@ mod test { #[test] fn hugr_with_type_apply() -> Result<(), Box> { - let reg = [PRELUDE.to_owned(), EXTENSION.to_owned()].into(); + let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), EXTENSION.to_owned()]).unwrap(); let pf_in = nested_func(); let pf_out = pf_in.instantiate(&[USIZE_TA], ®)?; let mut dfg = DFGBuilder::new(FunctionType::new( @@ -225,7 +225,7 @@ mod test { #[test] fn bad_type_apply() -> Result<(), Box> { - let reg = [PRELUDE.to_owned(), EXTENSION.to_owned()].into(); + let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), EXTENSION.to_owned()]).unwrap(); let pf = nested_func(); let pf_usz = pf.instantiate_poly(&[USIZE_TA], ®)?; let pf_bool = pf.instantiate_poly(&[TypeArg::Type { ty: BOOL_T }], ®)?; diff --git a/src/std_extensions/arithmetic/conversions.rs b/src/std_extensions/arithmetic/conversions.rs index 63207c219..d94baea6b 100644 --- a/src/std_extensions/arithmetic/conversions.rs +++ b/src/std_extensions/arithmetic/conversions.rs @@ -1,10 +1,7 @@ //! Conversions between integer and floating-point values. use crate::{ - extension::{ - prelude::sum_with_error, ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError, - PRELUDE, - }, + extension::{prelude::sum_with_error, ExtensionId, ExtensionSet}, type_row, types::{FunctionType, PolyFuncType}, Extension, @@ -16,23 +13,21 @@ use super::{float_types::FLOAT64_TYPE, int_types::LOG_WIDTH_TYPE_PARAM}; /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("arithmetic.conversions"); -fn ftoi_sig(temp_reg: &ExtensionRegistry) -> Result { - let body = FunctionType::new( - type_row![FLOAT64_TYPE], - vec![sum_with_error(int_type_var(0))], +/// Extension for basic arithmetic operations. +pub fn extension() -> Extension { + let ftoi_sig = PolyFuncType::new( + vec![LOG_WIDTH_TYPE_PARAM], + FunctionType::new( + type_row![FLOAT64_TYPE], + vec![sum_with_error(int_type_var(0))], + ), ); - PolyFuncType::new_validated(vec![LOG_WIDTH_TYPE_PARAM], body, temp_reg) -} - -fn itof_sig(temp_reg: &ExtensionRegistry) -> Result { - let body = FunctionType::new(vec![int_type_var(0)], type_row![FLOAT64_TYPE]); - - PolyFuncType::new_validated(vec![LOG_WIDTH_TYPE_PARAM], body, temp_reg) -} + let itof_sig = PolyFuncType::new( + vec![LOG_WIDTH_TYPE_PARAM], + FunctionType::new(vec![int_type_var(0)], type_row![FLOAT64_TYPE]), + ); -/// Extension for basic arithmetic operations. -pub fn extension() -> Extension { let mut extension = Extension::new_with_reqs( EXTENSION_ID, ExtensionSet::from_iter(vec![ @@ -40,38 +35,28 @@ pub fn extension() -> Extension { super::float_types::EXTENSION_ID, ]), ); - let temp_reg: ExtensionRegistry = [ - super::int_types::EXTENSION.to_owned(), - super::float_types::extension(), - PRELUDE.to_owned(), - ] - .into(); extension .add_op_type_scheme_simple( "trunc_u".into(), "float to unsigned int".to_owned(), - ftoi_sig(&temp_reg).unwrap(), + ftoi_sig.clone(), ) .unwrap(); extension - .add_op_type_scheme_simple( - "trunc_s".into(), - "float to signed int".to_owned(), - ftoi_sig(&temp_reg).unwrap(), - ) + .add_op_type_scheme_simple("trunc_s".into(), "float to signed int".to_owned(), ftoi_sig) .unwrap(); extension .add_op_type_scheme_simple( "convert_u".into(), "unsigned int to float".to_owned(), - itof_sig(&temp_reg).unwrap(), + itof_sig.clone(), ) .unwrap(); extension .add_op_type_scheme_simple( "convert_s".into(), "signed int to float".to_owned(), - itof_sig(&temp_reg).unwrap(), + itof_sig, ) .unwrap(); diff --git a/src/std_extensions/arithmetic/int_ops.rs b/src/std_extensions/arithmetic/int_ops.rs index cc09fc4a3..68ebb107d 100644 --- a/src/std_extensions/arithmetic/int_ops.rs +++ b/src/std_extensions/arithmetic/int_ops.rs @@ -2,7 +2,6 @@ use super::int_types::{get_log_width, int_type, int_type_var, LOG_WIDTH_TYPE_PARAM}; use crate::extension::prelude::{sum_with_error, BOOL_T}; -use crate::extension::{ExtensionRegistry, PRELUDE}; use crate::type_row; use crate::types::{FunctionType, PolyFuncType}; use crate::utils::collect_array; @@ -45,103 +44,73 @@ fn int_polytype( n_vars: usize, input: impl Into, output: impl Into, - temp_reg: &ExtensionRegistry, -) -> Result { - PolyFuncType::new_validated( +) -> PolyFuncType { + PolyFuncType::new( vec![LOG_WIDTH_TYPE_PARAM; n_vars], FunctionType::new(input, output), - temp_reg, ) } -fn itob_sig(temp_reg: &ExtensionRegistry) -> Result { - int_polytype(1, vec![int_type_var(0)], type_row![BOOL_T], temp_reg) -} - -fn btoi_sig(temp_reg: &ExtensionRegistry) -> Result { - int_polytype(1, type_row![BOOL_T], vec![int_type_var(0)], temp_reg) -} - -fn icmp_sig(temp_reg: &ExtensionRegistry) -> Result { - int_polytype(1, vec![int_type_var(0); 2], type_row![BOOL_T], temp_reg) -} - -fn ibinop_sig(temp_reg: &ExtensionRegistry) -> Result { +fn ibinop_sig() -> PolyFuncType { let int_type_var = int_type_var(0); - int_polytype( - 1, - vec![int_type_var.clone(); 2], - vec![int_type_var], - temp_reg, - ) + int_polytype(1, vec![int_type_var.clone(); 2], vec![int_type_var]) } -fn iunop_sig(temp_reg: &ExtensionRegistry) -> Result { +fn iunop_sig() -> PolyFuncType { let int_type_var = int_type_var(0); - int_polytype(1, vec![int_type_var.clone()], vec![int_type_var], temp_reg) + int_polytype(1, vec![int_type_var.clone()], vec![int_type_var]) } -fn idivmod_checked_sig(temp_reg: &ExtensionRegistry) -> Result { +fn idivmod_checked_sig() -> PolyFuncType { let intpair: TypeRow = vec![int_type_var(0), int_type_var(1)].into(); int_polytype( 2, intpair.clone(), vec![sum_with_error(Type::new_tuple(intpair))], - temp_reg, ) } -fn idivmod_sig(temp_reg: &ExtensionRegistry) -> Result { +fn idivmod_sig() -> PolyFuncType { let intpair: TypeRow = vec![int_type_var(0), int_type_var(1)].into(); - int_polytype(2, intpair.clone(), vec![Type::new_tuple(intpair)], temp_reg) + int_polytype(2, intpair.clone(), vec![Type::new_tuple(intpair)]) } -fn idiv_checked_sig(temp_reg: &ExtensionRegistry) -> Result { - int_polytype( +/// Extension for basic integer operations. +pub fn extension() -> Extension { + let itob_sig = int_polytype(1, vec![int_type_var(0)], type_row![BOOL_T]); + + let btoi_sig = int_polytype(1, type_row![BOOL_T], vec![int_type_var(0)]); + + let icmp_sig = int_polytype(1, vec![int_type_var(0); 2], type_row![BOOL_T]); + + let idiv_checked_sig = int_polytype( 2, vec![int_type_var(1)], vec![sum_with_error(int_type_var(0))], - temp_reg, - ) -} + ); -fn idiv_sig(temp_reg: &ExtensionRegistry) -> Result { - int_polytype(2, vec![int_type_var(1)], vec![int_type_var(0)], temp_reg) -} + let idiv_sig = int_polytype(2, vec![int_type_var(1)], vec![int_type_var(0)]); -fn imod_checked_sig(temp_reg: &ExtensionRegistry) -> Result { - int_polytype( + let imod_checked_sig = int_polytype( 2, vec![int_type_var(0), int_type_var(1).clone()], vec![sum_with_error(int_type_var(1))], - temp_reg, - ) -} + ); -fn imod_sig(temp_reg: &ExtensionRegistry) -> Result { - int_polytype( + let imod_sig = int_polytype( 2, vec![int_type_var(0), int_type_var(1).clone()], vec![int_type_var(1)], - temp_reg, - ) -} + ); -fn ish_sig(temp_reg: &ExtensionRegistry) -> Result { - int_polytype(2, vec![int_type_var(1)], vec![int_type_var(0)], temp_reg) -} + let ish_sig = int_polytype(2, vec![int_type_var(1)], vec![int_type_var(0)]); -/// Extension for basic integer operations. -pub fn extension() -> Extension { let mut extension = Extension::new_with_reqs( EXTENSION_ID, ExtensionSet::singleton(&super::int_types::EXTENSION_ID), ); - let temp_reg: ExtensionRegistry = - [super::int_types::EXTENSION.to_owned(), PRELUDE.to_owned()].into(); - extension .add_op_custom_sig_simple( "iwiden_u".into(), @@ -179,140 +148,132 @@ pub fn extension() -> Extension { .add_op_type_scheme_simple( "itobool".into(), "convert to bool (1 is true, 0 is false)".to_owned(), - itob_sig(&temp_reg).unwrap(), + itob_sig.clone(), ) .unwrap(); extension .add_op_type_scheme_simple( "ifrombool".into(), "convert from bool (1 is true, 0 is false)".to_owned(), - btoi_sig(&temp_reg).unwrap(), + btoi_sig.clone(), ) .unwrap(); extension - .add_op_type_scheme_simple( - "ieq".into(), - "equality test".to_owned(), - icmp_sig(&temp_reg).unwrap(), - ) + .add_op_type_scheme_simple("ieq".into(), "equality test".to_owned(), icmp_sig.clone()) .unwrap(); extension - .add_op_type_scheme_simple( - "ine".into(), - "inequality test".to_owned(), - icmp_sig(&temp_reg).unwrap(), - ) + .add_op_type_scheme_simple("ine".into(), "inequality test".to_owned(), icmp_sig.clone()) .unwrap(); extension .add_op_type_scheme_simple( "ilt_u".into(), "\"less than\" as unsigned integers".to_owned(), - icmp_sig(&temp_reg).unwrap(), + icmp_sig.clone(), ) .unwrap(); extension .add_op_type_scheme_simple( "ilt_s".into(), "\"less than\" as signed integers".to_owned(), - icmp_sig(&temp_reg).unwrap(), + icmp_sig.clone(), ) .unwrap(); extension .add_op_type_scheme_simple( "igt_u".into(), "\"greater than\" as unsigned integers".to_owned(), - icmp_sig(&temp_reg).unwrap(), + icmp_sig.clone(), ) .unwrap(); extension .add_op_type_scheme_simple( "igt_s".into(), "\"greater than\" as signed integers".to_owned(), - icmp_sig(&temp_reg).unwrap(), + icmp_sig.clone(), ) .unwrap(); extension .add_op_type_scheme_simple( "ile_u".into(), "\"less than or equal\" as unsigned integers".to_owned(), - icmp_sig(&temp_reg).unwrap(), + icmp_sig.clone(), ) .unwrap(); extension .add_op_type_scheme_simple( "ile_s".into(), "\"less than or equal\" as signed integers".to_owned(), - icmp_sig(&temp_reg).unwrap(), + icmp_sig.clone(), ) .unwrap(); extension .add_op_type_scheme_simple( "ige_u".into(), "\"greater than or equal\" as unsigned integers".to_owned(), - icmp_sig(&temp_reg).unwrap(), + icmp_sig.clone(), ) .unwrap(); extension .add_op_type_scheme_simple( "ige_s".into(), "\"greater than or equal\" as signed integers".to_owned(), - icmp_sig(&temp_reg).unwrap(), + icmp_sig.clone(), ) .unwrap(); extension .add_op_type_scheme_simple( "imax_u".into(), "maximum of unsigned integers".to_owned(), - ibinop_sig(&temp_reg).unwrap(), + ibinop_sig(), ) .unwrap(); extension .add_op_type_scheme_simple( "imax_s".into(), "maximum of signed integers".to_owned(), - ibinop_sig(&temp_reg).unwrap(), + ibinop_sig(), ) .unwrap(); extension .add_op_type_scheme_simple( "imin_u".into(), "minimum of unsigned integers".to_owned(), - ibinop_sig(&temp_reg).unwrap(), + ibinop_sig(), ) .unwrap(); extension .add_op_type_scheme_simple( "imin_s".into(), "minimum of signed integers".to_owned(), - ibinop_sig(&temp_reg).unwrap(), + ibinop_sig(), ) .unwrap(); extension .add_op_type_scheme_simple( "iadd".into(), "addition modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - ibinop_sig(&temp_reg).unwrap(), + ibinop_sig(), ) .unwrap(); extension .add_op_type_scheme_simple( "isub".into(), "subtraction modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - ibinop_sig(&temp_reg).unwrap(), + ibinop_sig(), ) .unwrap(); extension .add_op_type_scheme_simple( "ineg".into(), "negation modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - iunop_sig(&temp_reg).unwrap(), + iunop_sig(), ) .unwrap(); extension .add_op_type_scheme_simple( "imul".into(), "multiplication modulo 2^N (signed and unsigned versions are the same op)".to_owned(), - ibinop_sig(&temp_reg).unwrap(), + ibinop_sig(), ) .unwrap(); extension @@ -321,7 +282,7 @@ pub fn extension() -> Extension { "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^M, generates unsigned q, r where \ q*m+r=n, 0<=r Extension { "given unsigned integers 0 <= n < 2^N, 0 <= m < 2^M, generates unsigned q, r where \ q*m+r=n, 0<=r Extension { "given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^M, generates \ signed q and unsigned r where q*m+r=n, 0<=r Extension { "given signed integer -2^{N-1} <= n < 2^{N-1} and unsigned 0 <= m < 2^M, generates \ signed q and unsigned r where q*m+r=n, 0<=r Extension { "shift first input left by k bits where k is unsigned interpretation of second input \ (leftmost bits dropped, rightmost bits set to zero" .to_owned(), - ish_sig(&temp_reg).unwrap(), + ish_sig.clone(), ) .unwrap(); extension @@ -457,7 +402,7 @@ pub fn extension() -> Extension { "shift first input right by k bits where k is unsigned interpretation of second input \ (rightmost bits dropped, leftmost bits set to zero)" .to_owned(), - ish_sig(&temp_reg).unwrap(), + ish_sig.clone(), ) .unwrap(); extension @@ -466,7 +411,7 @@ pub fn extension() -> Extension { "rotate first input left by k bits where k is unsigned interpretation of second input \ (leftmost bits replace rightmost bits)" .to_owned(), - ish_sig(&temp_reg).unwrap(), + ish_sig.clone(), ) .unwrap(); extension @@ -475,7 +420,7 @@ pub fn extension() -> Extension { "rotate first input right by k bits where k is unsigned interpretation of second input \ (rightmost bits replace leftmost bits)" .to_owned(), - ish_sig( &temp_reg).unwrap(), + ish_sig.clone(), ) .unwrap(); diff --git a/src/std_extensions/collections.rs b/src/std_extensions/collections.rs index cc830f930..99d208c63 100644 --- a/src/std_extensions/collections.rs +++ b/src/std_extensions/collections.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use smol_str::SmolStr; use crate::{ - extension::{ExtensionId, ExtensionRegistry, TypeDef, TypeDefBound}, + extension::{ExtensionId, TypeDef, TypeDefBound}, types::{ type_param::{TypeArg, TypeParam}, CustomCheckFailure, CustomType, FunctionType, PolyFuncType, Type, TypeBound, @@ -72,7 +72,6 @@ fn extension() -> Extension { TypeDefBound::FromParams(vec![0]), ) .unwrap(); - let temp_reg: ExtensionRegistry = [extension.clone()].into(); let list_type_def = extension.get_type(&LIST_TYPENAME).unwrap(); let (l, e) = list_and_elem_type(list_type_def); @@ -80,24 +79,17 @@ fn extension() -> Extension { .add_op_type_scheme_simple( POP_NAME, "Pop from back of list".into(), - PolyFuncType::new_validated( + PolyFuncType::new( vec![TP], FunctionType::new(vec![l.clone()], vec![l.clone(), e.clone()]), - &temp_reg, - ) - .unwrap(), + ), ) .unwrap(); extension .add_op_type_scheme_simple( PUSH_NAME, "Push to back of list".into(), - PolyFuncType::new_validated( - vec![TP], - FunctionType::new(vec![l.clone(), e], vec![l]), - &temp_reg, - ) - .unwrap(), + PolyFuncType::new(vec![TP], FunctionType::new(vec![l.clone(), e], vec![l])), ) .unwrap(); extension @@ -126,7 +118,7 @@ mod test { use crate::{ extension::{ prelude::{ConstUsize, QB_T, USIZE_T}, - OpDef, PRELUDE, + ExtensionRegistry, OpDef, PRELUDE, }, std_extensions::arithmetic::float_types::{self, ConstF64, FLOAT64_TYPE}, types::{type_param::TypeArg, Type, TypeRow}, @@ -169,14 +161,14 @@ mod test { #[test] fn test_list_ops() { - let reg = &[ + let reg = ExtensionRegistry::try_new([ EXTENSION.to_owned(), PRELUDE.to_owned(), float_types::extension(), - ] - .into(); + ]) + .unwrap(); let pop_sig = get_op(&POP_NAME) - .compute_signature(&[TypeArg::Type { ty: QB_T }], reg) + .compute_signature(&[TypeArg::Type { ty: QB_T }], ®) .unwrap(); let list_type = Type::new_extension(CustomType::new( @@ -192,7 +184,7 @@ mod test { assert_eq!(pop_sig.output(), &both_row); let push_sig = get_op(&PUSH_NAME) - .compute_signature(&[TypeArg::Type { ty: FLOAT64_TYPE }], reg) + .compute_signature(&[TypeArg::Type { ty: FLOAT64_TYPE }], ®) .unwrap(); let list_type = Type::new_extension(CustomType::new( diff --git a/src/types.rs b/src/types.rs index 123a320a6..463848d26 100644 --- a/src/types.rs +++ b/src/types.rs @@ -412,18 +412,10 @@ pub(crate) mod test { pub(crate) use poly_func::test::nested_func; use super::*; - use crate::{ - extension::{prelude::USIZE_T, PRELUDE}, - ops::AliasDecl, - std_extensions::arithmetic::float_types, - }; + use crate::{extension::prelude::USIZE_T, ops::AliasDecl}; use crate::types::TypeBound; - pub(crate) fn test_registry() -> ExtensionRegistry { - vec![PRELUDE.to_owned(), float_types::extension()].into() - } - #[test] fn construct() { let t: Type = Type::new_tuple(vec![ diff --git a/src/types/poly_func.rs b/src/types/poly_func.rs index 8b33287e8..759776fb1 100644 --- a/src/types/poly_func.rs +++ b/src/types/poly_func.rs @@ -49,29 +49,24 @@ impl PolyFuncType { &self.params } - /// Create a new PolyFuncType and validates it. (This will only succeed - /// for outermost PolyFuncTypes i.e. with no free type-variables.) - /// The [ExtensionRegistry] should be the same (or a subset) of that which will later - /// be used to validate the Hugr; at this point we only need the types. - /// - /// #Errors - /// Validates that all types in the schema are well-formed and all variables in the body - /// are declared with [TypeParam]s that guarantee they will fit. - pub fn new_validated( - params: impl Into>, - body: FunctionType, - extension_registry: &ExtensionRegistry, - ) -> Result { - let params = params.into(); - body.validate(extension_registry, ¶ms)?; - Ok(Self { params, body }) + /// Create a new PolyFuncType given the kinds of the variables it declares + /// and the underlying [FunctionType]. + pub fn new(params: impl Into>, body: FunctionType) -> Self { + Self { + params: params.into(), + body, + } } - pub(super) fn validate( + /// Validates this instance, checking that the types in the body are + /// wellformed with respect to the registry, and that all type variables + /// are declared (perhaps in an enclosing scope, kinds passed in). + pub fn validate( &self, reg: &ExtensionRegistry, external_var_decls: &[TypeParam], ) -> Result<(), SignatureError> { + // TODO https://github.com/CQCL/hugr/issues/624 validate TypeParams declared here, too let mut v; // Declared here so live until end of scope let all_var_decls = if self.params.is_empty() { external_var_decls @@ -224,6 +219,7 @@ impl<'a> Substitution for InsideBinders<'a> { pub(crate) mod test { use std::num::NonZeroU64; + use lazy_static::lazy_static; use smol_str::SmolStr; use crate::extension::prelude::{PRELUDE_ID, USIZE_CUSTOM_T, USIZE_T}; @@ -237,19 +233,35 @@ pub(crate) mod test { use super::PolyFuncType; + lazy_static! { + static ref REGISTRY: ExtensionRegistry = + ExtensionRegistry::try_new([PRELUDE.to_owned(), EXTENSION.to_owned()]).unwrap(); + } + + impl PolyFuncType { + fn new_validated( + params: impl Into>, + body: FunctionType, + extension_registry: &ExtensionRegistry, + ) -> Result { + let res = Self::new(params, body); + res.validate(extension_registry, &[])?; + Ok(res) + } + } + #[test] fn test_opaque() -> Result<(), SignatureError> { let list_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap(); let tyvar = TypeArg::new_var_use(0, TypeParam::Type(TypeBound::Any)); let list_of_var = Type::new_extension(list_def.instantiate([tyvar.clone()])?); - let reg: ExtensionRegistry = [PRELUDE.to_owned(), EXTENSION.to_owned()].into(); let list_len = PolyFuncType::new_validated( [TypeParam::Type(TypeBound::Any)], FunctionType::new(vec![list_of_var], vec![USIZE_T]), - ®, + ®ISTRY, )?; - let t = list_len.instantiate(&[TypeArg::Type { ty: USIZE_T }], ®)?; + let t = list_len.instantiate(&[TypeArg::Type { ty: USIZE_T }], ®ISTRY)?; assert_eq!( t, FunctionType::new( @@ -330,14 +342,14 @@ pub(crate) mod test { let tv = TypeArg::new_var_use(0, TypeParam::Type(TypeBound::Copyable)); let list_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap(); let body_type = id_fn(Type::new_extension(list_def.instantiate([tv])?)); - let reg = [EXTENSION.to_owned()].into(); for decl in [ TypeParam::Extensions, TypeParam::List(Box::new(TypeParam::max_nat())), TypeParam::Opaque(USIZE_CUSTOM_T), TypeParam::Tuple(vec![TypeParam::Type(TypeBound::Any), TypeParam::max_nat()]), ] { - let invalid_ts = PolyFuncType::new_validated([decl.clone()], body_type.clone(), ®); + let invalid_ts = + PolyFuncType::new_validated([decl.clone()], body_type.clone(), ®ISTRY); assert_eq!( invalid_ts.err(), Some(SignatureError::TypeVarDoesNotMatchDeclaration { @@ -347,7 +359,7 @@ pub(crate) mod test { ); } // Variable not declared at all - let invalid_ts = PolyFuncType::new_validated([], body_type, ®); + let invalid_ts = PolyFuncType::new_validated([], body_type, ®ISTRY); assert_eq!( invalid_ts.err(), Some(SignatureError::FreeTypeVar { @@ -376,7 +388,7 @@ pub(crate) mod test { ) .unwrap(); - let reg: ExtensionRegistry = [e].into(); + let reg = ExtensionRegistry::try_new([e]).unwrap(); let make_scheme = |tp: TypeParam| { PolyFuncType::new_validated( @@ -525,7 +537,7 @@ pub(crate) mod test { ), ))], ), - &[EXTENSION.to_owned()].into(), + ®ISTRY, ) .unwrap() } @@ -533,7 +545,6 @@ pub(crate) mod test { #[test] fn test_instantiate_nested() -> Result<(), SignatureError> { let outer = nested_func(); - let reg: ExtensionRegistry = [EXTENSION.to_owned(), PRELUDE.to_owned()].into(); let arg = new_array(USIZE_T, TypeArg::BoundedNat { n: 5 }); // `arg` -> (forall C. C -> List(Tuple(C, `arg`))) @@ -551,7 +562,7 @@ pub(crate) mod test { ))], ); - let res = outer.instantiate(&[TypeArg::Type { ty: arg }], ®)?; + let res = outer.instantiate(&[TypeArg::Type { ty: arg }], ®ISTRY)?; assert_eq!(res, outer_applied); Ok(()) } @@ -561,11 +572,10 @@ pub(crate) mod test { let outer = nested_func(); // Now substitute in a free var from further outside - let reg = [EXTENSION.to_owned(), PRELUDE.to_owned()].into(); const FREE: usize = 3; const TP_EQ: TypeParam = TypeParam::Type(TypeBound::Eq); let res = outer - .instantiate(&[TypeArg::new_var_use(FREE, TP_EQ)], ®) + .instantiate(&[TypeArg::new_var_use(FREE, TP_EQ)], ®ISTRY) .unwrap(); assert_eq!( res, @@ -598,7 +608,7 @@ pub(crate) mod test { }; let res = outer - .instantiate(&[TypeArg::Type { ty: rhs(FREE) }], ®) + .instantiate(&[TypeArg::Type { ty: rhs(FREE) }], ®ISTRY) .unwrap(); assert_eq!( res, diff --git a/src/utils.rs b/src/utils.rs index 73787b4ed..8818e2822 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -37,7 +37,7 @@ pub(crate) mod test_quantum_extension { ExtensionId, ExtensionRegistry, PRELUDE, }, ops::LeafOp, - std_extensions::arithmetic::float_types::FLOAT64_TYPE, + std_extensions::arithmetic::float_types, type_row, types::{FunctionType, PolyFuncType}, Extension, @@ -64,7 +64,8 @@ pub(crate) mod test_quantum_extension { .add_op_type_scheme_simple( SmolStr::new_inline("RzF64"), "Rotation specified by float".into(), - FunctionType::new(type_row![QB_T, FLOAT64_TYPE], type_row![QB_T]).into(), + FunctionType::new(type_row![QB_T, float_types::FLOAT64_TYPE], type_row![QB_T]) + .into(), ) .unwrap(); @@ -86,7 +87,7 @@ pub(crate) mod test_quantum_extension { lazy_static! { /// Quantum extension definition. pub static ref EXTENSION: Extension = extension(); - static ref REG: ExtensionRegistry = [EXTENSION.to_owned(), PRELUDE.to_owned()].into(); + static ref REG: ExtensionRegistry = ExtensionRegistry::try_new([EXTENSION.to_owned(), PRELUDE.to_owned(), float_types::extension()]).unwrap(); } fn get_gate(gate_name: &str) -> LeafOp {