Skip to content

Commit

Permalink
Update rust/tvm-sys/src/datatype.rs
Browse files Browse the repository at this point in the history
Co-authored-by: Nick Hynes <[email protected]>
  • Loading branch information
jroesch and nhynes committed May 8, 2020
1 parent 902c271 commit 6818797
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions rust/tvm-sys/src/datatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,35 +97,40 @@ impl From<DLDataType> for DataType {
}

#[derive(Debug, Error)]
pub enum ParseTvmTypeError {
pub enum ParseDataTypeError {
#[error("invalid number: {0}")]
InvalidNumber(std::num::ParseIntError),
#[error("missing data type specifier (e.g., int32, float64)")]
MissingDataType,
#[error("unknown type: {0}")]
UnknownType(String),
}

/// Implements TVMType conversion from `&str` of general format `{dtype}{bits}x{lanes}`
/// such as "int32", "float32" or with lane "float32x1".
impl FromStr for DataType {
type Err = ParseTvmTypeError;
type Err = ParseDataTypeError;

fn from_str(type_str: &str) -> Result<Self, Self::Err> {
use ParseDataTypeError::*;

if type_str == "bool" {
return Ok(DataType::new(1, 1, 1));
}

let mut type_lanes = type_str.split('x');
let typ = type_lanes.next().expect("Missing dtype");
let typ = type_lanes.next().ok_or(MissingDataType)?;
let lanes = type_lanes
.next()
.map(|l| <u16>::from_str_radix(l, 10))
.unwrap_or(Ok(1))
.map_err(ParseTvmTypeError::InvalidNumber)?;
.map_err(InvalidNumber)?;
let (type_name, bits) = match typ.find(char::is_numeric) {
Some(idx) => {
let (name, bits_str) = typ.split_at(idx);
(
name,
u8::from_str_radix(bits_str, 10).map_err(ParseTvmTypeError::InvalidNumber)?,
u8::from_str_radix(bits_str, 10).map_err(InvalidNumber)?,
)
}
None => (typ, 32),
Expand All @@ -136,7 +141,7 @@ impl FromStr for DataType {
"uint" => DL_UINT_CODE,
"float" => DL_FLOAT_CODE,
"handle" => DL_HANDLE,
_ => return Err(ParseTvmTypeError::UnknownType(type_name.to_string())),
_ => return Err(UnknownType(type_name.to_string())),
};

Ok(DataType::new(type_code, bits, lanes))
Expand Down

0 comments on commit 6818797

Please sign in to comment.