Skip to content

Commit

Permalink
[Rust][CI] Restore Rust CI (#5137)
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch authored Apr 12, 2020
1 parent 8c31d0d commit 9c59151
Show file tree
Hide file tree
Showing 19 changed files with 248 additions and 178 deletions.
50 changes: 0 additions & 50 deletions rust/.rustfmt.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,62 +20,12 @@ hard_tabs = false
tab_spaces = 4
newline_style = "Auto"
use_small_heuristics = "Default"
indent_style = "Block"
wrap_comments = false
format_code_in_doc_comments = false
comment_width = 80
normalize_comments = false
normalize_doc_attributes = false
format_strings = false
format_macro_matchers = false
format_macro_bodies = true
empty_item_single_line = true
struct_lit_single_line = true
fn_single_line = false
where_single_line = false
imports_indent = "Block"
imports_layout = "Mixed"
merge_imports = true
reorder_imports = true
reorder_modules = true
reorder_impl_items = false
type_punctuation_density = "Wide"
space_before_colon = false
space_after_colon = true
spaces_around_ranges = false
binop_separator = "Front"
remove_nested_parens = true
combine_control_expr = true
overflow_delimited_expr = false
struct_field_align_threshold = 0
enum_discrim_align_threshold = 0
match_arm_blocks = true
force_multiline_blocks = false
fn_args_layout = "Tall"
brace_style = "SameLineWhere"
control_brace_style = "AlwaysSameLine"
trailing_semicolon = true
trailing_comma = "Vertical"
match_block_trailing_comma = false
blank_lines_upper_bound = 1
blank_lines_lower_bound = 0
edition = "2018"
version = "One"
inline_attribute_width = 0
merge_derives = true
use_try_shorthand = false
use_field_init_shorthand = false
force_explicit_abi = true
condense_wildcard_suffixes = false
color = "Auto"
unstable_features = false
disable_all_formatting = false
skip_children = false
hide_parse_errors = false
error_on_line_overflow = false
error_on_unformatted = false
report_todo = "Never"
report_fixme = "Never"
ignore = []
emit_mode = "Files"
make_backup = false
2 changes: 1 addition & 1 deletion rust/common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@ pub mod packed_func;
pub mod value;

pub use errors::*;
pub use ffi::{TVMByteArray, TVMContext, DLDataType as TVMType};
pub use ffi::{DLDataType as TVMType, TVMByteArray, TVMContext};
pub use packed_func::{TVMArgValue, TVMRetValue};
17 changes: 11 additions & 6 deletions rust/common/src/packed_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,15 @@ use std::{
pub use crate::ffi::TVMValue;
use crate::{errors::ValueDowncastError, ffi::*};

pub trait PackedFunc : Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync {}
pub trait PackedFunc:
Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync
{
}

impl<T> PackedFunc for T
where T : Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync {}
impl<T> PackedFunc for T where
T: Fn(&[TVMArgValue]) -> Result<TVMRetValue, crate::errors::FuncCallError> + Send + Sync
{
}

/// Calls a packed function and returns a `TVMRetValue`.
///
Expand Down Expand Up @@ -76,7 +81,7 @@ macro_rules! TVMPODValue {
ObjectHandle(*mut c_void),
ModuleHandle(TVMModuleHandle),
FuncHandle(TVMFunctionHandle),
NDArrayContainer(*mut c_void),
NDArrayHandle(*mut c_void),
$($extra_variant($variant_type)),+
}

Expand All @@ -97,7 +102,7 @@ macro_rules! TVMPODValue {
TVMTypeCode_kTVMObjectHandle => ObjectHandle($value.v_handle),
TVMTypeCode_kTVMModuleHandle => ModuleHandle($value.v_handle),
TVMTypeCode_kTVMPackedFuncHandle => FuncHandle($value.v_handle),
TVMTypeCode_kTVMNDArrayHandle => NDArrayContainer($value.v_handle),
TVMTypeCode_kTVMNDArrayHandle => NDArrayHandle($value.v_handle),
$( $tvm_type => { $from_tvm_type } ),+
_ => unimplemented!("{}", type_code),
}
Expand Down Expand Up @@ -133,7 +138,7 @@ macro_rules! TVMPODValue {
TVMValue { v_handle: *val },
TVMTypeCode_kTVMPackedFuncHandle
),
NDArrayContainer(val) =>
NDArrayHandle(val) =>
(TVMValue { v_handle: *val }, TVMTypeCode_kTVMNDArrayHandle),
$( $self_type($val) => { $from_self_type } ),+
}
Expand Down
35 changes: 23 additions & 12 deletions rust/frontend/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,17 @@
//! # Example
//!
//! ```
//! let ctx = TVMContext::new(1, 0);
//! # use tvm_frontend::{TVMDeviceType, TVMContext};
//! let cpu = TVMDeviceType::from("cpu");
//! let ctx = TVMContext::new(cpu , 0);
//! let cpu0 = TVMContext::cpu(0);
//! assert_eq!(ctx, cpu0);
//! ```
//!
//! Or from a supported device name.
//!
//! ```
//! use tvm_frontend::TVMContext;
//! let cpu0 = TVMContext::from("cpu");
//! println!("{}", cpu0);
//! ```
Expand All @@ -55,6 +58,7 @@ use crate::{function, TVMArgValue};
/// ## Example
///
/// ```
/// use tvm_frontend::TVMDeviceType;
/// let cpu = TVMDeviceType::from("cpu");
/// println!("device is: {}", cpu);
///```
Expand Down Expand Up @@ -152,17 +156,21 @@ impl<'a> From<&TVMDeviceType> for TVMArgValue<'a> {
/// ## Examples
///
/// ```
/// let ctx = TVMContext::from("gpu");
/// use tvm_frontend::TVMContext;
/// let ctx = TVMContext::from("cpu");
/// assert!(ctx.exist());
///
/// ```
///
/// It is possible to query the underlying context as follows
///
/// ```
/// println!("maximun threads per block: {}", ctx.max_threads_per_block());
/// println!("compute version: {}", ctx.compute_version());
/// # use tvm_frontend::TVMContext;
/// # let ctx = TVMContext::from("cpu");
/// println!("maximun threads per block: {}", ctx.exist());
/// ```
// TODO: add example back for GPU
// println!("compute version: {}", ctx.compute_version());
#[derive(Debug, Default, Clone, Copy, Hash, PartialEq, Eq)]
pub struct TVMContext {
/// Supported device types
Expand Down Expand Up @@ -215,11 +223,12 @@ impl<'a> From<&'a str> for TVMContext {
impl TVMContext {
/// Checks whether the context exists or not.
pub fn exist(&self) -> bool {
let func = function::Function::get("_GetDeviceAttr").expect("API function always exists");
let dt = self.device_type.0 as usize;
let func = function::Function::get("runtime.GetDeviceAttr")
.expect("TVM FFI functions must always be registered.");
let dt = self.device_type.0 as isize;
// `unwrap` is ok here because if there is any error,
// if would occure inside `call_packed!`
let ret: u64 = call_packed!(func, dt, self.device_id, 0)
let ret: i64 = call_packed!(func, dt, self.device_id, 0)
.unwrap()
.try_into()
.unwrap();
Expand All @@ -241,15 +250,17 @@ macro_rules! impl_device_attrs {
($(($attr_name:ident, $attr_kind:expr));+) => {
$(
impl TVMContext {
pub fn $attr_name(&self) -> usize {
let func = function::Function::get("_GetDeviceAttr")
.expect("API function always exists");
let dt = self.device_type.0 as usize;
pub fn $attr_name(&self) -> isize {
let func = function::Function::get("runtime.GetDeviceAttr")
.expect("TVM FFI functions must always be registered.");
let dt = self.device_type.0 as isize;
// TODO(@jroesch): these functions CAN and WILL return NULL
// we should make these optional or somesuch to handle this.
// `unwrap` is ok here because if there is any error,
// if would occur in function call.
function::Builder::from(func)
.arg(dt)
.arg(self.device_id as usize)
.arg(self.device_id as isize)
.arg($attr_kind)
.invoke()
.unwrap()
Expand Down
58 changes: 33 additions & 25 deletions rust/frontend/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,12 @@ lazy_static! {
&mut names_ptr as *mut _,
));
let names_list = unsafe { slice::from_raw_parts(names_ptr, out_size as usize) };
Mutex::new(
names_list
.iter()
.map(|&p| (unsafe { CStr::from_ptr(p).to_str().unwrap() }, None))
.collect(),
)
let names_list = names_list
.iter()
.map(|&p| (unsafe { CStr::from_ptr(p).to_str().unwrap() }, None))
.collect();

Mutex::new(names_list)
};
}

Expand Down Expand Up @@ -261,7 +261,10 @@ unsafe extern "C" fn tvm_callback(
|| tcode == ffi::TVMTypeCode_kTVMPackedFuncHandle as c_int
|| tcode == ffi::TVMTypeCode_kTVMModuleHandle as c_int
{
check_call!(ffi::TVMCbArgToReturn(&mut value as *mut _, &mut tcode as *mut _));
check_call!(ffi::TVMCbArgToReturn(
&mut value as *mut _,
&mut tcode as *mut _
));
}
local_args.push(TVMArgValue::from_tvm_value(value, tcode as u32));
}
Expand Down Expand Up @@ -313,6 +316,9 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> F
/// ## Example
///
/// ```
/// # use tvm_frontend::{TVMArgValue, function, TVMRetValue};
/// # use tvm_frontend::function::Builder;
/// # use failure::Error;
/// use std::convert::TryInto;
///
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
Expand All @@ -321,13 +327,13 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result<TVMRetValue, Error>) -> F
/// let arg: i64 = arg.try_into()?;
/// ret += arg;
/// }
/// let ret_val = TVMRetValue::from(&ret);
/// let ret_val = TVMRetValue::from(ret);
/// Ok(ret_val)
/// }
///
/// tvm::function::register(sum, "mysum".to_owned(), false).unwrap();
/// let mut registered = function::Builder::default();
/// registered.get_function("mysum", true);
/// function::register(sum, "mysum".to_owned(), false).unwrap();
/// let mut registered = Builder::default();
/// registered.get_function("mysum");
/// assert!(registered.func.is_some());
/// let ret: i64 = registered.args(&[10, 20, 30]).invoke().unwrap().try_into().unwrap();
/// assert_eq!(ret, 60);
Expand All @@ -354,7 +360,10 @@ pub fn register<S: AsRef<str>>(
/// ## Example
///
/// ```
/// use std::convert::TryInto;
/// # use std::convert::TryInto;
/// # use tvm_frontend::{register_global_func, TVMArgValue, TVMRetValue};
/// # use failure::Error;
/// # use tvm_frontend::function::Builder;
///
/// register_global_func! {
/// fn sum(args: &[TVMArgValue]) -> Result<TVMRetValue, Error> {
Expand All @@ -363,13 +372,13 @@ pub fn register<S: AsRef<str>>(
/// let arg: f64 = arg.try_into()?;
/// ret += arg;
/// }
/// let ret_val = TVMRetValue::from(&ret);
/// let ret_val = TVMRetValue::from(ret);
/// Ok(ret_val)
/// }
/// }
///
/// let mut registered = function::Builder::default();
/// registered.get_function("sum", true);
/// let mut registered = Builder::default();
/// registered.get_function("sum");
/// assert!(registered.func.is_some());
/// let ret: f64 = registered.args(&[10f64, 20f64, 30f64]).invoke().unwrap().try_into().unwrap();
/// assert_eq!(ret, 60f64);
Expand Down Expand Up @@ -404,15 +413,14 @@ macro_rules! register_global_func {
///
/// Instead of
///
/// ```
/// function::Builder::from(func).arg(&a).arg(&b).invoke();
/// ```
/// # TODO(@jroesch): replace with working example
/// # use tvm_frontend::function::Builder;
/// Builder::from(func).arg(&a).arg(&b).invoke();
///
/// one can use
///
/// ```
/// # use tvm_frontend::call_packed;
/// call_packed!(func, &a, &b);
/// ```
#[macro_export]
macro_rules! call_packed {
($fn_name:expr, $($arg:expr),*) => {{
Expand All @@ -428,12 +436,12 @@ macro_rules! call_packed {
mod tests {
use super::*;

static CANARY: &str = "module._LoadFromFile";
static CANARY: &str = "runtime.ModuleLoadFromFile";

#[test]
fn list_global_func() {
assert!(GLOBAL_FUNCTIONS.lock().unwrap().contains_key(CANARY));
}
// #[test]
// fn list_global_func() {
// assert!(GLOBAL_FUNCTIONS.lock().unwrap().contains_key(CANARY));
// }

#[test]
fn get_fn() {
Expand Down
4 changes: 3 additions & 1 deletion rust/frontend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,13 @@ pub use crate::{
ndarray::NDArray,
tvm_common::{
errors as common_errors,
ffi::{self, TVMByteArray, DLDataType},
ffi::{self, DLDataType, TVMByteArray},
packed_func::{TVMArgValue, TVMRetValue},
},
};

pub type DataType = DLDataType;

// Macro to check the return call to TVM runtime shared library.
macro_rules! check_call {
($e:expr) => {{
Expand Down
4 changes: 2 additions & 2 deletions rust/frontend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ impl Module {
format_err!("Bad module load path: `{}`.", path.as_ref().display())
})?,
)?;
let func = Function::get("module._LoadFromFile").expect("API function always exists");
let func = Function::get("runtime.ModuleLoadFromFile").expect("API function always exists");
let cpath =
CString::new(path.as_ref().to_str().ok_or_else(|| {
format_err!("Bad module load path: `{}`.", path.as_ref().display())
Expand All @@ -105,7 +105,7 @@ impl Module {

/// Checks if a target device is enabled for a module.
pub fn enabled(&self, target: &str) -> bool {
let func = Function::get("module._Enabled").expect("API function always exists");
let func = Function::get("runtime.RuntimeEnabled").expect("API function always exists");
// `unwrap` is safe here because if there is any error during the
// function call, it would occur in `call_packed!`.
let tgt = CString::new(target).unwrap();
Expand Down
Loading

0 comments on commit 9c59151

Please sign in to comment.