Skip to content

Commit

Permalink
[Rust] Fixes for wasm32 target (#5489)
Browse files Browse the repository at this point in the history
* [Rust] Fixes for wasm32 target

* [Rust] Add test for wasm32

* allow cargo config to be into repo

* Disable wasm tests in CI
  • Loading branch information
kazum authored May 2, 2020
1 parent 8599f7c commit c7a16d8
Show file tree
Hide file tree
Showing 15 changed files with 239 additions and 8 deletions.
1 change: 1 addition & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ members = [
"runtime",
"runtime/tests/test_tvm_basic",
"runtime/tests/test_tvm_dso",
"runtime/tests/test_wasm32",
"runtime/tests/test_nn",
"frontend",
"frontend/tests/basics",
Expand Down
1 change: 1 addition & 0 deletions rust/common/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ fn main() {
.layout_tests(false)
.derive_partialeq(true)
.derive_eq(true)
.derive_default(true)
.generate()
.expect("unable to generate bindings")
.write_to_file(PathBuf::from("src/c_runtime_api.rs"))
Expand Down
1 change: 1 addition & 0 deletions rust/common/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ macro_rules! impl_dltensor_from_ndarray {
shape: arr.shape().as_ptr() as *const i64 as *mut i64,
strides: arr.strides().as_ptr() as *const isize as *mut i64,
byte_offset: 0,
..Default::default()
}
}
}
Expand Down
9 changes: 7 additions & 2 deletions rust/common/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,13 @@ pub mod ffi {

include!(concat!(env!("CARGO_MANIFEST_DIR"), "/src/c_runtime_api.rs"));

pub type BackendPackedCFunc =
extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int;
pub type BackendPackedCFunc = extern "C" fn(
args: *const TVMValue,
type_codes: *const c_int,
num_args: c_int,
out_ret_value: *mut TVMValue,
out_ret_tcode: *mut u32,
) -> c_int;
}

pub mod array;
Expand Down
1 change: 1 addition & 0 deletions rust/runtime/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ impl<'a> Tensor<'a> {
self.strides.as_ref().unwrap().as_ptr()
} as *mut i64,
byte_offset: 0,
..Default::default()
}
}
}
Expand Down
13 changes: 12 additions & 1 deletion rust/runtime/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,18 @@ named! {
// Converts a bytes to String.
named! {
name<String>,
map_res!(length_data!(le_u64), |b: &[u8]| String::from_utf8(b.to_vec()))
do_parse!(
len_l: le_u32 >>
len_h: le_u32 >>
data: take!(len_l) >>
(
if len_h == 0 {
String::from_utf8(data.to_vec()).unwrap()
} else {
panic!("Too long string")
}
)
)
}

// Parses a TVMContext
Expand Down
12 changes: 10 additions & 2 deletions rust/runtime/src/module/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,17 @@ fn wrap_backend_packed_func(func_name: String, func: BackendPackedCFunc) -> Box<
(val, code as i32)
})
.unzip();
let exit_code = func(values.as_ptr(), type_codes.as_ptr(), values.len() as i32);
let ret: TVMRetValue = TVMRetValue::default();
let (mut ret_val, mut ret_type_code) = ret.to_tvm_value();
let exit_code = func(
values.as_ptr(),
type_codes.as_ptr(),
values.len() as i32,
&mut ret_val,
&mut ret_type_code,
);
if exit_code == 0 {
Ok(TVMRetValue::default())
Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code))
} else {
Err(tvm_common::errors::FuncCallError::get_with_context(
func_name.clone(),
Expand Down
9 changes: 7 additions & 2 deletions rust/runtime/src/threading.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/

use std::{
env,
os::raw::{c_int, c_void},
sync::{
atomic::{AtomicUsize, Ordering},
Expand All @@ -27,6 +26,9 @@ use std::{
thread::{self, JoinHandle},
};

#[cfg(not(target_arch = "wasm32"))]
use std::env;

use crossbeam::channel::{bounded, Receiver, Sender};
use tvm_common::ffi::TVMParallelGroupEnv;

Expand Down Expand Up @@ -147,7 +149,10 @@ impl ThreadPool {

fn run_worker(queue: Receiver<Task>) {
loop {
let task = queue.recv().expect("should recv");
let task = match queue.recv() {
Ok(v) => v,
Err(_) => break,
};
let result = task.run();
if result == <i32>::min_value() {
break;
Expand Down
2 changes: 2 additions & 0 deletions rust/runtime/tests/test_wasm32/.cargo/config
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[build]
target = "wasm32-wasi"
26 changes: 26 additions & 0 deletions rust/runtime/tests/test_wasm32/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

[package]
name = "test-wasm32"
version = "0.0.0"
license = "Apache-2.0"
authors = ["TVM Contributors"]

[dependencies]
ndarray="0.12"
tvm-runtime = { path = "../../" }
71 changes: 71 additions & 0 deletions rust/runtime/tests/test_wasm32/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

use std::{path::PathBuf, process::Command};

fn main() {
let mut out_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
out_dir.push("lib");

if !out_dir.is_dir() {
std::fs::create_dir(&out_dir).unwrap();
}

let obj_file = out_dir.join("test.o");
let lib_file = out_dir.join("libtest_wasm32.a");

let output = Command::new(concat!(
env!("CARGO_MANIFEST_DIR"),
"/src/build_test_lib.py"
))
.arg(&out_dir)
.output()
.expect("Failed to execute command");
assert!(
obj_file.exists(),
"Could not build tvm lib: {}",
String::from_utf8(output.stderr)
.unwrap()
.trim()
.split("\n")
.last()
.unwrap_or("")
);

let ar = option_env!("LLVM_AR").unwrap_or("llvm-ar-8");
let output = Command::new(ar)
.arg("rcs")
.arg(&lib_file)
.arg(&obj_file)
.output()
.expect("Failed to execute command");
assert!(
lib_file.exists(),
"Could not create archive: {}",
String::from_utf8(output.stderr)
.unwrap()
.trim()
.split("\n")
.last()
.unwrap_or("")
);

println!("cargo:rustc-link-lib=static=test_wasm32");
println!("cargo:rustc-link-search=native={}", out_dir.display());
}
38 changes: 38 additions & 0 deletions rust/runtime/tests/test_wasm32/src/build_test_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Prepares a simple TVM library for testing."""

from os import path as osp
import sys

import tvm
from tvm import te

def main():
n = te.var('n')
A = te.placeholder((n,), name='A')
B = te.placeholder((n,), name='B')
C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.te.create_schedule(C.op)
s[C].parallel(s[C].op.axis[0])
print(tvm.lower(s, [A, B, C], simple_mode=True))
tvm.build(s, [A, B, C], 'llvm -target=wasm32-unknown-unknown --system-lib').save(osp.join(sys.argv[1], 'test.o'))

if __name__ == '__main__':
main()
54 changes: 54 additions & 0 deletions rust/runtime/tests/test_wasm32/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

extern "C" {
static __tvm_module_ctx: i32;
}

#[no_mangle]
unsafe fn __get_tvm_module_ctx() -> i32 {
// Refer a symbol in the libtest_wasm32.a to make sure that the link of the
// library is not optimized out.
__tvm_module_ctx
}

extern crate ndarray;
#[macro_use]
extern crate tvm_runtime;

use ndarray::Array;
use tvm_runtime::{DLTensor, Module as _, SystemLibModule};

fn main() {
// try static
let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
let mut c = Array::from_vec(vec![0f32; 4]);
let e = Array::from_vec(vec![2f32, 2., 4., 4.]);
let mut a_dl: DLTensor = (&mut a).into();
let mut b_dl: DLTensor = (&mut b).into();
let mut c_dl: DLTensor = (&mut c).into();

let syslib = SystemLibModule::default();
let add = syslib
.get_function("default_function")
.expect("main function not found");
call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl).unwrap();
assert!(c.all_close(&e, 1e-8f32));
}
3 changes: 2 additions & 1 deletion tests/lint/check_file_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@
"KEYS",
"DISCLAIMER",
"Jenkinsfile",
# sgx config
# cargo config
"rust/runtime/tests/test_wasm32/.cargo/config",
"apps/sgx/.cargo/config",
# html for demo purposes
"tests/webgl/test_static_webgl_library.html",
Expand Down
6 changes: 6 additions & 0 deletions tests/scripts/task_rust.sh
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ cd tests/test_tvm_dso
cargo run
cd -

# # run wasm32 test
# cd tests/test_wasm32
# cargo build
# wasmtime $RUST_DIR/target/wasm32-wasi/debug/test-wasm32.wasm
# cd -

# run nn graph test
cd tests/test_nn
cargo run
Expand Down

0 comments on commit c7a16d8

Please sign in to comment.