From 77bec46aecdb32710b9b279b7ddbd327e43bc651 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Mon, 13 Nov 2023 16:30:55 +0000 Subject: [PATCH] refactor!: Closes Flatten `Prim(Type/Value)` in to parent enum #665 BREAKING_CHANGES: In serialization, extension and function values no longer wrapped by "pv". --- src/types.rs | 50 ++++++++++++++++++++++--------------- src/types/check.rs | 40 ++++++------------------------ src/types/primitive.rs | 38 ---------------------------- src/types/serialize.rs | 11 +++------ src/values.rs | 56 +++++++++++------------------------------- 5 files changed, 57 insertions(+), 138 deletions(-) delete mode 100644 src/types/primitive.rs diff --git a/src/types.rs b/src/types.rs index e2452d2b3..5d15f79be 100644 --- a/src/types.rs +++ b/src/types.rs @@ -3,7 +3,7 @@ mod check; pub mod custom; mod poly_func; -mod primitive; +// mod primitive; mod serialize; mod signature; pub mod type_param; @@ -26,7 +26,6 @@ use crate::ops::AliasDecl; use crate::type_row; use std::fmt::Debug; -pub use self::primitive::PrimType; use self::type_param::TypeParam; #[cfg(feature = "pyo3")] @@ -155,8 +154,20 @@ impl From for Type { #[derive(Clone, PartialEq, Debug, Eq, derive_more::Display)] /// Core types: primitive (leaf), tuple (product) or sum (co-product). pub enum TypeEnum { + // TODO optimise with Box ? + // or some static version of this? #[allow(missing_docs)] - Prim(PrimType), + Extension(CustomType), + #[allow(missing_docs)] + #[display(fmt = "Alias({})", "_0.name()")] + Alias(AliasDecl), + #[allow(missing_docs)] + #[display(fmt = "Function({})", "_0")] + Function(Box), + // DeBruijn index, and cache of TypeBound (checked in validation) + #[allow(missing_docs)] + #[display(fmt = "Variable({})", _0)] + Variable(usize, TypeBound), #[allow(missing_docs)] #[display(fmt = "Tuple({})", "_0")] Tuple(TypeRow), @@ -168,7 +179,10 @@ impl TypeEnum { /// The smallest type bound that covers the whole type. fn least_upper_bound(&self) -> TypeBound { match self { - TypeEnum::Prim(p) => p.bound(), + TypeEnum::Extension(c) => c.bound(), + TypeEnum::Alias(a) => a.bound, + TypeEnum::Function(_) => TypeBound::Copyable, + TypeEnum::Variable(_, b) => *b, TypeEnum::Sum(SumType::Unit { size: _ }) => TypeBound::Eq, TypeEnum::Sum(SumType::General { row }) => { least_upper_bound(row.iter().map(Type::least_upper_bound)) @@ -216,7 +230,7 @@ impl Type { /// Initialize a new function type. pub fn new_function(fun_ty: impl Into) -> Self { - Self::new(TypeEnum::Prim(PrimType::Function(Box::new(fun_ty.into())))) + Self::new(TypeEnum::Function(Box::new(fun_ty.into()))) } /// Initialize a new tuple type by providing the elements. @@ -235,12 +249,12 @@ impl Type { // TODO remove? Extensions/TypeDefs should just provide `Type` directly pub const fn new_extension(opaque: CustomType) -> Self { let bound = opaque.bound(); - Type(TypeEnum::Prim(PrimType::Extension(opaque)), bound) + Type(TypeEnum::Extension(opaque), bound) } /// Initialize a new alias. pub fn new_alias(alias: AliasDecl) -> Self { - Self::new(TypeEnum::Prim(PrimType::Alias(alias))) + Self::new(TypeEnum::Alias(alias)) } fn new(type_e: TypeEnum) -> Self { @@ -267,7 +281,7 @@ impl Type { /// For use in type schemes only: `bound` must match that with which the /// variable was declared (i.e. as a [TypeParam::Type]`(bound)`). pub fn new_var_use(idx: usize, bound: TypeBound) -> Self { - Self(TypeEnum::Prim(PrimType::Variable(idx, bound)), bound) + Self(TypeEnum::Variable(idx, bound), bound) } /// Report the least upper TypeBound, if there is one. @@ -307,12 +321,10 @@ impl Type { .iter() .try_for_each(|t| t.validate(extension_registry, var_decls)), TypeEnum::Sum(SumType::Unit { .. }) => Ok(()), // No leaves there - TypeEnum::Prim(PrimType::Alias(_)) => Ok(()), - TypeEnum::Prim(PrimType::Extension(custy)) => { - custy.validate(extension_registry, var_decls) - } - TypeEnum::Prim(PrimType::Function(ft)) => ft.validate(extension_registry, var_decls), - TypeEnum::Prim(PrimType::Variable(idx, bound)) => { + TypeEnum::Alias(_) => Ok(()), + TypeEnum::Extension(custy) => custy.validate(extension_registry, var_decls), + TypeEnum::Function(ft) => ft.validate(extension_registry, var_decls), + TypeEnum::Variable(idx, bound) => { check_typevar_decl(var_decls, *idx, &TypeParam::Type(*bound)) } } @@ -320,12 +332,10 @@ impl Type { pub(crate) fn substitute(&self, t: &impl Substitution) -> Self { match &self.0 { - TypeEnum::Prim(PrimType::Alias(_)) | TypeEnum::Sum(SumType::Unit { .. }) => { - self.clone() - } - TypeEnum::Prim(PrimType::Variable(idx, bound)) => t.apply_typevar(*idx, *bound), - TypeEnum::Prim(PrimType::Extension(cty)) => Type::new_extension(cty.substitute(t)), - TypeEnum::Prim(PrimType::Function(bf)) => Type::new_function(bf.substitute(t)), + TypeEnum::Alias(_) | TypeEnum::Sum(SumType::Unit { .. }) => self.clone(), + TypeEnum::Variable(idx, bound) => t.apply_typevar(*idx, *bound), + TypeEnum::Extension(cty) => Type::new_extension(cty.substitute(t)), + TypeEnum::Function(bf) => Type::new_function(bf.substitute(t)), TypeEnum::Tuple(elems) => Type::new_tuple(subst_row(elems, t)), TypeEnum::Sum(SumType::General { row }) => Type::new_sum(subst_row(row, t)), } diff --git a/src/types/check.rs b/src/types/check.rs index 496cc2a2d..31c8d1962 100644 --- a/src/types/check.rs +++ b/src/types/check.rs @@ -1,12 +1,9 @@ //! Logic for checking values against types. use thiserror::Error; -use crate::{ - values::{PrimValue, Value}, - HugrView, -}; +use crate::{values::Value, HugrView}; -use super::{primitive::PrimType, CustomType, Type, TypeEnum}; +use super::{CustomType, Type, TypeEnum}; /// Struct for custom type check fails. #[derive(Clone, Debug, PartialEq, Eq, Error)] @@ -48,46 +45,25 @@ pub enum ConstTypeError { CustomCheckFail(#[from] CustomCheckFailure), } -impl PrimType { - /// Check that a [`PrimValue`] is a valid instance of this [`PrimType`]. +impl Type { + /// Check that a [`Value`] is a valid instance of this [`Type`]. /// /// # Errors /// /// This function will return an error if there is a type check error. - pub fn check_type(&self, val: &PrimValue) -> Result<(), ConstTypeError> { - if let PrimType::Alias(alias) = self { - return Err(ConstTypeError::NoAliases(alias.name().to_string())); - } - - match (self, val) { - (PrimType::Extension(e), PrimValue::Extension { c: e_val }) => { + pub fn check_type(&self, val: &Value) -> Result<(), ConstTypeError> { + match (&self.0, val) { + (TypeEnum::Extension(e), Value::Extension { c: e_val }) => { e_val.0.check_custom_type(e)?; Ok(()) } - (PrimType::Function(t), PrimValue::Function { hugr: v }) + (TypeEnum::Function(t), Value::Function { hugr: v }) if v.get_function_type().is_some_and(|f| &**t == f) => { // exact signature equality, in future this may need to be // relaxed to be compatibility checks between the signatures. Ok(()) } - _ => Err(ConstTypeError::ValueCheckFail( - Type::new(TypeEnum::Prim(self.clone())), - Value::Prim { val: val.clone() }, - )), - } - } -} - -impl Type { - /// Check that a [`Value`] is a valid instance of this [`Type`]. - /// - /// # Errors - /// - /// This function will return an error if there is a type check error. - pub fn check_type(&self, val: &Value) -> Result<(), ConstTypeError> { - match (&self.0, val) { - (TypeEnum::Prim(p), Value::Prim { val: p_v }) => p.check_type(p_v), (TypeEnum::Tuple(t), Value::Tuple { vs: t_v }) => { if t.len() != t_v.len() { return Err(ConstTypeError::TupleWrongLength); diff --git a/src/types/primitive.rs b/src/types/primitive.rs deleted file mode 100644 index fe34151e5..000000000 --- a/src/types/primitive.rs +++ /dev/null @@ -1,38 +0,0 @@ -//! Primitive types which are leaves of the type tree - -use crate::ops::AliasDecl; - -use super::{CustomType, PolyFuncType, TypeBound}; - -#[derive( - Clone, PartialEq, Debug, Eq, derive_more::Display, serde::Serialize, serde::Deserialize, -)] -/// Representation of a Primitive type, i.e. neither a Sum nor a Tuple. -pub enum PrimType { - // TODO optimise with Box ? - // or some static version of this? - #[allow(missing_docs)] - Extension(CustomType), - #[allow(missing_docs)] - #[display(fmt = "Alias({})", "_0.name()")] - Alias(AliasDecl), - #[allow(missing_docs)] - #[display(fmt = "Function({})", "_0")] - Function(Box), - // DeBruijn index, and cache of TypeBound (checked in validation) - #[allow(missing_docs)] - #[display(fmt = "Variable({})", _0)] - Variable(usize, TypeBound), -} - -impl PrimType { - /// Returns the bound of this [`PrimType`]. - pub fn bound(&self) -> TypeBound { - match self { - PrimType::Extension(c) => c.bound(), - PrimType::Alias(a) => a.bound, - PrimType::Function(_) => TypeBound::Copyable, - PrimType::Variable(_, b) => *b, - } - } -} diff --git a/src/types/serialize.rs b/src/types/serialize.rs index 34ad609ed..4febe2238 100644 --- a/src/types/serialize.rs +++ b/src/types/serialize.rs @@ -4,7 +4,6 @@ use super::custom::CustomType; use crate::extension::prelude::{array_type, QB_T, USIZE_T}; use crate::ops::AliasDecl; -use crate::types::primitive::PrimType; #[derive(serde::Serialize, serde::Deserialize, Clone, Debug)] #[serde(tag = "t")] @@ -31,12 +30,10 @@ impl From for SerSimpleType { // TODO short circuiting for array. let Type(value, _) = value; match value { - TypeEnum::Prim(t) => match t { - PrimType::Extension(c) => SerSimpleType::Opaque(c), - PrimType::Alias(a) => SerSimpleType::Alias(a), - PrimType::Function(sig) => SerSimpleType::G(sig), - PrimType::Variable(i, b) => SerSimpleType::V { i, b }, - }, + TypeEnum::Extension(c) => SerSimpleType::Opaque(c), + TypeEnum::Alias(a) => SerSimpleType::Alias(a), + TypeEnum::Function(sig) => SerSimpleType::G(sig), + TypeEnum::Variable(i, b) => SerSimpleType::V { i, b }, TypeEnum::Sum(sum) => SerSimpleType::Sum(sum), TypeEnum::Tuple(inner) => SerSimpleType::Tuple { inner }, } diff --git a/src/values.rs b/src/values.rs index 808cf883e..3ace6fe5a 100644 --- a/src/values.rs +++ b/src/values.rs @@ -13,10 +13,11 @@ use crate::{Hugr, HugrView}; use crate::types::{CustomCheckFailure, CustomType}; -/// A constant value of a primitive (or leaf) type. +/// A value that can be stored as a static constant. Representing core types and +/// extension types. #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] -#[serde(tag = "pv")] -pub enum PrimValue { +#[serde(tag = "v")] +pub enum Value { /// An extension constant value, that can check it is of a given [CustomType]. /// // Note: the extra level of tupling is to avoid https://github.com/rust-lang/rust/issues/78808 @@ -30,32 +31,6 @@ pub enum PrimValue { #[allow(missing_docs)] hugr: Box, }, -} - -impl PrimValue { - fn name(&self) -> String { - match self { - PrimValue::Extension { c: e } => format!("const:custom:{}", e.0.name()), - PrimValue::Function { hugr: h } => { - let Some(t) = h.get_function_type() else { - panic!("HUGR root node isn't a valid function parent."); - }; - format!("const:function:[{}]", t) - } - } - } -} - -/// A value that can be stored as a static constant. Representing core types and -/// extension types. -#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] -#[serde(tag = "v")] -pub enum Value { - /// A primitive (non-container) value. - Prim { - #[allow(missing_docs)] - val: PrimValue, - }, /// A tuple Tuple { #[allow(missing_docs)] @@ -75,7 +50,13 @@ impl Value { /// Returns the name of this [`Value`]. pub fn name(&self) -> String { match self { - Value::Prim { val: p } => p.name(), + Value::Extension { c: e } => format!("const:custom:{}", e.0.name()), + Value::Function { hugr: h } => { + let Some(t) = h.get_function_type() else { + panic!("HUGR root node isn't a valid function parent."); + }; + format!("const:function:[{}]", t) + } Value::Tuple { vs: vals } => { let names: Vec<_> = vals.iter().map(Value::name).collect(); format!("const:seq:{{{}}}", names.join(", ")) @@ -123,17 +104,12 @@ impl Value { /// New custom value (of type that implements [`CustomConst`]). pub fn custom(c: C) -> Self { - Self::Prim { - val: PrimValue::Extension { c: (Box::new(c),) }, - } + Self::Extension { c: (Box::new(c),) } } /// For a Const holding a CustomConst, extract the CustomConst by downcasting. pub fn get_custom_value(&self) -> Option<&T> { - if let Value::Prim { - val: PrimValue::Extension { c: (custom,) }, - } = self - { + if let Value::Extension { c: (custom,) } = self { custom.downcast_ref() } else { None @@ -286,10 +262,8 @@ pub(crate) mod test { #[rstest] fn function_value(simple_dfg_hugr: Hugr) { - let v = Value::Prim { - val: PrimValue::Function { - hugr: Box::new(simple_dfg_hugr), - }, + let v = Value::Function { + hugr: Box::new(simple_dfg_hugr), }; let correct_type = Type::new_function(FunctionType::new_linear(type_row![