diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 6630bf3ded20..0d2fa5ce3986 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -36,6 +37,8 @@ namespace tvm { +using tvm::runtime::String; + /*! * \brief Base type of all the expressions. * \sa Expr @@ -186,7 +189,7 @@ class GlobalVar; class GlobalVarNode : public RelayExprNode { public: /*! \brief The name of the variable, this only acts as a hint. */ - std::string name_hint; + String name_hint; void VisitAttrs(AttrVisitor* v) { v->Visit("name_hint", &name_hint); diff --git a/include/tvm/node/container.h b/include/tvm/node/container.h index ba1edf84383e..82e0fd08c35c 100644 --- a/include/tvm/node/container.h +++ b/include/tvm/node/container.h @@ -472,7 +472,7 @@ class Map : public ObjectRef { /*! * \brief Read element from map. * \param key The key - * \return the corresonding element. + * \return the corresponding element. */ inline const V operator[](const K& key) const { return DowncastNoCheck( diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index 9ed87df46618..2ce91b256d36 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -76,7 +76,7 @@ class AttrVisitor { /*! * \brief Virtual function table to support IR/AST node reflection. * - * Functions are stored in columar manner. + * Functions are stored in a columnar manner. * Each column is a vector indexed by Object's type_index. */ class ReflectionVTable { diff --git a/include/tvm/relay/base.h b/include/tvm/relay/base.h index 1d0120675e99..d4454765ea2d 100644 --- a/include/tvm/relay/base.h +++ b/include/tvm/relay/base.h @@ -91,7 +91,7 @@ class IdNode : public Object { * this only acts as a hint to the user, * and is not used for equality. */ - std::string name_hint; + String name_hint; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("name_hint", &name_hint); diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index fe240c30e471..6bc3e57f7410 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include "./base.h" @@ -36,6 +37,7 @@ namespace tvm { namespace relay { +using tvm::runtime::String; using Expr = tvm::RelayExpr; using ExprNode = tvm::RelayExprNode; using BaseFunc = tvm::BaseFunc; @@ -172,7 +174,7 @@ class VarNode : public ExprNode { Type type_annotation; /*! \return The name hint of the variable */ - const std::string& name_hint() const { + const String& name_hint() const { return vid->name_hint; } diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index 920ecfbf9b13..eb3b5aca2654 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -540,6 +540,15 @@ TVM_DLL int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex); */ TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); +/*! + * \brief Increase the reference count of an object. + * + * \param obj The object handle. + * \note Internally we increase the reference counter of the object. + * \return 0 when success, -1 when failure happens + */ +TVM_DLL int TVMObjectRetain(TVMObjectHandle obj); + /*! * \brief Free the object. * @@ -550,6 +559,8 @@ TVM_DLL int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex); */ TVM_DLL int TVMObjectFree(TVMObjectHandle obj); +TVM_DLL int TVMObjectDerivedFrom(uint32_t child_type_index, uint32_t parent_type_index, int* is_derived); + #ifdef __cplusplus } // TVM_EXTERN_C #endif diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 8f426415ffee..7eef492c5702 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -325,6 +325,8 @@ class StringObj : public Object { friend class String; }; +TVM_REGISTER_OBJECT_TYPE(StringObj); + /*! * \brief Reference to string objects. * diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index edca925baeb0..6bb076025729 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -68,7 +68,7 @@ enum TypeIndex { * the type index will be assigned during runtime. * Runtime type index can be accessed by ObjectType::TypeIndex(); * - _type_key: - * The unique string identifier of tyep type. + * The unique string identifier of the type. * - _type_final: * Whether the type is terminal type(there is no subclass of the type in the object system). * This field is automatically set by marco TVM_DECLARE_FINAL_OBJECT_INFO diff --git a/python/tvm/runtime/object_generic.py b/python/tvm/runtime/object_generic.py index ac20b67e8299..709279d4c1a7 100644 --- a/python/tvm/runtime/object_generic.py +++ b/python/tvm/runtime/object_generic.py @@ -38,7 +38,7 @@ def asobject(self): def convert_to_object(value): - """Convert a python value to corresponding object type. + """Convert a Python value to corresponding object type. Parameters ---------- diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 8467f6a92ea8..325f73abeb94 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -17,14 +17,14 @@ [workspace] members = [ - "common", + "tvm-sys", "macros", - "runtime", - "runtime/tests/test_tvm_basic", - "runtime/tests/test_tvm_dso", - "runtime/tests/test_nn", - "frontend", - "frontend/tests/basics", - "frontend/tests/callback", - "frontend/examples/resnet" + "graph-runtime", + "graph-runtime/tests/test_tvm_basic", + "graph-runtime/tests/test_tvm_dso", + "graph-runtime/tests/test_nn", + "tvm", + "tvm/tests/basics", + "tvm/tests/callback", + "tvm/examples/resnet" ] diff --git a/rust/common/build.rs b/rust/common/build.rs deleted file mode 100644 index b3ae7b6d1837..000000000000 --- a/rust/common/build.rs +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -extern crate bindgen; - -use std::path::PathBuf; - -fn main() { - let tvm_home = option_env!("TVM_HOME").map(str::to_string).unwrap_or({ - let crate_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")) - .canonicalize() - .unwrap(); - crate_dir - .parent() - .unwrap() - .parent() - .unwrap() - .to_str() - .unwrap() - .to_string() - }); - if cfg!(feature = "bindings") { - println!("cargo:rerun-if-env-changed=TVM_HOME"); - println!("cargo:rustc-link-lib=dylib=tvm_runtime"); - println!("cargo:rustc-link-search={}/build", tvm_home); - } - - // @see rust-bindgen#550 for `blacklist_type` - bindgen::Builder::default() - .header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home)) - .header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home)) - .clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home)) - .clang_arg(format!("-I{}/include/", tvm_home)) - .blacklist_type("max_align_t") - .layout_tests(false) - .derive_partialeq(true) - .derive_eq(true) - .generate() - .expect("unable to generate bindings") - .write_to_file(PathBuf::from("src/c_runtime_api.rs")) - .expect("can not write the bindings!"); -} diff --git a/rust/frontend/.travis.yml b/rust/graph-runtime/.travis.yml similarity index 100% rename from rust/frontend/.travis.yml rename to rust/graph-runtime/.travis.yml diff --git a/rust/runtime/Cargo.toml b/rust/graph-runtime/Cargo.toml similarity index 89% rename from rust/runtime/Cargo.toml rename to rust/graph-runtime/Cargo.toml index eb531f96e5be..460bd2f810ab 100644 --- a/rust/runtime/Cargo.toml +++ b/rust/graph-runtime/Cargo.toml @@ -16,10 +16,10 @@ # under the License. [package] -name = "tvm-runtime" -version = "0.1.0" +name = "tvm-graph-runtime" +version = "0.2.0" license = "Apache-2.0" -description = "A static TVM runtime" +description = "A static linking friendly TVM graph runtime." repository = "https://github.com/apache/incubator-tvm" readme = "README.md" keywords = ["tvm"] @@ -38,7 +38,7 @@ num_cpus = "1.10" serde = "1.0" serde_derive = "1.0" serde_json = "1.0" -tvm-common = { version = "0.1", path = "../common" } +tvm-sys = { version = "0.1", path = "../tvm-sys" } tvm-macros = { version = "0.1", path = "../macros" } [target.'cfg(not(any(target_arch = "wasm32", target_env = "sgx")))'.dependencies] diff --git a/rust/runtime/src/allocator.rs b/rust/graph-runtime/src/allocator.rs similarity index 100% rename from rust/runtime/src/allocator.rs rename to rust/graph-runtime/src/allocator.rs diff --git a/rust/runtime/src/array.rs b/rust/graph-runtime/src/array.rs similarity index 99% rename from rust/runtime/src/array.rs rename to rust/graph-runtime/src/array.rs index 2b6c7c217e28..cd93f11c4007 100644 --- a/rust/runtime/src/array.rs +++ b/rust/graph-runtime/src/array.rs @@ -19,7 +19,7 @@ use std::{convert::TryFrom, mem, os::raw::c_void, ptr, slice}; -use failure::Error; +use anyhow::Error; use ndarray; use tvm_common::{ array::{DataType, TVMContext}, diff --git a/rust/runtime/src/errors.rs b/rust/graph-runtime/src/errors.rs similarity index 100% rename from rust/runtime/src/errors.rs rename to rust/graph-runtime/src/errors.rs diff --git a/rust/runtime/src/graph.rs b/rust/graph-runtime/src/graph.rs similarity index 99% rename from rust/runtime/src/graph.rs rename to rust/graph-runtime/src/graph.rs index 518bf724f319..251eeabaca2b 100644 --- a/rust/runtime/src/graph.rs +++ b/rust/graph-runtime/src/graph.rs @@ -19,7 +19,7 @@ use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str}; -use failure::Error; +use anyhow::Error; use nom::{ character::complete::{alpha1, digit1}, number::complete::{le_i32, le_i64, le_u16, le_u32, le_u64, le_u8}, @@ -263,7 +263,7 @@ impl<'m, 't> GraphExecutor<'m, 't> { let func = lib .get_function(&attrs.func_name) - .ok_or_else(|| format_err!("Library is missing function {}", attrs.func_name))?; + .ok_or_else(|| anyhow!("Library is missing function {}", attrs.func_name))?; let arg_indices = node .inputs .iter() diff --git a/rust/runtime/src/lib.rs b/rust/graph-runtime/src/lib.rs similarity index 100% rename from rust/runtime/src/lib.rs rename to rust/graph-runtime/src/lib.rs diff --git a/rust/runtime/src/module/dso.rs b/rust/graph-runtime/src/module/dso.rs similarity index 100% rename from rust/runtime/src/module/dso.rs rename to rust/graph-runtime/src/module/dso.rs diff --git a/rust/runtime/src/module/mod.rs b/rust/graph-runtime/src/module/mod.rs similarity index 100% rename from rust/runtime/src/module/mod.rs rename to rust/graph-runtime/src/module/mod.rs diff --git a/rust/runtime/src/module/syslib.rs b/rust/graph-runtime/src/module/syslib.rs similarity index 100% rename from rust/runtime/src/module/syslib.rs rename to rust/graph-runtime/src/module/syslib.rs diff --git a/rust/runtime/src/threading.rs b/rust/graph-runtime/src/threading.rs similarity index 100% rename from rust/runtime/src/threading.rs rename to rust/graph-runtime/src/threading.rs diff --git a/rust/runtime/src/workspace.rs b/rust/graph-runtime/src/workspace.rs similarity index 99% rename from rust/runtime/src/workspace.rs rename to rust/graph-runtime/src/workspace.rs index 8344dfbb1adf..028e605e22cd 100644 --- a/rust/runtime/src/workspace.rs +++ b/rust/graph-runtime/src/workspace.rs @@ -23,7 +23,7 @@ use std::{ ptr, }; -use failure::Error; +use anyhow::Error; use crate::allocator::Allocation; diff --git a/rust/runtime/tests/.gitignore b/rust/graph-runtime/tests/.gitignore similarity index 100% rename from rust/runtime/tests/.gitignore rename to rust/graph-runtime/tests/.gitignore diff --git a/rust/runtime/tests/build_model.py b/rust/graph-runtime/tests/build_model.py similarity index 98% rename from rust/runtime/tests/build_model.py rename to rust/graph-runtime/tests/build_model.py index ddfa03bae97f..4169441d2942 100755 --- a/rust/runtime/tests/build_model.py +++ b/rust/graph-runtime/tests/build_model.py @@ -41,6 +41,7 @@ def main(): dshape = (32, 16) net = _get_model(dshape) mod, params = testing.create_workload(net) + import pdb; pdb.set_trace() graph, lib, params = relay.build( mod, 'llvm', params=params) diff --git a/rust/runtime/tests/test_graph_serde.rs b/rust/graph-runtime/tests/test_graph_serde.rs similarity index 94% rename from rust/runtime/tests/test_graph_serde.rs rename to rust/graph-runtime/tests/test_graph_serde.rs index 6cea4ad99a39..3d127a110260 100644 --- a/rust/runtime/tests/test_graph_serde.rs +++ b/rust/graph-runtime/tests/test_graph_serde.rs @@ -75,9 +75,9 @@ fn test_load_graph() { .unwrap() .get("func_name") .unwrap(), - "fused_nn_dense_nn_bias_add" + "fused_split" ); - assert_eq!(graph.nodes[3].inputs[0].index, 0); - assert_eq!(graph.nodes[4].inputs[0].index, 0); + assert_eq!(graph.nodes[5].inputs[0].index, 0); + assert_eq!(graph.nodes[6].inputs[0].index, 0); assert_eq!(graph.heads.len(), 3); } diff --git a/rust/runtime/tests/test_nn/Cargo.toml b/rust/graph-runtime/tests/test_nn/Cargo.toml similarity index 96% rename from rust/runtime/tests/test_nn/Cargo.toml rename to rust/graph-runtime/tests/test_nn/Cargo.toml index 89f4bf8aaf73..c14cc29329cc 100644 --- a/rust/runtime/tests/test_nn/Cargo.toml +++ b/rust/graph-runtime/tests/test_nn/Cargo.toml @@ -25,7 +25,7 @@ authors = ["TVM Contributors"] ndarray="0.12" serde = "1.0" serde_json = "1.0" -tvm-runtime = { path = "../../" } +tvm-graph-runtime = { path = "../../" } [build-dependencies] ar = "0.6" diff --git a/rust/runtime/tests/test_nn/build.rs b/rust/graph-runtime/tests/test_nn/build.rs similarity index 100% rename from rust/runtime/tests/test_nn/build.rs rename to rust/graph-runtime/tests/test_nn/build.rs diff --git a/rust/runtime/tests/test_nn/src/build_test_graph.py b/rust/graph-runtime/tests/test_nn/src/build_test_graph.py similarity index 100% rename from rust/runtime/tests/test_nn/src/build_test_graph.py rename to rust/graph-runtime/tests/test_nn/src/build_test_graph.py diff --git a/rust/runtime/tests/test_nn/src/main.rs b/rust/graph-runtime/tests/test_nn/src/main.rs similarity index 100% rename from rust/runtime/tests/test_nn/src/main.rs rename to rust/graph-runtime/tests/test_nn/src/main.rs diff --git a/rust/runtime/tests/test_tvm_basic/Cargo.toml b/rust/graph-runtime/tests/test_tvm_basic/Cargo.toml similarity index 95% rename from rust/runtime/tests/test_tvm_basic/Cargo.toml rename to rust/graph-runtime/tests/test_tvm_basic/Cargo.toml index d11531450298..94a0a894ceba 100644 --- a/rust/runtime/tests/test_tvm_basic/Cargo.toml +++ b/rust/graph-runtime/tests/test_tvm_basic/Cargo.toml @@ -23,7 +23,7 @@ authors = ["TVM Contributors"] [dependencies] ndarray="0.12" -tvm-runtime = { path = "../../" } +tvm-graph-runtime = { path = "../../" } [build-dependencies] ar = "0.6" diff --git a/rust/runtime/tests/test_tvm_basic/build.rs b/rust/graph-runtime/tests/test_tvm_basic/build.rs similarity index 100% rename from rust/runtime/tests/test_tvm_basic/build.rs rename to rust/graph-runtime/tests/test_tvm_basic/build.rs diff --git a/rust/runtime/tests/test_tvm_basic/src/build_test_lib.py b/rust/graph-runtime/tests/test_tvm_basic/src/build_test_lib.py similarity index 100% rename from rust/runtime/tests/test_tvm_basic/src/build_test_lib.py rename to rust/graph-runtime/tests/test_tvm_basic/src/build_test_lib.py diff --git a/rust/runtime/tests/test_tvm_basic/src/main.rs b/rust/graph-runtime/tests/test_tvm_basic/src/main.rs similarity index 100% rename from rust/runtime/tests/test_tvm_basic/src/main.rs rename to rust/graph-runtime/tests/test_tvm_basic/src/main.rs diff --git a/rust/runtime/tests/test_tvm_dso/Cargo.toml b/rust/graph-runtime/tests/test_tvm_dso/Cargo.toml similarity index 95% rename from rust/runtime/tests/test_tvm_dso/Cargo.toml rename to rust/graph-runtime/tests/test_tvm_dso/Cargo.toml index afe7f26e1220..ba301b849e82 100644 --- a/rust/runtime/tests/test_tvm_dso/Cargo.toml +++ b/rust/graph-runtime/tests/test_tvm_dso/Cargo.toml @@ -23,4 +23,4 @@ authors = ["TVM Contributors"] [dependencies] ndarray="0.12" -tvm-runtime = { path = "../../" } +tvm-graph-runtime = { path = "../../" } diff --git a/rust/runtime/tests/test_tvm_dso/build.rs b/rust/graph-runtime/tests/test_tvm_dso/build.rs similarity index 100% rename from rust/runtime/tests/test_tvm_dso/build.rs rename to rust/graph-runtime/tests/test_tvm_dso/build.rs diff --git a/rust/runtime/tests/test_tvm_dso/src/build_test_lib.py b/rust/graph-runtime/tests/test_tvm_dso/src/build_test_lib.py similarity index 100% rename from rust/runtime/tests/test_tvm_dso/src/build_test_lib.py rename to rust/graph-runtime/tests/test_tvm_dso/src/build_test_lib.py diff --git a/rust/runtime/tests/test_tvm_dso/src/main.rs b/rust/graph-runtime/tests/test_tvm_dso/src/main.rs similarity index 100% rename from rust/runtime/tests/test_tvm_dso/src/main.rs rename to rust/graph-runtime/tests/test_tvm_dso/src/main.rs diff --git a/rust/macros/Cargo.toml b/rust/macros/Cargo.toml index 784b35e2fdae..7abc9ae64f7c 100644 --- a/rust/macros/Cargo.toml +++ b/rust/macros/Cargo.toml @@ -32,5 +32,5 @@ proc-macro = true [dependencies] goblin = "0.0.24" proc-macro2 = "^1.0" -quote = "1.0" -syn = "1.0" +quote = "^1.0" +syn = { version = "1.0.17", features = ["full", "extra-traits"] } diff --git a/rust/macros/src/lib.rs b/rust/macros/src/lib.rs index 9f28c74febd6..e9ddc25ddf9c 100644 --- a/rust/macros/src/lib.rs +++ b/rust/macros/src/lib.rs @@ -17,121 +17,17 @@ * under the License. */ -extern crate proc_macro; - -use quote::quote; -use std::{fs::File, io::Read}; -use syn::parse::{Parse, ParseStream, Result}; -use syn::LitStr; - -use std::path::PathBuf; - -struct ImportModule { - importing_file: LitStr, -} - -impl Parse for ImportModule { - fn parse(input: ParseStream) -> Result { - let importing_file: LitStr = input.parse()?; - Ok(ImportModule { importing_file }) - } -} +use proc_macro::TokenStream; +mod import_module; +mod object; #[proc_macro] -pub fn import_module(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let import_module_args = syn::parse_macro_input!(input as ImportModule); - - let manifest = - std::env::var("CARGO_MANIFEST_DIR").expect("variable should always be set by Cargo."); - - let mut path = PathBuf::new(); - path.push(manifest); - path = path.join(import_module_args.importing_file.value()); - - let mut fd = File::open(&path) - .unwrap_or_else(|_| panic!("Unable to find TVM object file at `{}`", path.display())); - let mut buffer = Vec::new(); - fd.read_to_end(&mut buffer).unwrap(); - - let fn_names = match goblin::Object::parse(&buffer).unwrap() { - goblin::Object::Elf(elf) => elf - .syms - .iter() - .filter_map(|s| { - if s.st_type() == 0 || goblin::elf::sym::type_to_str(s.st_type()) == "FILE" { - return None; - } - match elf.strtab.get(s.st_name) { - Some(Ok(name)) if name != "" => { - Some(syn::Ident::new(name, proc_macro2::Span::call_site())) - } - _ => None, - } - }) - .collect::>(), - goblin::Object::Mach(goblin::mach::Mach::Binary(obj)) => { - obj.symbols() - .filter_map(|s| match s { - Ok((name, ref nlist)) - if nlist.is_global() - && nlist.n_sect != 0 - && !name.ends_with("tvm_module_ctx") => - { - Some(syn::Ident::new( - if name.starts_with('_') { - // Mach objects prepend a _ to globals. - &name[1..] - } else { - &name - }, - proc_macro2::Span::call_site(), - )) - } - _ => None, - }) - .collect::>() - } - _ => panic!("Unsupported object format."), - }; - - let extern_fns = quote! { - mod ext { - extern "C" { - #( - pub(super) fn #fn_names( - args: *const tvm_runtime::ffi::TVMValue, - type_codes: *const std::os::raw::c_int, - num_args: std::os::raw::c_int - ) -> std::os::raw::c_int; - )* - } - } - }; - - let fns = quote! { - use tvm_runtime::{ffi::TVMValue, TVMArgValue, TVMRetValue, FuncCallError}; - #extern_fns - - #( - pub fn #fn_names(args: &[TVMArgValue]) -> Result { - let (values, type_codes): (Vec, Vec) = args - .into_iter() - .map(|arg| { - let (val, code) = arg.to_tvm_value(); - (val, code as i32) - }) - .unzip(); - let exit_code = unsafe { - ext::#fn_names(values.as_ptr(), type_codes.as_ptr(), values.len() as i32) - }; - if exit_code == 0 { - Ok(TVMRetValue::default()) - } else { - Err(FuncCallError::get_with_context(stringify!(#fn_names).to_string())) - } - } - )* - }; +pub fn import_module(input: TokenStream) -> TokenStream { + import_module::macro_impl(input) +} - proc_macro::TokenStream::from(fns) +#[proc_macro_derive(Object, attributes(base, ref_name, type_key))] +pub fn macro_impl(input: TokenStream) -> TokenStream { + // let input = proc_macro2::TokenStream::from(input); + TokenStream::from(object::macro_impl(input)) } diff --git a/rust/frontend/.gitignore b/rust/tvm-rt/.gitignore similarity index 100% rename from rust/frontend/.gitignore rename to rust/tvm-rt/.gitignore diff --git a/rust/runtime/.travis.yml b/rust/tvm-rt/.travis.yml similarity index 100% rename from rust/runtime/.travis.yml rename to rust/tvm-rt/.travis.yml diff --git a/rust/tvm-rt/Cargo.toml b/rust/tvm-rt/Cargo.toml new file mode 100644 index 000000000000..417f2567595c --- /dev/null +++ b/rust/tvm-rt/Cargo.toml @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "tvm-rt" +version = "0.1.0" +license = "Apache-2.0" +description = "Rust bindings for the TVM runtime API." +repository = "https://github.com/apache/incubator-tvm" +homepage = "https://github.com/apache/incubator-tvm" +readme = "README.md" +keywords = ["rust", "tvm"] +categories = ["api-bindings", "science"] +authors = ["TVM Contributors"] +edition = "2018" + +[dependencies] +thiserror = "^1.0" +anyhow = "^1.0" +lazy_static = "1.1" +ndarray = "0.12" +num-traits = "0.2" +tvm-sys = { version = "0.1", path = "../tvm-sys/", features = ["bindings"] } +tvm-macros = { version = "0.1", path = "../macros" } +paste = "0.1" +mashup = "0.1" +once_cell = "^1.3.1" + +[features] +blas = ["ndarray/blas"] diff --git a/rust/frontend/README.md b/rust/tvm-rt/README.md similarity index 100% rename from rust/frontend/README.md rename to rust/tvm-rt/README.md diff --git a/rust/frontend/examples/resnet/Cargo.toml b/rust/tvm-rt/examples/resnet/Cargo.toml similarity index 100% rename from rust/frontend/examples/resnet/Cargo.toml rename to rust/tvm-rt/examples/resnet/Cargo.toml diff --git a/rust/frontend/examples/resnet/README.md b/rust/tvm-rt/examples/resnet/README.md similarity index 100% rename from rust/frontend/examples/resnet/README.md rename to rust/tvm-rt/examples/resnet/README.md diff --git a/rust/frontend/examples/resnet/build.rs b/rust/tvm-rt/examples/resnet/build.rs similarity index 100% rename from rust/frontend/examples/resnet/build.rs rename to rust/tvm-rt/examples/resnet/build.rs diff --git a/rust/frontend/examples/resnet/src/build_resnet.py b/rust/tvm-rt/examples/resnet/src/build_resnet.py similarity index 100% rename from rust/frontend/examples/resnet/src/build_resnet.py rename to rust/tvm-rt/examples/resnet/src/build_resnet.py diff --git a/rust/frontend/examples/resnet/src/main.rs b/rust/tvm-rt/examples/resnet/src/main.rs similarity index 100% rename from rust/frontend/examples/resnet/src/main.rs rename to rust/tvm-rt/examples/resnet/src/main.rs diff --git a/rust/frontend/src/context.rs b/rust/tvm-rt/src/context.rs similarity index 99% rename from rust/frontend/src/context.rs rename to rust/tvm-rt/src/context.rs index 6d08e391fc78..993fa1b4fcee 100644 --- a/rust/frontend/src/context.rs +++ b/rust/tvm-rt/src/context.rs @@ -46,9 +46,9 @@ use std::{ ptr, }; -use failure::Error; +use anyhow::Result; -use tvm_common::ffi; +use tvm_sys::ffi; use crate::{function, TVMArgValue}; @@ -236,7 +236,7 @@ impl TVMContext { } /// Synchronize the context stream. - pub fn sync(&self) -> Result<(), Error> { + pub fn sync(&self) -> Result<()> { check_call!(ffi::TVMSynchronize( self.device_type.0 as i32, self.device_id as i32, diff --git a/rust/frontend/src/errors.rs b/rust/tvm-rt/src/errors.rs similarity index 70% rename from rust/frontend/src/errors.rs rename to rust/tvm-rt/src/errors.rs index ceda69773a38..77dbba747527 100644 --- a/rust/frontend/src/errors.rs +++ b/rust/tvm-rt/src/errors.rs @@ -17,29 +17,29 @@ * under the License. */ -pub use failure::Error; +use thiserror::Error; -#[derive(Debug, Fail)] -#[fail(display = "Cannot convert from an empty array.")] +#[derive(Debug, Error)] +#[error("Cannot convert from an empty array.")] pub struct EmptyArrayError; -#[derive(Debug, Fail)] -#[fail(display = "Handle `{}` is null.", name)] +#[derive(Debug, Error)] +#[error("Handle `{name}` is null.")] pub struct NullHandleError { pub name: String, } -#[derive(Debug, Fail)] -#[fail(display = "Function was not set in `function::Builder`")] +#[derive(Debug, Error)] +#[error("Function was not set in `function::Builder`")] pub struct FunctionNotFoundError; -#[derive(Debug, Fail)] -#[fail(display = "Expected type `{}` but found `{}`", expected, actual)] +#[derive(Debug, Error)] +#[error("Expected type `{expected}` but found `{actual}`")] pub struct TypeMismatchError { pub expected: String, pub actual: String, } -#[derive(Debug, Fail)] -#[fail(display = "Missing NDArray shape.")] +#[derive(Debug, Error)] +#[error("Missing NDArray shape.")] pub struct MissingShapeError; diff --git a/rust/frontend/src/function.rs b/rust/tvm-rt/src/function.rs similarity index 96% rename from rust/frontend/src/function.rs rename to rust/tvm-rt/src/function.rs index 8411b03592d1..07af33dd4e55 100644 --- a/rust/frontend/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -34,7 +34,8 @@ use std::{ sync::Mutex, }; -use failure::Error; +use anyhow::{ensure, Result}; +use lazy_static::lazy_static; use crate::{errors, ffi, Module, TVMArgValue, TVMRetValue}; @@ -179,11 +180,11 @@ impl<'a, 'm> Builder<'a, 'm> { /// Pushes multiple [`TVMArgValue`]s into the function argument buffer. pub fn args(&mut self, args: I) -> &mut Self where - I: IntoIterator, - TVMArgValue<'a>: From<&'a T>, + I: IntoIterator, + TVMArgValue<'a>: From, { args.into_iter().for_each(|arg| { - self.arg(&arg); + self.arg(arg); }); self } @@ -199,7 +200,7 @@ impl<'a, 'm> Builder<'a, 'm> { } /// Calls the function that created from `Builder`. - pub fn invoke(&mut self) -> Result { + pub fn invoke(&mut self) -> Result { #![allow(unused_unsafe)] ensure!(self.func.is_some(), errors::FunctionNotFoundError); @@ -252,8 +253,7 @@ unsafe extern "C" fn tvm_callback( let mut local_args: Vec = Vec::new(); let mut value = MaybeUninit::uninit().assume_init(); let mut tcode = MaybeUninit::uninit().assume_init(); - let rust_fn = - mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result>(fhandle); + let rust_fn = mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result>(fhandle); for i in 0..len { value = args_list[i]; tcode = type_codes_list[i]; @@ -290,13 +290,13 @@ unsafe extern "C" fn tvm_callback( unsafe extern "C" fn tvm_callback_finalizer(fhandle: *mut c_void) { let _rust_fn = - mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result>(fhandle); + mem::transmute::<*mut c_void, fn(&[TVMArgValue]) -> Result>(fhandle); // XXX: give converted functions lifetimes so they're not called after use } -fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result) -> Function { +fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result) -> Function { let mut fhandle = ptr::null_mut() as ffi::TVMFunctionHandle; - let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result; + let resource_handle = f as *mut fn(&[TVMArgValue]) -> Result; check_call!(ffi::TVMFuncCreateFromCFunc( Some(tvm_callback), resource_handle as *mut c_void, @@ -318,7 +318,7 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result) -> F /// ``` /// # use tvm_frontend::{TVMArgValue, function, TVMRetValue}; /// # use tvm_frontend::function::Builder; -/// # use failure::Error; +/// # use anyhow::Error; /// use std::convert::TryInto; /// /// fn sum(args: &[TVMArgValue]) -> Result { @@ -339,10 +339,10 @@ fn convert_to_tvm_func(f: fn(&[TVMArgValue]) -> Result) -> F /// assert_eq!(ret, 60); /// ``` pub fn register>( - f: fn(&[TVMArgValue]) -> Result, + f: fn(&[TVMArgValue]) -> Result, name: S, override_: bool, -) -> Result<(), Error> { +) -> Result<()> { let func = convert_to_tvm_func(f); let name = CString::new(name.as_ref())?; check_call!(ffi::TVMFuncRegisterGlobal( @@ -362,7 +362,7 @@ pub fn register>( /// ``` /// # use std::convert::TryInto; /// # use tvm_frontend::{register_global_func, TVMArgValue, TVMRetValue}; -/// # use failure::Error; +/// # use anyhow::Error; /// # use tvm_frontend::function::Builder; /// /// register_global_func! { diff --git a/rust/frontend/src/lib.rs b/rust/tvm-rt/src/lib.rs similarity index 90% rename from rust/frontend/src/lib.rs rename to rust/tvm-rt/src/lib.rs index 10e70d2881c1..ab881c13a460 100644 --- a/rust/frontend/src/lib.rs +++ b/rust/tvm-rt/src/lib.rs @@ -30,20 +30,20 @@ //! //! Checkout the `examples` repository for more details. -#[macro_use] -extern crate failure; -#[macro_use] -extern crate lazy_static; extern crate ndarray as rust_ndarray; -extern crate num_traits; -extern crate tvm_common; + +pub mod object; +pub mod string; + +pub use object::*; +pub use string::*; use std::{ ffi::{CStr, CString}, str, }; -use failure::Error; +use anyhow::Error; pub use crate::{ context::{TVMContext, TVMDeviceType}, @@ -51,16 +51,18 @@ pub use crate::{ function::Function, module::Module, ndarray::NDArray, - tvm_common::{ - errors as common_errors, - ffi::{self, DLDataType, TVMByteArray}, - packed_func::{TVMArgValue, TVMRetValue}, - }, +}; + +pub use tvm_sys::{ + errors as common_errors, + ffi::{self, DLDataType, TVMByteArray}, + packed_func::{TVMArgValue, TVMRetValue}, }; pub type DataType = DLDataType; // Macro to check the return call to TVM runtime shared library. +#[macro_export] macro_rules! check_call { ($e:expr) => {{ if unsafe { $e } != 0 { diff --git a/rust/frontend/src/module.rs b/rust/tvm-rt/src/module.rs similarity index 91% rename from rust/frontend/src/module.rs rename to rust/tvm-rt/src/module.rs index 1ae4bf752ed7..13b193039184 100644 --- a/rust/frontend/src/module.rs +++ b/rust/tvm-rt/src/module.rs @@ -27,8 +27,8 @@ use std::{ ptr, }; -use failure::Error; -use tvm_common::ffi; +use anyhow::{anyhow, ensure, Error}; +use tvm_sys::ffi; use crate::{errors, function::Function}; @@ -90,15 +90,14 @@ impl Module { .extension() .unwrap_or_else(|| std::ffi::OsStr::new("")) .to_str() - .ok_or_else(|| { - format_err!("Bad module load path: `{}`.", path.as_ref().display()) - })?, + .ok_or_else(|| anyhow!("Bad module load path: `{}`.", path.as_ref().display()))?, )?; let func = Function::get("runtime.ModuleLoadFromFile").expect("API function always exists"); - let cpath = - CString::new(path.as_ref().to_str().ok_or_else(|| { - format_err!("Bad module load path: `{}`.", path.as_ref().display()) - })?)?; + let cpath = CString::new( + path.as_ref() + .to_str() + .ok_or_else(|| anyhow!("Bad module load path: `{}`.", path.as_ref().display()))?, + )?; let ret: Module = call_packed!(func, cpath.as_c_str(), ext.as_c_str())?.try_into()?; Ok(ret) } diff --git a/rust/frontend/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs similarity index 96% rename from rust/frontend/src/ndarray.rs rename to rust/tvm-rt/src/ndarray.rs index 6ebd3cb0705e..45f9fb93a08a 100644 --- a/rust/frontend/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -49,13 +49,13 @@ use std::{convert::TryFrom, mem, os::raw::c_int, ptr, slice, str::FromStr}; -use failure::Error; +use anyhow::{bail, ensure, Result}; use num_traits::Num; use rust_ndarray::{Array, ArrayD}; use std::convert::TryInto; use std::ffi::c_void; -use tvm_common::ffi::DLTensor; -use tvm_common::{ffi, TVMType}; +use tvm_sys::ffi::DLTensor; +use tvm_sys::{ffi, TVMType}; use crate::{errors, TVMByteArray, TVMContext}; @@ -147,7 +147,7 @@ impl NDArray { } /// Shows whether the underlying ndarray is contiguous in memory or not. - pub fn is_contiguous(&self) -> Result { + pub fn is_contiguous(&self) -> Result { Ok(match self.strides() { None => true, Some(strides) => { @@ -189,7 +189,7 @@ impl NDArray { /// assert_eq!(ndarray.shape(), Some(&mut shape[..])); /// assert_eq!(ndarray.to_vec::().unwrap(), data); /// ``` - pub fn to_vec(&self) -> Result, Error> { + pub fn to_vec(&self) -> Result> { ensure!(self.shape().is_some(), errors::EmptyArrayError); let earr = NDArray::empty( self.shape().ok_or(errors::MissingShapeError)?, @@ -209,7 +209,7 @@ impl NDArray { } /// Converts the NDArray to [`TVMByteArray`]. - pub fn to_bytearray(&self) -> Result { + pub fn to_bytearray(&self) -> Result { let v = self.to_vec::()?; Ok(TVMByteArray::from(v)) } @@ -239,7 +239,7 @@ impl NDArray { } /// Copies the NDArray to another target NDArray. - pub fn copy_to_ndarray(&self, target: NDArray) -> Result { + pub fn copy_to_ndarray(&self, target: NDArray) -> Result { if self.dtype() != target.dtype() { bail!( "{}", @@ -258,7 +258,7 @@ impl NDArray { } /// Copies the NDArray to a target context. - pub fn copy_to_ctx(&self, target: &TVMContext) -> Result { + pub fn copy_to_ctx(&self, target: &TVMContext) -> Result { let tmp = NDArray::empty( self.shape().ok_or(errors::MissingShapeError)?, *target, @@ -273,7 +273,7 @@ impl NDArray { rnd: &ArrayD, ctx: TVMContext, dtype: TVMType, - ) -> Result { + ) -> Result { let shape = rnd.shape().to_vec(); let mut nd = NDArray::empty(&shape, ctx, dtype); let mut buf = Array::from_iter(rnd.into_iter().map(|&v| v as T)); @@ -304,8 +304,8 @@ impl NDArray { macro_rules! impl_from_ndarray_rustndarray { ($type:ty, $type_name:tt) => { impl<'a> TryFrom<&'a NDArray> for ArrayD<$type> { - type Error = Error; - fn try_from(nd: &NDArray) -> Result, Self::Error> { + type Error = anyhow::Error; + fn try_from(nd: &NDArray) -> Result> { ensure!(nd.shape().is_some(), errors::MissingShapeError); assert_eq!(nd.dtype(), TVMType::from_str($type_name)?, "Type mismatch"); Ok(Array::from_shape_vec( @@ -316,8 +316,8 @@ macro_rules! impl_from_ndarray_rustndarray { } impl<'a> TryFrom<&'a mut NDArray> for ArrayD<$type> { - type Error = Error; - fn try_from(nd: &mut NDArray) -> Result, Self::Error> { + type Error = anyhow::Error; + fn try_from(nd: &mut NDArray) -> Result> { ensure!(nd.shape().is_some(), errors::MissingShapeError); assert_eq!(nd.dtype(), TVMType::from_str($type_name)?, "Type mismatch"); Ok(Array::from_shape_vec( diff --git a/rust/tvm-rt/src/object.rs b/rust/tvm-rt/src/object.rs new file mode 100644 index 000000000000..6b139c693e6c --- /dev/null +++ b/rust/tvm-rt/src/object.rs @@ -0,0 +1,265 @@ +use std::convert::TryFrom; +use std::convert::TryInto; +use std::ffi::CString; +use std::ptr::NonNull; +use tvm_sys::ffi::{self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeKey2Index}; +use tvm_sys::{TVMArgValue, TVMRetValue}; + +type Deleter = unsafe extern "C" fn(object: *mut T) -> (); + +#[derive(Debug)] +#[repr(C)] +pub struct Object { + pub type_index: u32, + pub ref_count: i32, + pub fdeleter: Deleter, +} + +unsafe extern "C" fn delete(object: *mut Object) { + let typed_object: *mut T = object as *mut T; + T::typed_delete(typed_object); +} + +fn derived_from(child_type_index: u32, parent_type_index: u32) -> bool { + let mut is_derived = 0; + crate::check_call!(ffi::TVMObjectDerivedFrom( + child_type_index, + parent_type_index, + &mut is_derived + )); + if is_derived == 0 { + false + } else { + true + } +} + +impl Object { + fn new(type_index: u32, deleter: Deleter) -> Object { + Object { + type_index, + // Note: do not touch this field directly again, this is + // a critical section, we write a 1 to the atomic which will now + // be managed by the C++ atomics. + // In the future we should probably use C-atomcis. + ref_count: 1, + fdeleter: deleter, + } + } + + fn get_type_index() -> u32 { + let type_key = T::TYPE_KEY; + let cstring = CString::new(type_key).expect("type key must not contain null characters"); + if type_key == "Object" { + return 0; + } else { + let mut index = 0; + unsafe { + if TVMObjectTypeKey2Index(cstring.as_ptr(), &mut index) != 0 { + panic!(crate::get_last_error()) + } + } + return index; + } + } + + pub fn base_object() -> Object { + let index = Object::get_type_index::(); + Object::new(index, delete::) + } +} + +pub unsafe trait IsObject { + const TYPE_KEY: &'static str; + + fn as_object<'s>(&'s self) -> &'s Object; + + unsafe extern "C" fn typed_delete(object: *mut Self) { + // let object = Box::from_raw(object); + // drop(object) + } +} + +unsafe impl IsObject for Object { + const TYPE_KEY: &'static str = "Object"; + + fn as_object<'s>(&'s self) -> &'s Object { + self + } +} + +// unsafe impl IsObject for ObjectPtr { +// fn as_object<'s>(&'s self) -> &'s Object { +// unsafe { self.ptr.as_ref().as_object() } +// } +// } + +#[repr(C)] +pub struct ObjectPtr { + ptr: NonNull, +} + +impl ObjectPtr { + fn from_raw(object_ptr: *mut Object) -> Option> { + let non_null = NonNull::new(object_ptr); + non_null.map(|ptr| ObjectPtr { ptr }) + } +} + +impl Clone for ObjectPtr { + fn clone(&self) -> Self { + let raw_ptr = self.ptr.as_ptr() as *mut std::ffi::c_void; + unsafe { + assert_eq!(TVMObjectRetain(raw_ptr), 0); + } + ObjectPtr { ptr: self.ptr } + } +} + +impl Drop for ObjectPtr { + fn drop(&mut self) { + let ptr = self.ptr.as_ptr() as *mut std::ffi::c_void; + unsafe { assert_eq!(TVMObjectFree(ptr), 0) } + } +} + +impl ObjectPtr { + pub fn new(object: T) -> ObjectPtr { + let object_ptr = Box::new(object); + let ptr = NonNull::from(Box::leak(object_ptr)); + ObjectPtr { ptr } + } + + pub fn count(&self) -> i32 { + // need to do atomic read in C++ + // ABI compatible atomics is funky/hard. + self.as_object().ref_count + } + + fn as_object<'s>(&'s self) -> &'s Object { + unsafe { self.ptr.as_ref().as_object() } + } + + pub fn upcast(&self) -> ObjectPtr { + ObjectPtr { + ptr: self.ptr.cast(), + } + } + + pub fn downcast(&self) -> anyhow::Result> { + let child_index = Object::get_type_index::(); + let object_index = self.as_object().type_index; + + let is_derived = if child_index == object_index { + true + } else { + derived_from(child_index, object_index) + }; + + if is_derived { + Ok(ObjectPtr { + ptr: self.ptr.cast(), + }) + } else { + Err(anyhow::anyhow!("failed to downcast to object subtype")) + } + } +} + +impl std::ops::Deref for ObjectPtr { + type Target = T; + + fn deref(&self) -> &Self::Target { + unsafe { self.ptr.as_ref() } + } +} + +#[derive(Clone)] +pub struct ObjectRef(pub Option>); + +impl ObjectRef { + pub fn null() -> ObjectRef { + ObjectRef(None) + } +} + +pub trait ToObjectRef { + fn to_object_ref(&self) -> ObjectRef; +} + +impl ToObjectRef for ObjectRef { + fn to_object_ref(&self) -> ObjectRef { + self.clone() + } +} + +impl TryFrom for ObjectRef { + type Error = anyhow::Error; + + fn try_from(ret_val: TVMRetValue) -> Result { + match ret_val { + TVMRetValue::ObjectHandle(handle) => + // I think we can type the lower-level bindings even further. + { + let handle = handle as *mut Object; + Ok(ObjectRef(ObjectPtr::from_raw(handle))) + } + _ => Err(anyhow::anyhow!("unable to convert the result to an Object")), + } + } +} + +impl<'a> From<&ObjectRef> for TVMArgValue<'a> { + fn from(object_ref: &ObjectRef) -> TVMArgValue<'a> { + let object_ptr = &object_ref.0; + let raw_object_ptr = object_ptr + .as_ref() + .map(|p| p.ptr.as_ptr()) + .unwrap_or(std::ptr::null_mut()); + // Should be able to hide this unsafety in raw bindings. + let void_ptr = raw_object_ptr as *mut std::ffi::c_void; + TVMArgValue::ObjectHandle(void_ptr) + } +} + +impl From for TVMArgValue<'static> { + fn from(object_ref: ObjectRef) -> TVMArgValue<'static> { + let object_ptr = &object_ref.0; + let raw_object_ptr = object_ptr + .as_ref() + .map(|p| p.ptr.as_ptr()) + .unwrap_or(std::ptr::null_mut()); + // Should be able to hide this unsafety in raw bindings. + let void_ptr = raw_object_ptr as *mut std::ffi::c_void; + TVMArgValue::ObjectHandle(void_ptr) + } +} + +#[macro_export] +macro_rules! external_func { + (fn $name:ident ( $($arg:ident : $ty:ty),* ) -> $ret_type:ty as $ext_name:literal;) => { + ::paste::item! { + #[allow(non_upper_case_globals)] + static []: ::once_cell::sync::Lazy<&'static $crate::Function> = + ::once_cell::sync::Lazy::new(|| { + $crate::Function::get($ext_name) + .expect(concat!("unable to load external function", stringify!($ext_name), "from TVM registry.")) + }); + } + + pub fn $name($($arg : $ty),*) -> Result<$ret_type, anyhow::Error> { + let func_ref: &$crate::Function = ::paste::expr! { &*[] }; + let res = $crate::call_packed!(func_ref,$($arg),*)?; + let res = res.try_into()?; + Ok(res) + } + } +} + +external_func! { + fn debug_print(object: &ObjectRef) -> CString as "ir.DebugPrinter"; +} + +external_func! { + fn as_text(object: &ObjectRef) -> CString as "ir.TextPrinter"; +} diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs new file mode 100644 index 000000000000..fa7ab73a5a2a --- /dev/null +++ b/rust/tvm-rt/src/string.rs @@ -0,0 +1,83 @@ +use std::ffi::{CString, NulError}; +use std::os::raw::c_char; + +use super::{IsObject, Object, ObjectPtr, ObjectRef}; + +#[repr(C)] +pub struct StringObj { + base: Object, + data: *const c_char, + size: u64, +} + +unsafe impl IsObject for StringObj { + const TYPE_KEY: &'static str = "runtime.String"; + + fn as_object<'s>(&'s self) -> &'s Object { + &self.base + } +} + +pub struct String(Option>); + +impl String { + fn upcast(&self) -> ObjectRef { + ObjectRef(self.0.as_ref().map(|o| o.upcast())) + } +} + +impl String { + pub fn new(string: std::string::String) -> Result { + let cstring = CString::new(string)?; + println!("{:?}", cstring); + // The string is being corrupted. + // why is this wrong + let length = cstring.as_bytes().len(); + + let string_obj = StringObj { + base: Object::base_object::(), + data: cstring.into_raw(), + size: length as u64, + }; + + let object_ptr = ObjectPtr::new(string_obj); + Ok(String(Some(object_ptr))) + } + + pub fn to_cstring(&self) -> Result { + use std::slice; + let ptr = self.0.as_ref().unwrap().data; + let size = self.0.as_ref().unwrap().size; + unsafe { + let slice: &[u8] = slice::from_raw_parts(ptr as *const u8, size as usize); + CString::new(slice) + } + } + + pub fn to_string(&self) -> anyhow::Result { + let string = self.to_cstring()?.into_string()?; + Ok(string) + } +} + +// impl std::convert::From for std::string::String { +// fn from(string: String) -> std::string::String { +// u +// } +// } + +#[cfg(test)] +mod tests { + use super::String; + use crate::{debug_print, IsObject, Object, ObjectPtr, ObjectRef}; + + #[test] + fn test_string_debug() { + let s = String::new("foo".to_string()).unwrap(); + assert!(debug_print(&s.upcast()) + .expect("debug_print failed") + .into_string() + .expect("is cstring") + .contains("foo")) + } +} diff --git a/rust/frontend/src/value.rs b/rust/tvm-rt/src/value.rs similarity index 98% rename from rust/frontend/src/value.rs rename to rust/tvm-rt/src/value.rs index 453c1830a27b..f0e56ffe146b 100644 --- a/rust/frontend/src/value.rs +++ b/rust/tvm-rt/src/value.rs @@ -25,7 +25,7 @@ use std::convert::TryFrom; // use std::ffi::c_void; use crate::{Function, Module, NDArray, TVMArgValue, TVMRetValue}; -use tvm_common::{ +use tvm_sys::{ errors::ValueDowncastError, ffi::{TVMFunctionHandle, TVMModuleHandle}, try_downcast, @@ -135,7 +135,7 @@ impl TryFrom for NDArray { mod tests { use std::{convert::TryInto, str::FromStr}; - use tvm_common::{TVMByteArray, TVMContext, TVMType}; + use tvm_sys::{TVMByteArray, TVMContext, TVMType}; use super::*; diff --git a/rust/frontend/tests/basics/.gitignore b/rust/tvm-rt/tests/basics/.gitignore similarity index 100% rename from rust/frontend/tests/basics/.gitignore rename to rust/tvm-rt/tests/basics/.gitignore diff --git a/rust/frontend/tests/basics/Cargo.toml b/rust/tvm-rt/tests/basics/Cargo.toml similarity index 100% rename from rust/frontend/tests/basics/Cargo.toml rename to rust/tvm-rt/tests/basics/Cargo.toml diff --git a/rust/frontend/tests/basics/build.rs b/rust/tvm-rt/tests/basics/build.rs similarity index 100% rename from rust/frontend/tests/basics/build.rs rename to rust/tvm-rt/tests/basics/build.rs diff --git a/rust/frontend/tests/basics/src/main.rs b/rust/tvm-rt/tests/basics/src/main.rs similarity index 100% rename from rust/frontend/tests/basics/src/main.rs rename to rust/tvm-rt/tests/basics/src/main.rs diff --git a/rust/frontend/tests/basics/src/tvm_add.py b/rust/tvm-rt/tests/basics/src/tvm_add.py similarity index 100% rename from rust/frontend/tests/basics/src/tvm_add.py rename to rust/tvm-rt/tests/basics/src/tvm_add.py diff --git a/rust/frontend/tests/callback/Cargo.toml b/rust/tvm-rt/tests/callback/Cargo.toml similarity index 100% rename from rust/frontend/tests/callback/Cargo.toml rename to rust/tvm-rt/tests/callback/Cargo.toml diff --git a/rust/frontend/tests/callback/src/bin/array.rs b/rust/tvm-rt/tests/callback/src/bin/array.rs similarity index 100% rename from rust/frontend/tests/callback/src/bin/array.rs rename to rust/tvm-rt/tests/callback/src/bin/array.rs diff --git a/rust/frontend/tests/callback/src/bin/error.rs b/rust/tvm-rt/tests/callback/src/bin/error.rs similarity index 100% rename from rust/frontend/tests/callback/src/bin/error.rs rename to rust/tvm-rt/tests/callback/src/bin/error.rs diff --git a/rust/frontend/tests/callback/src/bin/float.rs b/rust/tvm-rt/tests/callback/src/bin/float.rs similarity index 100% rename from rust/frontend/tests/callback/src/bin/float.rs rename to rust/tvm-rt/tests/callback/src/bin/float.rs diff --git a/rust/frontend/tests/callback/src/bin/int.rs b/rust/tvm-rt/tests/callback/src/bin/int.rs similarity index 100% rename from rust/frontend/tests/callback/src/bin/int.rs rename to rust/tvm-rt/tests/callback/src/bin/int.rs diff --git a/rust/frontend/tests/callback/src/bin/string.rs b/rust/tvm-rt/tests/callback/src/bin/string.rs similarity index 100% rename from rust/frontend/tests/callback/src/bin/string.rs rename to rust/tvm-rt/tests/callback/src/bin/string.rs diff --git a/rust/tvm-rt/tests/test_ir.rs b/rust/tvm-rt/tests/test_ir.rs new file mode 100644 index 000000000000..32ae8ce35076 --- /dev/null +++ b/rust/tvm-rt/tests/test_ir.rs @@ -0,0 +1,38 @@ +extern crate tvm_frontend as tvm; +use std::convert::TryInto; +use std::str::FromStr; +use tvm::ir::IntImmNode; +use tvm::runtime::String as TString; +use tvm::runtime::{debug_print, Object, ObjectPtr, ObjectRef}; +use tvm::{call_packed, DLDataType, Function}; +use tvm_sys::TVMRetValue; + +#[test] +fn test_new_object() -> anyhow::Result<()> { + let object = Object::base_object::(); + let ptr = ObjectPtr::new(object); + assert_eq!(ptr.count(), 1); + Ok(()) +} + +#[test] +fn test_new_string() -> anyhow::Result<()> { + let string = TString::new("hello world!".to_string())?; + Ok(()) +} + +#[test] +fn test_obj_build() -> anyhow::Result<()> { + let int_imm = Function::get("ir.IntImm").expect("Stable TVM API not found."); + + let dt = DLDataType::from_str("int32").expect("Known datatype doesn't convert."); + + let ret_val: ObjectRef = call_packed!(int_imm, dt, 1337) + .expect("foo") + .try_into() + .unwrap(); + + debug_print(&ret_val); + + Ok(()) +} diff --git a/rust/tvm-rt/tvm-sys/build.rs b/rust/tvm-rt/tvm-sys/build.rs new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/rust/common/Cargo.toml b/rust/tvm-sys/Cargo.toml similarity index 90% rename from rust/common/Cargo.toml rename to rust/tvm-sys/Cargo.toml index 60f5a6b336d4..7897e21046fa 100644 --- a/rust/common/Cargo.toml +++ b/rust/tvm-sys/Cargo.toml @@ -16,7 +16,7 @@ # under the License. [package] -name = "tvm-common" +name = "tvm-sys" version = "0.1.0" authors = ["TVM Contributors"] license = "Apache-2.0" @@ -26,7 +26,8 @@ edition = "2018" bindings = [] [dependencies] -failure = { version = "0.1", default-features = false, features = ["derive"] } +thiserror = "^1.0" +anyhow = "^1.0" ndarray = "0.12" [build-dependencies] diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs new file mode 100644 index 000000000000..915827bf95f0 --- /dev/null +++ b/rust/tvm-sys/build.rs @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +extern crate bindgen; + +use std::path::PathBuf; + +// extern crate cmake; + +use std::env; +// use std::path::Path; +// use std::process::Command; +// use cmake::Config; + +// fn main() { +// if !Path::new("tvm/.git").exists() { +// let _ = Command::new("git") +// .args(&["submodule", "update", "--recursive", "--init"]) +// .status(); +// } + +// let dst = Config::new("tvm") +// .very_verbose(true) +// .build(); + +// // let dst = dst.join("build"); + +// let out_dir = env::var("OUT_DIR").unwrap(); + +// println!("{}", out_dir); +// // let _ = Command::new("mv") +// // .args(&[format!("{}/build/libtvm.dylib", dst.display()), out_dir]) +// // .status(); + +// println!("cargo:rustc-link-search=native={}/lib", dst.display()); +// // TODO(@jroesch): hack for dylib behavior +// for lib in &[/* "tvm", */ "tvm_runtime", /* "tvm_topi" */] { +// // let src = format!("{}/lib/lib{}.dylib", out_dir, lib); +// // let dst = format!("{}/../../../deps", out_dir); +// // let _ = Command::new("mv") +// // .args(&[src, dst]) +// // .status(); +// println!("cargo:rustc-link-lib=dylib={}", lib); +// } +// // "-Wl,-rpath,/scratch/library/" +// println!("cargo:rustc-env=TVM_HOME={}/build", dst.display()); +// // panic!(""); +// // cc::Build::new() +// // .cpp(true) +// // .flag("-std=c++11") +// // .flag("-Wno-ignored-qualifiers") +// // .flag("-Wno-unused-parameter") +// // .include("/Users/jroesch/Git/tvm/include") +// // .include("/Users/jroesch/Git/tvm/3rdparty/dmlc-core/include") +// // .include("/Users/jroesch/Git/tvm/3rdparty/dlpack/include") +// // .include("/Users/jroesch/Git/tvm/3rdparty/HalideIR/src") +// // .file("tvm_wrapper.cc") +// // .compile("tvm_ffi"); +// // println!("cargo:rustc-link-lib=dylib=tvm"); +// // println!("cargo:rustc-link-search=/Users/jroesch/Git/tvm/build"); +// } + +fn main() { + let tvm_home = option_env!("TVM_HOME").map(str::to_string).unwrap_or({ + let crate_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .canonicalize() + .unwrap(); + crate_dir + .parent() + .unwrap() + .parent() + .unwrap() + .to_str() + .unwrap() + .to_string() + }); + + if cfg!(feature = "bindings") { + println!("cargo:rerun-if-env-changed=TVM_HOME"); + // println!("cargo:rustc-link-lib=dylib=tvm_runtime"); + // TODO: move to core + // println!("cargo:rustc-link-lib=dylib=tvm_runtime"); + println!("cargo:rustc-link-lib=dylib=tvm"); + println!("cargo:rustc-link-search={}/build", tvm_home); + } + + // @see rust-bindgen#550 for `blacklist_type` + bindgen::Builder::default() + .header(format!("{}/include/tvm/runtime/c_runtime_api.h", tvm_home)) + .header(format!("{}/include/tvm/runtime/c_backend_api.h", tvm_home)) + .clang_arg(format!("-I{}/3rdparty/dlpack/include/", tvm_home)) + .clang_arg(format!("-I{}/include/", tvm_home)) + .blacklist_type("max_align_t") + .layout_tests(false) + .derive_partialeq(true) + .derive_eq(true) + .generate() + .expect("unable to generate bindings") + .write_to_file(PathBuf::from("src/c_runtime_api.rs")) + .expect("can not write the bindings!"); +} diff --git a/rust/common/src/array.rs b/rust/tvm-sys/src/array.rs similarity index 100% rename from rust/common/src/array.rs rename to rust/tvm-sys/src/array.rs diff --git a/rust/common/src/errors.rs b/rust/tvm-sys/src/errors.rs similarity index 85% rename from rust/common/src/errors.rs rename to rust/tvm-sys/src/errors.rs index 4b8a9ffcb1eb..8479ec62f19f 100644 --- a/rust/common/src/errors.rs +++ b/rust/tvm-sys/src/errors.rs @@ -17,18 +17,17 @@ * under the License. */ -#[derive(Debug, Fail)] -#[fail( - display = "Could not downcast `{}` into `{}`", - expected_type, actual_type -)] +use thiserror::Error; + +#[derive(Error, Debug)] +#[error("invalid header (expected {expected_type:?}, found {actual_type:?})")] pub struct ValueDowncastError { pub actual_type: String, pub expected_type: &'static str, } -#[derive(Debug, Fail)] -#[fail(display = "Function call `{}` returned error: {}", context, message)] +#[derive(Error, Debug)] +#[error("Function call `{context:?}` returned error: {message:?}")] pub struct FuncCallError { context: String, message: String, diff --git a/rust/common/src/lib.rs b/rust/tvm-sys/src/lib.rs similarity index 97% rename from rust/common/src/lib.rs rename to rust/tvm-sys/src/lib.rs index 2ae64e7a32b3..826da8a58316 100644 --- a/rust/common/src/lib.rs +++ b/rust/tvm-sys/src/lib.rs @@ -20,9 +20,6 @@ //! This crate contains the refactored basic components required //! for `runtime` and `frontend` TVM crates. -#[macro_use] -extern crate failure; - /// Unified ffi module for both runtime and frontend crates. pub mod ffi { #![allow(non_camel_case_types, non_snake_case, non_upper_case_globals, unused)] diff --git a/rust/common/src/packed_func.rs b/rust/tvm-sys/src/packed_func.rs similarity index 97% rename from rust/common/src/packed_func.rs rename to rust/tvm-sys/src/packed_func.rs index f3bac39b6a10..9594fe1a1bd3 100644 --- a/rust/common/src/packed_func.rs +++ b/rust/tvm-sys/src/packed_func.rs @@ -363,3 +363,11 @@ impl Default for TVMRetValue { Self::Int(0) } } + +impl TryFrom for std::ffi::CString { + type Error = ValueDowncastError; + fn try_from(val: TVMRetValue) -> Result { + try_downcast!(val -> std::ffi::CString, + |TVMRetValue::Str(val)| { val.into() }) + } +} diff --git a/rust/common/src/value.rs b/rust/tvm-sys/src/value.rs similarity index 97% rename from rust/common/src/value.rs rename to rust/tvm-sys/src/value.rs index 321cebefa873..c23c56adbf20 100644 --- a/rust/common/src/value.rs +++ b/rust/tvm-sys/src/value.rs @@ -19,6 +19,8 @@ use std::{os::raw::c_char, str::FromStr}; +use thiserror::Error; + use crate::ffi::*; impl DLDataType { @@ -31,11 +33,11 @@ impl DLDataType { } } -#[derive(Debug, Fail)] +#[derive(Debug, Error)] pub enum ParseTvmTypeError { - #[fail(display = "invalid number: {}", _0)] + #[error("invalid number: {0}")] InvalidNumber(std::num::ParseIntError), - #[fail(display = "unknown type: {}", _0)] + #[error("unknown type: {0}")] UnknownType(String), } @@ -126,8 +128,8 @@ impl_pod_tvm_value!(v_float64, f64, f32, f64); impl_pod_tvm_value!(v_type, DLDataType); impl_pod_tvm_value!(v_ctx, TVMContext); -#[derive(Debug, Fail)] -#[fail(display = "unsupported device: {}", _0)] +#[derive(Debug, Error)] +#[error("unsupported device: {0}")] pub struct UnsupportedDeviceError(String); macro_rules! impl_tvm_context { diff --git a/rust/tvm/.gitignore b/rust/tvm/.gitignore new file mode 100644 index 000000000000..2430329c78b6 --- /dev/null +++ b/rust/tvm/.gitignore @@ -0,0 +1,7 @@ +target +**/*.rs.bk +Cargo.lock +/tests/basics/add_* +/examples/resnet/deploy_* +/examples/resnet/*.png +/examples/resnet/synset.* diff --git a/rust/tvm/.travis.yml b/rust/tvm/.travis.yml new file mode 100644 index 000000000000..e963b7c0ede5 --- /dev/null +++ b/rust/tvm/.travis.yml @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +language: rust +rust: + - nightly +matrix: + fast_finish: true diff --git a/rust/frontend/Cargo.toml b/rust/tvm/Cargo.toml similarity index 83% rename from rust/frontend/Cargo.toml rename to rust/tvm/Cargo.toml index 920d069109e9..4cbb6193d952 100644 --- a/rust/frontend/Cargo.toml +++ b/rust/tvm/Cargo.toml @@ -16,7 +16,7 @@ # under the License. [package] -name = "tvm-frontend" +name = "tvm" version = "0.1.0" license = "Apache-2.0" description = "Rust frontend support for TVM" @@ -29,11 +29,17 @@ authors = ["TVM Contributors"] edition = "2018" [dependencies] -failure = "0.1" +thiserror = "^1.0" +anyhow = "^1.0" lazy_static = "1.1" ndarray = "0.12" num-traits = "0.2" -tvm-common = { version = "0.1", path = "../common/", features = ["bindings"] } +tvm-rt = { version = "0.1", path = "../tvm-rt/" } +tvm-sys = { version = "0.1", path = "../tvm-sys/" } +tvm-macros = { version = "*", path = "../macros/" } +paste = "0.1" +mashup = "0.1" +once_cell = "^1.3.1" [features] blas = ["ndarray/blas"] diff --git a/rust/tvm/README.md b/rust/tvm/README.md new file mode 100644 index 000000000000..01e088f2ea81 --- /dev/null +++ b/rust/tvm/README.md @@ -0,0 +1,235 @@ + + + + + + + + + + + + + + + + + +# TVM Runtime Frontend Support + +This crate provides an idiomatic Rust API for [TVM](https://github.com/apache/incubator-tvm) runtime frontend. Currently this requires **Nightly Rust** and tested on `rustc 1.32.0-nightly` + +## What Does This Crate Offer? + +Here is a major workflow + +1. Train your **Deep Learning** model using any major framework such as [PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.incubator.apache.org/) or [TensorFlow](https://www.tensorflow.org/) +2. Use **TVM** to build optimized model artifacts on a supported context such as CPU, GPU, OpenCL and specialized accelerators. +3. Deploy your models using **Rust** :heart: + +### Example: Deploy Image Classification from Pretrained Resnet18 on ImageNet1k + +Please checkout [examples/resnet](examples/resnet) for the complete end-to-end example. + +Here's a Python snippet for downloading and building a pretrained Resnet18 via Apache MXNet and TVM + +```python +block = get_model('resnet18_v1', pretrained=True) + +sym, params = relay.frontend.from_mxnet(block, shape_dict) +# compile the model +with relay.build_config(opt_level=opt_level): + graph, lib, params = relay.build( + net, target, params=params) +# same the model artifacts +lib.save(os.path.join(target_dir, "deploy_lib.o")) +cc.create_shared(os.path.join(target_dir, "deploy_lib.so"), + [os.path.join(target_dir, "deploy_lib.o")]) + +with open(os.path.join(target_dir, "deploy_graph.json"), "w") as fo: + fo.write(graph.json()) +with open(os.path.join(target_dir,"deploy_param.params"), "wb") as fo: + fo.write(relay.save_param_dict(params)) +``` + +Now, we need to input the artifacts to create and run the *Graph Runtime* to detect our input cat image + +![cat](https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true) + +as demostrated in the following Rust snippet + +```rust + let graph = fs::read_to_string("deploy_graph.json")?; + // load the built module + let lib = Module::load(&Path::new("deploy_lib.so"))?; + // get the global TVM graph runtime function + let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap(); + let runtime_create_fn_ret = call_packed!( + runtime_create_fn, + &graph, + &lib, + &ctx.device_type, + &ctx.device_id + )?; + // get graph runtime module + let graph_runtime_module: Module = runtime_create_fn_ret.try_into()?; + // get the registered `load_params` from runtime module + let ref load_param_fn = graph_runtime_module + .get_function("load_params", false) + .unwrap(); + // parse parameters and convert to TVMByteArray + let params: Vec = fs::read("deploy_param.params")?; + let barr = TVMByteArray::from(¶ms); + // load the parameters + call_packed!(load_param_fn, &barr)?; + // get the set_input function + let ref set_input_fn = graph_runtime_module + .get_function("set_input", false) + .unwrap(); + + call_packed!(set_input_fn, "data", &input)?; + // get `run` function from runtime module + let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); + // execute the run function. Note that it has no argument + call_packed!(run_fn,)?; + // prepare to get the output + let output_shape = &mut [1, 1000]; + let output = empty(output_shape, TVMContext::cpu(0), TVMType::from("float32")); + // get the `get_output` function from runtime module + let ref get_output_fn = graph_runtime_module + .get_function("get_output", false) + .unwrap(); + // execute the get output function + call_packed!(get_output_fn, &0, &output)?; + // flatten the output as Vec + let output = output.to_vec::()?; +``` + +and the model correctly predicts the input image as **tiger cat**. + +## Installations + +Please follow TVM [installations](https://tvm.apache.org/docs/install/index.html), `export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`. + +*Note:* To run the end-to-end examples and tests, `tvm` and `topi` need to be added to your `PYTHONPATH` or it's automatic via an Anaconda environment when it is installed individually. + +## Supported TVM Functionalities + +### Use TVM to Generate Shared Library + +One can use the following Python snippet to generate `add_gpu.so` which add two vectors on GPU. + +```python +import os +import tvm +from tvm import te +from tvm.contrib import cc + +def test_add(target_dir): + if not tvm.runtime.enabled("cuda"): + print("skip {__file__} because cuda is not enabled...".format(__file__=__file__)) + return + n = te.var("n") + A = te.placeholder((n,), name='A') + B = te.placeholder((n,), name='B') + C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") + s = te.create_schedule(C.op) + bx, tx = s[C].split(C.op.axis[0], factor=64) + s[C].bind(bx, tvm.thread_axis("blockIdx.x")) + s[C].bind(tx, tvm.thread_axis("threadIdx.x")) + fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd") + + fadd_cuda.save(os.path.join(target_dir, "add_gpu.o")) + fadd_cuda.imported_modules[0].save(os.path.join(target_dir, "add_gpu.ptx")) + cc.create_shared(os.path.join(target_dir, "add_gpu.so"), + [os.path.join(target_dir, "add_gpu.o")]) + + +if __name__ == "__main__": + import sys + if len(sys.argv) != 2: + sys.exit(-1) + test_add(sys.argv[1]) +``` + +### Run the Generated Shared Library + +The following code snippet demonstrates how to load and test the generated shared library (`add_gpu.so`) in Rust. + +```rust +extern crate tvm_frontend as tvm; + +use tvm::*; + +fn main() { + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + let mut arr = empty(shape, TVMContext::gpu(0), TVMType::from("float32")); + arr.copy_from_buffer(data.as_mut_slice()); + let mut ret = empty(shape, TVMContext::gpu(0), TVMType::from("float32")); + let mut fadd = Module::load(&Path::new("add_gpu.so")).unwrap(); + let fadd_dep = Module::load(&Path::new("add_gpu.ptx")).unwrap(); + assert!(fadd.enabled("gpu")); + fadd.import_module(fadd_dep); + fadd.entry(); + function::Builder::from(&mut fadd) + .arg(&arr) + .arg(&arr) + .set_output(&mut ret)? + .invoke() + .unwrap(); + + assert_eq!(ret.to_vec::().unwrap(), vec![6f32, 8.0]); +} +``` + +**Note:** it is required to instruct the `rustc` to link to the generated `add_gpu.so` in runtime, for example by +`cargo:rustc-link-search=native=add_gpu`. + +See the tests and examples custom `build.rs` for more details. + +### Convert and Register a Rust Function as a TVM Packed Function + +One can use `register_global_func!` macro to convert and register a Rust +function of type `fn(&[TVMArgValue]) -> Result` to a global TVM **packed function** as follows + +```rust +#[macro_use] +extern crate tvm_frontend as tvm; +use std::convert::TryInto; +use tvm::*; + +fn main() { + register_global_func! { + fn sum(args: &[TVMArgValue]) -> Result { + let mut ret = 0f32; + let shape = &mut [2]; + for arg in args.iter() { + let e = empty(shape, TVMContext::cpu(0), TVMType::from("float32")); + let arg: NDArray = arg.try_into()?; + let arr = arg.copy_to_ndarray(e).unwrap(); + let rnd: ArrayD = ArrayD::try_from(&arr).unwrap(); + ret += rnd.scalar_sum(); + } + let ret_val = TVMRetValue::from(&ret); + Ok(ret_val) + } + } + + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + let mut arr = empty(shape, TVMContext::cpu(0), TVMType::from("float32")); + arr.copy_from_buffer(data.as_mut_slice()); + let mut registered = function::Builder::default(); + let ret: f64 = registered + .get_function("sum", true) + .arg(&arr) + .arg(&arr) + .invoke() + .unwrap() + .try_into() + .unwrap(); + + assert_eq!(ret, 14f64); +} +``` diff --git a/rust/tvm/examples/resnet/Cargo.toml b/rust/tvm/examples/resnet/Cargo.toml new file mode 100644 index 000000000000..e1a474eb5479 --- /dev/null +++ b/rust/tvm/examples/resnet/Cargo.toml @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "resnet" +version = "0.0.0" +authors = ["TVM Contributors"] +license = "Apache-2.0" +build = "build.rs" + +[dependencies] +ndarray = "0.12" +tvm = { path = "../../" } +image = "0.20" +csv = "1.1" diff --git a/rust/tvm/examples/resnet/README.md b/rust/tvm/examples/resnet/README.md new file mode 100644 index 000000000000..d6e32f7fa768 --- /dev/null +++ b/rust/tvm/examples/resnet/README.md @@ -0,0 +1,45 @@ + + + + + + + + + + + + + + + + + +## Resnet example + +This end-to-end example shows how to: +* build `Resnet 18` with `tvm` from Python +* use the provided Rust frontend API to test for an input image + +To run the example with pretrained resnet weights, first `tvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet` +and to install `tvm` with `llvm` follow the [TVM installation guide](https://tvm.apache.org/docs/install/index.html). + +* **Build the example**: `cargo build + +To have a successful build, note that it is required to instruct Rust compiler to link to the compiled shared library, for example with +`println!("cargo:rustc-link-search=native={}", build_path)`. See the `build.rs` for more details. + +* **Run the example**: `cargo run` + +Note: To use pretrained weights, one can enable `--pretrained` in `build.rs` with + +``` +let output = Command::new("python") + .arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py")) + .arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR"))) + .arg(&format!("--pretrained")) + .output() + .expect("Failed to execute command"); +``` + +Otherwise, *random weights* are used, therefore, the prediction will be `limpkin, Aramus pictus`! diff --git a/rust/tvm/examples/resnet/build.rs b/rust/tvm/examples/resnet/build.rs new file mode 100644 index 000000000000..b9a3c4ccdf12 --- /dev/null +++ b/rust/tvm/examples/resnet/build.rs @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{path::Path, process::Command}; + +fn main() { + let output = Command::new("python3") + .arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py")) + .arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR"))) + .output() + .expect("Failed to execute command"); + assert!( + Path::new(&format!("{}/deploy_lib.o", env!("CARGO_MANIFEST_DIR"))).exists(), + "Could not prepare demo: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); + println!( + "cargo:rustc-link-search=native={}", + env!("CARGO_MANIFEST_DIR") + ); +} diff --git a/rust/tvm/examples/resnet/src/build_resnet.py b/rust/tvm/examples/resnet/src/build_resnet.py new file mode 100644 index 000000000000..49c67bf1c4f3 --- /dev/null +++ b/rust/tvm/examples/resnet/src/build_resnet.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse +import csv +import logging +from os import path as osp +import sys + +import numpy as np + +import tvm +from tvm import te +from tvm import relay +from tvm.relay import testing +from tvm.contrib import graph_runtime, cc + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +parser = argparse.ArgumentParser(description='Resnet build example') +aa = parser.add_argument +aa('--build-dir', type=str, required=True, help='directory to put the build artifacts') +aa('--pretrained', action='store_true', help='use a pretrained resnet') +aa('--batch-size', type=int, default=1, help='input image batch size') +aa('--opt-level', type=int, default=3, + help='level of optimization. 0 is unoptimized and 3 is the highest level') +aa('--target', type=str, default='llvm', help='target context for compilation') +aa('--image-shape', type=str, default='3,224,224', help='input image dimensions') +aa('--image-name', type=str, default='cat.png', help='name of input image to download') +args = parser.parse_args() + +build_dir = args.build_dir +batch_size = args.batch_size +opt_level = args.opt_level +target = tvm.target.create(args.target) +image_shape = tuple(map(int, args.image_shape.split(","))) +data_shape = (batch_size,) + image_shape + +def build(target_dir): + """ Compiles resnet18 with TVM""" + deploy_lib = osp.join(target_dir, 'deploy_lib.o') + if osp.exists(deploy_lib): + return + + if args.pretrained: + # needs mxnet installed + from mxnet.gluon.model_zoo.vision import get_model + + # if `--pretrained` is enabled, it downloads a pretrained + # resnet18 trained on imagenet1k dataset for image classification task + block = get_model('resnet18_v1', pretrained=True) + net, params = relay.frontend.from_mxnet(block, {"data": data_shape}) + # we want a probability so add a softmax operator + net = relay.Function(net.params, relay.nn.softmax(net.body), + None, net.type_params, net.attrs) + else: + # use random weights from relay.testing + net, params = relay.testing.resnet.get_workload( + num_layers=18, batch_size=batch_size, image_shape=image_shape) + + # compile the model + with relay.build_config(opt_level=opt_level): + graph, lib, params = relay.build_module.build(net, target, params=params) + + # save the model artifacts + lib.save(deploy_lib) + cc.create_shared(osp.join(target_dir, "deploy_lib.so"), + [osp.join(target_dir, "deploy_lib.o")]) + + with open(osp.join(target_dir, "deploy_graph.json"), "w") as fo: + fo.write(graph) + + with open(osp.join(target_dir,"deploy_param.params"), "wb") as fo: + fo.write(relay.save_param_dict(params)) + +def download_img_labels(): + """ Download an image and imagenet1k class labels for test""" + from mxnet.gluon.utils import download + + img_name = 'cat.png' + synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/', + '4d0b62f3d01426887599d4f7ede23ee5/raw/', + '596b27d23537e5a1b5751d2b0481ef172f58b539/', + 'imagenet1000_clsid_to_human.txt']) + synset_name = 'synset.txt' + download('https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', img_name) + download(synset_url, synset_name) + + with open(synset_name) as fin: + synset = eval(fin.read()) + + with open("synset.csv", "w") as fout: + w = csv.writer(fout) + w.writerows(synset.items()) + +def test_build(build_dir): + """ Sanity check with random input""" + graph = open(osp.join(build_dir, "deploy_graph.json")).read() + lib = tvm.runtime.load(osp.join(build_dir, "deploy_lib.so")) + params = bytearray(open(osp.join(build_dir,"deploy_param.params"), "rb").read()) + input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32")) + ctx = tvm.cpu() + module = graph_runtime.create(graph, lib, ctx) + module.load_params(params) + module.run(data=input_data) + out = module.get_output(0).asnumpy() + + +if __name__ == '__main__': + logger.info("building the model") + build(build_dir) + logger.info("build was successful") + logger.info("test the build artifacts") + test_build(build_dir) + logger.info("test was successful") + if args.pretrained: + download_img_labels() + logger.info("image and synset downloads are successful") diff --git a/rust/tvm/examples/resnet/src/main.rs b/rust/tvm/examples/resnet/src/main.rs new file mode 100644 index 000000000000..0aed72b1eb52 --- /dev/null +++ b/rust/tvm/examples/resnet/src/main.rs @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +extern crate csv; +extern crate image; +extern crate ndarray; +extern crate tvm_frontend as tvm; + +use std::{ + collections::HashMap, + convert::TryInto, + fs::{self, File}, + path::Path, + str::FromStr, +}; + +use image::{FilterType, GenericImageView}; +use ndarray::{Array, ArrayD, Axis}; + +use tvm::*; + +fn main() { + let ctx = TVMContext::cpu(0); + let img = image::open(concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")).unwrap(); + println!("original image dimensions: {:?}", img.dimensions()); + // for bigger size images, one needs to first resize to 256x256 + // with `img.resize_exact` method and then `image.crop` to 224x224 + let img = img.resize(224, 224, FilterType::Nearest).to_rgb(); + println!("resized image dimensions: {:?}", img.dimensions()); + let mut pixels: Vec = vec![]; + for pixel in img.pixels() { + let tmp = pixel.data; + // normalize the RGB channels using mean, std of imagenet1k + let tmp = [ + (tmp[0] as f32 - 123.0) / 58.395, // R + (tmp[1] as f32 - 117.0) / 57.12, // G + (tmp[2] as f32 - 104.0) / 57.375, // B + ]; + for e in &tmp { + pixels.push(*e); + } + } + + let arr = Array::from_shape_vec((224, 224, 3), pixels).unwrap(); + let arr: ArrayD = arr.permuted_axes([2, 0, 1]).into_dyn(); + // make arr shape as [1, 3, 224, 224] acceptable to resnet + let arr = arr.insert_axis(Axis(0)); + // create input tensor from rust's ndarray + let input = NDArray::from_rust_ndarray( + &arr, + TVMContext::cpu(0), + DLDataType::from_str("float32").unwrap(), + ) + .unwrap(); + println!( + "input size is {:?}", + input.shape().expect("cannot get the input shape") + ); + let graph = + fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")).unwrap(); + // load the built module + let lib = Module::load(&Path::new(concat!( + env!("CARGO_MANIFEST_DIR"), + "/deploy_lib.so" + ))) + .unwrap(); + // get the global TVM graph runtime function + let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap(); + let runtime_create_fn_ret = call_packed!( + runtime_create_fn, + graph, + &lib, + &ctx.device_type, + &ctx.device_id + ) + .unwrap(); + // get graph runtime module + let graph_runtime_module: Module = runtime_create_fn_ret.try_into().unwrap(); + // get the registered `load_params` from runtime module + let ref load_param_fn = graph_runtime_module + .get_function("load_params", false) + .unwrap(); + // parse parameters and convert to TVMByteArray + let params: Vec = + fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params")).unwrap(); + let barr = TVMByteArray::from(¶ms); + // load the parameters + call_packed!(load_param_fn, &barr).unwrap(); + // get the set_input function + let ref set_input_fn = graph_runtime_module + .get_function("set_input", false) + .unwrap(); + + call_packed!(set_input_fn, "data".to_string(), &input).unwrap(); + // get `run` function from runtime module + let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); + // execute the run function. Note that it has no argument + call_packed!(run_fn,).unwrap(); + // prepare to get the output + let output_shape = &mut [1, 1000]; + let output = NDArray::empty( + output_shape, + TVMContext::cpu(0), + DLDataType::from_str("float32").unwrap(), + ); + // get the `get_output` function from runtime module + let ref get_output_fn = graph_runtime_module + .get_function("get_output", false) + .unwrap(); + // execute the get output function + call_packed!(get_output_fn, &0, &output).unwrap(); + // flatten the output as Vec + let output = output.to_vec::().unwrap(); + // find the maximum entry in the output and its index + let mut argmax = -1; + let mut max_prob = 0.; + for i in 0..output.len() { + if output[i] > max_prob { + max_prob = output[i]; + argmax = i as i32; + } + } + // create a hash map of (class id, class name) + let mut synset: HashMap = HashMap::new(); + let file = File::open("synset.csv").unwrap(); + let mut rdr = csv::ReaderBuilder::new() + .has_headers(true) + .from_reader(file); + + for result in rdr.records() { + let record = result.unwrap(); + let id: i32 = record[0].parse().unwrap(); + let cls = record[1].to_string(); + synset.insert(id, cls); + } + + println!( + "input image belongs to the class `{}` with probability {}", + synset + .get(&argmax) + .expect("cannot find the class id for argmax"), + max_prob + ); +} diff --git a/rust/tvm/src/ir/array.rs b/rust/tvm/src/ir/array.rs new file mode 100644 index 000000000000..fb0a1c38fe5a --- /dev/null +++ b/rust/tvm/src/ir/array.rs @@ -0,0 +1,69 @@ +use crate::runtime::function::Builder; +use crate::runtime::object::{ObjectRef, ToObjectRef}; +use std::convert::{TryFrom, TryInto}; +use std::marker::PhantomData; +use tvm_sys::TVMRetValue; + +use anyhow::Result; + +pub struct Array { + object: ObjectRef, + _data: PhantomData, +} + +impl Array { + pub fn from_vec(data: Vec) -> Result> { + let iter = data.iter().map(|element| element.to_object_ref()); + + let array_data = Builder::default() + .get_function("node.Array") + .args(iter) + .invoke()? + .try_into()?; + + Ok(Array { + object: array_data, + _data: PhantomData, + }) + } + + pub fn get(&self, index: isize) -> Result + where + T: TryFrom, + { + // TODO(@jroesch): why do we used a signed index here? + let element: T = Builder::default() + .get_function("node.ArrayGetItem") + .arg(self.object.clone()) + .arg(index) + .invoke()? + .try_into()?; + + Ok(element) + } +} +// mod array_api { +// extern_fn! { +// fn _create_array( +// } +// } + +#[cfg(test)] +mod tests { + use super::Array; + use crate::ir::relay::Var; + use crate::runtime::object::ObjectRef; + use anyhow::Result; + + #[test] + fn create_array_and_get() -> Result<()> { + let vec = vec![ + Var::new("foo".into(), ObjectRef::null()), + Var::new("bar".into(), ObjectRef::null()), + ]; + let array = Array::from_vec(vec)?; + assert_eq!(array.get(0)?.name_hint().to_string()?, "foo"); + assert_eq!(array.get(1)?.name_hint().to_string()?, "bar"); + Ok(()) + } +} diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs new file mode 100644 index 000000000000..bc667fdb19b8 --- /dev/null +++ b/rust/tvm/src/ir/mod.rs @@ -0,0 +1,17 @@ +use crate::runtime::Object; +use crate::DataType; + +pub mod array; +pub mod relay; + +#[repr(C)] +pub struct PrimExprNode { + pub base: Object, + pub dtype: DataType, +} + +#[repr(C)] +pub struct IntImmNode { + pub base: PrimExprNode, + pub value: i64, +} diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs new file mode 100644 index 000000000000..914f70409b8f --- /dev/null +++ b/rust/tvm/src/ir/relay/mod.rs @@ -0,0 +1,218 @@ +use super::array::Array; +use crate::runtime::{IsObject, Object, ObjectPtr, ObjectRef, String as TString, ToObjectRef}; +use crate::DataType; +use std::convert::TryFrom; +use std::convert::TryInto; +use tvm_macros::Object; +use tvm_rt::TVMRetValue; + +#[repr(C)] +pub struct IdNode { + pub base: Object, + pub name_hint: TString, +} + +unsafe impl IsObject for IdNode { + const TYPE_KEY: &'static str = "relay.Id"; + + fn as_object<'s>(&'s self) -> &'s Object { + &self.base + } +} + +#[repr(C)] +pub struct Id(Option>); + +impl Id { + fn new(name_hint: TString) -> Id { + let node = IdNode { + base: Object::base_object::(), + name_hint: name_hint, + }; + Id(Some(ObjectPtr::new(node))) + } + + fn upcast(&self) -> ObjectRef { + ObjectRef(self.0.as_ref().map(|o| o.upcast())) + } +} + +// define_ref!(Id, IdNode); + +#[repr(C)] +#[derive(Object)] +#[ref_name = "BaseExpr"] +#[type_key = "Expr"] +pub struct BaseExprNode { + pub base: Object, +} + +#[repr(C)] +pub struct PrimExprNode { + pub base: BaseExprNode, + pub datatype: DataType, +} + +impl BaseExprNode { + fn base() -> BaseExprNode { + BaseExprNode { + base: Object::base_object::(), + } + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Expr"] +#[type_key = "relay.Expr"] +pub struct RelayExpr { + pub base: BaseExprNode, + pub span: ObjectRef, + pub checked_type: ObjectRef, +} + +impl RelayExpr { + fn base() -> RelayExpr { + RelayExpr { + base: BaseExprNode::base::(), + span: ObjectRef::null(), + checked_type: ObjectRef::null(), + } + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "GlobalVar"] +#[type_key = "relay.GlobalVar"] +pub struct GlobalVarNode { + pub base: RelayExpr, + pub name_hint: TString, +} + +impl GlobalVar { + pub fn new(name_hint: String, _span: ObjectRef) -> GlobalVar { + let node = GlobalVarNode { + base: RelayExpr::base::(), + // span: span, + // checked_type: ObjectRef(None),, + name_hint: TString::new(name_hint).unwrap(), + }; + GlobalVar(Some(ObjectPtr::new(node))) + } + + pub fn upcast(&self) -> ObjectRef { + ObjectRef(self.0.as_ref().map(|o| o.upcast())) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Constant"] +#[type_key = "relay.Constant"] +pub struct ConstantNode { + pub base: RelayExpr, + pub data: ObjectRef, // make this NDArray. +} + +impl Constant { + pub fn new(data: ObjectRef, _span: ObjectRef) -> Constant { + let node = ConstantNode { + base: RelayExpr::base::(), + data: data, + }; + Constant(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Var"] +#[type_key = "relay.Var"] +pub struct VarNode { + pub base: RelayExpr, + pub vid: Id, + pub type_annotation: ObjectRef, +} + +impl Var { + pub fn new(name_hint: String, _span: ObjectRef) -> Var { + let node = VarNode { + base: RelayExpr::base::(), + vid: Id::new(TString::new(name_hint.to_string()).unwrap()), + type_annotation: ObjectRef::null(), + }; + Var(Some(ObjectPtr::new(node))) + } + + pub fn name_hint(&self) -> &TString { + &self.vid.0.as_ref().unwrap().name_hint + } +} + +type Type = ObjectRef; +type Attrs = ObjectRef; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Call"] +#[type_key = "relay.Call"] +pub struct CallNode { + pub base: RelayExpr, + pub op: Expr, + pub args: Array, + pub attrs: ObjectRef, + pub type_args: Array, +} + +impl Call { + pub fn new( + op: Expr, + args: Array, + attrs: Attrs, + type_args: Array, + _span: ObjectRef, + ) -> Call { + let node = CallNode { + base: RelayExpr::base::(), + op: op, + args: args, + attrs: attrs, + type_args: type_args, + }; + Call(Some(ObjectPtr::new(node))) + } +} + + +#[cfg(test)] +mod tests { + use super::*; + use crate::runtime::{as_text, String as TString}; + use anyhow::Result; + + #[test] + fn test_id() -> Result<()> { + let string = TString::new("foo".to_string()).expect("bar"); + let id = Id::new(string); + let cstr = as_text(&id.upcast())?; + assert!(cstr.into_string()?.contains("relay.Id")); + Ok(()) + } + + #[test] + fn test_global() -> Result<()> { + let gv = GlobalVar::new("main".to_string(), ObjectRef::null()); + let cstr = as_text(&gv.upcast())?; + assert!(cstr.into_string()?.contains("@main")); + Ok(()) + } + + #[test] + fn test_var() -> Result<()> { + let var = Var::new("local".to_string(), ObjectRef::null()); + let cstr = as_text(&var.upcast())?; + assert!(cstr.into_string()?.contains("%local")); + Ok(()) + } +} diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs new file mode 100644 index 000000000000..a7a4e641776d --- /dev/null +++ b/rust/tvm/src/lib.rs @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! [TVM](https://github.com/apache/incubator-tvm) is a compiler stack for deep learning systems. +//! +//! This crate provides an idiomatic Rust API for TVM runtime frontend. +//! +//! One particular use case is that given optimized deep learning model artifacts, +//! (compiled with TVM) which include a shared library +//! `lib.so`, `graph.json` and a byte-array `param.params`, one can load them +//! in Rust idomatically to create a TVM Graph Runtime and +//! run the model for some inputs and get the +//! desired predictions *all in Rust*. +//! +//! Checkout the `examples` repository for more details. + +pub use crate::{ + context::{TVMContext, TVMDeviceType}, + errors::*, + function::Function, + module::Module, + ndarray::NDArray, +}; + +// TODO: refactor +pub use tvm_sys::{ + errors as common_errors, + ffi::{self, DLDataType, TVMByteArray}, + packed_func::{TVMArgValue, TVMRetValue}, +}; + +pub type DataType = DLDataType; + +pub use tvm_rt::context; +pub use tvm_rt::errors; +pub use tvm_rt::function; +pub use tvm_rt::module; +pub use tvm_rt::ndarray; +pub use tvm_rt::value; +pub mod ir; +pub mod runtime; + +pub use runtime::version; diff --git a/rust/tvm/src/runtime/mod.rs b/rust/tvm/src/runtime/mod.rs new file mode 100644 index 000000000000..57d43eea81c9 --- /dev/null +++ b/rust/tvm/src/runtime/mod.rs @@ -0,0 +1 @@ +pub use tvm_rt::*; diff --git a/rust/tvm/tests/basics/.gitignore b/rust/tvm/tests/basics/.gitignore new file mode 100644 index 000000000000..10a4b225a705 --- /dev/null +++ b/rust/tvm/tests/basics/.gitignore @@ -0,0 +1,7 @@ +/target +**/*.rs.bk +Cargo.lock +*.o +*.so +*.ptx +*.json diff --git a/rust/tvm/tests/basics/Cargo.toml b/rust/tvm/tests/basics/Cargo.toml new file mode 100644 index 000000000000..0b059da7727b --- /dev/null +++ b/rust/tvm/tests/basics/Cargo.toml @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "basics" +version = "0.0.0" +authors = ["TVM Contributors"] +license = "Apache-2.0" +build = "build.rs" + +[dependencies] +ndarray = "0.12" +tvm = { path = "../../" } + +[features] +default = ["cpu"] +cpu = [] +gpu = [] diff --git a/rust/tvm/tests/basics/build.rs b/rust/tvm/tests/basics/build.rs new file mode 100644 index 000000000000..77a3bae3627d --- /dev/null +++ b/rust/tvm/tests/basics/build.rs @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +fn main() { + let out_dir = std::env::var("OUT_DIR").unwrap(); + + let output = std::process::Command::new(concat!(env!("CARGO_MANIFEST_DIR"), "/src/tvm_add.py")) + .args(&[ + if cfg!(feature = "cpu") { + "llvm" + } else { + "cuda" + }, + &std::env::var("OUT_DIR").unwrap(), + ]) + .output() + .expect("Failed to execute command"); + assert!( + std::path::Path::new(&format!("{}/test_add.so", out_dir)).exists(), + "Could not build tvm lib: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); + + println!("cargo:rustc-link-search=native={}", out_dir); +} diff --git a/rust/tvm/tests/basics/src/main.rs b/rust/tvm/tests/basics/src/main.rs new file mode 100644 index 000000000000..ca53dcf999dc --- /dev/null +++ b/rust/tvm/tests/basics/src/main.rs @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +extern crate ndarray as rust_ndarray; +extern crate tvm_frontend as tvm; + +use std::str::FromStr; + +use tvm::*; + +fn main() { + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + + let (ctx, ctx_name) = if cfg!(feature = "cpu") { + (TVMContext::cpu(0), "cpu") + } else { + (TVMContext::gpu(0), "gpu") + }; + let dtype = DLDataType::from_str("float32").unwrap(); + let mut arr = NDArray::empty(shape, ctx, dtype); + arr.copy_from_buffer(data.as_mut_slice()); + let mut ret = NDArray::empty(shape, ctx, dtype); + let mut fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap(); + if !fadd.enabled(ctx_name) { + return; + } + if cfg!(feature = "gpu") { + fadd.import_module(Module::load(&concat!(env!("OUT_DIR"), "/test_add.ptx")).unwrap()); + } + function::Builder::from(&mut fadd) + .arg(&arr) + .arg(&arr) + .arg(&mut ret) + .invoke() + .unwrap(); + + assert_eq!(ret.to_vec::().unwrap(), vec![6f32, 8.0]); +} diff --git a/rust/tvm/tests/basics/src/tvm_add.py b/rust/tvm/tests/basics/src/tvm_add.py new file mode 100755 index 000000000000..3911d4074e45 --- /dev/null +++ b/rust/tvm/tests/basics/src/tvm_add.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os.path as osp +import sys + +import tvm +from tvm import te +from tvm.contrib import cc + + +def main(target, out_dir): + n = te.var('n') + A = te.placeholder((n,), name='A') + B = te.placeholder((n,), name='B') + C = te.compute(A.shape, lambda i: A[i] + B[i], name='C') + s = te.create_schedule(C.op) + + if target == 'cuda': + bx, tx = s[C].split(C.op.axis[0], factor=64) + s[C].bind(bx, te.thread_axis('blockIdx.x')) + s[C].bind(tx, te.thread_axis('threadIdx.x')) + + fadd = tvm.build(s, [A, B, C], target, target_host='llvm', name='myadd') + + fadd.save(osp.join(out_dir, 'test_add.o')) + if target == 'cuda': + fadd.imported_modules[0].save(osp.join(out_dir, 'test_add.ptx')) + cc.create_shared( + osp.join(out_dir, 'test_add.so'), [osp.join(out_dir, 'test_add.o')]) + + +if __name__ == '__main__': + main(sys.argv[1], sys.argv[2]) + diff --git a/rust/tvm/tests/callback/Cargo.toml b/rust/tvm/tests/callback/Cargo.toml new file mode 100644 index 000000000000..5c89d2ac6375 --- /dev/null +++ b/rust/tvm/tests/callback/Cargo.toml @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "callback" +version = "0.0.0" +authors = ["TVM Contributors"] +edition = "2018" + +[dependencies] +ndarray = "0.12" +tvm = { path = "../../" } diff --git a/rust/tvm/tests/callback/src/bin/array.rs b/rust/tvm/tests/callback/src/bin/array.rs new file mode 100644 index 000000000000..cb4a8229c401 --- /dev/null +++ b/rust/tvm/tests/callback/src/bin/array.rs @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#![allow(unused_imports)] + +extern crate ndarray as rust_ndarray; +#[macro_use] +extern crate tvm_frontend as tvm; + +use rust_ndarray::ArrayD; +use std::{ + convert::{TryFrom, TryInto}, + str::FromStr, +}; + +use tvm::{errors::Error, *}; + +fn main() { + register_global_func! { + fn sum(args: &[TVMArgValue]) -> Result { + let mut ret = 0f32; + let shape = &mut [2]; + for arg in args.iter() { + let e = NDArray::empty( + shape, TVMContext::cpu(0), + DLDataType::from_str("float32").unwrap() + ); + let arg: NDArray = arg.try_into()?; + let arr = arg.copy_to_ndarray(e)?; + let rnd: ArrayD = ArrayD::try_from(&arr)?; + ret += rnd.scalar_sum(); + } + Ok(TVMRetValue::from(ret)) + } + } + + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + let mut arr = NDArray::empty( + shape, + TVMContext::cpu(0), + DLDataType::from_str("float32").unwrap(), + ); + arr.copy_from_buffer(data.as_mut_slice()); + + let mut registered = function::Builder::default(); + let ret: f32 = registered + .get_function("sum") + .arg(&arr) + .arg(&arr) + .invoke() + .unwrap() + .try_into() + .unwrap(); + assert_eq!(ret, 7f32); +} diff --git a/rust/tvm/tests/callback/src/bin/error.rs b/rust/tvm/tests/callback/src/bin/error.rs new file mode 100644 index 000000000000..c9f9a6f771cf --- /dev/null +++ b/rust/tvm/tests/callback/src/bin/error.rs @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::panic; + +use tvm_frontend::{errors::Error, *}; + +fn main() { + register_global_func! { + fn error(_args: &[TVMArgValue]) -> Result { + Err(errors::TypeMismatchError{ + expected: "i64".to_string(), + actual: "f64".to_string(), + }.into()) + } + } + + let mut registered = function::Builder::default(); + registered.get_function("error"); + assert!(registered.func.is_some()); + registered.args(&[10, 20]); + + println!("expected error message is:"); + panic::set_hook(Box::new(|panic_info| { + // if let Some(msg) = panic_info.message() { + // println!("{:?}", msg); + // } + if let Some(location) = panic_info.location() { + println!( + "panic occurred in file '{}' at line {}", + location.file(), + location.line() + ); + } else { + println!("panic occurred but can't get location information"); + } + })); + + let _result = registered.invoke(); +} diff --git a/rust/tvm/tests/callback/src/bin/float.rs b/rust/tvm/tests/callback/src/bin/float.rs new file mode 100644 index 000000000000..7111e287187f --- /dev/null +++ b/rust/tvm/tests/callback/src/bin/float.rs @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#![allow(unused_imports)] + +#[macro_use] +extern crate tvm_frontend as tvm; + +use std::convert::TryInto; +use tvm::{errors::Error, *}; + +fn main() { + register_global_func! { + fn sum(args: &[TVMArgValue]) -> Result { + let mut ret = 0.0; + for arg in args.into_iter() { + let val: f64 = arg.try_into()?; + ret += val; + } + Ok(TVMRetValue::from(ret)) + } + } + + let mut registered = function::Builder::default(); + registered.get_function("sum"); + assert!(registered.func.is_some()); + let ret: f64 = registered + .args(&[10.0f64, 20.0, 30.0]) + .invoke() + .unwrap() + .try_into() + .unwrap(); + assert_eq!(ret, 60f64); +} diff --git a/rust/tvm/tests/callback/src/bin/int.rs b/rust/tvm/tests/callback/src/bin/int.rs new file mode 100644 index 000000000000..23910a3244f7 --- /dev/null +++ b/rust/tvm/tests/callback/src/bin/int.rs @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#![allow(unused_imports)] + +extern crate tvm_frontend as tvm; + +use std::convert::TryInto; +use tvm::{errors::Error, *}; + +fn main() { + fn sum(args: &[TVMArgValue]) -> Result { + let mut ret = 0i64; + for arg in args.iter() { + let val: i64 = arg.try_into()?; + ret += val; + } + Ok(TVMRetValue::from(ret)) + } + + tvm::function::register(sum, "mysum".to_owned(), false).unwrap(); + + let mut registered = function::Builder::default(); + registered.get_function("mysum"); + assert!(registered.func.is_some()); + let ret: i64 = registered + .args(&[10, 20, 30]) + .invoke() + .unwrap() + .try_into() + .unwrap(); + assert_eq!(ret, 60); +} diff --git a/rust/tvm/tests/callback/src/bin/string.rs b/rust/tvm/tests/callback/src/bin/string.rs new file mode 100644 index 000000000000..9ead58733bbb --- /dev/null +++ b/rust/tvm/tests/callback/src/bin/string.rs @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#![allow(unused_imports)] + +#[macro_use] +extern crate tvm_frontend as tvm; +use std::convert::TryInto; +use tvm::{errors::Error, *}; + +// FIXME +fn main() { + register_global_func! { + fn concate_str(args: &[TVMArgValue]) -> Result { + let mut ret = "".to_string(); + for arg in args.iter() { + let val: &str = arg.try_into()?; + ret += val; + } + Ok(TVMRetValue::from(ret)) + } + } + let a = std::ffi::CString::new("a").unwrap(); + let b = std::ffi::CString::new("b").unwrap(); + let c = std::ffi::CString::new("c").unwrap(); + let mut registered = function::Builder::default(); + registered.get_function("concate_str"); + assert!(registered.func.is_some()); + let ret: String = registered + .arg(a.as_c_str()) + .arg(b.as_c_str()) + .arg(c.as_c_str()) + .invoke() + .unwrap() + .try_into() + .unwrap(); + assert_eq!(ret, "abc".to_owned()); +} diff --git a/rust/tvm/tests/test_ir.rs b/rust/tvm/tests/test_ir.rs new file mode 100644 index 000000000000..a43f27e82eea --- /dev/null +++ b/rust/tvm/tests/test_ir.rs @@ -0,0 +1,37 @@ +use std::convert::TryInto; +use std::str::FromStr; +use tvm::ir::IntImmNode; +use tvm::runtime::String as TString; +use tvm::runtime::{debug_print, Object, ObjectPtr, ObjectRef}; +use tvm_rt::{call_packed, DLDataType, Function}; +use tvm_sys::TVMRetValue; + +#[test] +fn test_new_object() -> anyhow::Result<()> { + let object = Object::base_object::(); + let ptr = ObjectPtr::new(object); + assert_eq!(ptr.count(), 1); + Ok(()) +} + +#[test] +fn test_new_string() -> anyhow::Result<()> { + let string = TString::new("hello world!".to_string())?; + Ok(()) +} + +#[test] +fn test_obj_build() -> anyhow::Result<()> { + let int_imm = Function::get("ir.IntImm").expect("Stable TVM API not found."); + + let dt = DLDataType::from_str("int32").expect("Known datatype doesn't convert."); + + let ret_val: ObjectRef = call_packed!(int_imm, dt, 1337) + .expect("foo") + .try_into() + .unwrap(); + + debug_print(&ret_val); + + Ok(()) +} diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 7272213ad406..b3223889cc72 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -162,7 +162,7 @@ GlobalVar::GlobalVar(std::string name_hint) { TVM_REGISTER_NODE_TYPE(GlobalVarNode); TVM_REGISTER_GLOBAL("ir.GlobalVar") -.set_body_typed([](std::string name){ +.set_body_typed([](String name){ return GlobalVar(name); }); @@ -214,4 +214,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) } p->stream << '}'; }); + +TVM_REGISTER_GLOBAL("ir.DebugPrinter") +.set_body([](TVMArgs args, TVMRetValue* ret) { + ObjectRef ref = args[0]; + std::stringstream ss; + ss << ref; + *ret = ss.str(); +}); + } // namespace tvm diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index bda997a59d4d..fc9546a14c9a 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -193,8 +193,7 @@ class RelayTextPrinter : case kTypeData: return Doc::Text("TypeData"); default: - LOG(ERROR) << "Unknown Kind"; - throw; + CHECK(false) << "Unknown Kind"; } } /*! @@ -479,7 +478,8 @@ class RelayTextPrinter : } Doc VisitExpr_(const GlobalVarNode* op) final { - return Doc::Text('@' + op->name_hint); + std::string name_hint = op->name_hint; + return Doc::Text('@' + name_hint); } Doc VisitExpr_(const OpNode* op) final { @@ -939,4 +939,13 @@ TVM_REGISTER_GLOBAL("ir.PrettyPrint") TVM_REGISTER_GLOBAL("ir.AsText") .set_body_typed(AsText); + +TVM_REGISTER_GLOBAL("ir.TextPrinter") +.set_body_typed([](ObjectRef node) { + std::cout << "The program: " << node << std::endl; + auto text = AsText(node, false, nullptr); + std::cout << "The text " << text; + return text; +}); + } // namespace tvm diff --git a/src/relay/transforms/to_cps.cc b/src/relay/transforms/to_cps.cc index e6c83928b098..65ee57f6a3f3 100644 --- a/src/relay/transforms/to_cps.cc +++ b/src/relay/transforms/to_cps.cc @@ -164,7 +164,7 @@ Function ToCPS(const Function& f, // only look unfold non-external calls. BaseFunc base_func = m->Lookup(gv); if (auto* n = base_func.as()) { - auto cps_gv = GlobalVar(gv->name_hint + "_cps"); + auto cps_gv = GlobalVar(std::string(gv->name_hint) + "_cps"); cm->insert({gv, cps_gv}); m->Add(cps_gv, ToCPS(GetRef(n), m, cm)); } else { diff --git a/src/runtime/object.cc b/src/runtime/object.cc index 0d85b9dab42c..2a957f2da6bb 100644 --- a/src/runtime/object.cc +++ b/src/runtime/object.cc @@ -218,12 +218,26 @@ int TVMObjectGetTypeIndex(TVMObjectHandle obj, unsigned* out_tindex) { API_END(); } +int TVMObjectRetain(TVMObjectHandle obj) { + API_BEGIN(); + tvm::runtime::ObjectInternal::ObjectRetain(obj); + API_END(); +} + int TVMObjectFree(TVMObjectHandle obj) { API_BEGIN(); tvm::runtime::ObjectInternal::ObjectFree(obj); API_END(); } + +int TVMObjectDerivedFrom(uint32_t child_type_index, uint32_t parent_type_index, int* is_derived) { + API_BEGIN(); + *is_derived = tvm::runtime::TypeContext::Global()-> + DerivedFrom(child_type_index, parent_type_index); + API_END(); +} + int TVMObjectTypeKey2Index(const char* type_key, unsigned* out_tindex) { API_BEGIN(); out_tindex[0] = tvm::runtime::ObjectInternal::ObjectTypeKey2Index( diff --git a/src/runtime/object_internal.h b/src/runtime/object_internal.h index 79551309d67c..ab48802e774c 100644 --- a/src/runtime/object_internal.h +++ b/src/runtime/object_internal.h @@ -37,6 +37,15 @@ namespace runtime { */ class ObjectInternal { public: + /*! + * \brief Retain an object handle. + */ + static void ObjectRetain(TVMObjectHandle obj) { + if (obj != nullptr) { + static_cast(obj)->IncRef(); + } + } + /*! * \brief Free an object handle. */ diff --git a/tests/scripts/task_rust.sh b/tests/scripts/task_rust.sh index fae07d34e992..987c7cfbee34 100755 --- a/tests/scripts/task_rust.sh +++ b/tests/scripts/task_rust.sh @@ -18,13 +18,14 @@ set -e set -u +set -x export TVM_HOME="$(git rev-parse --show-toplevel)" export LD_LIBRARY_PATH="$TVM_HOME/lib:$TVM_HOME/build:${LD_LIBRARY_PATH:-}" export PYTHONPATH="$TVM_HOME/python":"$TVM_HOME/topi/python" export RUST_DIR="$TVM_HOME/rust" -export LLVM_CONFIG_PATH=`which llvm-config-8` +export LLVM_CONFIG_PATH=`which llvm-config-8 || which llvm-config` echo "Using $LLVM_CONFIG_PATH" cd $RUST_DIR