Skip to content

Commit

Permalink
[RUST][FRONTEND] Fix resnet example (apache#3000)
Browse files Browse the repository at this point in the history
Due to the previous changes the frontend resnet example failed to build.  So this patch 

1) fixes it 
2) adds ~~a local `run_tests.sh` to remedy non-existence of MXNet CI (used in python build example)~~ the example build to CI with random weights and a flag for pretrained resnet weights

Please review: @tqchen @nhynes @kazimuth
  • Loading branch information
ehsanmok authored and wweic committed May 13, 2019
1 parent 2e7f691 commit 7bb37d9
Show file tree
Hide file tree
Showing 10 changed files with 138 additions and 139 deletions.
14 changes: 13 additions & 1 deletion rust/common/src/packed_func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ TVMPODValue! {
Bytes(val) => {
(TVMValue { v_handle: val.clone() as *const _ as *mut c_void }, TVMTypeCode_kBytes)
}
Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kStr)}
Str(val) => { (TVMValue { v_handle: val.as_ptr() as *mut c_void }, TVMTypeCode_kStr) }
}
}

Expand Down Expand Up @@ -260,12 +260,24 @@ impl<'a> From<&'a str> for TVMArgValue<'a> {
}
}

impl<'a> From<String> for TVMArgValue<'a> {
fn from(s: String) -> Self {
Self::String(CString::new(s).unwrap())
}
}

impl<'a> From<&'a CStr> for TVMArgValue<'a> {
fn from(s: &'a CStr) -> Self {
Self::Str(s)
}
}

impl<'a> From<&'a TVMByteArray> for TVMArgValue<'a> {
fn from(s: &'a TVMByteArray) -> Self {
Self::Bytes(s)
}
}

impl<'a> TryFrom<TVMArgValue<'a>> for &'a str {
type Error = ValueDowncastError;
fn try_from(val: TVMArgValue<'a>) -> Result<Self, Self::Error> {
Expand Down
52 changes: 46 additions & 6 deletions rust/common/src/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
* under the License.
*/

use std::str::FromStr;
use std::{os::raw::c_char, str::FromStr};

use failure::Error;

Expand Down Expand Up @@ -157,17 +157,57 @@ impl_tvm_context!(
DLDeviceType_kDLExtDev: [ext_dev]
);

/// A struct holding TVM byte-array.
///
/// ## Example
///
/// ```
/// let v = b"hello";
/// let barr = TVMByteArray::from(&v);
/// assert_eq!(barr.len(), v.len());
/// assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]);
/// ```
impl TVMByteArray {
/// Gets the underlying byte-array
pub fn data(&self) -> &'static [u8] {
unsafe { std::slice::from_raw_parts(self.data as *const u8, self.size) }
}

/// Gets the length of the underlying byte-array
pub fn len(&self) -> usize {
self.size
}

/// Converts the underlying byte-array to `Vec<u8>`
pub fn to_vec(&self) -> Vec<u8> {
self.data().to_vec()
}
}

