Skip to content

Commit

Permalink
Almost there ...
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed May 5, 2020
1 parent 7fedb66 commit f4a940e
Show file tree
Hide file tree
Showing 2 changed files with 193 additions and 81 deletions.
62 changes: 31 additions & 31 deletions rust/tvm-rt/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use lazy_static::lazy_static;
pub use tvm_sys::{ffi, ArgValue, RetValue};

use crate::{errors, Module};
use super::to_function::ToFunction;
use super::to_function::{ToFunction, Typed};

lazy_static! {
static ref GLOBAL_FUNCTIONS: Mutex<BTreeMap<&'static str, Option<Function>>> = {
Expand Down Expand Up @@ -279,7 +279,7 @@ impl<'a, 'm> From<&'m mut Module> for Builder<'a, 'm> {
pub fn register<'a, F, I, O, S: AsRef<str>>(
f: F,
name: S,
) -> Result<()> where F: ToFunction<'a, I, O> {
) -> Result<()> where F: ToFunction<I, O>, F: Typed<I, O> {
register_override(f, name, false)
}

Expand Down Expand Up @@ -319,7 +319,7 @@ pub fn register_override<'a, F, I, O, S: AsRef<str>>(
f: F,
name: S,
override_: bool,
) -> Result<()> where F: ToFunction<'a, I, O> {
) -> Result<()> where F: ToFunction<I, O>, F: Typed<I, O> {
let func = f.to_function();
let name = CString::new(name.as_ref())?;
check_call!(ffi::TVMFuncRegisterGlobal(
Expand Down Expand Up @@ -391,38 +391,37 @@ mod tests {
assert_eq!(func.arg_buf.len(), 3);
}

// #[test]
// fn register_and_call_fn() {
// use crate::{ArgValue, function, RetValue};
// use crate::function::Builder;
// use anyhow::Error;
// use std::convert::TryInto;

// fn sum(args: &[ArgValue]) -> Result<RetValue, Error> {
// let mut ret = 0i64;
// for arg in args.iter() {
// let arg: i64 = arg.try_into()?;
// ret += arg;
// }
// let ret_val = RetValue::from(ret);
// Ok(ret_val)
// }

// function::register_override(sum, "mysum".to_owned(), true).unwrap();
// let mut registered = Builder::default();
// registered.get_function("mysum");
// println!("{:?}", registered.func);
// assert!(registered.func.is_some());
// let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap();
// assert_eq!(ret, 60);
// }
#[test]
fn register_and_call_fn() {
use crate::{ArgValue, function, RetValue};
use crate::function::Builder;
use anyhow::Error;
use std::convert::TryInto;

fn sum(args: &[ArgValue]) -> Result<RetValue, Error> {
let mut ret = 0i64;
for arg in args.iter() {
let arg: i64 = arg.try_into()?;
ret += arg;
}
let ret_val = RetValue::from(ret);
Ok(ret_val)
}

function::register_override(sum, "mysum".to_owned(), true).unwrap();
let mut registered = Builder::default();
registered.get_function("mysum");
println!("{:?}", registered.func);
assert!(registered.func.is_some());
let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap();
assert_eq!(ret, 60);
}


#[test]
fn register_and_call_closure0() {
use crate::{ArgValue, function, RetValue};
use crate::{function};
use crate::function::Builder;
use anyhow::Error;
use std::convert::TryInto;

fn sum() -> i64 {
Expand All @@ -444,6 +443,7 @@ mod tests {
use crate::function::Builder;
use anyhow::Error;
use std::convert::TryInto;
use tvm_sys::value::*;

fn sum(x: i64) -> i64 {
return 10;
Expand All @@ -464,7 +464,7 @@ mod tests {
use crate::function::Builder;
use anyhow::Error;
use std::convert::TryInto;

use tvm_sys::value::*;
fn sum(a: i64, b: i64, c: i64) -> i64 {
return a + b + c;
}
Expand Down
212 changes: 162 additions & 50 deletions rust/tvm-rt/src/to_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,106 @@ use anyhow::{Result};

pub use tvm_sys::{ffi, ArgValue, RetValue};

use std::convert::{TryInto, TryFrom};
use super::Function;

pub trait ToFunction<'a, I, O>: Sized {
/// A trait representing whether the function arguments
/// and return type can be assigned to a TVM packed function.
///
/// By splitting the conversion to function into two traits
/// we are able to improve error reporting, by splitting the
/// conversion of inputs and outputs to this trait.
///
/// And the implementation of it to `ToFunction`.
pub trait Typed<I, O> {
fn args(i: &[ArgValue<'static>]) -> anyhow::Result<I>;
fn ret(o: O) -> RetValue;
}

impl<'a, F> Typed<&'a [ArgValue<'static>], anyhow::Result<RetValue>> for F where F: Fn(&'a [ArgValue]) -> anyhow::Result<RetValue> {
fn args(args: &[ArgValue<'static>]) -> anyhow::Result<&'a [ArgValue<'static>]> {
// this is BAD but just hacking for time being
Ok(unsafe { std::mem::transmute(args) })
}

fn ret(ret_value: anyhow::Result<RetValue>) -> RetValue {
ret_value.unwrap()
}
}

impl<F, O: Into<RetValue>> Typed<(), O> for F where F: Fn() -> O {
fn args(_args: &[ArgValue<'static>]) -> anyhow::Result<()> {
debug_assert!(_args.len() == 0);
Ok(())
}

fn ret(o: O) -> RetValue {
o.into()
}
}

impl<F, A, O: Into<RetValue>, E: Into<anyhow::Error>> Typed<(A,), O> for F
where F: Fn(A) -> O,
E: std::error::Error + Send + Sync + 'static,
A: TryFrom<ArgValue<'static>, Error=E> {
fn args(args: &[ArgValue<'static>]) -> anyhow::Result<(A,)> {
debug_assert!(args.len() == 1);
let a: A = args[0].clone().try_into()?;
Ok((a,))
}

fn ret(o: O) -> RetValue {
o.into()
}
}

impl<F, A, B, O: Into<RetValue>, E: Into<anyhow::Error>> Typed<(A,B,), O> for F
where F: Fn(A, B) -> O,
E: std::error::Error + Send + Sync + 'static,
A: TryFrom<ArgValue<'static>, Error=E>,
B: TryFrom<ArgValue<'static>, Error=E> {
fn args(args: &[ArgValue<'static>]) -> anyhow::Result<(A,B)> {
debug_assert!(args.len() == 1);
let a: A = args[0].clone().try_into()?;
let b: B = args[1].clone().try_into()?;
Ok((a, b))
}

fn ret(o: O) -> RetValue {
o.into()
}
}

impl<F, A, B, C, O: Into<RetValue>, E: Into<anyhow::Error>> Typed<(A,B,C), O> for F
where F: Fn(A, B, C) -> O,
E: std::error::Error + Send + Sync + 'static,
A: TryFrom<ArgValue<'static>, Error=E>,
B: TryFrom<ArgValue<'static>, Error=E>,
C: TryFrom<ArgValue<'static>, Error=E> {
fn args(args: &[ArgValue<'static>]) -> anyhow::Result<(A,B,C,)> {
debug_assert!(args.len() == 1);
let a: A = args[0].clone().try_into()?;
let b: B = args[1].clone().try_into()?;
let c: C = args[2].clone().try_into()?;
Ok((a, b, c,))
}

fn ret(o: O) -> RetValue {
o.into()
}
}

pub trait ToFunction<I, O>: Sized {
type Handle;

fn into_raw(self) -> *mut Self::Handle;
fn call(handle: *mut Self::Handle, args: &[ArgValue<'a>]) -> anyhow::Result<RetValue>;

fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> anyhow::Result<RetValue>
where Self: Typed<I, O>;

fn drop(handle: *mut Self::Handle);

fn to_function(self) -> Function {
fn to_function(self) -> Function where Self: Typed<I, O> {
let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle;
let resource_handle = self.into_raw();
check_call!(ffi::TVMFuncCreateFromCFunc(
Expand All @@ -65,7 +155,7 @@ pub trait ToFunction<'a, I, O>: Sized {
num_args: c_int,
ret: ffi::TVMRetValueHandle,
fhandle: *mut c_void,
) -> c_int {
) -> c_int where Self:Typed<I, O> {
// turning off the incorrect linter complaints
#![allow(unused_assignments, unused_unsafe)]
println!("here");
Expand Down Expand Up @@ -123,87 +213,109 @@ pub trait ToFunction<'a, I, O>: Sized {
}
}

// impl<'a, 'b> ToFunction<&'a [ArgValue<'b>], RetValue> for fn(&[ArgValue]) -> Result<RetValue> {
// type Handle = for <'x, 'y> fn(&'x [ArgValue<'y>]) -> Result<RetValue>;
impl ToFunction<&[ArgValue<'static>], Result<RetValue>> for for <'a> fn(&'a [ArgValue<'static>]) -> Result<RetValue> {
type Handle = fn(&[ArgValue<'static>]) -> Result<RetValue>;

// fn into_raw(self) -> *mut Self::Handle {
// self as *mut Self::Handle
// }
fn into_raw(self) -> *mut Self::Handle {
self as *mut Self::Handle
}

// fn call(handle: *mut Self::Handle, args: &[ArgValue]) -> Result<RetValue> {
// println!("calls");
// let handle: Self::Handle = unsafe { std::mem::transmute(handle) };
// let r = handle(args);
// println!("afters");
// r
// }
fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue> {
panic!()
// println!("calls");
// let handle: Self::Handle = unsafe { std::mem::transmute(handle) };
// let r = handle(args);
// println!("afters");
// r
}

// // Function's don't need de-allocation because the pointers are into the code section of memory.
// fn drop(_: *mut Self::Handle) {}
// }
// Function's don't need de-allocation because the pointers are into the code section of memory.
fn drop(_: *mut Self::Handle) {}
}

impl<'a, O: Into<RetValue>, F> ToFunction<'a, (), O> for F where F: Fn() -> O + 'static {
impl<O, F> ToFunction<(), O> for F where F: Fn() -> O + 'static {
type Handle = Box<dyn Fn() -> O + 'static>;

fn into_raw(self) -> *mut Self::Handle {
let ptr: Box<Self::Handle> = Box::new(Box::new(self));
Box::into_raw(ptr)
}

fn call(handle: *mut Self::Handle, _: &[ArgValue<'a>]) -> Result<RetValue> {
fn call(handle: *mut Self::Handle, _: &[ArgValue<'static>]) -> Result<RetValue> where F: Typed<(), O> {
// Ideally we shouldn't need to clone, probably doesn't really matter.
unsafe { Ok((*handle)().into()) }
let out = unsafe { (*handle)() };
Ok(F::ret(out))
}

fn drop(_: *mut Self::Handle) {}
}

macro_rules! to_function_instance {
($(($param:ident,$index:expr),)+) => {
impl<'a, $($param,)+ O: Into<RetValue>, F> ToFunction<'a, ($($param,)+), O> for
F where F: Fn($($param,)+) -> O + 'static,
$($param: From<ArgValue<'a>>,)+
$($param: 'a,)+ {
($(($param:ident,$index:tt),)+) => {
impl<F, $($param,)+ O> ToFunction<($($param,)+), O> for
F where F: Fn($($param,)+) -> O + 'static {
type Handle = Box<dyn Fn($($param,)+) -> O + 'static>;

fn into_raw(self) -> *mut Self::Handle {
let ptr: Box<Self::Handle> = Box::new(Box::new(self));
Box::into_raw(ptr)
}

fn call(handle: *mut Self::Handle, args: &[ArgValue<'a>]) -> Result<RetValue> {
fn call(handle: *mut Self::Handle, args: &[ArgValue<'static>]) -> Result<RetValue> where F: Typed<($($param,)+), O> {
// Ideally we shouldn't need to clone, probably doesn't really matter.
let res = unsafe {
(*handle)($(args[$index].clone().into(),)+)
let args = F::args(args)?;
let out = unsafe {
(*handle)($(args.$index),+)
};
Ok(res.into())
Ok(F::ret(out))
}

fn drop(_: *mut Self::Handle) {}
}
}
}

// impl<'a, A, O: Into<RetValue>, F> ToFunction<'a, (A,), O> for F
// where
// F: Fn(A) -> O + 'static,
// A: From<ArgValue<'a>>,
// A: 'a,
// {
// type Handle = Box<dyn Fn(A) -> O + 'static>;
// fn into_raw(self) -> *mut Self::Handle {
// let ptr: Box<Self::Handle> = Box::new(Box::new(self));
// Box::into_raw(ptr)
// }
// fn call(handle: *mut Self::Handle, args: &[ArgValue<'a>]) -> Result<RetValue> {
// let res = unsafe { (*handle)(args[0].clone().into()) };
// Ok(res.into())
// }

// fn drop(_: *mut Self::Handle) {}
// }

to_function_instance!((A, 0),);
to_function_instance!((A, 0), (B, 1),);
to_function_instance!((A, 0), (B, 1), (C, 2),);
to_function_instance!((A, 0), (B, 1), (C, 2), (D, 3),);

#[cfg(test)]
mod tests {
use super::{Function, Typed, ToFunction};
use super::{ArgValue, RetValue};

fn zero() -> i32 { 10 }

fn helper<F, I, O>(f: F) -> Function where F: ToFunction<I, O>, F: Typed<I, O> {
f.to_function()
}

fn func_args(args: &[ArgValue<'static>]) -> anyhow::Result<RetValue> {
Ok(10.into())
}

#[test]
fn test_fn_ptr() {
ToFunction::<&[ArgValue<'static>], anyhow::Result<RetValue>>::to_function(func_args);
}

#[test]
fn test_to_function0() {
helper(zero);
}

fn one_arg(i: i32) -> i32 { i }

#[test]
fn test_to_function1() {
helper(one_arg);
}

fn two_arg(i: i32, j: i32) -> i32 { i }

#[test]
fn test_to_function2() {
helper(two_arg);
}
}

0 comments on commit f4a940e

Please sign in to comment.