Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SerSimpleType: use Vec not TypeRow #381

Merged
merged 3 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/types/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ pub trait PrimType: TypeRowElem + std::fmt::Debug + sealed::Sealed {
fn tag(&self) -> TypeTag;
}

impl TypeRowElem for SimpleType {}
impl TypeRowElem for ClassicType {}
impl TypeRowElem for HashableType {}

// sealed trait pattern to prevent users extending PrimType
mod sealed {
use super::{ClassicType, HashableType, SimpleType};
Expand Down
39 changes: 18 additions & 21 deletions src/types/simple/serialize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use super::HashableType;
use super::PrimType;
use super::TypeTag;

use itertools::Itertools;
use smol_str::SmolStr;

use super::super::custom::CustomType;
Expand All @@ -17,6 +18,7 @@ use super::SimpleType;
use super::super::AbstractSignature;

use crate::ops::constant::HugrIntWidthStore;
use crate::types::type_row::TypeRowElem;

#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
#[serde(tag = "t")]
Expand All @@ -31,11 +33,11 @@ pub(crate) enum SerSimpleType {
signature: Box<AbstractSignature>,
},
Tuple {
row: Box<TypeRow<SerSimpleType>>,
row: Vec<SerSimpleType>,
c: TypeTag,
},
Sum {
row: Box<TypeRow<SerSimpleType>>,
row: Vec<SerSimpleType>,
c: TypeTag,
},
Array {
Expand Down Expand Up @@ -80,15 +82,15 @@ where
fn from(value: Container<T>) -> Self {
match value {
Container::Sum(inner) => SerSimpleType::Sum {
row: Box::new(inner.map_into()),
row: inner.into_owned().into_iter().map_into().collect(),
c: T::TAG, // We could inspect inner.containing_tag(), but this should have been done already
},
Container::Tuple(inner) => SerSimpleType::Tuple {
row: Box::new(inner.map_into()),
row: inner.into_owned().into_iter().map_into().collect(),
c: T::TAG,
},
Container::Array(inner, len) => SerSimpleType::Array {
inner: box_convert(*inner),
inner: Box::new((*inner).into()),
len,
c: T::TAG,
},
Expand Down Expand Up @@ -132,19 +134,14 @@ impl From<SimpleType> for SerSimpleType {
}
}

pub(crate) fn box_convert_try<T, F>(value: T) -> Box<F>
where
T: TryInto<F>,
<T as TryInto<F>>::Error: std::fmt::Debug,
{
Box::new((value).try_into().unwrap())
}

pub(crate) fn box_convert<T, F>(value: T) -> Box<F>
where
T: Into<F>,
{
Box::new((value).into())
fn try_convert_list<T: TryInto<T2>, T2: TypeRowElem>(
values: Vec<T>,
) -> Result<TypeRow<T2>, T::Error> {
let vals = values
.into_iter()
.map(T::try_into)
.collect::<Result<Vec<T2>, T::Error>>()?;
Ok(TypeRow::from(vals))
}

macro_rules! handle_container {
Expand All @@ -166,13 +163,13 @@ impl From<SerSimpleType> for SimpleType {
SerSimpleType::S => HashableType::String.into(),
SerSimpleType::G { signature } => ClassicType::Graph(Box::new(*signature)).into(),
SerSimpleType::Tuple { row: inner, c } => {
handle_container!(c, Tuple(Box::new(inner.try_convert_elems().unwrap())))
handle_container!(c, Tuple(Box::new(try_convert_list(inner).unwrap())))
}
SerSimpleType::Sum { row: inner, c } => {
handle_container!(c, Sum(Box::new(inner.try_convert_elems().unwrap())))
handle_container!(c, Sum(Box::new(try_convert_list(inner).unwrap())))
}
SerSimpleType::Array { inner, len, c } => {
handle_container!(c, Array(box_convert_try(*inner), len))
handle_container!(c, Array(Box::new((*inner).try_into().unwrap()), len))
}
SerSimpleType::Alias { name: s, c } => handle_container!(c, Alias(s)),
SerSimpleType::Opaque { custom, c } => {
Expand Down
2 changes: 0 additions & 2 deletions src/types/type_row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@ use crate::utils::display_list;
/// Base trait for anything that can be put in a [TypeRow]
pub trait TypeRowElem: Clone + 'static {}

impl<T: Clone + 'static> TypeRowElem for T {}

/// List of types, used for function signatures.
#[derive(Clone, PartialEq, Eq, Debug, serde::Serialize, serde::Deserialize)]
//#[cfg_attr(feature = "pyo3", pyclass)] // TODO: expose unparameterized versions
Expand Down