From 5b9b16cd47e220e729d4f2a673b01270a85bc3f5 Mon Sep 17 00:00:00 2001 From: Schell Carl Scivally Date: Thu, 5 Aug 2021 12:43:53 +1200 Subject: [PATCH] Fix rust rt link (#8631) * Fix support for linking to only libtvm_runtime also ensures that the ResNet example uses the new support. * Fix build.rs to rebuild if the Python script changes Co-authored-by: Jared Roesch --- rust/tvm-rt/Cargo.toml | 1 + .../src/runtime => tvm-rt/src}/graph_rt.rs | 4 +- rust/tvm-rt/src/lib.rs | 63 ++++++++++--------- rust/tvm-sys/Cargo.toml | 1 + rust/tvm-sys/build.rs | 16 +++-- rust/tvm/examples/resnet/Cargo.toml | 2 +- rust/tvm/examples/resnet/build.rs | 10 ++- rust/tvm/examples/resnet/src/build_resnet.py | 3 + rust/tvm/examples/resnet/src/main.rs | 6 +- rust/tvm/src/runtime/mod.rs | 2 - 10 files changed, 64 insertions(+), 44 deletions(-) rename rust/{tvm/src/runtime => tvm-rt/src}/graph_rt.rs (97%) diff --git a/rust/tvm-rt/Cargo.toml b/rust/tvm-rt/Cargo.toml index eb49558ec6ce..8e69dcd9397f 100644 --- a/rust/tvm-rt/Cargo.toml +++ b/rust/tvm-rt/Cargo.toml @@ -32,6 +32,7 @@ edition = "2018" default = ["dynamic-linking"] dynamic-linking = ["tvm-sys/dynamic-linking"] static-linking = ["tvm-sys/static-linking"] +standalone = ["tvm-sys/runtime-only"] blas = ["ndarray/blas"] [dependencies] diff --git a/rust/tvm/src/runtime/graph_rt.rs b/rust/tvm-rt/src/graph_rt.rs similarity index 97% rename from rust/tvm/src/runtime/graph_rt.rs rename to rust/tvm-rt/src/graph_rt.rs index 421a00386cf5..7db53d466665 100644 --- a/rust/tvm/src/runtime/graph_rt.rs +++ b/rust/tvm-rt/src/graph_rt.rs @@ -19,8 +19,8 @@ use std::convert::TryInto; -use crate::runtime::Function; -use crate::{runtime::function::Result, runtime::ByteArray, Device, Module, NDArray}; +use crate::Function; +use crate::{function::Result, ByteArray, Device, Module, NDArray}; /// An instance of the C++ graph executor. /// diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs index ce2d709c2a6c..824dc63f0b50 100644 --- a/rust/tvm-rt/src/lib.rs +++ b/rust/tvm-rt/src/lib.rs @@ -26,8 +26,40 @@ //! The TVM object system enables cross-language interoperability including that of closures for all //! supported languages including C++, and Python. +// Macro to check the return call to TVM runtime shared library. + +#[macro_export] +macro_rules! tvm_call { + ($e:expr) => {{ + if unsafe { $e } != 0 { + Err($crate::get_last_error().into()) + } else { + Ok(()) + } + }}; +} + +#[macro_export] +macro_rules! check_call { + ($e:expr) => {{ + if unsafe { $e } != 0 { + panic!("{}", $crate::get_last_error()); + } + }}; +} + +// Define all sumodules. +pub mod array; +pub mod device; +pub mod errors; +pub mod function; +pub mod graph_rt; +pub mod map; +pub mod module; +pub mod ndarray; pub mod object; pub mod string; +mod to_function; pub use object::*; pub use string::*; @@ -52,28 +84,6 @@ use tvm_sys::ffi; pub use tvm_macros::external; -// Macro to check the return call to TVM runtime shared library. - -#[macro_export] -macro_rules! tvm_call { - ($e:expr) => {{ - if unsafe { $e } != 0 { - Err($crate::get_last_error().into()) - } else { - Ok(()) - } - }}; -} - -#[macro_export] -macro_rules! check_call { - ($e:expr) => {{ - if unsafe { $e } != 0 { - panic!("{}", $crate::get_last_error()); - } - }}; -} - /// Gets the last error message. pub fn get_last_error() -> &'static str { unsafe { @@ -91,15 +101,6 @@ pub(crate) fn set_last_error(err: &E) { } } -pub mod array; -pub mod device; -pub mod errors; -pub mod function; -pub mod map; -pub mod module; -pub mod ndarray; -mod to_function; - /// Outputs the current TVM version. pub fn version() -> &'static str { match str::from_utf8(ffi::TVM_VERSION) { diff --git a/rust/tvm-sys/Cargo.toml b/rust/tvm-sys/Cargo.toml index c7ee98fc455a..1daa3086028b 100644 --- a/rust/tvm-sys/Cargo.toml +++ b/rust/tvm-sys/Cargo.toml @@ -27,6 +27,7 @@ description = "Low level bindings to TVM's cross language API." default = ["dynamic-linking"] static-linking = [] dynamic-linking = [] +runtime-only = [] [dependencies] thiserror = "^1.0" diff --git a/rust/tvm-sys/build.rs b/rust/tvm-sys/build.rs index 170ccce0a9f1..5990f0d8064f 100644 --- a/rust/tvm-sys/build.rs +++ b/rust/tvm-sys/build.rs @@ -84,11 +84,19 @@ fn main() -> Result<()> { println!("cargo:rerun-if-changed={}", build_path.display()); println!("cargo:rerun-if-changed={}/include", source_path.display()); - match &std::env::var("CARGO_CFG_TARGET_ARCH").unwrap()[..] { + let library_name = if cfg!(feature = "runtime-only") { + "tvm_runtime" + } else { + "tvm" + }; + + match &std::env::var("CARGO_CFG_TARGET_ARCH") + .expect("CARGO_CFG_TARGET_ARCH must be set by CARGO")[..] + { "wasm32" => {} _ => { if cfg!(feature = "static-linking") { - println!("cargo:rustc-link-lib=static=tvm"); + println!("cargo:rustc-link-lib=static={}", library_name); // TODO(@jroesch): move this to tvm-build as library_path? println!( "cargo:rustc-link-search=native={}/build", @@ -97,14 +105,14 @@ fn main() -> Result<()> { } if cfg!(feature = "dynamic-linking") { - println!("cargo:rustc-link-lib=dylib=tvm"); + println!("cargo:rustc-link-lib=dylib={}", library_name); println!( "cargo:rustc-link-search=native={}/build", build_path.display() ); } } - } + }; let runtime_api = source_path.join("include/tvm/runtime/c_runtime_api.h"); let backend_api = source_path.join("include/tvm/runtime/c_backend_api.h"); diff --git a/rust/tvm/examples/resnet/Cargo.toml b/rust/tvm/examples/resnet/Cargo.toml index 646385a6373e..1e45739dd93d 100644 --- a/rust/tvm/examples/resnet/Cargo.toml +++ b/rust/tvm/examples/resnet/Cargo.toml @@ -25,7 +25,7 @@ edition = "2018" [dependencies] ndarray = "0.12" -tvm = { path = "../../" } +tvm-rt = { path = "../../../tvm-rt", features = ["standalone"] } image = "0.20" csv = "1.1" anyhow = "^1.0" diff --git a/rust/tvm/examples/resnet/build.rs b/rust/tvm/examples/resnet/build.rs index 9bf7d867e50f..9e3a76433ffc 100644 --- a/rust/tvm/examples/resnet/build.rs +++ b/rust/tvm/examples/resnet/build.rs @@ -22,17 +22,25 @@ use std::{io::Write, path::Path, process::Command}; fn main() -> Result<()> { let out_dir = std::env::var("CARGO_MANIFEST_DIR")?; + let python_script = concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py"); + let synset_txt = concat!(env!("CARGO_MANIFEST_DIR"), "/synset.txt"); + + println!("cargo:rerun-if-changed={}", python_script); + println!("cargo:rerun-if-changed={}", synset_txt); + let output = Command::new("python3") - .arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py")) + .arg(python_script) .arg(&format!("--build-dir={}", out_dir)) .output() .with_context(|| anyhow::anyhow!("failed to run python3"))?; + if !output.status.success() { std::io::stdout() .write_all(&output.stderr) .context("Failed to write error")?; panic!("Failed to execute build script"); } + assert!( Path::new(&format!("{}/deploy_lib.o", out_dir)).exists(), "Could not prepare demo: {}", diff --git a/rust/tvm/examples/resnet/src/build_resnet.py b/rust/tvm/examples/resnet/src/build_resnet.py index 277555eeb409..df02dd78f57c 100644 --- a/rust/tvm/examples/resnet/src/build_resnet.py +++ b/rust/tvm/examples/resnet/src/build_resnet.py @@ -115,6 +115,9 @@ def download_img_labels(): f.write(synset[key]) f.write("\n") + print(synset_path) + print(synset_name) + return synset diff --git a/rust/tvm/examples/resnet/src/main.rs b/rust/tvm/examples/resnet/src/main.rs index 7f5fcd458c26..bd0de1c56ba3 100644 --- a/rust/tvm/examples/resnet/src/main.rs +++ b/rust/tvm/examples/resnet/src/main.rs @@ -27,8 +27,8 @@ use ::ndarray::{Array, ArrayD, Axis}; use image::{FilterType, GenericImageView}; use anyhow::Context as _; -use tvm::runtime::graph_rt::GraphRt; -use tvm::*; +use tvm_rt::graph_rt::GraphRt; +use tvm_rt::*; fn main() -> anyhow::Result<()> { let dev = Device::cpu(0); @@ -107,7 +107,7 @@ fn main() -> anyhow::Result<()> { // create a hash map of (class id, class name) let file = File::open("synset.txt").context("failed to open synset")?; - let synset: Vec = BufReader::new(file) + let synset: Vec = BufReader::new(file) .lines() .into_iter() .map(|x| x.expect("readline failed")) diff --git a/rust/tvm/src/runtime/mod.rs b/rust/tvm/src/runtime/mod.rs index 84da186557f7..69fbb371824a 100644 --- a/rust/tvm/src/runtime/mod.rs +++ b/rust/tvm/src/runtime/mod.rs @@ -18,5 +18,3 @@ */ pub use tvm_rt::*; - -pub mod graph_rt;