Skip to content

Commit

Permalink
refactor!: Closes Flatten Prim(Type/Value) in to parent enum #665
Browse files Browse the repository at this point in the history
BREAKING_CHANGES: In serialization, extension and function values no longer
wrapped by "pv".
  • Loading branch information
ss2165 committed Nov 13, 2023
1 parent aa68c49 commit 3551e2d
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 140 deletions.
51 changes: 30 additions & 21 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
mod check;
pub mod custom;
mod poly_func;
mod primitive;
mod serialize;
mod signature;
pub mod type_param;
Expand All @@ -26,7 +25,6 @@ use crate::ops::AliasDecl;
use crate::type_row;
use std::fmt::Debug;

pub use self::primitive::PrimType;
use self::type_param::TypeParam;

#[cfg(feature = "pyo3")]
Expand Down Expand Up @@ -153,10 +151,22 @@ impl From<SumType> for Type {
}

#[derive(Clone, PartialEq, Debug, Eq, derive_more::Display)]
/// Core types: primitive (leaf), tuple (product) or sum (co-product).
/// Core types
pub enum TypeEnum {
// TODO optimise with Box<CustomType> ?
// or some static version of this?
#[allow(missing_docs)]
Prim(PrimType),
Extension(CustomType),
#[allow(missing_docs)]
#[display(fmt = "Alias({})", "_0.name()")]
Alias(AliasDecl),
#[allow(missing_docs)]
#[display(fmt = "Function({})", "_0")]
Function(Box<PolyFuncType>),
// DeBruijn index, and cache of TypeBound (checked in validation)
#[allow(missing_docs)]
#[display(fmt = "Variable({})", _0)]
Variable(usize, TypeBound),
#[allow(missing_docs)]
#[display(fmt = "Tuple({})", "_0")]
Tuple(TypeRow),
Expand All @@ -168,7 +178,10 @@ impl TypeEnum {
/// The smallest type bound that covers the whole type.
fn least_upper_bound(&self) -> TypeBound {
match self {
TypeEnum::Prim(p) => p.bound(),
TypeEnum::Extension(c) => c.bound(),
TypeEnum::Alias(a) => a.bound,
TypeEnum::Function(_) => TypeBound::Copyable,
TypeEnum::Variable(_, b) => *b,
TypeEnum::Sum(SumType::Unit { size: _ }) => TypeBound::Eq,
TypeEnum::Sum(SumType::General { row }) => {
least_upper_bound(row.iter().map(Type::least_upper_bound))
Expand Down Expand Up @@ -216,7 +229,7 @@ impl Type {

/// Initialize a new function type.
pub fn new_function(fun_ty: impl Into<PolyFuncType>) -> Self {
Self::new(TypeEnum::Prim(PrimType::Function(Box::new(fun_ty.into()))))
Self::new(TypeEnum::Function(Box::new(fun_ty.into())))
}

/// Initialize a new tuple type by providing the elements.
Expand All @@ -235,12 +248,12 @@ impl Type {
// TODO remove? Extensions/TypeDefs should just provide `Type` directly
pub const fn new_extension(opaque: CustomType) -> Self {
let bound = opaque.bound();
Type(TypeEnum::Prim(PrimType::Extension(opaque)), bound)
Type(TypeEnum::Extension(opaque), bound)
}

/// Initialize a new alias.
pub fn new_alias(alias: AliasDecl) -> Self {
Self::new(TypeEnum::Prim(PrimType::Alias(alias)))
Self::new(TypeEnum::Alias(alias))
}

fn new(type_e: TypeEnum) -> Self {
Expand All @@ -267,7 +280,7 @@ impl Type {
/// For use in type schemes only: `bound` must match that with which the
/// variable was declared (i.e. as a [TypeParam::Type]`(bound)`).
pub fn new_var_use(idx: usize, bound: TypeBound) -> Self {
Self(TypeEnum::Prim(PrimType::Variable(idx, bound)), bound)
Self(TypeEnum::Variable(idx, bound), bound)
}

/// Report the least upper TypeBound, if there is one.
Expand Down Expand Up @@ -307,25 +320,21 @@ impl Type {
.iter()
.try_for_each(|t| t.validate(extension_registry, var_decls)),
TypeEnum::Sum(SumType::Unit { .. }) => Ok(()), // No leaves there
TypeEnum::Prim(PrimType::Alias(_)) => Ok(()),
TypeEnum::Prim(PrimType::Extension(custy)) => {
custy.validate(extension_registry, var_decls)
}
TypeEnum::Prim(PrimType::Function(ft)) => ft.validate(extension_registry, var_decls),
TypeEnum::Prim(PrimType::Variable(idx, bound)) => {
TypeEnum::Alias(_) => Ok(()),
TypeEnum::Extension(custy) => custy.validate(extension_registry, var_decls),
TypeEnum::Function(ft) => ft.validate(extension_registry, var_decls),
TypeEnum::Variable(idx, bound) => {
check_typevar_decl(var_decls, *idx, &TypeParam::Type(*bound))
}
}
}

pub(crate) fn substitute(&self, t: &impl Substitution) -> Self {
match &self.0 {
TypeEnum::Prim(PrimType::Alias(_)) | TypeEnum::Sum(SumType::Unit { .. }) => {
self.clone()
}
TypeEnum::Prim(PrimType::Variable(idx, bound)) => t.apply_typevar(*idx, *bound),
TypeEnum::Prim(PrimType::Extension(cty)) => Type::new_extension(cty.substitute(t)),
TypeEnum::Prim(PrimType::Function(bf)) => Type::new_function(bf.substitute(t)),
TypeEnum::Alias(_) | TypeEnum::Sum(SumType::Unit { .. }) => self.clone(),
TypeEnum::Variable(idx, bound) => t.apply_typevar(*idx, *bound),
TypeEnum::Extension(cty) => Type::new_extension(cty.substitute(t)),
TypeEnum::Function(bf) => Type::new_function(bf.substitute(t)),
TypeEnum::Tuple(elems) => Type::new_tuple(subst_row(elems, t)),
TypeEnum::Sum(SumType::General { row }) => Type::new_sum(subst_row(row, t)),
}
Expand Down
40 changes: 8 additions & 32 deletions src/types/check.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
//! Logic for checking values against types.
use thiserror::Error;

use crate::{
values::{PrimValue, Value},
HugrView,
};
use crate::{values::Value, HugrView};

use super::{primitive::PrimType, CustomType, Type, TypeEnum};
use super::{CustomType, Type, TypeEnum};

/// Struct for custom type check fails.
#[derive(Clone, Debug, PartialEq, Eq, Error)]
Expand Down Expand Up @@ -48,46 +45,25 @@ pub enum ConstTypeError {
CustomCheckFail(#[from] CustomCheckFailure),
}

impl PrimType {
/// Check that a [`PrimValue`] is a valid instance of this [`PrimType`].
impl Type {
/// Check that a [`Value`] is a valid instance of this [`Type`].
///
/// # Errors
///
/// This function will return an error if there is a type check error.
pub fn check_type(&self, val: &PrimValue) -> Result<(), ConstTypeError> {
if let PrimType::Alias(alias) = self {
return Err(ConstTypeError::NoAliases(alias.name().to_string()));
}

match (self, val) {
(PrimType::Extension(e), PrimValue::Extension { c: e_val }) => {
pub fn check_type(&self, val: &Value) -> Result<(), ConstTypeError> {
match (&self.0, val) {
(TypeEnum::Extension(e), Value::Extension { c: e_val }) => {
e_val.0.check_custom_type(e)?;
Ok(())
}
(PrimType::Function(t), PrimValue::Function { hugr: v })
(TypeEnum::Function(t), Value::Function { hugr: v })
if v.get_function_type().is_some_and(|f| &**t == f) =>
{
// exact signature equality, in future this may need to be
// relaxed to be compatibility checks between the signatures.
Ok(())
}
_ => Err(ConstTypeError::ValueCheckFail(
Type::new(TypeEnum::Prim(self.clone())),
Value::Prim { val: val.clone() },
)),
}
}
}

impl Type {
/// Check that a [`Value`] is a valid instance of this [`Type`].
///
/// # Errors
///
/// This function will return an error if there is a type check error.
pub fn check_type(&self, val: &Value) -> Result<(), ConstTypeError> {
match (&self.0, val) {
(TypeEnum::Prim(p), Value::Prim { val: p_v }) => p.check_type(p_v),
(TypeEnum::Tuple(t), Value::Tuple { vs: t_v }) => {
if t.len() != t_v.len() {
return Err(ConstTypeError::TupleWrongLength);
Expand Down
2 changes: 1 addition & 1 deletion src/types/poly_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use super::{FunctionType, Substitution};
/// A polymorphic function type, e.g. of a [Graph], or perhaps an [OpDef].
/// (Nodes/operations in the Hugr are not polymorphic.)
///
/// [Graph]: crate::values::PrimValue::Function
/// [Graph]: crate::values::Value::Function
/// [OpDef]: crate::extension::OpDef
#[derive(
Clone, PartialEq, Debug, Eq, derive_more::Display, serde::Serialize, serde::Deserialize,
Expand Down
38 changes: 0 additions & 38 deletions src/types/primitive.rs

This file was deleted.

11 changes: 4 additions & 7 deletions src/types/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use super::custom::CustomType;

use crate::extension::prelude::{array_type, QB_T, USIZE_T};
use crate::ops::AliasDecl;
use crate::types::primitive::PrimType;

#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
#[serde(tag = "t")]
Expand All @@ -31,12 +30,10 @@ impl From<Type> for SerSimpleType {
// TODO short circuiting for array.
let Type(value, _) = value;
match value {
TypeEnum::Prim(t) => match t {
PrimType::Extension(c) => SerSimpleType::Opaque(c),
PrimType::Alias(a) => SerSimpleType::Alias(a),
PrimType::Function(sig) => SerSimpleType::G(sig),
PrimType::Variable(i, b) => SerSimpleType::V { i, b },
},
TypeEnum::Extension(c) => SerSimpleType::Opaque(c),
TypeEnum::Alias(a) => SerSimpleType::Alias(a),
TypeEnum::Function(sig) => SerSimpleType::G(sig),
TypeEnum::Variable(i, b) => SerSimpleType::V { i, b },
TypeEnum::Sum(sum) => SerSimpleType::Sum(sum),
TypeEnum::Tuple(inner) => SerSimpleType::Tuple { inner },
}
Expand Down
56 changes: 15 additions & 41 deletions src/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ use crate::{Hugr, HugrView};

use crate::types::{CustomCheckFailure, CustomType};

/// A constant value of a primitive (or leaf) type.
/// A value that can be stored as a static constant. Representing core types and
/// extension types.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "pv")]
pub enum PrimValue {
#[serde(tag = "v")]
pub enum Value {
/// An extension constant value, that can check it is of a given [CustomType].
///
// Note: the extra level of tupling is to avoid https://github.com/rust-lang/rust/issues/78808
Expand All @@ -30,32 +31,6 @@ pub enum PrimValue {
#[allow(missing_docs)]
hugr: Box<Hugr>,
},
}

impl PrimValue {
fn name(&self) -> String {
match self {
PrimValue::Extension { c: e } => format!("const:custom:{}", e.0.name()),
PrimValue::Function { hugr: h } => {
let Some(t) = h.get_function_type() else {
panic!("HUGR root node isn't a valid function parent.");
};
format!("const:function:[{}]", t)
}
}
}
}

/// A value that can be stored as a static constant. Representing core types and
/// extension types.
#[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(tag = "v")]
pub enum Value {
/// A primitive (non-container) value.
Prim {
#[allow(missing_docs)]
val: PrimValue,
},
/// A tuple
Tuple {
#[allow(missing_docs)]
Expand All @@ -75,7 +50,13 @@ impl Value {
/// Returns the name of this [`Value`].
pub fn name(&self) -> String {
match self {
Value::Prim { val: p } => p.name(),
Value::Extension { c: e } => format!("const:custom:{}", e.0.name()),
Value::Function { hugr: h } => {
let Some(t) = h.get_function_type() else {
panic!("HUGR root node isn't a valid function parent.");
};
format!("const:function:[{}]", t)
}
Value::Tuple { vs: vals } => {
let names: Vec<_> = vals.iter().map(Value::name).collect();
format!("const:seq:{{{}}}", names.join(", "))
Expand Down Expand Up @@ -123,17 +104,12 @@ impl Value {

/// New custom value (of type that implements [`CustomConst`]).
pub fn custom<C: CustomConst>(c: C) -> Self {
Self::Prim {
val: PrimValue::Extension { c: (Box::new(c),) },
}
Self::Extension { c: (Box::new(c),) }
}

/// For a Const holding a CustomConst, extract the CustomConst by downcasting.
pub fn get_custom_value<T: CustomConst>(&self) -> Option<&T> {
if let Value::Prim {
val: PrimValue::Extension { c: (custom,) },
} = self
{
if let Value::Extension { c: (custom,) } = self {
custom.downcast_ref()
} else {
None
Expand Down Expand Up @@ -286,10 +262,8 @@ pub(crate) mod test {

#[rstest]
fn function_value(simple_dfg_hugr: Hugr) {
let v = Value::Prim {
val: PrimValue::Function {
hugr: Box::new(simple_dfg_hugr),
},
let v = Value::Function {
hugr: Box::new(simple_dfg_hugr),
};

let correct_type = Type::new_function(FunctionType::new_linear(type_row![
Expand Down

0 comments on commit 3551e2d

Please sign in to comment.