diff --git a/apps/README.md b/apps/README.md
index 685750633493..6db95ac7f593 100644
--- a/apps/README.md
+++ b/apps/README.md
@@ -26,3 +26,4 @@ If you are interested in writing optimized kernels with TVM, checkout [TOPI: TVM
- [android_rpc](android_rpc) Android RPC server.
- [benchmark](benchmark) Example end to end compilation benchmarks
- [howto_deploy](howto_deploy) Tutorial on how to deploy TVM with minimum code dependency.
+- [wasm_standalone](tvm-standalone) WebAssembly standalone for deep learning framework with TVM runtime.
diff --git a/apps/wasm-standalone/.gitignore b/apps/wasm-standalone/.gitignore
new file mode 100644
index 000000000000..54fb6c73048d
--- /dev/null
+++ b/apps/wasm-standalone/.gitignore
@@ -0,0 +1,8 @@
+# Built packages
+**/lib/
+
+
+#Added by cargo
+
+**/target/
+**/Cargo.lock
diff --git a/apps/wasm-standalone/README.md b/apps/wasm-standalone/README.md
new file mode 100644
index 000000000000..4b6678797795
--- /dev/null
+++ b/apps/wasm-standalone/README.md
@@ -0,0 +1,202 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+# WebAssembly Standalone for Deep Learning Framework with TVM Runtime
+
+#### Experimental notice: This project is still *experimental* and only serves as a proof of concept for running deep learning frameworks on [WebAssembly runtime](https://github.com/bytecodealliance/wasmtime) with [TVM stack](https://tvm.apache.org/).
+
+- [WebAssembly Standalone for Deep Learning Framework with TVM Runtime](#webassembly-standalone-for-deep-learning-framework-with-tvm-runtime)
+ - [Motivation](#motivation)
+ - [Framework Landscape](#framework-landscape)
+ - [Project Status](#project-status)
+ - [PoC Guidelines](#poc-guidelines)
+ - [Pre-installation](#pre-installation)
+ - [Build ResNet50 model](#build-resnet50-model)
+ - [Build wasm-graph package](#build-wasm-graph-package)
+ - [Test](#test)
+ - [Future Work](#future-work)
+ - [More networks support](#more-networks-support)
+ - [Performance benchmark](#performance-benchmark)
+ - [Native TVM Rust runtime support](#native-tvm-rust-runtime-support)
+ - [Appendix](#appendix)
+ - [System packages install](#system-packages-install)
+
+## Motivation
+
+
+
+As demonstrated in TVM runtime [tutorials](https://tvm.apache.org/docs/tutorials/relay_quick_start.html), TVM already supports WASM as the optional hardware backend, so we can leverage the features of WebAssembly (portability, security) and TVM runtime (domain-specific, optimization) to build a flexible and auto-optimized graph compiler for all deep learning frameworks.
+
+## Framework Landscape
+
+The figures below demonstrate the whole landscape of running deep learning frameworks on WASM runtime with TVM compiler stack.
+
+* WASM graph generation
+ ```
+ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
+ | | | | | |
+ | Framework Model | ---> | ONNX Model | ---> | TVM Relay Python API |
+ |_ _ _ _ _ _ _ _ _ _| |_ _ _ _ _ _ _| |_ _ _ _ _ _ _ _ _ _ _ _|
+ ||
+ \/
+ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
+ | | | |
+ | WASM Graph Builder | | TVM Compiler Stack |
+ | (TVM runtime) | |_ _ _ _ _ _ _ _ _ _ _|
+ |_ _ _ _ _ _ _ _ _ _ _| ||
+ || \/
+ _ _ _ _ _ _ _ _ _ || _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
+ | | \/ | | llvm-ar | |
+ | wasm_graph.wasm | <--- | libgraph_wasm32.a | <------- | graph.o |
+ |_ _ _ _ _ _ _ _ _| |_ _ _ _ _ _ _ _ _ _| |_ _ _ _ _|
+ ```
+
+* WASM graph loading
+ ```
+ _ _ _ _ _ _ _ _ _ _ _
+ | |
+ | WASM Graph Loader |
+ | (WASM runtime) |
+ |_ _ _ _ _ _ _ _ _ _ _|
+ ||
+ \/
+ _ _ _ _ _ _ _ _ _ _
+ | |
+ | wasm_graph.wasm |
+ |_ _ _ _ _ _ _ _ _ _|
+ ```
+
+## Project Status
+
+This project should be considered **experimental** at the very early stage, all rich features are under active development. Here is the current operator support matrix:
+
+| Model Name | Status |
+| ---------- | ------ |
+| ResNet50 | ✔️ |
+| LeNet |
— |
+
+**NOTICE**: Currently this project is ONLY tested on Ubuntu system, so `Ubuntu 16.04+` should be prepared as the testing environment.
+
+## PoC Guidelines
+
+### Pre-installation
+
+* Rust
+
+ Before running this demo, please make sure [Rust](#system-packages-install) has been installed.
+
+ After Rust installed, execute the code below to add `wasm32-wasi` target:
+ ```shell
+ rustup target add wasm32-wasi
+ ```
+
+* TVM
+
+ Please follow TVM [installations](https://tvm.apache.org/docs/install/index.html) for the detailed instruction.
+
+* LLVM
+
+ `LLVM 10.0` or later is REQUIRED.
+
+### Build ResNet50 model
+
+- Build DL library in the WebAssembly format.
+
+ - Download model
+
+ ```
+ cd wasm-graph/tools && wget https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet50v1/resnet50v1.onnx
+ ```
+
+ - Compile
+
+ ```
+ LLVM_AR=llvm-ar-10 python ./build_graph_lib.py -O3 ./resnet50v1.onnx
+ ```
+
+### Build wasm-graph package
+
+```shell
+cd wasm-graph && cargo build --release
+cp ./target/wasm32-wasi/release/wasm_graph.wasm ./lib/wasm_graph_resnet50.wasm
+```
+
+### Test
+
+Before running this demo, please make sure [`Rust`](#system-packages-install) has been installed.
+
+Next run the command below to install the runtime package for testing (`rust` REQUIRED):
+
+```shell
+cd wasm-runtime/tests/test_graph_resnet50 && cargo build
+```
+
+Check the usage of `test_graph_resnet50`:
+
+```shell
+~# ./target/debug/test_graph_resnet50 -h
+
+Usage: ./target/debug/test_graph_resnet50 [options]
+
+Options:
+ -g, --wasm-graph-file FILE_PATH
+ set the path to wasm graph file
+ -i, --input-data-file FILE_PATH
+ set the path to input image file
+ -l, --label-class-file FILE_PATH
+ set the path to label class file
+ -h, --help print this help menu
+```
+
+Next perform model inference using these commands below:
+```
+$ cp ../../../wasm-graph/lib/wasm_graph_resnet50.wasm ./
+$ wget -O cat.png https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true
+$ wget -O synset.csv https://raw.githubusercontent.com/kazum/tvm-wasm/master/synset.csv
+$ ./target/debug/test_graph_resnet50 -g ./wasm_graph_resnet50.wasm -i ./cat.png -l ./synset.csv
+original image dimensions: (256, 256)
+resized image dimensions: (224, 224)
+input image belongs to the class `tabby, tabby cat`
+```
+
+## Future Work
+
+### More networks support
+TODO
+
+### Performance benchmark
+
+We are working on several improvements on performances:
+* WebAssembly simd128 support (**Done**)
+* Auto-tvm enhancement for llvm target
+
+### Native TVM Rust runtime support
+TODO
+
+## Appendix
+
+### System packages install
+
+* Rust (latest version)
+
+ If you are running Windows, to install Rust, download and run the [RUST-INIT.EXE](https://win.rustup.rs/), and then follow the onscreen instructions.
+
+ If you are a Linux user, run the following in your terminal, then follow the on-screen instructions to install Rust.
+
+ ```shell
+ curl https://sh.rustup.rs -sSf | sh
+ ```
diff --git a/apps/wasm-standalone/wasm-graph/.cargo/config b/apps/wasm-standalone/wasm-graph/.cargo/config
new file mode 100644
index 000000000000..b01a37beeb90
--- /dev/null
+++ b/apps/wasm-standalone/wasm-graph/.cargo/config
@@ -0,0 +1,3 @@
+[build]
+target = "wasm32-wasi"
+rustflags = ["-C", "link-arg=--whole-archive", "-C", "link-arg=-lgraph_wasm32"]
diff --git a/apps/wasm-standalone/wasm-graph/Cargo.toml b/apps/wasm-standalone/wasm-graph/Cargo.toml
new file mode 100644
index 000000000000..9cdc8f599579
--- /dev/null
+++ b/apps/wasm-standalone/wasm-graph/Cargo.toml
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[package]
+name = "wasm-graph"
+version = "0.1.0"
+authors = ["TVM Contributors"]
+edition = "2018"
+description = "WebAssembly graph to deep learning frameworks using TVM"
+readme = "README.md"
+repository = "https://github.com/apache/incubator-tvm"
+license = "Apache-2.0"
+keywords = ["wasm", "machine learning", "tvm"]
+
+[profile.release]
+lto = true
+opt-level = 's'
+
+[lib]
+crate-type = ['cdylib']
+
+[dependencies]
+serde = "1.0.53"
+serde_derive = "1.0.53"
+serde_json = "1.0.53"
+ndarray = "0.12"
+tvm-sys = { path = "../../../rust/tvm-sys" }
+tvm-graph-rt = { path = "../../../rust/tvm-graph-rt" }
+lazy_static = "1.1.1"
diff --git a/apps/wasm-standalone/wasm-graph/build.rs b/apps/wasm-standalone/wasm-graph/build.rs
new file mode 100644
index 000000000000..8fd4c3c411fc
--- /dev/null
+++ b/apps/wasm-standalone/wasm-graph/build.rs
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+fn main() {
+ let out_dir = concat!(env!("CARGO_MANIFEST_DIR"), "/lib");
+
+ println!("cargo:rustc-link-search=native={}", out_dir);
+}
diff --git a/apps/wasm-standalone/wasm-graph/src/lib.rs b/apps/wasm-standalone/wasm-graph/src/lib.rs
new file mode 100644
index 000000000000..2b4187849edc
--- /dev/null
+++ b/apps/wasm-standalone/wasm-graph/src/lib.rs
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#[macro_use]
+extern crate lazy_static;
+#[macro_use]
+extern crate serde_derive;
+
+mod types;
+mod utils;
+
+use std::{collections::HashMap, convert::TryFrom, env, sync::Mutex};
+
+use tvm_graph_rt::{Graph, GraphExecutor, SystemLibModule, Tensor as TVMTensor};
+
+use types::Tensor;
+
+extern "C" {
+ fn __wasm_call_ctors();
+}
+
+lazy_static! {
+ static ref SYSLIB: SystemLibModule = SystemLibModule::default();
+ static ref GRAPH_EXECUTOR: Mutex> = {
+ unsafe {
+ // This is necessary to invoke TVMBackendRegisterSystemLibSymbol
+ // API calls.
+ __wasm_call_ctors();
+ }
+ let graph = Graph::try_from(include_str!(concat!(
+ env!("CARGO_MANIFEST_DIR"),
+ "/lib/graph.json"
+ )))
+ .unwrap();
+ let params_bytes =
+ include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/lib/graph.params"));
+ let params = tvm_graph_rt::load_param_dict(params_bytes)
+ .unwrap()
+ .into_iter()
+ .map(|(k, v)| (k, v.to_owned()))
+ .collect::>>();
+
+ let mut exec = GraphExecutor::new(graph, &*SYSLIB).unwrap();
+ exec.load_params(params);
+
+ Mutex::new(exec)
+ };
+}
+
+#[no_mangle]
+pub extern "C" fn run(wasm_addr: i32, in_size: i32) -> i32 {
+ let in_tensor = unsafe { utils::load_input(wasm_addr, in_size as usize) };
+ let input: TVMTensor = in_tensor.as_dltensor().into();
+
+ GRAPH_EXECUTOR.lock().unwrap().set_input("data", input);
+ GRAPH_EXECUTOR.lock().unwrap().run();
+ let output = GRAPH_EXECUTOR
+ .lock()
+ .unwrap()
+ .get_output(0)
+ .unwrap()
+ .as_dltensor(false);
+
+ let out_tensor: Tensor = output.into();
+ let out_size = unsafe { utils::store_output(wasm_addr, out_tensor) };
+ out_size as i32
+}
diff --git a/apps/wasm-standalone/wasm-graph/src/types.rs b/apps/wasm-standalone/wasm-graph/src/types.rs
new file mode 100644
index 000000000000..9d4dff96d189
--- /dev/null
+++ b/apps/wasm-standalone/wasm-graph/src/types.rs
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{
+ any::TypeId,
+ os::raw::{c_int, c_void},
+ slice,
+};
+pub use tvm_sys::ffi::DLTensor;
+use tvm_sys::ffi::{
+ DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDeviceType_kDLCPU,
+};
+
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
+pub enum DataType {
+ FP32,
+ INT32,
+ INT8,
+}
+
+impl DataType {
+ pub fn as_dldtype(&self) -> DLDataType {
+ match self {
+ DataType::INT32 => DLDataType {
+ code: DLDataTypeCode_kDLInt as u8,
+ bits: 32u8,
+ lanes: 1u16,
+ },
+ DataType::INT8 => DLDataType {
+ code: DLDataTypeCode_kDLInt as u8,
+ bits: 8u8,
+ lanes: 1u16,
+ },
+ DataType::FP32 => DLDataType {
+ code: DLDataTypeCode_kDLFloat as u8,
+ bits: 32u8,
+ lanes: 1u16,
+ },
+ }
+ }
+
+ /// Returns whether this `DataType` represents primitive type `T`.
+ pub fn is_type(&self) -> bool {
+ let typ = TypeId::of::();
+ typ == TypeId::of::() || typ == TypeId::of::() || typ == TypeId::of::()
+ }
+}
+
+impl From for DataType {
+ fn from(dl_dtype: DLDataType) -> Self {
+ if dl_dtype.code == DLDataTypeCode_kDLInt as u8 && dl_dtype.bits == 32u8 {
+ DataType::INT32
+ } else if dl_dtype.code == DLDataTypeCode_kDLInt as u8 && dl_dtype.bits == 8u8 {
+ DataType::INT8
+ } else if dl_dtype.code == DLDataTypeCode_kDLFloat as u8 && dl_dtype.bits == 32u8 {
+ DataType::FP32
+ } else {
+ DataType::FP32
+ }
+ }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Tensor {
+ pub(crate) dtype: DataType,
+ pub(crate) shape: Vec,
+ pub(crate) strides: Option>,
+ pub(crate) data: Vec,
+}
+
+#[allow(dead_code)]
+impl Tensor {
+ pub fn new(dtype: DataType, shape: Vec, strides: Vec, data: Vec) -> Self {
+ Tensor {
+ dtype,
+ shape,
+ strides: Some(strides),
+ data,
+ }
+ }
+
+ pub fn dtype(&self) -> DataType {
+ self.dtype.clone()
+ }
+
+ pub fn ndim(&self) -> usize {
+ self.shape.len()
+ }
+
+ pub fn shape(&self) -> Vec {
+ self.shape.clone()
+ }
+
+ pub fn data(&self) -> Vec {
+ self.data.clone()
+ }
+
+ pub fn as_dltensor(&self) -> DLTensor {
+ DLTensor {
+ data: self.data.as_ptr() as *mut c_void,
+ ctx: DLContext {
+ device_type: DLDeviceType_kDLCPU,
+ device_id: 0 as c_int,
+ },
+ ndim: self.shape.len() as c_int,
+ dtype: self.dtype().as_dldtype(),
+ shape: self.shape.as_ptr() as *mut i64,
+ strides: self.strides.as_ref().unwrap().as_ptr() as *mut i64,
+ byte_offset: 0,
+ ..Default::default()
+ }
+ }
+
+ /// Returns the data of this `Tensor` as a `Vec`.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the `Tensor` does not contain elements of type `T`.
+ pub fn to_vec(&self) -> Vec {
+ assert!(self.dtype().is_type::());
+
+ unsafe {
+ slice::from_raw_parts(
+ self.data().as_ptr() as *const T,
+ self.shape().iter().map(|v| *v as usize).product::() as usize,
+ )
+ .to_vec()
+ }
+ }
+}
+
+impl Default for Tensor {
+ fn default() -> Self {
+ Self {
+ dtype: DataType::FP32,
+ shape: Vec::new(),
+ strides: None,
+ data: Vec::new(),
+ }
+ }
+}
+
+impl From for Tensor {
+ fn from(dlt: DLTensor) -> Self {
+ unsafe {
+ let shape = slice::from_raw_parts_mut(dlt.shape, dlt.ndim as usize).to_vec();
+ let size = shape.iter().map(|v| *v as usize).product::() as usize;
+ let itemsize: usize = (dlt.dtype.bits >> 3).into();
+ let data = slice::from_raw_parts(dlt.data as *const u8, size * itemsize).to_vec();
+
+ Self {
+ dtype: DataType::from(dlt.dtype),
+ shape,
+ strides: if dlt.strides.is_null() {
+ None
+ } else {
+ Some(
+ slice::from_raw_parts_mut(dlt.strides as *mut usize, dlt.ndim as usize)
+ .to_vec(),
+ )
+ },
+ data,
+ }
+ }
+ }
+}
diff --git a/apps/wasm-standalone/wasm-graph/src/utils.rs b/apps/wasm-standalone/wasm-graph/src/utils.rs
new file mode 100644
index 000000000000..fd4a71745f4f
--- /dev/null
+++ b/apps/wasm-standalone/wasm-graph/src/utils.rs
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use super::types::*;
+use serde_json;
+use std::ptr;
+
+pub unsafe fn load_input(in_addr: i32, in_size: usize) -> Tensor {
+ let in_addr = in_addr as *mut u8;
+
+ let mut data_vec = Vec::new();
+ for i in 0..in_size {
+ data_vec.push(ptr::read(in_addr.offset(i as isize)));
+ }
+ let input: Tensor = serde_json::from_slice(&data_vec).unwrap();
+
+ input
+}
+
+pub unsafe fn store_output(out_addr: i32, output: Tensor) -> usize {
+ let out_addr = out_addr as *mut u8;
+
+ let data_vec = serde_json::to_vec(&output).unwrap();
+ let data_size = data_vec.len();
+ for i in 0..data_size {
+ ptr::write(out_addr.offset(i as isize), *data_vec.get(i).unwrap());
+ }
+
+ data_size
+}
diff --git a/apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py b/apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py
new file mode 100644
index 000000000000..78f80faeea8f
--- /dev/null
+++ b/apps/wasm-standalone/wasm-graph/tools/build_graph_lib.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Builds a simple graph for testing."""
+import argparse
+import os
+import subprocess
+import sys
+
+import onnx
+import tvm
+from tvm import relay
+
+
+def _get_mod_and_params(model_file):
+ onnx_model = onnx.load(model_file)
+ shape_dict = {}
+ for input in onnx_model.graph.input:
+ shape_dict[input.name] = [dim.dim_value for dim in input.type.tensor_type.shape.dim]
+
+ return relay.frontend.from_onnx(onnx_model, shape_dict)
+
+
+def build_graph_lib(model_file, opt_level):
+ """Compiles the pre-trained model with TVM"""
+ out_dir = os.path.join(sys.path[0], "../lib")
+ if not os.path.exists(out_dir):
+ os.makedirs(out_dir)
+
+ # Compile the relay mod
+ mod, params = _get_mod_and_params(model_file)
+ target = 'llvm -target=wasm32-unknown-unknown -mattr=+simd128 --system-lib'
+ with tvm.transform.PassContext(opt_level=opt_level):
+ graph_json, lib, params = relay.build(mod, target=target, params=params)
+
+ # Save the model artifacts to obj_file
+ obj_file = os.path.join(out_dir, 'graph.o')
+ lib.save(obj_file)
+ # Run llvm-ar to archive obj_file into lib_file
+ lib_file = os.path.join(out_dir, 'libgraph_wasm32.a')
+ cmds = [os.environ.get("LLVM_AR", "llvm-ar-10"), 'rcs', lib_file, obj_file]
+ subprocess.run(cmds)
+
+ with open(os.path.join(out_dir, 'graph.json'), 'w') as f_graph:
+ f_graph.write(graph_json)
+
+ with open(os.path.join(out_dir, 'graph.params'), 'wb') as f_params:
+ f_params.write(relay.save_param_dict(params))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='ONNX model build example')
+ parser.add_argument('model_file', type=str, help='the path of onnx model file')
+ parser.add_argument('-O', '--opt-level', type=int, default=0,
+ help='level of optimization. 0 is unoptimized and 3 is the highest level')
+ args = parser.parse_args()
+
+ build_graph_lib(args.model_file, args.opt_level)
diff --git a/apps/wasm-standalone/wasm-runtime/Cargo.toml b/apps/wasm-standalone/wasm-runtime/Cargo.toml
new file mode 100644
index 000000000000..db00a55c31b5
--- /dev/null
+++ b/apps/wasm-standalone/wasm-runtime/Cargo.toml
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[package]
+name = "wasm-runtime"
+version = "0.1.0"
+authors = ["TVM Contributors"]
+edition = "2018"
+description = "WebAssembly runtime to deep learning frameworks using wasmtime"
+repository = "https://github.com/apache/incubator-tvm"
+license = "Apache-2.0"
+keywords = ["wasm", "machine learning", "wasmtime"]
+
+[dependencies]
+wasmtime = "0.16.0"
+wasmtime-wasi = "0.16.0"
+anyhow = "1.0.31"
+serde = "1.0.53"
+serde_json = "1.0.53"
+serde_derive = "1.0.53"
+ndarray = "0.12"
diff --git a/apps/wasm-standalone/wasm-runtime/src/graph.rs b/apps/wasm-standalone/wasm-runtime/src/graph.rs
new file mode 100644
index 000000000000..e7c39cbb0687
--- /dev/null
+++ b/apps/wasm-standalone/wasm-runtime/src/graph.rs
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use anyhow::Result;
+use wasmtime::*;
+use wasmtime_wasi::{Wasi, WasiCtx};
+
+use super::Tensor;
+
+pub struct GraphExecutor {
+ pub(crate) wasm_addr: i32,
+ pub(crate) input_size: i32,
+ pub(crate) output_size: i32,
+ pub(crate) instance: Option,
+}
+
+#[allow(dead_code)]
+impl GraphExecutor {
+ pub fn new() -> Self {
+ Self {
+ wasm_addr: 0,
+ input_size: 0,
+ output_size: 0,
+ instance: None,
+ }
+ }
+
+ pub fn instantiate(&mut self, wasm_graph_file: String) -> Result<()> {
+ let engine = Engine::new(Config::new().wasm_simd(true));
+ let store = Store::new(&engine);
+
+ // First set up our linker which is going to be linking modules together. We
+ // want our linker to have wasi available, so we set that up here as well.
+ let mut linker = Linker::new(&store);
+ // Create an instance of `Wasi` which contains a `WasiCtx`. Note that
+ // `WasiCtx` provides a number of ways to configure what the target program
+ // will have access to.
+ let wasi = Wasi::new(&store, WasiCtx::new(std::env::args())?);
+ wasi.add_to_linker(&mut linker)?;
+
+ let module = Module::from_file(&store, &wasm_graph_file)?;
+ self.instance = Some(linker.instantiate(&module)?);
+
+ Ok(())
+ }
+
+ pub fn set_input(&mut self, input_data: Tensor) -> Result<()> {
+ let memory = self
+ .instance
+ .as_ref()
+ .unwrap()
+ .get_memory("memory")
+ .ok_or_else(|| anyhow::format_err!("failed to find `memory` export"))?;
+
+ // Specify the wasm address to access the wasm memory.
+ let wasm_addr = memory.data_size();
+ // Serialize the data into a JSON string.
+ let in_data = serde_json::to_vec(&input_data)?;
+ let in_size = in_data.len();
+ // Grow up memory size according to in_size to avoid memory leak.
+ memory.grow((in_size >> 16) as u32 + 1)?;
+
+ // Insert the input data into wasm memory.
+ for i in 0..in_size {
+ unsafe {
+ memory.data_unchecked_mut()[wasm_addr + i] = *in_data.get(i).unwrap();
+ }
+ }
+
+ self.wasm_addr = wasm_addr as i32;
+ self.input_size = in_size as i32;
+ Ok(())
+ }
+
+ pub fn run(&mut self) -> Result<()> {
+ // Invoke `run` export.
+ let run = self
+ .instance
+ .as_ref()
+ .unwrap()
+ .get_func("run")
+ .ok_or_else(|| anyhow::format_err!("failed to find `run` function export!"))?
+ .get2::()?;
+
+ let out_size = run(self.wasm_addr, self.input_size)?;
+ if out_size == 0 {
+ panic!("graph run failed!");
+ }
+
+ self.output_size = out_size;
+ Ok(())
+ }
+
+ pub fn get_output(&self) -> Result {
+ let memory = self
+ .instance
+ .as_ref()
+ .unwrap()
+ .get_memory("memory")
+ .ok_or_else(|| anyhow::format_err!("failed to find `memory` export"))?;
+
+ let out_data = unsafe {
+ &memory.data_unchecked()[self.wasm_addr as usize..][..self.output_size as usize]
+ };
+ let out_vec: Tensor = serde_json::from_slice(out_data).unwrap();
+ Ok(out_vec)
+ }
+}
+
+impl Default for GraphExecutor {
+ fn default() -> Self {
+ Self::new()
+ }
+}
diff --git a/apps/wasm-standalone/wasm-runtime/src/lib.rs b/apps/wasm-standalone/wasm-runtime/src/lib.rs
new file mode 100644
index 000000000000..fa41cade035d
--- /dev/null
+++ b/apps/wasm-standalone/wasm-runtime/src/lib.rs
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#[macro_use]
+extern crate serde_derive;
+
+mod graph;
+mod types;
+
+pub use graph::GraphExecutor;
+pub use types::Tensor;
diff --git a/apps/wasm-standalone/wasm-runtime/src/types.rs b/apps/wasm-standalone/wasm-runtime/src/types.rs
new file mode 100644
index 000000000000..762a75d3c910
--- /dev/null
+++ b/apps/wasm-standalone/wasm-runtime/src/types.rs
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use std::{any::TypeId, mem, slice};
+
+#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
+pub enum DataType {
+ FP32,
+ INT32,
+ INT8,
+}
+
+impl DataType {
+ /// Returns whether this `DataType` represents primitive type `T`.
+ pub fn is_type(&self) -> bool {
+ let typ = TypeId::of::();
+ typ == TypeId::of::() || typ == TypeId::of::() || typ == TypeId::of::()
+ }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Tensor {
+ pub(crate) dtype: DataType,
+ pub(crate) shape: Vec,
+ pub(crate) strides: Option>,
+ pub(crate) data: Vec,
+}
+
+#[allow(dead_code)]
+impl Tensor {
+ pub fn new(dtype: DataType, shape: Vec, strides: Vec, data: Vec) -> Self {
+ Tensor {
+ dtype,
+ shape,
+ strides: Some(strides),
+ data,
+ }
+ }
+
+ pub fn dtype(&self) -> DataType {
+ self.dtype.clone()
+ }
+
+ pub fn ndim(&self) -> usize {
+ self.shape.len()
+ }
+
+ pub fn shape(&self) -> Vec {
+ self.shape.clone()
+ }
+
+ pub fn data(&self) -> Vec {
+ self.data.clone()
+ }
+
+ /// Returns the data of this `Tensor` as a `Vec`.
+ ///
+ /// # Panics
+ ///
+ /// Panics if the `Tensor` does not contain elements of type `T`.
+ pub fn to_vec(&self) -> Vec {
+ assert!(self.dtype().is_type::());
+
+ unsafe {
+ slice::from_raw_parts(
+ self.data().as_ptr() as *const T,
+ self.shape().iter().map(|v| *v as usize).product::() as usize,
+ )
+ .to_vec()
+ }
+ }
+}
+
+impl Default for Tensor {
+ fn default() -> Self {
+ Self {
+ dtype: DataType::FP32,
+ shape: Vec::new(),
+ strides: None,
+ data: Vec::new(),
+ }
+ }
+}
+
+/// `From` conversions to `Tensor` for `ndarray::Array`.
+/// Takes a reference to the `ndarray` since `Tensor` is not owned.
+macro_rules! impl_tensor_from_ndarray {
+ ($type:ty, $typecode:expr) => {
+ impl From> for Tensor {
+ fn from(arr: ndarray::Array<$type, D>) -> Self {
+ Tensor {
+ dtype: $typecode,
+ shape: arr.shape().iter().map(|v| *v as i64).collect(),
+ strides: Some(arr.strides().iter().map(|v| *v as usize).collect()),
+ data: unsafe {
+ slice::from_raw_parts(
+ arr.as_ptr() as *const u8,
+ arr.len() * mem::size_of::<$type>(),
+ )
+ .to_vec()
+ },
+ }
+ }
+ }
+ };
+}
+
+impl_tensor_from_ndarray!(f32, DataType::FP32);
+impl_tensor_from_ndarray!(i32, DataType::INT32);
+impl_tensor_from_ndarray!(i8, DataType::INT8);
diff --git a/apps/wasm-standalone/wasm-runtime/tests/test_graph_resnet50/Cargo.toml b/apps/wasm-standalone/wasm-runtime/tests/test_graph_resnet50/Cargo.toml
new file mode 100644
index 000000000000..67ffe3429363
--- /dev/null
+++ b/apps/wasm-standalone/wasm-runtime/tests/test_graph_resnet50/Cargo.toml
@@ -0,0 +1,30 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[package]
+name = "test_graph_resnet50"
+version = "0.1.0"
+license = "Apache-2.0"
+authors = ["TVM Contributors"]
+edition = "2018"
+
+[dependencies]
+getopts = "0.2.21"
+ndarray = "0.12"
+csv = "1.1"
+image = "0.20"
+wasm-runtime = { path = "../../" }
diff --git a/apps/wasm-standalone/wasm-runtime/tests/test_graph_resnet50/src/main.rs b/apps/wasm-standalone/wasm-runtime/tests/test_graph_resnet50/src/main.rs
new file mode 100644
index 000000000000..befac124e9e4
--- /dev/null
+++ b/apps/wasm-standalone/wasm-runtime/tests/test_graph_resnet50/src/main.rs
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+use getopts::Options;
+use image::{FilterType, GenericImageView};
+use ndarray::Array;
+use std::{collections::HashMap, env, fs::File, io::BufReader};
+use wasm_runtime::{GraphExecutor, Tensor};
+
+const IMG_HEIGHT: usize = 224;
+const IMG_WIDTH: usize = 224;
+
+fn print_usage(program: &str, opts: Options) {
+ let brief = format!("Usage: {} [options]", program);
+ print!("{}", opts.usage(&brief));
+}
+
+fn main() {
+ let args: Vec = env::args().collect();
+ let program = args[0].clone();
+
+ let mut opts = Options::new();
+ opts.optopt(
+ "g",
+ "wasm-graph-file",
+ "set the path to wasm graph file",
+ "FILE_PATH",
+ );
+ opts.optopt(
+ "i",
+ "input-data-file",
+ "set the path to input image file",
+ "FILE_PATH",
+ );
+ opts.optopt(
+ "l",
+ "label-class-file",
+ "set the path to label class file",
+ "FILE_PATH",
+ );
+ opts.optflag("h", "help", "print this help menu");
+ let matches = match opts.parse(&args[1..]) {
+ Ok(m) => m,
+ Err(f) => panic!(f.to_string()),
+ };
+ if matches.opt_present("h") {
+ print_usage(&program, opts);
+ return;
+ }
+ let wasm_graph_file: String = match matches.opt_str("g") {
+ Some(s) => s,
+ None => String::from(""),
+ };
+ let input_data_file: String = match matches.opt_str("i") {
+ Some(s) => s,
+ None => String::from(""),
+ };
+ let label_class_file: String = match matches.opt_str("l") {
+ Some(s) => s,
+ None => String::from(""),
+ };
+ let img = image::open(input_data_file).unwrap();
+ let input = data_preprocess(img);
+
+ let mut graph_exec = GraphExecutor::new();
+ graph_exec.instantiate(wasm_graph_file).unwrap();
+ graph_exec.set_input(input).unwrap();
+ graph_exec.run().unwrap();
+ let output: Tensor = match graph_exec.get_output() {
+ Ok(m) => m,
+ Err(f) => panic!(f.to_string()),
+ };
+ output_assert(output, label_class_file);
+}
+
+fn data_preprocess(img: image::DynamicImage) -> Tensor {
+ println!("original image dimensions: {:?}", img.dimensions());
+ let img = img
+ .resize_exact(IMG_HEIGHT as u32, IMG_WIDTH as u32, FilterType::Nearest)
+ .to_rgb();
+ println!("resized image dimensions: {:?}", img.dimensions());
+ let mut pixels: Vec = vec![];
+ for pixel in img.pixels() {
+ let tmp = pixel.data;
+ // normalize the RGB channels using mean, std of imagenet1k
+ let tmp = [
+ (tmp[0] as f32 - 123.0) / 58.395, // R
+ (tmp[1] as f32 - 117.0) / 57.12, // G
+ (tmp[2] as f32 - 104.0) / 57.375, // B
+ ];
+ for e in &tmp {
+ pixels.push(*e);
+ }
+ }
+
+ // (H,W,C) -> (C,H,W)
+ let arr = Array::from_shape_vec((IMG_HEIGHT, IMG_WIDTH, 3), pixels).unwrap();
+ let arr = arr.permuted_axes([2, 0, 1]);
+ let arr = Array::from_iter(arr.into_iter().copied().map(|v| v));
+
+ Tensor::from(arr)
+}
+
+fn output_assert(out_tensor: Tensor, label_class_file: String) {
+ let output = out_tensor.to_vec::();
+
+ // Find the maximum entry in the output and its index.
+ let mut argmax = -1;
+ let mut max_prob = 0.;
+ for i in 0..output.len() {
+ if output[i] > max_prob {
+ max_prob = output[i];
+ argmax = i as i32;
+ }
+ }
+
+ // Create a hash map of (class id, class name)
+ let mut synset: HashMap = HashMap::new();
+ let mut rdr = csv::ReaderBuilder::new().from_reader(BufReader::new(
+ File::open(label_class_file.as_str()).unwrap(),
+ ));
+
+ for result in rdr.records() {
+ let record = result.unwrap();
+ let id: i32 = record[0].parse().unwrap();
+ let cls = record[1].to_string();
+ synset.insert(id, cls);
+ }
+
+ println!(
+ "input image belongs to the class `{}`",
+ synset
+ .get(&argmax)
+ .expect("cannot find the class id for argmax")
+ );
+}
diff --git a/rust/tvm-graph-rt/src/array.rs b/rust/tvm-graph-rt/src/array.rs
index b911aa816489..deacf11bec04 100644
--- a/rust/tvm-graph-rt/src/array.rs
+++ b/rust/tvm-graph-rt/src/array.rs
@@ -271,7 +271,7 @@ impl<'a> Tensor<'a> {
}
}
- pub(crate) fn as_dltensor(&self, flatten: bool) -> DLTensor {
+ pub fn as_dltensor(&self, flatten: bool) -> DLTensor {
assert!(!flatten || self.is_contiguous());
DLTensor {
data: unsafe { self.data.as_mut_ptr().offset(self.byte_offset) } as *mut c_void,
diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs
index 4359db9b8c20..802d7aeb6779 100644
--- a/rust/tvm-macros/src/external.rs
+++ b/rust/tvm-macros/src/external.rs
@@ -51,7 +51,12 @@ impl Parse for External {
assert!(method.semi_token != None);
let ident = sig.ident;
let generics = sig.generics;
- let inputs = sig.inputs.iter().map(|param| param.clone()).collect();
+ let inputs = sig
+ .inputs
+ .iter()
+ .cloned()
+ .map(|param| param.clone())
+ .collect();
let ret_type = sig.output;
Ok(External {
diff --git a/tests/lint/check_file_type.py b/tests/lint/check_file_type.py
index 1379a506d0b7..f803647d91a1 100644
--- a/tests/lint/check_file_type.py
+++ b/tests/lint/check_file_type.py
@@ -79,7 +79,7 @@
"idl",
# opencl file
"cl",
- }
+}
# List of file names allowed
ALLOW_FILE_NAME = {
@@ -98,7 +98,7 @@
".scalafmt.conf",
"Cargo.lock",
"with_the_same_user",
- }
+}
# List of specific files allowed in relpath to
ALLOW_SPECIFIC_FILE = {
@@ -111,6 +111,7 @@
"rust/runtime/tests/test_wasm32/.cargo/config",
"rust/tvm-graph-rt/tests/test_wasm32/.cargo/config",
"apps/sgx/.cargo/config",
+ "apps/wasm-standalone/wasm-graph/.cargo/config",
# html for demo purposes
"web/apps/browser/rpc_server.html",
# images are normally not allowed
@@ -121,7 +122,7 @@
"docs/_static/css/tvm_theme.css",
"docs/_static/img/tvm-logo-small.png",
"docs/_static/img/tvm-logo-square.png",
- }
+}
def filename_allowed(name):
@@ -162,7 +163,7 @@ def copyright_line(line):
if line.find("Copyright " + "(c)") != -1:
return True
if (line.find("Copyright") != -1 and
- line.find(" by") != -1):
+ line.find(" by") != -1):
return True
return False
@@ -236,5 +237,6 @@ def main():
print("check_file_type.py: all checks passed..")
+
if __name__ == "__main__":
main()