impl<'a> From<&'a [u8]> for TVMByteArray {
fn from(bytes: &[u8]) -> Self {
Self {
data: bytes.as_ptr() as *const i8,
size: bytes.len(),
// Needs AsRef for Vec
impl<T: AsRef<[u8]>> From<T> for TVMByteArray {
fn from(arg: T) -> Self {
let arg = arg.as_ref();
TVMByteArray {
data: arg.as_ptr() as *const c_char,
size: arg.len(),
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn convert() {
let v = vec![1u8, 2, 3];
let barr = TVMByteArray::from(&v);
assert_eq!(barr.len(), v.len());
assert_eq!(barr.to_vec(), vec![1u8, 2, 3]);
let v = b"hello";
let barr = TVMByteArray::from(&v);
assert_eq!(barr.len(), v.len());
assert_eq!(barr.data(), &[104u8, 101, 108, 108, 111]);
}
}
17 changes: 15 additions & 2 deletions rust/frontend/examples/resnet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,25 @@ This end-to-end example shows how to:
* build `Resnet 18` with `tvm` and `nnvm` from Python
* use the provided Rust frontend API to test for an input image

To run the example, first `tvm`, `nnvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet`
To run the example with pretrained resnet weights, first `tvm`, `nnvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet`
and to install `tvm` and `nnvm` with `llvm` follow the [TVM installation guide](https://docs.tvm.ai/install/index.html).

* **Build the example**: `cargo build`
* **Build the example**: `cargo build

To have a successful build, note that it is required to instruct Rust compiler to link to the compiled shared library, for example with
`println!("cargo:rustc-link-search=native={}", build_path)`. See the `build.rs` for more details.

* **Run the example**: `cargo run`

Note: To use pretrained weights, one can enable `--pretrained` in `build.rs` with

```
let output = Command::new("python")
.arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py"))
.arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR")))
.arg(&format!("--pretrained"))
.output()
.expect("Failed to execute command");
```

Otherwise, *random weights* are used, therefore, the prediction will be `limpkin, Aramus pictus`!
15 changes: 11 additions & 4 deletions rust/frontend/examples/resnet/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,23 @@
* under the License.
*/

use std::process::Command;
use std::{path::Path, process::Command};

fn main() {
let output = Command::new(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py"))
let output = Command::new("python3")
.arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py"))
.arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR")))
.output()
.expect("Failed to execute command");
assert!(
std::path::Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_lib.o")).exists(),
Path::new(&format!("{}/deploy_lib.o", env!("CARGO_MANIFEST_DIR"))).exists(),
"Could not prepare demo: {}",
String::from_utf8(output.stderr).unwrap().trim()
String::from_utf8(output.stderr)
.unwrap()
.trim()
.split("\n")
.last()
.unwrap_or("")
);
println!(
"cargo:rustc-link-search=native={}",
Expand Down
62 changes: 37 additions & 25 deletions rust/frontend/examples/resnet/src/build_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,18 @@

import numpy as np

import mxnet as mx
from mxnet.gluon.model_zoo.vision import get_model
from mxnet.gluon.utils import download

import tvm
from tvm import relay
from tvm.relay import testing
from tvm.contrib import graph_runtime, cc
import nnvm

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser(description='Resnet build example')
aa = parser.add_argument
aa('--build-dir', type=str, required=True, help='directory to put the build artifacts')
aa('--pretrained', action='store_true', help='use a pretrained resnet')
aa('--batch-size', type=int, default=1, help='input image batch size')
aa('--opt-level', type=int, default=3,
help='level of optimization. 0 is unoptimized and 3 is the highest level')
Expand All @@ -45,7 +44,7 @@
aa('--image-name', type=str, default='cat.png', help='name of input image to download')
args = parser.parse_args()

target_dir = osp.dirname(osp.dirname(osp.realpath(__file__)))
build_dir = args.build_dir
batch_size = args.batch_size
opt_level = args.opt_level
target = tvm.target.create(args.target)
Expand All @@ -57,30 +56,42 @@ def build(target_dir):
deploy_lib = osp.join(target_dir, 'deploy_lib.o')
if osp.exists(deploy_lib):
return
# download the pretrained resnet18 trained on imagenet1k dataset for
# image classification task
block = get_model('resnet18_v1', pretrained=True)

sym, params = nnvm.frontend.from_mxnet(block)
# add the softmax layer for prediction
net = nnvm.sym.softmax(sym)
if args.pretrained:
# needs mxnet installed
from mxnet.gluon.model_zoo.vision import get_model

# if `--pretrained` is enabled, it downloads a pretrained
# resnet18 trained on imagenet1k dataset for image classification task
block = get_model('resnet18_v1', pretrained=True)
net, params = relay.frontend.from_mxnet(block, {"data": data_shape})
# we want a probability so add a softmax operator
net = relay.Function(net.params, relay.nn.softmax(net.body),
None, net.type_params, net.attrs)
else:
# use random weights from relay.testing
net, params = relay.testing.resnet.get_workload(
num_layers=18, batch_size=batch_size, image_shape=image_shape)

# compile the model
with nnvm.compiler.build_config(opt_level=opt_level):
graph, lib, params = nnvm.compiler.build(
net, target, shape={"data": data_shape}, params=params)
with relay.build_config(opt_level=opt_level):
graph, lib, params = relay.build_module.build(net, target, params=params)

# save the model artifacts
lib.save(deploy_lib)
cc.create_shared(osp.join(target_dir, "deploy_lib.so"),
[osp.join(target_dir, "deploy_lib.o")])

with open(osp.join(target_dir, "deploy_graph.json"), "w") as fo:
fo.write(graph.json())
fo.write(graph)

with open(osp.join(target_dir,"deploy_param.params"), "wb") as fo:
fo.write(nnvm.compiler.save_param_dict(params))
fo.write(relay.save_param_dict(params))

def download_img_labels():
""" Download an image and imagenet1k class labels for test"""
from mxnet.gluon.utils import download

img_name = 'cat.png'
synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/',
'4d0b62f3d01426887599d4f7ede23ee5/raw/',
Expand All @@ -97,11 +108,11 @@ def download_img_labels():
w = csv.writer(fout)
w.writerows(synset.items())

def test_build(target_dir):
def test_build(build_dir):
""" Sanity check with random input"""
graph = open(osp.join(target_dir, "deploy_graph.json")).read()
lib = tvm.module.load(osp.join(target_dir, "deploy_lib.so"))
params = bytearray(open(osp.join(target_dir,"deploy_param.params"), "rb").read())
graph = open(osp.join(build_dir, "deploy_graph.json")).read()
lib = tvm.module.load(osp.join(build_dir, "deploy_lib.so"))
params = bytearray(open(osp.join(build_dir,"deploy_param.params"), "rb").read())
input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32"))
ctx = tvm.cpu()
module = graph_runtime.create(graph, lib, ctx)
Expand All @@ -112,10 +123,11 @@ def test_build(target_dir):

if __name__ == '__main__':
logger.info("building the model")
build(target_dir)
build(build_dir)
logger.info("build was successful")
logger.info("test the build artifacts")
test_build(target_dir)
test_build(build_dir)
logger.info("test was successful")
download_img_labels()
logger.info("image and synset downloads are successful")
if args.pretrained:
download_img_labels()
logger.info("image and synset downloads are successful")
5 changes: 2 additions & 3 deletions rust/frontend/examples/resnet/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ fn main() {
let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap();
let runtime_create_fn_ret = call_packed!(
runtime_create_fn,
&graph,
graph,
&lib,
&ctx.device_type,
&ctx.device_id
Expand All @@ -107,8 +107,7 @@ fn main() {
.get_function("set_input", false)
.unwrap();

let data_str = "data".to_string();
call_packed!(set_input_fn, &data_str, &input).unwrap();
call_packed!(set_input_fn, "data".to_string(), &input).unwrap();
// get `run` function from runtime module
let ref run_fn = graph_runtime_module.get_function("run", false).unwrap();
// execute the run function. Note that it has no argument
Expand Down
92 changes: 0 additions & 92 deletions rust/frontend/src/bytearray.rs

This file was deleted.

Loading

0 comments on commit 7bb37d9

Please sign in to comment.