diff --git a/src/extension/prelude.rs b/src/extension/prelude.rs index dcd2bede7..5ce17338f 100644 --- a/src/extension/prelude.rs +++ b/src/extension/prelude.rs @@ -80,7 +80,10 @@ pub const BOOL_T: Type = Type::new_simple_predicate(2); pub fn new_array(typ: Type, size: u64) -> Type { let array_def = PRELUDE.get_type("array").unwrap(); let custom_t = array_def - .instantiate_concrete(vec![TypeArg::Type(typ), TypeArg::BoundedNat(size)]) + .instantiate_concrete(vec![ + TypeArg::Type { ty: typ }, + TypeArg::BoundedNat { n: size }, + ]) .unwrap(); Type::new_extension(custom_t) } diff --git a/src/extension/type_def.rs b/src/extension/type_def.rs index 026689977..b636b1712 100644 --- a/src/extension/type_def.rs +++ b/src/extension/type_def.rs @@ -106,7 +106,7 @@ impl TypeDef { least_upper_bound(indices.iter().map(|i| { let ta = args.get(*i); match ta { - Some(TypeArg::Type(s)) => s.least_upper_bound(), + Some(TypeArg::Type { ty: s }) => s.least_upper_bound(), _ => panic!("TypeArg index does not refer to a type."), } })) @@ -174,22 +174,24 @@ mod test { bound: TypeDefBound::FromParams(vec![0]), }; let typ = Type::new_extension( - def.instantiate_concrete(vec![TypeArg::Type(Type::new_function(FunctionType::new( - vec![], - vec![], - )))]) + def.instantiate_concrete(vec![TypeArg::Type { + ty: Type::new_function(FunctionType::new(vec![], vec![])), + }]) .unwrap(), ); assert_eq!(typ.least_upper_bound(), TypeBound::Copyable); - let typ2 = Type::new_extension(def.instantiate_concrete([TypeArg::Type(USIZE_T)]).unwrap()); + let typ2 = Type::new_extension( + def.instantiate_concrete([TypeArg::Type { ty: USIZE_T }]) + .unwrap(), + ); assert_eq!(typ2.least_upper_bound(), TypeBound::Eq); // And some bad arguments...firstly, wrong kind of TypeArg: assert_eq!( - def.instantiate_concrete([TypeArg::Type(QB_T)]), + def.instantiate_concrete([TypeArg::Type { ty: QB_T }]), Err(SignatureError::TypeArgMismatch( TypeArgError::TypeMismatch { - arg: TypeArg::Type(QB_T), + arg: TypeArg::Type { ty: QB_T }, param: TypeParam::Type(TypeBound::Copyable) } )) @@ -201,8 +203,11 @@ mod test { ); // Too many arguments: assert_eq!( - def.instantiate_concrete([TypeArg::Type(FLOAT64_TYPE), TypeArg::Type(FLOAT64_TYPE),]) - .unwrap_err(), + def.instantiate_concrete([ + TypeArg::Type { ty: FLOAT64_TYPE }, + TypeArg::Type { ty: FLOAT64_TYPE }, + ]) + .unwrap_err(), SignatureError::TypeArgMismatch(TypeArgError::WrongNumberArgs(2, 1)) ); } diff --git a/src/hugr/validate.rs b/src/hugr/validate.rs index b20ef8e08..5e4438fee 100644 --- a/src/hugr/validate.rs +++ b/src/hugr/validate.rs @@ -1351,7 +1351,7 @@ mod test { let valid = Type::new_extension(CustomType::new( "MyContainer", - vec![TypeArg::Type(USIZE_T)], + vec![TypeArg::Type { ty: USIZE_T }], "MyExt", TypeBound::Any, )); @@ -1363,7 +1363,7 @@ mod test { // valid is Any, so is not allowed as an element of an outer MyContainer. let element_outside_bound = CustomType::new( "MyContainer", - vec![TypeArg::Type(valid.clone())], + vec![TypeArg::Type { ty: valid.clone() }], "MyExt", TypeBound::Any, ); @@ -1371,13 +1371,13 @@ mod test { validate_to_sig_error(element_outside_bound), SignatureError::TypeArgMismatch(TypeArgError::TypeMismatch { param: TypeParam::Type(TypeBound::Copyable), - arg: TypeArg::Type(valid) + arg: TypeArg::Type { ty: valid } }) ); let bad_bound = CustomType::new( "MyContainer", - vec![TypeArg::Type(USIZE_T)], + vec![TypeArg::Type { ty: USIZE_T }], "MyExt", TypeBound::Copyable, ); @@ -1392,7 +1392,9 @@ mod test { // bad_bound claims to be Copyable, which is valid as an element for the outer MyContainer. let nested = CustomType::new( "MyContainer", - vec![TypeArg::Type(Type::new_extension(bad_bound))], + vec![TypeArg::Type { + ty: Type::new_extension(bad_bound), + }], "MyExt", TypeBound::Any, ); @@ -1406,7 +1408,7 @@ mod test { let too_many_type_args = CustomType::new( "MyContainer", - vec![TypeArg::Type(USIZE_T), TypeArg::BoundedNat(3)], + vec![TypeArg::Type { ty: USIZE_T }, TypeArg::BoundedNat { n: 3 }], "MyExt", TypeBound::Any, ); diff --git a/src/ops/constant.rs b/src/ops/constant.rs index fe7757014..b75240fdf 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -202,7 +202,7 @@ mod test { fn test_yaml_const() { let typ_int = CustomType::new( "mytype", - vec![TypeArg::BoundedNat(8)], + vec![TypeArg::BoundedNat { n: 8 }], "myrsrc", TypeBound::Eq, ); diff --git a/src/ops/custom.rs b/src/ops/custom.rs index 413243972..475e20358 100644 --- a/src/ops/custom.rs +++ b/src/ops/custom.rs @@ -286,12 +286,12 @@ mod test { "res".into(), "op", "desc".into(), - vec![TypeArg::Type(USIZE_T)], + vec![TypeArg::Type { ty: USIZE_T }], None, ); let op: ExternalOp = op.into(); assert_eq!(op.name(), "res.op"); assert_eq!(op.description(), "desc"); - assert_eq!(op.args(), &[TypeArg::Type(USIZE_T)]); + assert_eq!(op.args(), &[TypeArg::Type { ty: USIZE_T }]); } } diff --git a/src/std_extensions/arithmetic/int_types.rs b/src/std_extensions/arithmetic/int_types.rs index 144caa604..655441f26 100644 --- a/src/std_extensions/arithmetic/int_types.rs +++ b/src/std_extensions/arithmetic/int_types.rs @@ -33,7 +33,7 @@ pub(super) fn int_type(width_arg: TypeArg) -> Type { lazy_static! { /// Array of valid integer types, indexed by log width of the integer. pub static ref INT_TYPES: [Type; LOG_WIDTH_BOUND as usize] = (0..LOG_WIDTH_BOUND) - .map(|i| int_type(TypeArg::BoundedNat(i as u64))) + .map(|i| int_type(TypeArg::BoundedNat { n: i as u64 })) .collect::>() .try_into() .unwrap(); @@ -58,7 +58,7 @@ pub const LOG_WIDTH_TYPE_PARAM: TypeParam = TypeParam::bounded_nat(unsafe { /// is invalid. pub(super) fn get_log_width(arg: &TypeArg) -> Result { match arg { - TypeArg::BoundedNat(n) if is_valid_log_width(*n as u8) => Ok(*n as u8), + TypeArg::BoundedNat { n } if is_valid_log_width(*n as u8) => Ok(*n as u8), _ => Err(TypeArgError::TypeMismatch { arg: arg.clone(), param: LOG_WIDTH_TYPE_PARAM, @@ -67,7 +67,9 @@ pub(super) fn get_log_width(arg: &TypeArg) -> Result { } pub(super) const fn type_arg(log_width: u8) -> TypeArg { - TypeArg::BoundedNat(log_width as u64) + TypeArg::BoundedNat { + n: log_width as u64, + } } /// An unsigned integer #[derive(Clone, Debug, Eq, PartialEq, serde::Serialize, serde::Deserialize)] @@ -193,12 +195,12 @@ mod test { #[test] fn test_int_widths() { - let type_arg_32 = TypeArg::BoundedNat(5); + let type_arg_32 = TypeArg::BoundedNat { n: 5 }; assert_matches!(get_log_width(&type_arg_32), Ok(5)); - let type_arg_128 = TypeArg::BoundedNat(7); + let type_arg_128 = TypeArg::BoundedNat { n: 7 }; assert_matches!(get_log_width(&type_arg_128), Ok(7)); - let type_arg_256 = TypeArg::BoundedNat(8); + let type_arg_256 = TypeArg::BoundedNat { n: 8 }; assert_matches!( get_log_width(&type_arg_256), Err(TypeArgError::TypeMismatch { .. }) diff --git a/src/std_extensions/arithmetic/mod.rs b/src/std_extensions/arithmetic/mod.rs index f63e3143a..58cf1ad95 100644 --- a/src/std_extensions/arithmetic/mod.rs +++ b/src/std_extensions/arithmetic/mod.rs @@ -20,7 +20,7 @@ mod test { for i in 0..LOG_WIDTH_BOUND { assert_eq!( INT_TYPES[i as usize], - int_type(TypeArg::BoundedNat(i as u64)) + int_type(TypeArg::BoundedNat { n: i as u64 }) ) } } diff --git a/src/std_extensions/collections.rs b/src/std_extensions/collections.rs index b311b2c01..eed5e46ee 100644 --- a/src/std_extensions/collections.rs +++ b/src/std_extensions/collections.rs @@ -44,7 +44,7 @@ impl CustomConst for ListValue { .map_err(|_| error())?; // constant can only hold classic type. - let [TypeArg::Type(t)] = typ.args() else { + let [TypeArg::Type { ty: t }] = typ.args() else { return Err(error()); }; @@ -118,7 +118,7 @@ fn get_type(name: &str) -> &TypeDef { fn list_types(args: &[TypeArg]) -> Result<(Type, Type), SignatureError> { let list_custom_type = get_type(&LIST_TYPENAME).instantiate_concrete(args)?; - let [TypeArg::Type(element_type)] = args else { + let [TypeArg::Type { ty: element_type }] = args else { panic!("should be checked by def.") }; @@ -156,11 +156,11 @@ mod test { let list_def = r.get_type(&LIST_TYPENAME).unwrap(); let list_type = list_def - .instantiate_concrete([TypeArg::Type(USIZE_T)]) + .instantiate_concrete([TypeArg::Type { ty: USIZE_T }]) .unwrap(); assert!(list_def - .instantiate_concrete([TypeArg::BoundedNat(3)]) + .instantiate_concrete([TypeArg::BoundedNat { n: 3 }]) .is_err()); list_def.check_custom(&list_type).unwrap(); @@ -175,12 +175,12 @@ mod test { #[test] fn test_list_ops() { let pop_sig = get_op(&POP_NAME) - .compute_signature(&[TypeArg::Type(QB_T)]) + .compute_signature(&[TypeArg::Type { ty: QB_T }]) .unwrap(); let list_type = Type::new_extension(CustomType::new( LIST_TYPENAME, - vec![TypeArg::Type(QB_T)], + vec![TypeArg::Type { ty: QB_T }], EXTENSION_NAME, TypeBound::Any, )); @@ -191,12 +191,12 @@ mod test { assert_eq!(pop_sig.output(), &both_row); let push_sig = get_op(&PUSH_NAME) - .compute_signature(&[TypeArg::Type(FLOAT64_TYPE)]) + .compute_signature(&[TypeArg::Type { ty: FLOAT64_TYPE }]) .unwrap(); let list_type = Type::new_extension(CustomType::new( LIST_TYPENAME, - vec![TypeArg::Type(FLOAT64_TYPE)], + vec![TypeArg::Type { ty: FLOAT64_TYPE }], EXTENSION_NAME, TypeBound::Copyable, )); diff --git a/src/std_extensions/logic.rs b/src/std_extensions/logic.rs index 178c9504d..66dbbcfc6 100644 --- a/src/std_extensions/logic.rs +++ b/src/std_extensions/logic.rs @@ -50,7 +50,7 @@ fn extension() -> Extension { |arg_values: &[TypeArg]| { let a = arg_values.iter().exactly_one().unwrap(); let n: u64 = match a { - TypeArg::BoundedNat(n) => *n, + TypeArg::BoundedNat { n } => *n, _ => { return Err(TypeArgError::TypeMismatch { arg: a.clone(), @@ -75,7 +75,7 @@ fn extension() -> Extension { |arg_values: &[TypeArg]| { let a = arg_values.iter().exactly_one().unwrap(); let n: u64 = match a { - TypeArg::BoundedNat(n) => *n, + TypeArg::BoundedNat { n } => *n, _ => { return Err(TypeArgError::TypeMismatch { arg: a.clone(), @@ -134,7 +134,7 @@ pub(crate) mod test { /// Generate a logic extension and "and" operation over [`crate::prelude::BOOL_T`] pub(crate) fn and_op() -> LeafOp { EXTENSION - .instantiate_extension_op(AND_NAME, [TypeArg::BoundedNat(2)]) + .instantiate_extension_op(AND_NAME, [TypeArg::BoundedNat { n: 2 }]) .unwrap() .into() } diff --git a/src/types/check.rs b/src/types/check.rs index 3f5ecbae3..d316cbd55 100644 --- a/src/types/check.rs +++ b/src/types/check.rs @@ -55,11 +55,11 @@ impl PrimType { } match (self, val) { - (PrimType::Extension(e), PrimValue::Extension(e_val)) => { + (PrimType::Extension(e), PrimValue::Extension { c: e_val }) => { e_val.0.check_custom_type(e)?; Ok(()) } - (PrimType::Function(t), PrimValue::Function(v)) + (PrimType::Function(t), PrimValue::Function { hugr: v }) if Some(t.as_ref()) == v.get_function_type() => { // exact signature equality, in future this may need to be @@ -68,7 +68,7 @@ impl PrimType { } _ => Err(ConstTypeError::ValueCheckFail( Type::new(TypeEnum::Prim(self.clone())), - Value::Prim(val.clone()), + Value::Prim { val: val.clone() }, )), } } @@ -82,8 +82,8 @@ impl Type { /// This function will return an error if there is a type check error. pub fn check_type(&self, val: &Value) -> Result<(), ConstTypeError> { match (&self.0, val) { - (TypeEnum::Prim(p), Value::Prim(p_v)) => p.check_type(p_v), - (TypeEnum::Tuple(t), Value::Tuple(t_v)) => { + (TypeEnum::Prim(p), Value::Prim { val: p_v }) => p.check_type(p_v), + (TypeEnum::Tuple(t), Value::Tuple { vs: t_v }) => { if t.len() != t_v.len() { return Err(ConstTypeError::TupleWrongLength); } @@ -92,7 +92,7 @@ impl Type { .try_for_each(|(elem, ty)| ty.check_type(elem)) .map_err(|_| ConstTypeError::ValueCheckFail(self.clone(), val.clone())) } - (TypeEnum::Sum(sum), Value::Sum(tag, value)) => sum + (TypeEnum::Sum(sum), Value::Sum { tag, value }) => sum .get_variant(*tag) .ok_or(ConstTypeError::InvalidSumTag)? .check_type(value), diff --git a/src/types/type_param.rs b/src/types/type_param.rs index 6e4b8d7d7..f20ab3eaf 100644 --- a/src/types/type_param.rs +++ b/src/types/type_param.rs @@ -66,18 +66,34 @@ impl TypeParam { /// A statically-known argument value to an operation. #[derive(Clone, Debug, PartialEq, Eq, serde::Deserialize, serde::Serialize)] #[non_exhaustive] +#[serde(tag = "tya")] pub enum TypeArg { /// Where the (Type/Op)Def declares that an argument is a [TypeParam::Type] - Type(Type), + Type { + #[allow(missing_docs)] + ty: Type, + }, /// Instance of [TypeParam::BoundedNat]. 64-bit unsigned integer. - BoundedNat(u64), + BoundedNat { + #[allow(missing_docs)] + n: u64, + }, ///Instance of [TypeParam::Opaque] An opaque value, stored as serialized blob. - Opaque(CustomTypeArg), + Opaque { + #[allow(missing_docs)] + arg: CustomTypeArg, + }, /// Instance of [TypeParam::List] or [TypeParam::Tuple], defined by a /// sequence of arguments. - Sequence(Vec), + Sequence { + #[allow(missing_docs)] + args: Vec, + }, /// Instance of [TypeParam::Extensions], providing the extension ids. - Extensions(ExtensionSet), + Extensions { + #[allow(missing_docs)] + es: ExtensionSet, + }, } impl TypeArg { @@ -86,15 +102,17 @@ impl TypeArg { extension_registry: &ExtensionRegistry, ) -> Result<(), SignatureError> { match self { - TypeArg::Type(ty) => ty.validate(extension_registry), - TypeArg::BoundedNat(_) => Ok(()), - TypeArg::Opaque(custarg) => { + TypeArg::Type { ty } => ty.validate(extension_registry), + TypeArg::BoundedNat { .. } => Ok(()), + TypeArg::Opaque { arg: custarg } => { // We could also add a facility to Extension to validate that the constant *value* // here is a valid instance of the type. custarg.typ.validate(extension_registry) } - TypeArg::Sequence(args) => args.iter().try_for_each(|a| a.validate(extension_registry)), - TypeArg::Extensions(_) => Ok(()), + TypeArg::Sequence { args } => { + args.iter().try_for_each(|a| a.validate(extension_registry)) + } + TypeArg::Extensions { es: _ } => Ok(()), } } } @@ -125,13 +143,15 @@ impl CustomTypeArg { /// Checks a [TypeArg] is as expected for a [TypeParam] pub fn check_type_arg(arg: &TypeArg, param: &TypeParam) -> Result<(), TypeArgError> { match (arg, param) { - (TypeArg::Type(t), TypeParam::Type(bound)) if bound.contains(t.least_upper_bound()) => { + (TypeArg::Type { ty: t }, TypeParam::Type(bound)) + if bound.contains(t.least_upper_bound()) => + { Ok(()) } - (TypeArg::Sequence(items), TypeParam::List(param)) => { + (TypeArg::Sequence { args: items }, TypeParam::List(param)) => { items.iter().try_for_each(|arg| check_type_arg(arg, param)) } - (TypeArg::Sequence(items), TypeParam::Tuple(types)) => { + (TypeArg::Sequence { args: items }, TypeParam::Tuple(types)) => { if items.len() != types.len() { Err(TypeArgError::WrongNumberTuple(items.len(), types.len())) } else { @@ -141,16 +161,18 @@ pub fn check_type_arg(arg: &TypeArg, param: &TypeParam) -> Result<(), TypeArgErr .try_for_each(|(arg, param)| check_type_arg(arg, param)) } } - (TypeArg::BoundedNat(val), TypeParam::BoundedNat(bound)) if bound.valid_value(*val) => { + (TypeArg::BoundedNat { n: val }, TypeParam::BoundedNat(bound)) + if bound.valid_value(*val) => + { Ok(()) } - (TypeArg::Opaque(arg), TypeParam::Opaque(param)) + (TypeArg::Opaque { arg }, TypeParam::Opaque(param)) if param.bound() == TypeBound::Eq && &arg.typ == param => { Ok(()) } - (TypeArg::Extensions(_), TypeParam::Extensions) => Ok(()), + (TypeArg::Extensions { .. }, TypeParam::Extensions) => Ok(()), _ => Err(TypeArgError::TypeMismatch { arg: arg.clone(), param: param.clone(), diff --git a/src/values.rs b/src/values.rs index 10d5ed387..a7d47eed7 100644 --- a/src/values.rs +++ b/src/values.rs @@ -15,21 +15,28 @@ use crate::types::{CustomCheckFailure, CustomType}; /// A constant value of a primitive (or leaf) type. #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] +#[serde(tag = "pv")] pub enum PrimValue { /// An extension constant value, that can check it is of a given [CustomType]. /// // Note: the extra level of tupling is to avoid https://github.com/rust-lang/rust/issues/78808 - Extension((Box,)), + Extension { + #[allow(missing_docs)] + c: (Box,), + }, /// A higher-order function value. // TODO use a root parametrised hugr, e.g. Hugr. - Function(Box), + Function { + #[allow(missing_docs)] + hugr: Box, + }, } impl PrimValue { fn name(&self) -> String { match self { - PrimValue::Extension(e) => format!("const:custom:{}", e.0.name()), - PrimValue::Function(h) => { + PrimValue::Extension { c: e } => format!("const:custom:{}", e.0.name()), + PrimValue::Function { hugr: h } => { let Some(t) = h.get_function_type() else { panic!("HUGR root node isn't a valid function parent."); }; @@ -42,26 +49,40 @@ impl PrimValue { /// A value that can be stored as a static constant. Representing core types and /// extension types. #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] +#[serde(tag = "v")] pub enum Value { /// A primitive (non-container) value. - Prim(PrimValue), + Prim { + #[allow(missing_docs)] + val: PrimValue, + }, /// A tuple - Tuple(Vec), + Tuple { + #[allow(missing_docs)] + vs: Vec, + }, /// A Sum variant -- for any Sum type where this value meets /// the type of the variant indicated by the tag - Sum(usize, Box), // Tag and value + Sum { + /// The tag index of the variant + tag: usize, + /// The value of the variant + value: Box, + }, } impl Value { /// Returns the name of this [`Value`]. pub fn name(&self) -> String { match self { - Value::Prim(p) => p.name(), - Value::Tuple(vals) => { + Value::Prim { val: p } => p.name(), + Value::Tuple { vs: vals } => { let names: Vec<_> = vals.iter().map(Value::name).collect(); format!("const:seq:{{{}}}", names.join(", ")) } - Value::Sum(tag, val) => format!("const:sum:{{tag:{tag}, val:{}}}", val.name()), + Value::Sum { tag, value: val } => { + format!("const:sum:{{tag:{tag}, val:{}}}", val.name()) + } } } @@ -72,7 +93,7 @@ impl Value { /// Constant unit type (empty Tuple). pub const fn unit() -> Self { - Self::Tuple(vec![]) + Self::Tuple { vs: vec![] } } /// Constant Sum over units, used as predicates. @@ -87,17 +108,24 @@ impl Value { /// Tuple of values. pub fn tuple(items: impl IntoIterator) -> Self { - Self::Tuple(items.into_iter().collect()) + Self::Tuple { + vs: items.into_iter().collect(), + } } /// Sum value (could be of any compatible type, e.g. a predicate) pub fn sum(tag: usize, value: Value) -> Self { - Self::Sum(tag, Box::new(value)) + Self::Sum { + tag, + value: Box::new(value), + } } /// New custom value (of type that implements [`CustomConst`]). pub fn custom(c: C) -> Self { - Self::Prim(PrimValue::Extension((Box::new(c),))) + Self::Prim { + val: PrimValue::Extension { c: (Box::new(c),) }, + } } } @@ -246,7 +274,11 @@ pub(crate) mod test { #[rstest] fn function_value(simple_dfg_hugr: Hugr) { - let v = Value::Prim(PrimValue::Function(Box::new(simple_dfg_hugr))); + let v = Value::Prim { + val: PrimValue::Function { + hugr: Box::new(simple_dfg_hugr), + }, + }; let correct_type = Type::new_function(FunctionType::new_linear(type_row![ crate::extension::prelude::BOOL_T