Skip to content

Commit

Permalink
to_tvm_value
Browse files Browse the repository at this point in the history
  • Loading branch information
nhynes committed Apr 5, 2019
1 parent 0a9ccab commit 784081e
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 24 deletions.
37 changes: 20 additions & 17 deletions rust/common/src/packed_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ macro_rules! TVMPODValue {
match $value:ident {
$($tvm_type:ident => { $from_tvm_type:expr })+
},
match self {
match &self {
$($self_type:ident ( $val:ident ) => { $from_self_type:expr })+
}
$(,)?
Expand Down Expand Up @@ -83,34 +83,37 @@ macro_rules! TVMPODValue {
}
}

pub fn into_tvm_value(self) -> (TVMValue, TVMTypeCode) {
pub fn to_tvm_value(&self) -> (TVMValue, TVMTypeCode) {
use $name::*;
match self {
Int(val) => (TVMValue { v_int64: val }, DLDataTypeCode_kDLInt),
UInt(val) => (TVMValue { v_int64: val as i64 }, DLDataTypeCode_kDLUInt),
Float(val) => (TVMValue { v_float64: val }, DLDataTypeCode_kDLFloat),
Int(val) => (TVMValue { v_int64: *val }, DLDataTypeCode_kDLInt),
UInt(val) => (TVMValue { v_int64: *val as i64 }, DLDataTypeCode_kDLUInt),
Float(val) => (TVMValue { v_float64: *val }, DLDataTypeCode_kDLFloat),
Null => (TVMValue{ v_int64: 0 },TVMTypeCode_kNull),
Type(val) => (TVMValue { v_type: val }, TVMTypeCode_kTVMType),
Context(val) => (TVMValue { v_ctx: val }, TVMTypeCode_kTVMContext),
Type(val) => (TVMValue { v_type: *val }, TVMTypeCode_kTVMType),
Context(val) => (TVMValue { v_ctx: val.clone() }, TVMTypeCode_kTVMContext),
String(val) => {
(
TVMValue { v_handle: val.into_raw() as *mut c_void },
TVMValue { v_handle: val.as_ptr() as *mut c_void },
TVMTypeCode_kStr,
)
}
Handle(val) => (TVMValue { v_handle: val }, TVMTypeCode_kHandle),
Handle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kHandle),
ArrayHandle(val) => {
(
TVMValue { v_handle: val as *const _ as *mut c_void },
TVMTypeCode_kArrayHandle,
)
},
NodeHandle(val) => (TVMValue { v_handle: val }, TVMTypeCode_kNodeHandle),
NodeHandle(val) => (TVMValue { v_handle: *val }, TVMTypeCode_kNodeHandle),
ModuleHandle(val) =>
(TVMValue { v_handle: val }, TVMTypeCode_kModuleHandle),
FuncHandle(val) => (TVMValue { v_handle: val }, TVMTypeCode_kFuncHandle),
(TVMValue { v_handle: *val }, TVMTypeCode_kModuleHandle),
FuncHandle(val) => (
TVMValue { v_handle: val.clone() },
TVMTypeCode_kFuncHandle
),
NDArrayContainer(val) =>
(TVMValue { v_handle: val }, TVMTypeCode_kNDArrayContainer),
(TVMValue { v_handle: val.clone() }, TVMTypeCode_kNDArrayContainer),
$( $self_type($val) => { $from_self_type } ),+
}
}
Expand All @@ -129,9 +132,9 @@ TVMPODValue! {
TVMTypeCode_kBytes => { Bytes(&*(value.v_handle as *const TVMByteArray)) }
TVMTypeCode_kStr => { Str(CStr::from_ptr(value.v_handle as *const i8)) }
},
match self {
match &self {
Bytes(val) => {
(TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kBytes)
(TVMValue { v_handle: val.clone() as *const _ as *mut c_void }, TVMTypeCode_kBytes)
}
Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kStr)}
}
Expand Down Expand Up @@ -159,9 +162,9 @@ TVMPODValue! {
TVMTypeCode_kBytes => { Bytes(*(value.v_handle as *const TVMByteArray)) }
TVMTypeCode_kStr => { Str(CStr::from_ptr(value.v_handle as *mut i8)) }
},
match self {
match &self {
Bytes(val) =>
{ (TVMValue { v_handle: &val as *const _ as *mut c_void }, TVMTypeCode_kBytes ) }
{ (TVMValue { v_handle: val as *const _ as *mut c_void }, TVMTypeCode_kBytes ) }
Str(val) =>
{ (TVMValue { v_str: val.as_ptr() }, TVMTypeCode_kStr ) }
}
Expand Down
9 changes: 3 additions & 6 deletions rust/frontend/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,11 +192,8 @@ impl<'a, 'm> Builder<'a, 'm> {
ensure!(self.func.is_some(), errors::FunctionNotFoundError);

let num_args = self.arg_buf.len();
let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMTypeCode>) = self
.arg_buf
.iter()
.map(|arg| arg.clone().into_tvm_value())
.unzip();
let (mut values, mut type_codes): (Vec<ffi::TVMValue>, Vec<ffi::TVMTypeCode>) =
self.arg_buf.iter().map(|arg| arg.to_tvm_value()).unzip();

let mut ret_val = unsafe { std::mem::uninitialized::<TVMValue>() };
let mut ret_type_code = 0i32;
Expand Down Expand Up @@ -265,7 +262,7 @@ unsafe extern "C" fn tvm_callback(
}
};

let (mut ret_val, ret_tcode) = rv.into_tvm_value();
let (mut ret_val, ret_tcode) = rv.to_tvm_value();
let mut ret_type_code = ret_tcode as c_int;
check_call!(ffi::TVMCFuncSetReturn(
ret,
Expand Down
2 changes: 1 addition & 1 deletion rust/runtime/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub(super) fn wrap_backend_packed_func(
let (values, type_codes): (Vec<TVMValue>, Vec<i32>) = args
.into_iter()
.map(|arg| {
let (val, code) = arg.clone().into_tvm_value();
let (val, code) = arg.to_tvm_value();
(val, code as i32)
})
.unzip();
Expand Down

0 comments on commit 784081e

Please sign in to comment.