Skip to content

Commit

Permalink
[Rust] Add first stage of updating and rewriting Rust bindings. (#5526)
Browse files Browse the repository at this point in the history
* Add tvm-sys

* Use as_mut_ptr

* Address CR feedback

* Update rust/tvm-sys/src/datatype.rs

Co-authored-by: Nick Hynes <[email protected]>

* Final CR comments

* Fix find and replace error in frontend

Co-authored-by: Nick Hynes <[email protected]>
  • Loading branch information
jroesch and nhynes authored May 8, 2020
1 parent 2175f6b commit aded92d
Show file tree
Hide file tree
Showing 12 changed files with 1,294 additions and 1 deletion.
1 change: 1 addition & 0 deletions rust/.rustfmt.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ merge_derives = true
use_try_shorthand = false
use_field_init_shorthand = false
force_explicit_abi = true

3 changes: 2 additions & 1 deletion rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,6 @@ members = [
"frontend",
"frontend/tests/basics",
"frontend/tests/callback",
"frontend/examples/resnet"
"frontend/examples/resnet",
"tvm-sys"
]
35 changes: 35 additions & 0 deletions rust/tvm-sys/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

[package]
name = "tvm-sys"
version = "0.1.0"
authors = ["TVM Contributors"]
license = "Apache-2.0"
edition = "2018"

[features]
bindings = []

[dependencies]
thiserror = "^1.0"
anyhow = "^1.0"
ndarray = "0.12"
enumn = "^0.1"

[build-dependencies]
bindgen = "0.51"
61 changes: 61 additions & 0 deletions rust/tvm-sys/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

extern crate bindgen;

use std::path::PathBuf;

use std::env;

fn main() {
let tvm_home = option_env!("TVM_HOME").map(str::to_string).unwrap_or({
let crate_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.canonicalize()
.unwrap();
crate_dir
.parent()
.unwrap()
.parent()
.unwrap()
.to_str()
.unwrap()
.to_string()
});

if cfg!(feature = "bindings") {
println!("cargo:rerun-if-env-changed=TVM_HOME");
println!("cargo:rustc-link-lib=dylib=tvm");
println!("cargo:rustc-link-search={}/build", tvm_home);
}

// @see rust-bindgen#550 for `blacklist_type`
bindgen::Builder::default()
.header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home))
.header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home))
.clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home))
.clang_arg(format!("-I{}/include/", tvm_home))
.blacklist_type("max_align_t")
.layout_tests(false)
.derive_partialeq(true)
.derive_eq(true)
.generate()
.expect("unable to generate bindings")
.write_to_file(PathBuf::from("src/c_runtime_api.rs"))
.expect("can not write the bindings!");
}
62 changes: 62 additions & 0 deletions rust/tvm-sys/src/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

use std::{
mem,
os::raw::{c_int, c_void},
};

use crate::ffi::{
DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt,
DLDeviceType_kDLCPU, DLTensor,
};

/// `From` conversions to `DLTensor` for `ndarray::Array`.
/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
macro_rules! impl_dltensor_from_ndarray {
($type:ty, $typecode:expr) => {
impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor {
fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self {
DLTensor {
data: arr.as_mut_ptr() as *mut c_void,
ctx: DLContext {
device_type: DLDeviceType_kDLCPU,
device_id: 0,
},
ndim: arr.ndim() as c_int,
dtype: DLDataType {
code: $typecode as u8,
bits: 8 * mem::size_of::<$type>() as u8,
lanes: 1,
},
shape: arr.shape().as_ptr() as *const i64 as *mut i64,
strides: arr.strides().as_ptr() as *const i64 as *mut i64,
byte_offset: 0,
}
}
}
};
}

impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat);
impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt);
impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt);
impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt);
87 changes: 87 additions & 0 deletions rust/tvm-sys/src/byte_array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
use std::os::raw::c_char;

use crate::ffi::TVMByteArray;

/// A newtype wrapping a raw TVM byte-array.
///
/// ## Example
///
/// ```
/// let v = b"hello";
/// let barr = tvm_sys::ByteArray::from(&v);
/// assert_eq!(barr.len(), v.len());
/// assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]);
/// ```
pub struct ByteArray {
/// The raw FFI ByteArray.
array: TVMByteArray,
}

impl ByteArray {
/// Gets the underlying byte-array
pub fn data(&self) -> &'static [u8] {
unsafe { std::slice::from_raw_parts(self.array.data as *const u8, self.array.size) }
}

/// Gets the length of the underlying byte-array
pub fn len(&self) -> usize {
self.array.size
}

/// Converts the underlying byte-array to `Vec<u8>`
pub fn to_vec(&self) -> Vec<u8> {
self.data().to_vec()
}

pub fn is_empty(&self) -> bool {
self.len() == 0
}
}

// Needs AsRef for Vec
impl<T: AsRef<[u8]>> From<T> for ByteArray {
fn from(arg: T) -> Self {
let arg = arg.as_ref();
ByteArray {
array: TVMByteArray {
data: arg.as_ptr() as *const c_char,
size: arg.len(),
},
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn convert() {
let v = vec![1u8, 2, 3];
let barr = ByteArray::from(&v);
assert_eq!(barr.len(), v.len());
assert_eq!(barr.to_vec(), vec![1u8, 2, 3]);
let v = b"hello";
let barr = ByteArray::from(&v);
assert_eq!(barr.len(), v.len());
assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]);
}
}
Loading

0 comments on commit aded92d

Please sign in to comment.