From fcdfffa2adfce22244ed19d3234060d7f4968147 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 10 Jun 2020 02:29:39 -0700 Subject: [PATCH 01/13] Adapt to new macro --- rust/Cargo.toml | 4 +++- rust/runtime/tests/test_wasm32/Cargo.toml | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 6849c039f86f..c0d0bb8cc8b2 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -29,5 +29,7 @@ members = [ "frontend/tests/callback", "frontend/examples/resnet", "tvm-sys", - "tvm-rt" + "tvm-macros", + "tvm-rt", + "tvm" ] diff --git a/rust/runtime/tests/test_wasm32/Cargo.toml b/rust/runtime/tests/test_wasm32/Cargo.toml index 1d3373a9e60f..51f15ff08b67 100644 --- a/rust/runtime/tests/test_wasm32/Cargo.toml +++ b/rust/runtime/tests/test_wasm32/Cargo.toml @@ -22,5 +22,6 @@ license = "Apache-2.0" authors = ["TVM Contributors"] [dependencies] +anyhow = "*" ndarray="0.12" tvm-runtime = { path = "../../" } From 43325abcdf9b1aa8d405a9769aa43dc12f66112d Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 10 Jun 2020 02:30:10 -0700 Subject: [PATCH 02/13] Add tvm crate --- rust/tvm/.gitignore | 7 + rust/tvm/.travis.yml | 22 ++ rust/tvm/Cargo.toml | 45 ++++ rust/tvm/README.md | 235 +++++++++++++++++++ rust/tvm/examples/resnet/Cargo.toml | 29 +++ rust/tvm/examples/resnet/README.md | 45 ++++ rust/tvm/examples/resnet/build.rs | 42 ++++ rust/tvm/examples/resnet/src/build_resnet.py | 134 +++++++++++ rust/tvm/examples/resnet/src/main.rs | 160 +++++++++++++ rust/tvm/src/ir/array.rs | 74 ++++++ rust/tvm/src/ir/mod.rs | 17 ++ rust/tvm/src/ir/relay/mod.rs | 232 ++++++++++++++++++ rust/tvm/src/lib.rs | 47 ++++ rust/tvm/src/runtime/mod.rs | 1 + rust/tvm/src/transform.rs | 42 ++++ rust/tvm/tests/basics/.gitignore | 7 + rust/tvm/tests/basics/Cargo.toml | 32 +++ rust/tvm/tests/basics/build.rs | 46 ++++ rust/tvm/tests/basics/src/main.rs | 55 +++++ rust/tvm/tests/basics/src/tvm_add.py | 50 ++++ rust/tvm/tests/callback/Cargo.toml | 26 ++ rust/tvm/tests/callback/src/bin/array.rs | 72 ++++++ rust/tvm/tests/callback/src/bin/error.rs | 56 +++++ rust/tvm/tests/callback/src/bin/float.rs | 50 ++++ rust/tvm/tests/callback/src/bin/int.rs | 49 ++++ rust/tvm/tests/callback/src/bin/string.rs | 54 +++++ rust/tvm/tests/test_ir.rs | 37 +++ 27 files changed, 1666 insertions(+) create mode 100644 rust/tvm/.gitignore create mode 100644 rust/tvm/.travis.yml create mode 100644 rust/tvm/Cargo.toml create mode 100644 rust/tvm/README.md create mode 100644 rust/tvm/examples/resnet/Cargo.toml create mode 100644 rust/tvm/examples/resnet/README.md create mode 100644 rust/tvm/examples/resnet/build.rs create mode 100644 rust/tvm/examples/resnet/src/build_resnet.py create mode 100644 rust/tvm/examples/resnet/src/main.rs create mode 100644 rust/tvm/src/ir/array.rs create mode 100644 rust/tvm/src/ir/mod.rs create mode 100644 rust/tvm/src/ir/relay/mod.rs create mode 100644 rust/tvm/src/lib.rs create mode 100644 rust/tvm/src/runtime/mod.rs create mode 100644 rust/tvm/src/transform.rs create mode 100644 rust/tvm/tests/basics/.gitignore create mode 100644 rust/tvm/tests/basics/Cargo.toml create mode 100644 rust/tvm/tests/basics/build.rs create mode 100644 rust/tvm/tests/basics/src/main.rs create mode 100755 rust/tvm/tests/basics/src/tvm_add.py create mode 100644 rust/tvm/tests/callback/Cargo.toml create mode 100644 rust/tvm/tests/callback/src/bin/array.rs create mode 100644 rust/tvm/tests/callback/src/bin/error.rs create mode 100644 rust/tvm/tests/callback/src/bin/float.rs create mode 100644 rust/tvm/tests/callback/src/bin/int.rs create mode 100644 rust/tvm/tests/callback/src/bin/string.rs create mode 100644 rust/tvm/tests/test_ir.rs diff --git a/rust/tvm/.gitignore b/rust/tvm/.gitignore new file mode 100644 index 000000000000..2430329c78b6 --- /dev/null +++ b/rust/tvm/.gitignore @@ -0,0 +1,7 @@ +target +**/*.rs.bk +Cargo.lock +/tests/basics/add_* +/examples/resnet/deploy_* +/examples/resnet/*.png +/examples/resnet/synset.* diff --git a/rust/tvm/.travis.yml b/rust/tvm/.travis.yml new file mode 100644 index 000000000000..e963b7c0ede5 --- /dev/null +++ b/rust/tvm/.travis.yml @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +language: rust +rust: + - nightly +matrix: + fast_finish: true diff --git a/rust/tvm/Cargo.toml b/rust/tvm/Cargo.toml new file mode 100644 index 000000000000..ebfb5e64a4a7 --- /dev/null +++ b/rust/tvm/Cargo.toml @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "tvm" +version = "0.1.0" +license = "Apache-2.0" +description = "Rust frontend support for TVM" +repository = "https://github.com/apache/incubator-tvm" +homepage = "https://github.com/apache/incubator-tvm" +readme = "README.md" +keywords = ["rust", "tvm"] +categories = ["api-bindings", "science"] +authors = ["TVM Contributors"] +edition = "2018" + +[dependencies] +thiserror = "^1.0" +anyhow = "^1.0" +lazy_static = "1.1" +ndarray = "0.12" +num-traits = "0.2" +tvm-rt = { version = "0.1", path = "../tvm-rt/" } +tvm-sys = { version = "0.1", path = "../tvm-sys/" } +tvm-macros = { version = "*", path = "../tvm-macros/" } +paste = "0.1" +mashup = "0.1" +once_cell = "^1.3.1" + +[features] +blas = ["ndarray/blas"] diff --git a/rust/tvm/README.md b/rust/tvm/README.md new file mode 100644 index 000000000000..01e088f2ea81 --- /dev/null +++ b/rust/tvm/README.md @@ -0,0 +1,235 @@ + + + + + + + + + + + + + + + + + +# TVM Runtime Frontend Support + +This crate provides an idiomatic Rust API for [TVM](https://github.com/apache/incubator-tvm) runtime frontend. Currently this requires **Nightly Rust** and tested on `rustc 1.32.0-nightly` + +## What Does This Crate Offer? + +Here is a major workflow + +1. Train your **Deep Learning** model using any major framework such as [PyTorch](https://pytorch.org/), [Apache MXNet](https://mxnet.incubator.apache.org/) or [TensorFlow](https://www.tensorflow.org/) +2. Use **TVM** to build optimized model artifacts on a supported context such as CPU, GPU, OpenCL and specialized accelerators. +3. Deploy your models using **Rust** :heart: + +### Example: Deploy Image Classification from Pretrained Resnet18 on ImageNet1k + +Please checkout [examples/resnet](examples/resnet) for the complete end-to-end example. + +Here's a Python snippet for downloading and building a pretrained Resnet18 via Apache MXNet and TVM + +```python +block = get_model('resnet18_v1', pretrained=True) + +sym, params = relay.frontend.from_mxnet(block, shape_dict) +# compile the model +with relay.build_config(opt_level=opt_level): + graph, lib, params = relay.build( + net, target, params=params) +# same the model artifacts +lib.save(os.path.join(target_dir, "deploy_lib.o")) +cc.create_shared(os.path.join(target_dir, "deploy_lib.so"), + [os.path.join(target_dir, "deploy_lib.o")]) + +with open(os.path.join(target_dir, "deploy_graph.json"), "w") as fo: + fo.write(graph.json()) +with open(os.path.join(target_dir,"deploy_param.params"), "wb") as fo: + fo.write(relay.save_param_dict(params)) +``` + +Now, we need to input the artifacts to create and run the *Graph Runtime* to detect our input cat image + +![cat](https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true) + +as demostrated in the following Rust snippet + +```rust + let graph = fs::read_to_string("deploy_graph.json")?; + // load the built module + let lib = Module::load(&Path::new("deploy_lib.so"))?; + // get the global TVM graph runtime function + let runtime_create_fn = Function::get("tvm.graph_runtime.create", true).unwrap(); + let runtime_create_fn_ret = call_packed!( + runtime_create_fn, + &graph, + &lib, + &ctx.device_type, + &ctx.device_id + )?; + // get graph runtime module + let graph_runtime_module: Module = runtime_create_fn_ret.try_into()?; + // get the registered `load_params` from runtime module + let ref load_param_fn = graph_runtime_module + .get_function("load_params", false) + .unwrap(); + // parse parameters and convert to TVMByteArray + let params: Vec = fs::read("deploy_param.params")?; + let barr = TVMByteArray::from(¶ms); + // load the parameters + call_packed!(load_param_fn, &barr)?; + // get the set_input function + let ref set_input_fn = graph_runtime_module + .get_function("set_input", false) + .unwrap(); + + call_packed!(set_input_fn, "data", &input)?; + // get `run` function from runtime module + let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); + // execute the run function. Note that it has no argument + call_packed!(run_fn,)?; + // prepare to get the output + let output_shape = &mut [1, 1000]; + let output = empty(output_shape, TVMContext::cpu(0), TVMType::from("float32")); + // get the `get_output` function from runtime module + let ref get_output_fn = graph_runtime_module + .get_function("get_output", false) + .unwrap(); + // execute the get output function + call_packed!(get_output_fn, &0, &output)?; + // flatten the output as Vec + let output = output.to_vec::()?; +``` + +and the model correctly predicts the input image as **tiger cat**. + +## Installations + +Please follow TVM [installations](https://tvm.apache.org/docs/install/index.html), `export TVM_HOME=/path/to/tvm` and add `libtvm_runtime` to your `LD_LIBRARY_PATH`. + +*Note:* To run the end-to-end examples and tests, `tvm` and `topi` need to be added to your `PYTHONPATH` or it's automatic via an Anaconda environment when it is installed individually. + +## Supported TVM Functionalities + +### Use TVM to Generate Shared Library + +One can use the following Python snippet to generate `add_gpu.so` which add two vectors on GPU. + +```python +import os +import tvm +from tvm import te +from tvm.contrib import cc + +def test_add(target_dir): + if not tvm.runtime.enabled("cuda"): + print("skip {__file__} because cuda is not enabled...".format(__file__=__file__)) + return + n = te.var("n") + A = te.placeholder((n,), name='A') + B = te.placeholder((n,), name='B') + C = te.compute(A.shape, lambda i: A[i] + B[i], name="C") + s = te.create_schedule(C.op) + bx, tx = s[C].split(C.op.axis[0], factor=64) + s[C].bind(bx, tvm.thread_axis("blockIdx.x")) + s[C].bind(tx, tvm.thread_axis("threadIdx.x")) + fadd_cuda = tvm.build(s, [A, B, C], "cuda", target_host="llvm", name="myadd") + + fadd_cuda.save(os.path.join(target_dir, "add_gpu.o")) + fadd_cuda.imported_modules[0].save(os.path.join(target_dir, "add_gpu.ptx")) + cc.create_shared(os.path.join(target_dir, "add_gpu.so"), + [os.path.join(target_dir, "add_gpu.o")]) + + +if __name__ == "__main__": + import sys + if len(sys.argv) != 2: + sys.exit(-1) + test_add(sys.argv[1]) +``` + +### Run the Generated Shared Library + +The following code snippet demonstrates how to load and test the generated shared library (`add_gpu.so`) in Rust. + +```rust +extern crate tvm_frontend as tvm; + +use tvm::*; + +fn main() { + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + let mut arr = empty(shape, TVMContext::gpu(0), TVMType::from("float32")); + arr.copy_from_buffer(data.as_mut_slice()); + let mut ret = empty(shape, TVMContext::gpu(0), TVMType::from("float32")); + let mut fadd = Module::load(&Path::new("add_gpu.so")).unwrap(); + let fadd_dep = Module::load(&Path::new("add_gpu.ptx")).unwrap(); + assert!(fadd.enabled("gpu")); + fadd.import_module(fadd_dep); + fadd.entry(); + function::Builder::from(&mut fadd) + .arg(&arr) + .arg(&arr) + .set_output(&mut ret)? + .invoke() + .unwrap(); + + assert_eq!(ret.to_vec::().unwrap(), vec![6f32, 8.0]); +} +``` + +**Note:** it is required to instruct the `rustc` to link to the generated `add_gpu.so` in runtime, for example by +`cargo:rustc-link-search=native=add_gpu`. + +See the tests and examples custom `build.rs` for more details. + +### Convert and Register a Rust Function as a TVM Packed Function + +One can use `register_global_func!` macro to convert and register a Rust +function of type `fn(&[TVMArgValue]) -> Result` to a global TVM **packed function** as follows + +```rust +#[macro_use] +extern crate tvm_frontend as tvm; +use std::convert::TryInto; +use tvm::*; + +fn main() { + register_global_func! { + fn sum(args: &[TVMArgValue]) -> Result { + let mut ret = 0f32; + let shape = &mut [2]; + for arg in args.iter() { + let e = empty(shape, TVMContext::cpu(0), TVMType::from("float32")); + let arg: NDArray = arg.try_into()?; + let arr = arg.copy_to_ndarray(e).unwrap(); + let rnd: ArrayD = ArrayD::try_from(&arr).unwrap(); + ret += rnd.scalar_sum(); + } + let ret_val = TVMRetValue::from(&ret); + Ok(ret_val) + } + } + + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + let mut arr = empty(shape, TVMContext::cpu(0), TVMType::from("float32")); + arr.copy_from_buffer(data.as_mut_slice()); + let mut registered = function::Builder::default(); + let ret: f64 = registered + .get_function("sum", true) + .arg(&arr) + .arg(&arr) + .invoke() + .unwrap() + .try_into() + .unwrap(); + + assert_eq!(ret, 14f64); +} +``` diff --git a/rust/tvm/examples/resnet/Cargo.toml b/rust/tvm/examples/resnet/Cargo.toml new file mode 100644 index 000000000000..e1a474eb5479 --- /dev/null +++ b/rust/tvm/examples/resnet/Cargo.toml @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "resnet" +version = "0.0.0" +authors = ["TVM Contributors"] +license = "Apache-2.0" +build = "build.rs" + +[dependencies] +ndarray = "0.12" +tvm = { path = "../../" } +image = "0.20" +csv = "1.1" diff --git a/rust/tvm/examples/resnet/README.md b/rust/tvm/examples/resnet/README.md new file mode 100644 index 000000000000..d6e32f7fa768 --- /dev/null +++ b/rust/tvm/examples/resnet/README.md @@ -0,0 +1,45 @@ + + + + + + + + + + + + + + + + + +## Resnet example + +This end-to-end example shows how to: +* build `Resnet 18` with `tvm` from Python +* use the provided Rust frontend API to test for an input image + +To run the example with pretrained resnet weights, first `tvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet` +and to install `tvm` with `llvm` follow the [TVM installation guide](https://tvm.apache.org/docs/install/index.html). + +* **Build the example**: `cargo build + +To have a successful build, note that it is required to instruct Rust compiler to link to the compiled shared library, for example with +`println!("cargo:rustc-link-search=native={}", build_path)`. See the `build.rs` for more details. + +* **Run the example**: `cargo run` + +Note: To use pretrained weights, one can enable `--pretrained` in `build.rs` with + +``` +let output = Command::new("python") + .arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py")) + .arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR"))) + .arg(&format!("--pretrained")) + .output() + .expect("Failed to execute command"); +``` + +Otherwise, *random weights* are used, therefore, the prediction will be `limpkin, Aramus pictus`! diff --git a/rust/tvm/examples/resnet/build.rs b/rust/tvm/examples/resnet/build.rs new file mode 100644 index 000000000000..b9a3c4ccdf12 --- /dev/null +++ b/rust/tvm/examples/resnet/build.rs @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::{path::Path, process::Command}; + +fn main() { + let output = Command::new("python3") + .arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py")) + .arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR"))) + .output() + .expect("Failed to execute command"); + assert!( + Path::new(&format!("{}/deploy_lib.o", env!("CARGO_MANIFEST_DIR"))).exists(), + "Could not prepare demo: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); + println!( + "cargo:rustc-link-search=native={}", + env!("CARGO_MANIFEST_DIR") + ); +} diff --git a/rust/tvm/examples/resnet/src/build_resnet.py b/rust/tvm/examples/resnet/src/build_resnet.py new file mode 100644 index 000000000000..49c67bf1c4f3 --- /dev/null +++ b/rust/tvm/examples/resnet/src/build_resnet.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse +import csv +import logging +from os import path as osp +import sys + +import numpy as np + +import tvm +from tvm import te +from tvm import relay +from tvm.relay import testing +from tvm.contrib import graph_runtime, cc + +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +parser = argparse.ArgumentParser(description='Resnet build example') +aa = parser.add_argument +aa('--build-dir', type=str, required=True, help='directory to put the build artifacts') +aa('--pretrained', action='store_true', help='use a pretrained resnet') +aa('--batch-size', type=int, default=1, help='input image batch size') +aa('--opt-level', type=int, default=3, + help='level of optimization. 0 is unoptimized and 3 is the highest level') +aa('--target', type=str, default='llvm', help='target context for compilation') +aa('--image-shape', type=str, default='3,224,224', help='input image dimensions') +aa('--image-name', type=str, default='cat.png', help='name of input image to download') +args = parser.parse_args() + +build_dir = args.build_dir +batch_size = args.batch_size +opt_level = args.opt_level +target = tvm.target.create(args.target) +image_shape = tuple(map(int, args.image_shape.split(","))) +data_shape = (batch_size,) + image_shape + +def build(target_dir): + """ Compiles resnet18 with TVM""" + deploy_lib = osp.join(target_dir, 'deploy_lib.o') + if osp.exists(deploy_lib): + return + + if args.pretrained: + # needs mxnet installed + from mxnet.gluon.model_zoo.vision import get_model + + # if `--pretrained` is enabled, it downloads a pretrained + # resnet18 trained on imagenet1k dataset for image classification task + block = get_model('resnet18_v1', pretrained=True) + net, params = relay.frontend.from_mxnet(block, {"data": data_shape}) + # we want a probability so add a softmax operator + net = relay.Function(net.params, relay.nn.softmax(net.body), + None, net.type_params, net.attrs) + else: + # use random weights from relay.testing + net, params = relay.testing.resnet.get_workload( + num_layers=18, batch_size=batch_size, image_shape=image_shape) + + # compile the model + with relay.build_config(opt_level=opt_level): + graph, lib, params = relay.build_module.build(net, target, params=params) + + # save the model artifacts + lib.save(deploy_lib) + cc.create_shared(osp.join(target_dir, "deploy_lib.so"), + [osp.join(target_dir, "deploy_lib.o")]) + + with open(osp.join(target_dir, "deploy_graph.json"), "w") as fo: + fo.write(graph) + + with open(osp.join(target_dir,"deploy_param.params"), "wb") as fo: + fo.write(relay.save_param_dict(params)) + +def download_img_labels(): + """ Download an image and imagenet1k class labels for test""" + from mxnet.gluon.utils import download + + img_name = 'cat.png' + synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/', + '4d0b62f3d01426887599d4f7ede23ee5/raw/', + '596b27d23537e5a1b5751d2b0481ef172f58b539/', + 'imagenet1000_clsid_to_human.txt']) + synset_name = 'synset.txt' + download('https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', img_name) + download(synset_url, synset_name) + + with open(synset_name) as fin: + synset = eval(fin.read()) + + with open("synset.csv", "w") as fout: + w = csv.writer(fout) + w.writerows(synset.items()) + +def test_build(build_dir): + """ Sanity check with random input""" + graph = open(osp.join(build_dir, "deploy_graph.json")).read() + lib = tvm.runtime.load(osp.join(build_dir, "deploy_lib.so")) + params = bytearray(open(osp.join(build_dir,"deploy_param.params"), "rb").read()) + input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32")) + ctx = tvm.cpu() + module = graph_runtime.create(graph, lib, ctx) + module.load_params(params) + module.run(data=input_data) + out = module.get_output(0).asnumpy() + + +if __name__ == '__main__': + logger.info("building the model") + build(build_dir) + logger.info("build was successful") + logger.info("test the build artifacts") + test_build(build_dir) + logger.info("test was successful") + if args.pretrained: + download_img_labels() + logger.info("image and synset downloads are successful") diff --git a/rust/tvm/examples/resnet/src/main.rs b/rust/tvm/examples/resnet/src/main.rs new file mode 100644 index 000000000000..0aed72b1eb52 --- /dev/null +++ b/rust/tvm/examples/resnet/src/main.rs @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +extern crate csv; +extern crate image; +extern crate ndarray; +extern crate tvm_frontend as tvm; + +use std::{ + collections::HashMap, + convert::TryInto, + fs::{self, File}, + path::Path, + str::FromStr, +}; + +use image::{FilterType, GenericImageView}; +use ndarray::{Array, ArrayD, Axis}; + +use tvm::*; + +fn main() { + let ctx = TVMContext::cpu(0); + let img = image::open(concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")).unwrap(); + println!("original image dimensions: {:?}", img.dimensions()); + // for bigger size images, one needs to first resize to 256x256 + // with `img.resize_exact` method and then `image.crop` to 224x224 + let img = img.resize(224, 224, FilterType::Nearest).to_rgb(); + println!("resized image dimensions: {:?}", img.dimensions()); + let mut pixels: Vec = vec![]; + for pixel in img.pixels() { + let tmp = pixel.data; + // normalize the RGB channels using mean, std of imagenet1k + let tmp = [ + (tmp[0] as f32 - 123.0) / 58.395, // R + (tmp[1] as f32 - 117.0) / 57.12, // G + (tmp[2] as f32 - 104.0) / 57.375, // B + ]; + for e in &tmp { + pixels.push(*e); + } + } + + let arr = Array::from_shape_vec((224, 224, 3), pixels).unwrap(); + let arr: ArrayD = arr.permuted_axes([2, 0, 1]).into_dyn(); + // make arr shape as [1, 3, 224, 224] acceptable to resnet + let arr = arr.insert_axis(Axis(0)); + // create input tensor from rust's ndarray + let input = NDArray::from_rust_ndarray( + &arr, + TVMContext::cpu(0), + DLDataType::from_str("float32").unwrap(), + ) + .unwrap(); + println!( + "input size is {:?}", + input.shape().expect("cannot get the input shape") + ); + let graph = + fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")).unwrap(); + // load the built module + let lib = Module::load(&Path::new(concat!( + env!("CARGO_MANIFEST_DIR"), + "/deploy_lib.so" + ))) + .unwrap(); + // get the global TVM graph runtime function + let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap(); + let runtime_create_fn_ret = call_packed!( + runtime_create_fn, + graph, + &lib, + &ctx.device_type, + &ctx.device_id + ) + .unwrap(); + // get graph runtime module + let graph_runtime_module: Module = runtime_create_fn_ret.try_into().unwrap(); + // get the registered `load_params` from runtime module + let ref load_param_fn = graph_runtime_module + .get_function("load_params", false) + .unwrap(); + // parse parameters and convert to TVMByteArray + let params: Vec = + fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params")).unwrap(); + let barr = TVMByteArray::from(¶ms); + // load the parameters + call_packed!(load_param_fn, &barr).unwrap(); + // get the set_input function + let ref set_input_fn = graph_runtime_module + .get_function("set_input", false) + .unwrap(); + + call_packed!(set_input_fn, "data".to_string(), &input).unwrap(); + // get `run` function from runtime module + let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); + // execute the run function. Note that it has no argument + call_packed!(run_fn,).unwrap(); + // prepare to get the output + let output_shape = &mut [1, 1000]; + let output = NDArray::empty( + output_shape, + TVMContext::cpu(0), + DLDataType::from_str("float32").unwrap(), + ); + // get the `get_output` function from runtime module + let ref get_output_fn = graph_runtime_module + .get_function("get_output", false) + .unwrap(); + // execute the get output function + call_packed!(get_output_fn, &0, &output).unwrap(); + // flatten the output as Vec + let output = output.to_vec::().unwrap(); + // find the maximum entry in the output and its index + let mut argmax = -1; + let mut max_prob = 0.; + for i in 0..output.len() { + if output[i] > max_prob { + max_prob = output[i]; + argmax = i as i32; + } + } + // create a hash map of (class id, class name) + let mut synset: HashMap = HashMap::new(); + let file = File::open("synset.csv").unwrap(); + let mut rdr = csv::ReaderBuilder::new() + .has_headers(true) + .from_reader(file); + + for result in rdr.records() { + let record = result.unwrap(); + let id: i32 = record[0].parse().unwrap(); + let cls = record[1].to_string(); + synset.insert(id, cls); + } + + println!( + "input image belongs to the class `{}` with probability {}", + synset + .get(&argmax) + .expect("cannot find the class id for argmax"), + max_prob + ); +} diff --git a/rust/tvm/src/ir/array.rs b/rust/tvm/src/ir/array.rs new file mode 100644 index 000000000000..2b5a23b63867 --- /dev/null +++ b/rust/tvm/src/ir/array.rs @@ -0,0 +1,74 @@ +use std::convert::TryFrom; +use std::marker::PhantomData; + +use crate::runtime::object::{ObjectRef, ToObjectRef}; + +use tvm_rt::external; +use tvm_rt::RetValue; + +use anyhow::Result; + +#[derive(Clone)] +pub struct Array { + object: ObjectRef, + _data: PhantomData, +} + +external! { + #[name("node.ArrayGetItem")] + fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef; +} + +impl Array { + pub fn from_vec(data: Vec) -> Result> { + unimplemented!() + // let iter = data.iter().map(|element| element.to_object_ref()); + + // let array_data = Builder::default() + // .get_function("node.Array") + // .args(iter) + // .invoke()? + // .try_into()?; + + // Ok(Array { + // object: array_data, + // _data: PhantomData, + // }) + } + + pub fn get(&self, index: isize) -> Result + where + T: TryFrom, + { + unimplemented!() + // // TODO(@jroesch): why do we used a signed index here? + // let element: T = Builder::default() + // .get_function("node.ArrayGetItem") + // .arg(self.object.clone()) + // .arg(index) + // .invoke()? + // .try_into()?; + + // Ok(element) + } +} + +#[cfg(test)] +mod tests { + use super::Array; + use crate::ir::relay::Var; + use crate::runtime::object::ObjectRef; + use anyhow::Result; + + #[test] + fn create_array_and_get() -> Result<()> { + let vec = vec![ + Var::new("foo".into(), ObjectRef::null()), + Var::new("bar".into(), ObjectRef::null()), + ]; + let array = Array::from_vec(vec)?; + assert_eq!(array.get(0)?.name_hint().to_string()?, "foo"); + assert_eq!(array.get(1)?.name_hint().to_string()?, "bar"); + Ok(()) + } +} diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs new file mode 100644 index 000000000000..bc667fdb19b8 --- /dev/null +++ b/rust/tvm/src/ir/mod.rs @@ -0,0 +1,17 @@ +use crate::runtime::Object; +use crate::DataType; + +pub mod array; +pub mod relay; + +#[repr(C)] +pub struct PrimExprNode { + pub base: Object, + pub dtype: DataType, +} + +#[repr(C)] +pub struct IntImmNode { + pub base: PrimExprNode, + pub value: i64, +} diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs new file mode 100644 index 000000000000..ac7b707bdcd9 --- /dev/null +++ b/rust/tvm/src/ir/relay/mod.rs @@ -0,0 +1,232 @@ +use super::array::Array; +use crate::runtime::{IsObject, Object, ObjectPtr, ObjectRef, String as TString, ToObjectRef}; +use crate::DataType; +use tvm_macros::Object; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Id"] +#[type_key = "relay.Id"] +pub struct IdNode { + pub base: Object, + pub name_hint: TString, +} + +impl Id { + fn new(name_hint: TString) -> Id { + let node = IdNode { + base: Object::base_object::(), + name_hint: name_hint, + }; + Id(Some(ObjectPtr::new(node))) + } +} + +// define_ref!(Id, IdNode); + +#[repr(C)] +#[derive(Object)] +#[ref_name = "BaseExpr"] +#[type_key = "Expr"] +pub struct BaseExprNode { + pub base: Object, +} + +#[repr(C)] +pub struct PrimExprNode { + pub base: BaseExprNode, + pub datatype: DataType, +} + +impl BaseExprNode { + fn base() -> BaseExprNode { + BaseExprNode { + base: Object::base_object::(), + } + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Expr"] +#[type_key = "relay.Expr"] +pub struct RelayExpr { + pub base: BaseExprNode, + pub span: ObjectRef, + pub checked_type: ObjectRef, +} + +impl RelayExpr { + fn base() -> RelayExpr { + RelayExpr { + base: BaseExprNode::base::(), + span: ObjectRef::null(), + checked_type: ObjectRef::null(), + } + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "GlobalVar"] +#[type_key = "relay.GlobalVar"] +pub struct GlobalVarNode { + pub base: RelayExpr, + pub name_hint: TString, +} + +impl GlobalVar { + pub fn new(name_hint: String, _span: ObjectRef) -> GlobalVar { + let node = GlobalVarNode { + base: RelayExpr::base::(), + // span: span, + // checked_type: ObjectRef(None),, + name_hint: TString::new(name_hint).unwrap(), + }; + GlobalVar(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Constant"] +#[type_key = "relay.Constant"] +pub struct ConstantNode { + pub base: RelayExpr, + pub data: ObjectRef, // make this NDArray. +} + +impl Constant { + pub fn new(data: ObjectRef, _span: ObjectRef) -> Constant { + let node = ConstantNode { + base: RelayExpr::base::(), + data: data, + }; + Constant(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Var"] +#[type_key = "relay.Var"] +pub struct VarNode { + pub base: RelayExpr, + pub vid: Id, + pub type_annotation: ObjectRef, +} + +impl Var { + pub fn new(name_hint: String, _span: ObjectRef) -> Var { + let node = VarNode { + base: RelayExpr::base::(), + vid: Id::new(TString::new(name_hint.to_string()).unwrap()), + type_annotation: ObjectRef::null(), + }; + Var(Some(ObjectPtr::new(node))) + } + + pub fn name_hint(&self) -> &TString { + &self.vid.0.as_ref().unwrap().name_hint + } + + pub fn to_expr(self) -> Expr { + unsafe { Expr(std::mem::transmute(self.0)) } + } +} + +pub type Type = ObjectRef; +pub type Attrs = ObjectRef; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Call"] +#[type_key = "relay.Call"] +pub struct CallNode { + pub base: RelayExpr, + pub op: Expr, + pub args: Array, + pub attrs: ObjectRef, + pub type_args: Array, +} + +impl Call { + pub fn new( + op: Expr, + args: Array, + attrs: Attrs, + type_args: Array, + _span: ObjectRef, + ) -> Call { + let node = CallNode { + base: RelayExpr::base::(), + op: op, + args: args, + attrs: attrs, + type_args: type_args, + }; + Call(Some(ObjectPtr::new(node))) + } +} + +#[repr(C)] +#[derive(Object)] +#[ref_name = "Function"] +#[type_key = "relay.Function"] +pub struct FunctionNode { + pub base: RelayExpr, + pub params: Array, + pub body: Expr, + pub ret_type: Type, + pub type_params: Array, +} + +impl Function { + pub fn new( + params: Array, + body: Expr, + ret_type: Type, + type_params: Array, + ) -> Function { + let node = FunctionNode { + base: RelayExpr::base::(), + params: params, + body: body, + ret_type: ret_type, + type_params: type_params, + }; + Function(Some(ObjectPtr::new(node))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::runtime::{as_text, String as TString}; + use anyhow::Result; + + #[test] + fn test_id() -> Result<()> { + let string = TString::new("foo".to_string()).expect("bar"); + let id = Id::new(string); + let cstr = as_text(&id.upcast())?; + assert!(cstr.into_string()?.contains("relay.Id")); + Ok(()) + } + + #[test] + fn test_global() -> Result<()> { + let gv = GlobalVar::new("main".to_string(), ObjectRef::null()); + let cstr = as_text(&gv.upcast())?; + assert!(cstr.into_string()?.contains("@main")); + Ok(()) + } + + #[test] + fn test_var() -> Result<()> { + let var = Var::new("local".to_string(), ObjectRef::null()); + let cstr = as_text(&var.upcast())?; + assert!(cstr.into_string()?.contains("%local")); + Ok(()) + } +} diff --git a/rust/tvm/src/lib.rs b/rust/tvm/src/lib.rs new file mode 100644 index 000000000000..64252a4f9c6f --- /dev/null +++ b/rust/tvm/src/lib.rs @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +//! [TVM](https://github.com/apache/incubator-tvm) is a compiler stack for deep learning systems. +//! +//! This crate provides an idiomatic Rust API for TVM runtime frontend. +//! +//! One particular use case is that given optimized deep learning model artifacts, +//! (compiled with TVM) which include a shared library +//! `lib.so`, `graph.json` and a byte-array `param.params`, one can load them +//! in Rust idomatically to create a TVM Graph Runtime and +//! run the model for some inputs and get the +//! desired predictions *all in Rust*. +//! +//! Checkout the `examples` repository for more details. + +pub use crate::{errors::*, function::Function, module::Module, ndarray::NDArray}; + +pub use tvm_rt::{Context, DataType, DeviceType}; + +pub use tvm_rt::context; +pub use tvm_rt::errors; +pub use tvm_rt::function; +pub use tvm_rt::module; +pub use tvm_rt::ndarray; +pub use tvm_rt::value; +pub mod ir; +pub mod runtime; +pub mod transform; + +pub use runtime::version; diff --git a/rust/tvm/src/runtime/mod.rs b/rust/tvm/src/runtime/mod.rs new file mode 100644 index 000000000000..57d43eea81c9 --- /dev/null +++ b/rust/tvm/src/runtime/mod.rs @@ -0,0 +1 @@ +pub use tvm_rt::*; diff --git a/rust/tvm/src/transform.rs b/rust/tvm/src/transform.rs new file mode 100644 index 000000000000..0f10ca3bc522 --- /dev/null +++ b/rust/tvm/src/transform.rs @@ -0,0 +1,42 @@ +use crate::ir::array::Array; +use crate::runtime::{external, Function, String as TString}; +use crate::runtime::{Object, ObjectPtr, ObjectRef}; +use tvm_macros::Object; + +type Pass = ObjectRef; + +#[repr(C)] +#[derive(Object)] +#[ref_name = "PassInfo"] +#[type_key = "transform.PassInfo"] +pub struct PassInfoNode { + pub base: Object, + pub opt_level: i32, + pub name: TString, + pub required: Array, +} + +impl PassInfo { + pub fn new(opt_level: i32, name: String, required: Vec) -> anyhow::Result { + let required: Result<_, _> = required + .into_iter() + .map(|name| TString::new(name)) + .collect(); + + let required = Array::from_vec(required?)?; + + let node = PassInfoNode { + base: Object::base_object::(), + opt_level, + name: TString::new(name).unwrap(), + required, + }; + + Ok(PassInfo(Some(ObjectPtr::new(node)))) + } +} + +external! { + #[name("relay._transform.MakeFunctionPass")] + fn create_func_pass(func: Function, pass_info: PassInfo) -> Pass; +} diff --git a/rust/tvm/tests/basics/.gitignore b/rust/tvm/tests/basics/.gitignore new file mode 100644 index 000000000000..10a4b225a705 --- /dev/null +++ b/rust/tvm/tests/basics/.gitignore @@ -0,0 +1,7 @@ +/target +**/*.rs.bk +Cargo.lock +*.o +*.so +*.ptx +*.json diff --git a/rust/tvm/tests/basics/Cargo.toml b/rust/tvm/tests/basics/Cargo.toml new file mode 100644 index 000000000000..0b059da7727b --- /dev/null +++ b/rust/tvm/tests/basics/Cargo.toml @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "basics" +version = "0.0.0" +authors = ["TVM Contributors"] +license = "Apache-2.0" +build = "build.rs" + +[dependencies] +ndarray = "0.12" +tvm = { path = "../../" } + +[features] +default = ["cpu"] +cpu = [] +gpu = [] diff --git a/rust/tvm/tests/basics/build.rs b/rust/tvm/tests/basics/build.rs new file mode 100644 index 000000000000..77a3bae3627d --- /dev/null +++ b/rust/tvm/tests/basics/build.rs @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +fn main() { + let out_dir = std::env::var("OUT_DIR").unwrap(); + + let output = std::process::Command::new(concat!(env!("CARGO_MANIFEST_DIR"), "/src/tvm_add.py")) + .args(&[ + if cfg!(feature = "cpu") { + "llvm" + } else { + "cuda" + }, + &std::env::var("OUT_DIR").unwrap(), + ]) + .output() + .expect("Failed to execute command"); + assert!( + std::path::Path::new(&format!("{}/test_add.so", out_dir)).exists(), + "Could not build tvm lib: {}", + String::from_utf8(output.stderr) + .unwrap() + .trim() + .split("\n") + .last() + .unwrap_or("") + ); + + println!("cargo:rustc-link-search=native={}", out_dir); +} diff --git a/rust/tvm/tests/basics/src/main.rs b/rust/tvm/tests/basics/src/main.rs new file mode 100644 index 000000000000..ca53dcf999dc --- /dev/null +++ b/rust/tvm/tests/basics/src/main.rs @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +extern crate ndarray as rust_ndarray; +extern crate tvm_frontend as tvm; + +use std::str::FromStr; + +use tvm::*; + +fn main() { + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + + let (ctx, ctx_name) = if cfg!(feature = "cpu") { + (TVMContext::cpu(0), "cpu") + } else { + (TVMContext::gpu(0), "gpu") + }; + let dtype = DLDataType::from_str("float32").unwrap(); + let mut arr = NDArray::empty(shape, ctx, dtype); + arr.copy_from_buffer(data.as_mut_slice()); + let mut ret = NDArray::empty(shape, ctx, dtype); + let mut fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap(); + if !fadd.enabled(ctx_name) { + return; + } + if cfg!(feature = "gpu") { + fadd.import_module(Module::load(&concat!(env!("OUT_DIR"), "/test_add.ptx")).unwrap()); + } + function::Builder::from(&mut fadd) + .arg(&arr) + .arg(&arr) + .arg(&mut ret) + .invoke() + .unwrap(); + + assert_eq!(ret.to_vec::().unwrap(), vec![6f32, 8.0]); +} diff --git a/rust/tvm/tests/basics/src/tvm_add.py b/rust/tvm/tests/basics/src/tvm_add.py new file mode 100755 index 000000000000..3911d4074e45 --- /dev/null +++ b/rust/tvm/tests/basics/src/tvm_add.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import os.path as osp +import sys + +import tvm +from tvm import te +from tvm.contrib import cc + + +def main(target, out_dir): + n = te.var('n') + A = te.placeholder((n,), name='A') + B = te.placeholder((n,), name='B') + C = te.compute(A.shape, lambda i: A[i] + B[i], name='C') + s = te.create_schedule(C.op) + + if target == 'cuda': + bx, tx = s[C].split(C.op.axis[0], factor=64) + s[C].bind(bx, te.thread_axis('blockIdx.x')) + s[C].bind(tx, te.thread_axis('threadIdx.x')) + + fadd = tvm.build(s, [A, B, C], target, target_host='llvm', name='myadd') + + fadd.save(osp.join(out_dir, 'test_add.o')) + if target == 'cuda': + fadd.imported_modules[0].save(osp.join(out_dir, 'test_add.ptx')) + cc.create_shared( + osp.join(out_dir, 'test_add.so'), [osp.join(out_dir, 'test_add.o')]) + + +if __name__ == '__main__': + main(sys.argv[1], sys.argv[2]) + diff --git a/rust/tvm/tests/callback/Cargo.toml b/rust/tvm/tests/callback/Cargo.toml new file mode 100644 index 000000000000..5c89d2ac6375 --- /dev/null +++ b/rust/tvm/tests/callback/Cargo.toml @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "callback" +version = "0.0.0" +authors = ["TVM Contributors"] +edition = "2018" + +[dependencies] +ndarray = "0.12" +tvm = { path = "../../" } diff --git a/rust/tvm/tests/callback/src/bin/array.rs b/rust/tvm/tests/callback/src/bin/array.rs new file mode 100644 index 000000000000..cb4a8229c401 --- /dev/null +++ b/rust/tvm/tests/callback/src/bin/array.rs @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#![allow(unused_imports)] + +extern crate ndarray as rust_ndarray; +#[macro_use] +extern crate tvm_frontend as tvm; + +use rust_ndarray::ArrayD; +use std::{ + convert::{TryFrom, TryInto}, + str::FromStr, +}; + +use tvm::{errors::Error, *}; + +fn main() { + register_global_func! { + fn sum(args: &[TVMArgValue]) -> Result { + let mut ret = 0f32; + let shape = &mut [2]; + for arg in args.iter() { + let e = NDArray::empty( + shape, TVMContext::cpu(0), + DLDataType::from_str("float32").unwrap() + ); + let arg: NDArray = arg.try_into()?; + let arr = arg.copy_to_ndarray(e)?; + let rnd: ArrayD = ArrayD::try_from(&arr)?; + ret += rnd.scalar_sum(); + } + Ok(TVMRetValue::from(ret)) + } + } + + let shape = &mut [2]; + let mut data = vec![3f32, 4.0]; + let mut arr = NDArray::empty( + shape, + TVMContext::cpu(0), + DLDataType::from_str("float32").unwrap(), + ); + arr.copy_from_buffer(data.as_mut_slice()); + + let mut registered = function::Builder::default(); + let ret: f32 = registered + .get_function("sum") + .arg(&arr) + .arg(&arr) + .invoke() + .unwrap() + .try_into() + .unwrap(); + assert_eq!(ret, 7f32); +} diff --git a/rust/tvm/tests/callback/src/bin/error.rs b/rust/tvm/tests/callback/src/bin/error.rs new file mode 100644 index 000000000000..c9f9a6f771cf --- /dev/null +++ b/rust/tvm/tests/callback/src/bin/error.rs @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use std::panic; + +use tvm_frontend::{errors::Error, *}; + +fn main() { + register_global_func! { + fn error(_args: &[TVMArgValue]) -> Result { + Err(errors::TypeMismatchError{ + expected: "i64".to_string(), + actual: "f64".to_string(), + }.into()) + } + } + + let mut registered = function::Builder::default(); + registered.get_function("error"); + assert!(registered.func.is_some()); + registered.args(&[10, 20]); + + println!("expected error message is:"); + panic::set_hook(Box::new(|panic_info| { + // if let Some(msg) = panic_info.message() { + // println!("{:?}", msg); + // } + if let Some(location) = panic_info.location() { + println!( + "panic occurred in file '{}' at line {}", + location.file(), + location.line() + ); + } else { + println!("panic occurred but can't get location information"); + } + })); + + let _result = registered.invoke(); +} diff --git a/rust/tvm/tests/callback/src/bin/float.rs b/rust/tvm/tests/callback/src/bin/float.rs new file mode 100644 index 000000000000..7111e287187f --- /dev/null +++ b/rust/tvm/tests/callback/src/bin/float.rs @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#![allow(unused_imports)] + +#[macro_use] +extern crate tvm_frontend as tvm; + +use std::convert::TryInto; +use tvm::{errors::Error, *}; + +fn main() { + register_global_func! { + fn sum(args: &[TVMArgValue]) -> Result { + let mut ret = 0.0; + for arg in args.into_iter() { + let val: f64 = arg.try_into()?; + ret += val; + } + Ok(TVMRetValue::from(ret)) + } + } + + let mut registered = function::Builder::default(); + registered.get_function("sum"); + assert!(registered.func.is_some()); + let ret: f64 = registered + .args(&[10.0f64, 20.0, 30.0]) + .invoke() + .unwrap() + .try_into() + .unwrap(); + assert_eq!(ret, 60f64); +} diff --git a/rust/tvm/tests/callback/src/bin/int.rs b/rust/tvm/tests/callback/src/bin/int.rs new file mode 100644 index 000000000000..23910a3244f7 --- /dev/null +++ b/rust/tvm/tests/callback/src/bin/int.rs @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#![allow(unused_imports)] + +extern crate tvm_frontend as tvm; + +use std::convert::TryInto; +use tvm::{errors::Error, *}; + +fn main() { + fn sum(args: &[TVMArgValue]) -> Result { + let mut ret = 0i64; + for arg in args.iter() { + let val: i64 = arg.try_into()?; + ret += val; + } + Ok(TVMRetValue::from(ret)) + } + + tvm::function::register(sum, "mysum".to_owned(), false).unwrap(); + + let mut registered = function::Builder::default(); + registered.get_function("mysum"); + assert!(registered.func.is_some()); + let ret: i64 = registered + .args(&[10, 20, 30]) + .invoke() + .unwrap() + .try_into() + .unwrap(); + assert_eq!(ret, 60); +} diff --git a/rust/tvm/tests/callback/src/bin/string.rs b/rust/tvm/tests/callback/src/bin/string.rs new file mode 100644 index 000000000000..9ead58733bbb --- /dev/null +++ b/rust/tvm/tests/callback/src/bin/string.rs @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#![allow(unused_imports)] + +#[macro_use] +extern crate tvm_frontend as tvm; +use std::convert::TryInto; +use tvm::{errors::Error, *}; + +// FIXME +fn main() { + register_global_func! { + fn concate_str(args: &[TVMArgValue]) -> Result { + let mut ret = "".to_string(); + for arg in args.iter() { + let val: &str = arg.try_into()?; + ret += val; + } + Ok(TVMRetValue::from(ret)) + } + } + let a = std::ffi::CString::new("a").unwrap(); + let b = std::ffi::CString::new("b").unwrap(); + let c = std::ffi::CString::new("c").unwrap(); + let mut registered = function::Builder::default(); + registered.get_function("concate_str"); + assert!(registered.func.is_some()); + let ret: String = registered + .arg(a.as_c_str()) + .arg(b.as_c_str()) + .arg(c.as_c_str()) + .invoke() + .unwrap() + .try_into() + .unwrap(); + assert_eq!(ret, "abc".to_owned()); +} diff --git a/rust/tvm/tests/test_ir.rs b/rust/tvm/tests/test_ir.rs new file mode 100644 index 000000000000..a43f27e82eea --- /dev/null +++ b/rust/tvm/tests/test_ir.rs @@ -0,0 +1,37 @@ +use std::convert::TryInto; +use std::str::FromStr; +use tvm::ir::IntImmNode; +use tvm::runtime::String as TString; +use tvm::runtime::{debug_print, Object, ObjectPtr, ObjectRef}; +use tvm_rt::{call_packed, DLDataType, Function}; +use tvm_sys::TVMRetValue; + +#[test] +fn test_new_object() -> anyhow::Result<()> { + let object = Object::base_object::(); + let ptr = ObjectPtr::new(object); + assert_eq!(ptr.count(), 1); + Ok(()) +} + +#[test] +fn test_new_string() -> anyhow::Result<()> { + let string = TString::new("hello world!".to_string())?; + Ok(()) +} + +#[test] +fn test_obj_build() -> anyhow::Result<()> { + let int_imm = Function::get("ir.IntImm").expect("Stable TVM API not found."); + + let dt = DLDataType::from_str("int32").expect("Known datatype doesn't convert."); + + let ret_val: ObjectRef = call_packed!(int_imm, dt, 1337) + .expect("foo") + .try_into() + .unwrap(); + + debug_print(&ret_val); + + Ok(()) +} From f7aa26a5b9cd7dbe4d793250223622375c4e0244 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 10 Jun 2020 15:03:55 -0700 Subject: [PATCH 03/13] Fix out of tree pass with new bindings --- rust/Cargo.toml | 3 +- rust/tvm-macros/src/external.rs | 6 ++-- rust/tvm-macros/src/object.rs | 26 +++++++++++------ rust/tvm-rt/src/function.rs | 5 ++-- rust/tvm-rt/src/object/mod.rs | 31 ++++++++++++++++---- rust/tvm-rt/src/string.rs | 4 +-- rust/tvm/src/ir/array.rs | 50 +++++++++++++-------------------- rust/tvm/src/ir/relay/mod.rs | 2 +- rust/tvm/src/transform.rs | 19 +++++++++---- 9 files changed, 88 insertions(+), 58 deletions(-) diff --git a/rust/Cargo.toml b/rust/Cargo.toml index c0d0bb8cc8b2..afe62071116a 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -31,5 +31,6 @@ members = [ "tvm-sys", "tvm-macros", "tvm-rt", - "tvm" + "tvm", + "out-of-tree" ] diff --git a/rust/tvm-macros/src/external.rs b/rust/tvm-macros/src/external.rs index 8833d6084574..2fcee49d3abd 100644 --- a/rust/tvm-macros/src/external.rs +++ b/rust/tvm-macros/src/external.rs @@ -88,7 +88,7 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let tvm_rt_crate = crate::util::get_tvm_rt_crate(); - let err_type = quote! { #tvm_rt_crate::Error }; + let result_type = quote! { #tvm_rt_crate::function::Result }; let mut items = Vec::new(); @@ -142,9 +142,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> proc_macro::TokenStream { items.push(global); let wrapper = quote! { - pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> Result<#ret_type, #err_type> { + pub fn #name<#(#ty_params),*>(#(#args : #tys),*) -> #result_type<#ret_type> { let func_ref: #tvm_rt_crate::Function = #global_name.clone(); - let func_ref: Box Result<#ret_type, #err_type>> = func_ref.to_boxed_fn(); + let func_ref: Box #result_type<#ret_type>> = func_ref.to_boxed_fn(); let res: #ret_type = func_ref(#(#args),*)?; Ok(res) } diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs index bee22c367189..55e983e2ce6c 100644 --- a/rust/tvm-macros/src/object.rs +++ b/rust/tvm-macros/src/object.rs @@ -27,6 +27,8 @@ use crate::util::get_tvm_rt_crate; pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { let tvm_rt_crate = get_tvm_rt_crate(); + let result = quote! { #tvm_rt_crate::function::Result }; + let error = quote! { #tvm_rt_crate::errors::Error }; let derive_input = syn::parse_macro_input!(input as DeriveInput); let payload_id = derive_input.ident; @@ -77,9 +79,15 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { #[derive(Clone)] pub struct #ref_id(Option<#tvm_rt_crate::object::ObjectPtr<#payload_id>>); - impl #tvm_rt_crate::object::ToObjectRef for #ref_id { - fn to_object_ref(&self) -> ObjectRef { - ObjectRef(self.0.as_ref().map(|o| o.upcast())) + impl #tvm_rt_crate::object::IsObjectRef for #ref_id { + type Object = #payload_id; + + fn as_object_ptr(&self) -> Option<&ObjectPtr> { + self.0.as_ref() + } + + fn from_object_ptr(object_ptr: Option>) -> Self { + #ref_id(object_ptr) } } @@ -92,9 +100,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { } impl std::convert::TryFrom<#tvm_rt_crate::RetValue> for #ref_id { - type Error = #tvm_rt_crate::Error; + type Error = #error; - fn try_from(ret_val: #tvm_rt_crate::RetValue) -> Result<#ref_id, Self::Error> { + fn try_from(ret_val: #tvm_rt_crate::RetValue) -> #result<#ref_id> { use std::convert::TryInto; let oref: ObjectRef = ret_val.try_into()?; let ptr = oref.0.ok_or(#tvm_rt_crate::Error::Null)?; @@ -125,9 +133,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { } impl<'a> std::convert::TryFrom<#tvm_rt_crate::ArgValue<'a>> for #ref_id { - type Error = #tvm_rt_crate::Error; + type Error = #error; - fn try_from(arg_value: #tvm_rt_crate::ArgValue<'a>) -> Result<#ref_id, Self::Error> { + fn try_from(arg_value: #tvm_rt_crate::ArgValue<'a>) -> #result<#ref_id> { use std::convert::TryInto; let optr = arg_value.try_into()?; Ok(#ref_id(Some(optr))) @@ -135,9 +143,9 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { } impl<'a> std::convert::TryFrom<&#tvm_rt_crate::ArgValue<'a>> for #ref_id { - type Error = #tvm_rt_crate::Error; + type Error = #error; - fn try_from(arg_value: &#tvm_rt_crate::ArgValue<'a>) -> Result<#ref_id, Self::Error> { + fn try_from(arg_value: &#tvm_rt_crate::ArgValue<'a>) -> #result<#ref_id> { use std::convert::TryInto; let optr = arg_value.try_into()?; Ok(#ref_id(Some(optr))) diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index cb8777a6227b..32423d39d43b 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -32,12 +32,13 @@ use std::{ ptr, str, }; -pub use tvm_sys::{ffi, ArgValue, RetValue}; use crate::errors::Error; use super::to_boxed_fn::ToBoxedFn; -use super::to_function::{ToFunction, Typed}; + +pub use tvm_sys::{ffi, ArgValue, RetValue}; +pub use super::to_function::{ToFunction, Typed}; pub type Result = std::result::Result; diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index c49f84e2d916..dc5668128759 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -39,13 +39,34 @@ impl ObjectRef { } } -pub trait ToObjectRef { - fn to_object_ref(&self) -> ObjectRef; -} +pub trait IsObjectRef: Sized { + type Object: IsObject; + fn as_object_ptr(&self) -> Option<&ObjectPtr>; + fn from_object_ptr(object_ptr: Option>) -> Self; -impl ToObjectRef for ObjectRef { fn to_object_ref(&self) -> ObjectRef { - self.clone() + let object_ptr = self.as_object_ptr().cloned(); + ObjectRef(object_ptr.map(|ptr| ptr.upcast())) + } + + fn downcast(&self) -> Result { + let ptr = + self.as_object_ptr() + .map(|ptr| ptr.downcast::()); + let ptr = ptr.transpose()?; + Ok(U::from_object_ptr(ptr)) + } +} + +impl IsObjectRef for ObjectRef { + type Object = Object; + + fn as_object_ptr(&self) -> Option<&ObjectPtr> { + self.0.as_ref() + } + + fn from_object_ptr(object_ptr: Option>) -> Self { + ObjectRef(object_ptr) } } diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs index 26758b1170e7..2805d4c300f5 100644 --- a/rust/tvm-rt/src/string.rs +++ b/rust/tvm-rt/src/string.rs @@ -36,7 +36,7 @@ pub struct StringObj { } impl String { - pub fn new(string: std::string::String) -> Result { + pub fn new(string: std::string::String) -> Result { let cstring = CString::new(string)?; // The string is being corrupted. @@ -73,7 +73,7 @@ impl String { // mod tests { // use super::String; // use crate::object::debug_print; -// use crate::ToObjectRef; +// use crate::IsObjectRef; // use anyhow::{ensure, Result}; // #[test] diff --git a/rust/tvm/src/ir/array.rs b/rust/tvm/src/ir/array.rs index 2b5a23b63867..954c2fe6a808 100644 --- a/rust/tvm/src/ir/array.rs +++ b/rust/tvm/src/ir/array.rs @@ -1,55 +1,45 @@ -use std::convert::TryFrom; +use std::convert::{TryFrom, TryInto}; use std::marker::PhantomData; -use crate::runtime::object::{ObjectRef, ToObjectRef}; +use crate::runtime::object::{ObjectRef, IsObjectRef}; -use tvm_rt::external; -use tvm_rt::RetValue; - -use anyhow::Result; +use tvm_rt::{external, RetValue, function::{Function, Result}}; +use tvm_rt::errors::Error; #[derive(Clone)] -pub struct Array { +pub struct Array { object: ObjectRef, _data: PhantomData, } +// TODO(@jroesch): convert to use generics instead of casting inside +// the implementation. external! { #[name("node.ArrayGetItem")] fn array_get_item(array: ObjectRef, index: isize) -> ObjectRef; } -impl Array { +impl Array { pub fn from_vec(data: Vec) -> Result> { - unimplemented!() - // let iter = data.iter().map(|element| element.to_object_ref()); + let iter = data.iter().map(|element| element.to_object_ref().into()).collect(); + + let func = Function::get("node.Array") + .expect("node.Array function is not registered, this is most likely a build or linking error"); - // let array_data = Builder::default() - // .get_function("node.Array") - // .args(iter) - // .invoke()? - // .try_into()?; + let array_data = func.invoke(iter)?.try_into()?; - // Ok(Array { - // object: array_data, - // _data: PhantomData, - // }) + Ok(Array { + object: array_data, + _data: PhantomData, + }) } pub fn get(&self, index: isize) -> Result where - T: TryFrom, + T: TryFrom, { - unimplemented!() - // // TODO(@jroesch): why do we used a signed index here? - // let element: T = Builder::default() - // .get_function("node.ArrayGetItem") - // .arg(self.object.clone()) - // .arg(index) - // .invoke()? - // .try_into()?; - - // Ok(element) + let oref: ObjectRef = array_get_item(self.object.clone(), index)?; + oref.downcast() } } diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index ac7b707bdcd9..72fe2f1ab765 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -1,5 +1,5 @@ use super::array::Array; -use crate::runtime::{IsObject, Object, ObjectPtr, ObjectRef, String as TString, ToObjectRef}; +use crate::runtime::{IsObject, Object, ObjectPtr, ObjectRef, String as TString, IsObjectRef}; use crate::DataType; use tvm_macros::Object; diff --git a/rust/tvm/src/transform.rs b/rust/tvm/src/transform.rs index 0f10ca3bc522..c655c0db3826 100644 --- a/rust/tvm/src/transform.rs +++ b/rust/tvm/src/transform.rs @@ -1,9 +1,13 @@ use crate::ir::array::Array; -use crate::runtime::{external, Function, String as TString}; +use crate::runtime::{external, function::{self, Result, ToFunction, Typed}, String as TString}; use crate::runtime::{Object, ObjectPtr, ObjectRef}; +use crate::ir::relay::Function; + use tvm_macros::Object; -type Pass = ObjectRef; +pub type Pass = ObjectRef; +pub type IRModule = ObjectRef; +pub type PassContext = ObjectRef; #[repr(C)] #[derive(Object)] @@ -17,8 +21,8 @@ pub struct PassInfoNode { } impl PassInfo { - pub fn new(opt_level: i32, name: String, required: Vec) -> anyhow::Result { - let required: Result<_, _> = required + pub fn new(opt_level: i32, name: String, required: Vec) -> Result { + let required: Result<_> = required .into_iter() .map(|name| TString::new(name)) .collect(); @@ -38,5 +42,10 @@ impl PassInfo { external! { #[name("relay._transform.MakeFunctionPass")] - fn create_func_pass(func: Function, pass_info: PassInfo) -> Pass; + fn create_func_pass(func: function::Function, pass_info: PassInfo) -> Pass; +} + +pub fn function_pass Function + 'static>(pass_fn: F, pass_info: PassInfo) -> Result { + let func = pass_fn.to_function(); + create_func_pass(func, pass_info) } From dc099df7a9f6bff6dcc0a076d13e370af150b291 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 11 Jun 2020 00:21:53 -0700 Subject: [PATCH 04/13] Super slick API working --- rust/tvm-rt/src/errors.rs | 2 ++ rust/tvm-rt/src/to_function.rs | 56 +++++++++++++++++++--------------- rust/tvm-sys/src/lib.rs | 10 ++++++ 3 files changed, 44 insertions(+), 24 deletions(-) diff --git a/rust/tvm-rt/src/errors.rs b/rust/tvm-rt/src/errors.rs index 0b45ebf445bf..779f04e6daa9 100644 --- a/rust/tvm-rt/src/errors.rs +++ b/rust/tvm-rt/src/errors.rs @@ -66,6 +66,8 @@ pub enum Error { NDArray(#[from] NDArrayError), #[error("{0}")] CallFailed(String), + #[error("this case will never occur")] + Infallible(#[from] std::convert::Infallible), } impl Error { diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index 4814d098238a..7a9bbeaf3a48 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -46,28 +46,32 @@ pub use tvm_sys::{ffi, ArgValue, RetValue}; /// And the implementation of it to `ToFunction`. pub trait Typed { fn args(i: &[ArgValue<'static>]) -> Result; - fn ret(o: O) -> RetValue; + fn ret(o: O) -> Result; } -impl> Typed<(), O> for F +impl Typed<(), O> for F where F: Fn() -> O, + Error: From, + O: TryInto { fn args(_args: &[ArgValue<'static>]) -> Result<()> { debug_assert!(_args.len() == 0); Ok(()) } - fn ret(o: O) -> RetValue { - o.into() + fn ret(o: O) -> Result { + o.try_into().map_err(|e| e.into()) } } -impl, E> Typed<(A,), O> for F +impl Typed<(A,), O> for F where F: Fn(A) -> O, - Error: From, - A: TryFrom, Error = E>, + Error: From, + Error: From, + A: TryFrom, Error = E1>, + O: TryInto, { fn args(args: &[ArgValue<'static>]) -> Result<(A,)> { debug_assert!(args.len() == 1); @@ -75,17 +79,19 @@ where Ok((a,)) } - fn ret(o: O) -> RetValue { - o.into() + fn ret(o: O) -> Result { + o.try_into().map_err(|e| e.into()) } } -impl, E> Typed<(A, B), O> for F +impl Typed<(A, B), O> for F where F: Fn(A, B) -> O, - Error: From, - A: TryFrom, Error = E>, - B: TryFrom, Error = E>, + Error: From, + Error: From, + A: TryFrom, Error = E1>, + B: TryFrom, Error = E1>, + O: TryInto, { fn args(args: &[ArgValue<'static>]) -> Result<(A, B)> { debug_assert!(args.len() == 2); @@ -94,18 +100,20 @@ where Ok((a, b)) } - fn ret(o: O) -> RetValue { - o.into() + fn ret(o: O) -> Result { + o.try_into().map_err(|e| e.into()) } } -impl, E> Typed<(A, B, C), O> for F +impl Typed<(A, B, C), O> for F where F: Fn(A, B, C) -> O, - Error: From, - A: TryFrom, Error = E>, - B: TryFrom, Error = E>, - C: TryFrom, Error = E>, + Error: From, + Error: From, + A: TryFrom, Error = E1>, + B: TryFrom, Error = E1>, + C: TryFrom, Error = E1>, + O: TryInto { fn args(args: &[ArgValue<'static>]) -> Result<(A, B, C)> { debug_assert!(args.len() == 3); @@ -115,8 +123,8 @@ where Ok((a, b, c)) } - fn ret(o: O) -> RetValue { - o.into() + fn ret(o: O) -> Result { + o.try_into().map_err(|e| e.into()) } } @@ -230,7 +238,7 @@ where { // Ideally we shouldn't need to clone, probably doesn't really matter. let out = unsafe { (*handle)() }; - Ok(F::ret(out)) + F::ret(out) } fn drop(_: *mut Self::Handle) {} @@ -253,7 +261,7 @@ macro_rules! to_function_instance { let out = unsafe { (*handle)($(args.$index),+) }; - Ok(F::ret(out)) + F::ret(out) } fn drop(_: *mut Self::Handle) {} diff --git a/rust/tvm-sys/src/lib.rs b/rust/tvm-sys/src/lib.rs index 0f455e726d26..2aa6122af674 100644 --- a/rust/tvm-sys/src/lib.rs +++ b/rust/tvm-sys/src/lib.rs @@ -57,3 +57,13 @@ pub use context::{Context, DeviceType}; pub use datatype::DataType; pub use errors::*; pub use packed_func::{ArgValue, RetValue}; + +impl std::convert::TryFrom> for RetValue +where RetValue: std::convert::TryFrom, + E: From<>::Error> { + type Error = E; + + fn try_from(val: Result) -> Result { + val.and_then(|t| RetValue::try_from(t).map_err(|e| e.into())) + } +} From ea64054299226e171f4ab19007ff3fd56fd5b8cc Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 11 Jun 2020 11:27:39 -0700 Subject: [PATCH 05/13] Add examples --- rust/egg/Cargo.toml | 16 ++++++++++++++++ rust/egg/import_pass.py | 27 +++++++++++++++++++++++++++ rust/egg/src/lib.rs | 33 +++++++++++++++++++++++++++++++++ rust/out-of-tree/Cargo.toml | 16 ++++++++++++++++ rust/out-of-tree/import_pass.py | 27 +++++++++++++++++++++++++++ rust/out-of-tree/src/lib.rs | 33 +++++++++++++++++++++++++++++++++ rust/tvm/src/ir/array.rs | 1 + rust/tvm/src/ir/relay/mod.rs | 22 ++++++++++++++++++++-- rust/tvm/src/transform.rs | 16 ++++++++++++++++ 9 files changed, 189 insertions(+), 2 deletions(-) create mode 100644 rust/egg/Cargo.toml create mode 100644 rust/egg/import_pass.py create mode 100644 rust/egg/src/lib.rs create mode 100644 rust/out-of-tree/Cargo.toml create mode 100644 rust/out-of-tree/import_pass.py create mode 100644 rust/out-of-tree/src/lib.rs diff --git a/rust/egg/Cargo.toml b/rust/egg/Cargo.toml new file mode 100644 index 000000000000..996a0191244b --- /dev/null +++ b/rust/egg/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "out-of-tree" +version = "0.1.0" +authors = ["Jared Roesch "] +edition = "2018" + + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +name = "my_pass" +crate-type = ["cdylib"] + +[dependencies] +tvm = { version = "0.1", path = "../tvm" } +tvm-sys = { version = "0.1", path = "../tvm-sys" } +anyhow = "*" diff --git a/rust/egg/import_pass.py b/rust/egg/import_pass.py new file mode 100644 index 000000000000..100478db822a --- /dev/null +++ b/rust/egg/import_pass.py @@ -0,0 +1,27 @@ +import tvm +import tvm.relay +from tvm.ir.transform import PassContext + +x = tvm.relay.var("x", shape=(10,)) +test_func = tvm.relay.Function([x], x) +test_mod = tvm.IRModule.from_expr(test_func) + +pass_dylib = "/Users/jroesch/Git/tvm/rust/target/debug/libmy_pass.dylib" + +def load_rust_extension(ext_dylib): + load_so = tvm.get_global_func("runtime.module.loadfile_so") + mod = load_so(ext_dylib) + mod.get_function("initialize")() + + +def load_pass(pass_name, dylib): + load_rust_extension(dylib) + return tvm.get_global_func(pass_name) + +MyPass = load_pass("out_of_tree.Pass", pass_dylib) +ctx = PassContext() +import pdb; pdb.set_trace() +f = MyPass(test_func, test_mod, ctx) +mod = MyPass()(test_mod) + +print(mod) diff --git a/rust/egg/src/lib.rs b/rust/egg/src/lib.rs new file mode 100644 index 000000000000..2526d7963785 --- /dev/null +++ b/rust/egg/src/lib.rs @@ -0,0 +1,33 @@ +use std::ffi::c_void; +use std::os::raw::c_int; +use tvm::ir::relay::{self, Function}; +use tvm::runtime::ObjectRef; +use tvm::transform::{function_pass, PassInfo, Pass, PassContext, IRModule}; +use tvm::runtime::function::{register, Result}; +use tvm::export_pass; + +fn my_pass_fn(func: relay::Function, module: IRModule, ctx: PassContext) -> Function { + let var = relay::Var::new("Hi from Rust!".into(), ObjectRef::null()); + relay::Function::new( + func.params.clone(), + var.to_expr(), + func.ret_type.clone(), + func.type_params.clone()) +} + +fn my_pass_fn2(func: relay::Function) -> Function { + let var = relay::Var::new("Hi from Rust!".into(), ObjectRef::null()); + relay::Function::new( + func.params.clone(), + var.to_expr(), + func.ret_type.clone(), + func.type_params.clone()) +} + + +// fn the_pass() -> Result { +// let pass_info = PassInfo::new(15, "RustPass".into(), vec![])?; +// function_pass(my_pass_fn, pass_info) +// } + +export_pass!("out_of_tree.Pass", my_pass_fn2); diff --git a/rust/out-of-tree/Cargo.toml b/rust/out-of-tree/Cargo.toml new file mode 100644 index 000000000000..996a0191244b --- /dev/null +++ b/rust/out-of-tree/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "out-of-tree" +version = "0.1.0" +authors = ["Jared Roesch "] +edition = "2018" + + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +name = "my_pass" +crate-type = ["cdylib"] + +[dependencies] +tvm = { version = "0.1", path = "../tvm" } +tvm-sys = { version = "0.1", path = "../tvm-sys" } +anyhow = "*" diff --git a/rust/out-of-tree/import_pass.py b/rust/out-of-tree/import_pass.py new file mode 100644 index 000000000000..100478db822a --- /dev/null +++ b/rust/out-of-tree/import_pass.py @@ -0,0 +1,27 @@ +import tvm +import tvm.relay +from tvm.ir.transform import PassContext + +x = tvm.relay.var("x", shape=(10,)) +test_func = tvm.relay.Function([x], x) +test_mod = tvm.IRModule.from_expr(test_func) + +pass_dylib = "/Users/jroesch/Git/tvm/rust/target/debug/libmy_pass.dylib" + +def load_rust_extension(ext_dylib): + load_so = tvm.get_global_func("runtime.module.loadfile_so") + mod = load_so(ext_dylib) + mod.get_function("initialize")() + + +def load_pass(pass_name, dylib): + load_rust_extension(dylib) + return tvm.get_global_func(pass_name) + +MyPass = load_pass("out_of_tree.Pass", pass_dylib) +ctx = PassContext() +import pdb; pdb.set_trace() +f = MyPass(test_func, test_mod, ctx) +mod = MyPass()(test_mod) + +print(mod) diff --git a/rust/out-of-tree/src/lib.rs b/rust/out-of-tree/src/lib.rs new file mode 100644 index 000000000000..2526d7963785 --- /dev/null +++ b/rust/out-of-tree/src/lib.rs @@ -0,0 +1,33 @@ +use std::ffi::c_void; +use std::os::raw::c_int; +use tvm::ir::relay::{self, Function}; +use tvm::runtime::ObjectRef; +use tvm::transform::{function_pass, PassInfo, Pass, PassContext, IRModule}; +use tvm::runtime::function::{register, Result}; +use tvm::export_pass; + +fn my_pass_fn(func: relay::Function, module: IRModule, ctx: PassContext) -> Function { + let var = relay::Var::new("Hi from Rust!".into(), ObjectRef::null()); + relay::Function::new( + func.params.clone(), + var.to_expr(), + func.ret_type.clone(), + func.type_params.clone()) +} + +fn my_pass_fn2(func: relay::Function) -> Function { + let var = relay::Var::new("Hi from Rust!".into(), ObjectRef::null()); + relay::Function::new( + func.params.clone(), + var.to_expr(), + func.ret_type.clone(), + func.type_params.clone()) +} + + +// fn the_pass() -> Result { +// let pass_info = PassInfo::new(15, "RustPass".into(), vec![])?; +// function_pass(my_pass_fn, pass_info) +// } + +export_pass!("out_of_tree.Pass", my_pass_fn2); diff --git a/rust/tvm/src/ir/array.rs b/rust/tvm/src/ir/array.rs index 954c2fe6a808..a2a472cd3b2d 100644 --- a/rust/tvm/src/ir/array.rs +++ b/rust/tvm/src/ir/array.rs @@ -6,6 +6,7 @@ use crate::runtime::object::{ObjectRef, IsObjectRef}; use tvm_rt::{external, RetValue, function::{Function, Result}}; use tvm_rt::errors::Error; +#[repr(C)] #[derive(Clone)] pub struct Array { object: ObjectRef, diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index 72fe2f1ab765..595ac9a67d59 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -169,12 +169,30 @@ impl Call { } } +#[repr(C)] +#[derive(Object)] +#[ref_name = "BaseFunc"] +#[type_key = "BaseFunc"] +pub struct BaseFuncNode { + pub base: RelayExpr, + pub attrs: ObjectRef, +} + +impl BaseFuncNode { + fn base() -> BaseFuncNode { + BaseFuncNode { + base: RelayExpr::base::(), + attrs: ObjectRef::null(), + } + } +} + #[repr(C)] #[derive(Object)] #[ref_name = "Function"] #[type_key = "relay.Function"] pub struct FunctionNode { - pub base: RelayExpr, + pub base: BaseFuncNode, pub params: Array, pub body: Expr, pub ret_type: Type, @@ -189,7 +207,7 @@ impl Function { type_params: Array, ) -> Function { let node = FunctionNode { - base: RelayExpr::base::(), + base: BaseFuncNode::base::(), params: params, body: body, ret_type: ret_type, diff --git a/rust/tvm/src/transform.rs b/rust/tvm/src/transform.rs index c655c0db3826..ad139f557f0b 100644 --- a/rust/tvm/src/transform.rs +++ b/rust/tvm/src/transform.rs @@ -49,3 +49,19 @@ pub fn function_pass Function + 'stati let func = pass_fn.to_function(); create_func_pass(func, pass_info) } + +#[macro_export] +macro_rules! export_pass { + ($name:literal,$func:expr) => { + #[no_mangle] + pub unsafe extern "C" fn initialize( + args: *mut tvm_sys::ffi::TVMValue, + type_codes: *mut c_int, + num_args: c_int, + ret: tvm_sys::ffi::TVMRetValueHandle, + ) -> c_int { + register($func, $name).unwrap(); + return 0; + } +}; +} From 11cab681d083b5633f2c50d3da583789536ee0f6 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 11 Jun 2020 13:00:00 -0700 Subject: [PATCH 06/13] Delay egg example and add ASF headers --- rust/egg/Cargo.toml | 16 -------- rust/egg/import_pass.py | 27 -------------- rust/egg/src/lib.rs | 33 ----------------- rust/out-of-tree/Cargo.toml | 17 +++++++++ rust/out-of-tree/import_pass.py | 17 +++++++++ rust/out-of-tree/src/lib.rs | 31 ++++++++++------ rust/tvm-rt/src/object/object_ptr.rs | 55 +++++++++++++++++++++++++--- rust/tvm/src/ir/array.rs | 19 ++++++++++ rust/tvm/src/ir/mod.rs | 19 ++++++++++ rust/tvm/src/ir/relay/mod.rs | 19 ++++++++++ rust/tvm/src/runtime/mod.rs | 19 ++++++++++ rust/tvm/src/transform.rs | 19 ++++++++++ rust/tvm/tests/test_ir.rs | 19 ++++++++++ 13 files changed, 217 insertions(+), 93 deletions(-) delete mode 100644 rust/egg/Cargo.toml delete mode 100644 rust/egg/import_pass.py delete mode 100644 rust/egg/src/lib.rs diff --git a/rust/egg/Cargo.toml b/rust/egg/Cargo.toml deleted file mode 100644 index 996a0191244b..000000000000 --- a/rust/egg/Cargo.toml +++ /dev/null @@ -1,16 +0,0 @@ -[package] -name = "out-of-tree" -version = "0.1.0" -authors = ["Jared Roesch "] -edition = "2018" - - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html -[lib] -name = "my_pass" -crate-type = ["cdylib"] - -[dependencies] -tvm = { version = "0.1", path = "../tvm" } -tvm-sys = { version = "0.1", path = "../tvm-sys" } -anyhow = "*" diff --git a/rust/egg/import_pass.py b/rust/egg/import_pass.py deleted file mode 100644 index 100478db822a..000000000000 --- a/rust/egg/import_pass.py +++ /dev/null @@ -1,27 +0,0 @@ -import tvm -import tvm.relay -from tvm.ir.transform import PassContext - -x = tvm.relay.var("x", shape=(10,)) -test_func = tvm.relay.Function([x], x) -test_mod = tvm.IRModule.from_expr(test_func) - -pass_dylib = "/Users/jroesch/Git/tvm/rust/target/debug/libmy_pass.dylib" - -def load_rust_extension(ext_dylib): - load_so = tvm.get_global_func("runtime.module.loadfile_so") - mod = load_so(ext_dylib) - mod.get_function("initialize")() - - -def load_pass(pass_name, dylib): - load_rust_extension(dylib) - return tvm.get_global_func(pass_name) - -MyPass = load_pass("out_of_tree.Pass", pass_dylib) -ctx = PassContext() -import pdb; pdb.set_trace() -f = MyPass(test_func, test_mod, ctx) -mod = MyPass()(test_mod) - -print(mod) diff --git a/rust/egg/src/lib.rs b/rust/egg/src/lib.rs deleted file mode 100644 index 2526d7963785..000000000000 --- a/rust/egg/src/lib.rs +++ /dev/null @@ -1,33 +0,0 @@ -use std::ffi::c_void; -use std::os::raw::c_int; -use tvm::ir::relay::{self, Function}; -use tvm::runtime::ObjectRef; -use tvm::transform::{function_pass, PassInfo, Pass, PassContext, IRModule}; -use tvm::runtime::function::{register, Result}; -use tvm::export_pass; - -fn my_pass_fn(func: relay::Function, module: IRModule, ctx: PassContext) -> Function { - let var = relay::Var::new("Hi from Rust!".into(), ObjectRef::null()); - relay::Function::new( - func.params.clone(), - var.to_expr(), - func.ret_type.clone(), - func.type_params.clone()) -} - -fn my_pass_fn2(func: relay::Function) -> Function { - let var = relay::Var::new("Hi from Rust!".into(), ObjectRef::null()); - relay::Function::new( - func.params.clone(), - var.to_expr(), - func.ret_type.clone(), - func.type_params.clone()) -} - - -// fn the_pass() -> Result { -// let pass_info = PassInfo::new(15, "RustPass".into(), vec![])?; -// function_pass(my_pass_fn, pass_info) -// } - -export_pass!("out_of_tree.Pass", my_pass_fn2); diff --git a/rust/out-of-tree/Cargo.toml b/rust/out-of-tree/Cargo.toml index 996a0191244b..67fb72386a22 100644 --- a/rust/out-of-tree/Cargo.toml +++ b/rust/out-of-tree/Cargo.toml @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + [package] name = "out-of-tree" version = "0.1.0" diff --git a/rust/out-of-tree/import_pass.py b/rust/out-of-tree/import_pass.py index 100478db822a..57e3c7ff2f69 100644 --- a/rust/out-of-tree/import_pass.py +++ b/rust/out-of-tree/import_pass.py @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + import tvm import tvm.relay from tvm.ir.transform import PassContext diff --git a/rust/out-of-tree/src/lib.rs b/rust/out-of-tree/src/lib.rs index 2526d7963785..12cbc2e94208 100644 --- a/rust/out-of-tree/src/lib.rs +++ b/rust/out-of-tree/src/lib.rs @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + use std::ffi::c_void; use std::os::raw::c_int; use tvm::ir::relay::{self, Function}; @@ -15,19 +34,9 @@ fn my_pass_fn(func: relay::Function, module: IRModule, ctx: PassContext) -> Func func.type_params.clone()) } -fn my_pass_fn2(func: relay::Function) -> Function { - let var = relay::Var::new("Hi from Rust!".into(), ObjectRef::null()); - relay::Function::new( - func.params.clone(), - var.to_expr(), - func.ret_type.clone(), - func.type_params.clone()) -} - - // fn the_pass() -> Result { // let pass_info = PassInfo::new(15, "RustPass".into(), vec![])?; // function_pass(my_pass_fn, pass_info) // } -export_pass!("out_of_tree.Pass", my_pass_fn2); +export_pass!("out_of_tree.Pass", my_pass_fn); diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 40e218454f6a..026a908768c2 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -29,16 +29,36 @@ use crate::errors::Error; type Deleter = unsafe extern "C" fn(object: *mut Object) -> (); +/// A TVM intrusive smart pointer header, in TVM all FFI compatible types +/// start with an Object as their first field. The base object tracks +/// a type_index which is an index into the runtime type information +/// table, an atomic reference count, and a customized deleter which +/// will be invoked when the reference count is zero. +/// #[derive(Debug)] #[repr(C)] pub struct Object { - pub type_index: u32, + /// The index into into TVM's runtime type information table. + pub(self) type_index: u32, // TODO(@jroesch): pretty sure Rust and C++ atomics are the same, but not sure. // NB: in general we should not touch this in Rust. + /// The reference count of the smart pointer. pub(self) ref_count: AtomicI32, - pub fdeleter: Deleter, + /// The deleter function which is used to deallocate the underlying data + /// when the reference count is zero. This field must always be set for + /// all objects. + /// + /// The common use case is ensuring that the allocator which allocated the + /// data is also the one that deletes it. + pub(self) fdeleter: Deleter, } +/// The default deleter for objects allocated in Rust, we use a bit of +/// trait magic here to get a monomorphized deleter for each object +/// "subtype". +/// +/// This function just transmutes the pointer to the correct type +/// and invokes the underlying typed delete function. unsafe extern "C" fn delete(object: *mut Object) { let typed_object: *mut T = std::mem::transmute(object); T::typed_delete(typed_object); @@ -63,10 +83,12 @@ impl Object { fn new(type_index: u32, deleter: Deleter) -> Object { Object { type_index, - // Note: do not touch this field directly again, this is - // a critical section, we write a 1 to the atomic which will now - // be managed by the C++ atomics. - // In the future we should probably use C-atomcis. + // NB(@jroesch): I believe it is sound to use Rust atomics + // in conjunction with C++ atomics given the memory model + // is nearly identical. + // + // Of course these are famous last words which I may later + // regret. ref_count: AtomicI32::new(0), fdeleter: deleter, } @@ -75,6 +97,7 @@ impl Object { fn get_type_index() -> u32 { let type_key = T::TYPE_KEY; let cstring = CString::new(type_key).expect("type key must not contain null characters"); + if type_key == "Object" { return 0; } else { @@ -89,11 +112,16 @@ impl Object { } } + /// Allocates a base object value for an object subtype of type T. + /// By using associated constants and generics we can provide a + /// type indexed abstraction over allocating objects with the + /// correct index and deleter. pub fn base_object() -> Object { let index = Object::get_type_index::(); Object::new(index, delete::) } + /// Increases the object's reference count by one. pub(self) fn inc_ref(&self) { unsafe { let raw_ptr = std::mem::transmute(self); @@ -101,6 +129,7 @@ impl Object { } } + /// Decreases the object's reference count by one. pub(self) fn dec_ref(&self) { unsafe { let raw_ptr = std::mem::transmute(self); @@ -109,6 +138,13 @@ impl Object { } } +/// An unsafe trait which should be implemented for an object +/// subtype. +/// +/// The trait contains the type key needed to compute the type +/// index, a method for accessing the base object given the +/// subtype, and a typed delete method which is specialized +/// to the subtype. pub unsafe trait IsObject { const TYPE_KEY: &'static str; @@ -128,6 +164,10 @@ unsafe impl IsObject for Object { } } +/// A smart pointer for types which implement IsObject. +/// This type directly corresponds to TVM's C++ type ObjectPtr. +/// +/// See object.h for more details. #[repr(C)] pub struct ObjectPtr { pub ptr: NonNull, @@ -240,6 +280,7 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { RetValue::ObjectHandle(handle) => { let handle: *mut Object = unsafe { std::mem::transmute(handle) }; let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?; + optr.inc_ref(); optr.downcast() } _ => Err(Error::downcast(format!("{:?}", ret_value), "ObjectHandle")), @@ -263,6 +304,7 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { ArgValue::ObjectHandle(handle) => { let handle = unsafe { std::mem::transmute(handle) }; let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?; + optr.inc_ref(); optr.downcast() } _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), @@ -278,6 +320,7 @@ impl<'a, T: IsObject> TryFrom<&ArgValue<'a>> for ObjectPtr { ArgValue::ObjectHandle(handle) => { let handle = unsafe { std::mem::transmute(handle) }; let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?; + optr.inc_ref(); optr.downcast() } _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), diff --git a/rust/tvm/src/ir/array.rs b/rust/tvm/src/ir/array.rs index a2a472cd3b2d..4dd6b116eb73 100644 --- a/rust/tvm/src/ir/array.rs +++ b/rust/tvm/src/ir/array.rs @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + use std::convert::{TryFrom, TryInto}; use std::marker::PhantomData; diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs index bc667fdb19b8..d0b71bd7c07c 100644 --- a/rust/tvm/src/ir/mod.rs +++ b/rust/tvm/src/ir/mod.rs @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + use crate::runtime::Object; use crate::DataType; diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index 595ac9a67d59..22d4eddaa7a6 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + use super::array::Array; use crate::runtime::{IsObject, Object, ObjectPtr, ObjectRef, String as TString, IsObjectRef}; use crate::DataType; diff --git a/rust/tvm/src/runtime/mod.rs b/rust/tvm/src/runtime/mod.rs index 57d43eea81c9..69fbb371824a 100644 --- a/rust/tvm/src/runtime/mod.rs +++ b/rust/tvm/src/runtime/mod.rs @@ -1 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + pub use tvm_rt::*; diff --git a/rust/tvm/src/transform.rs b/rust/tvm/src/transform.rs index ad139f557f0b..88c83b2dc494 100644 --- a/rust/tvm/src/transform.rs +++ b/rust/tvm/src/transform.rs @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + use crate::ir::array::Array; use crate::runtime::{external, function::{self, Result, ToFunction, Typed}, String as TString}; use crate::runtime::{Object, ObjectPtr, ObjectRef}; diff --git a/rust/tvm/tests/test_ir.rs b/rust/tvm/tests/test_ir.rs index a43f27e82eea..90e71854cc29 100644 --- a/rust/tvm/tests/test_ir.rs +++ b/rust/tvm/tests/test_ir.rs @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + use std::convert::TryInto; use std::str::FromStr; use tvm::ir::IntImmNode; From 8ef64c168c94c1c0c0465bb7459e9825bbe49322 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 16 Jun 2020 01:35:50 -0700 Subject: [PATCH 07/13] Move array.rs around --- rust/{tvm/src/ir => tvm-rt/src}/array.rs | 27 +++--------------------- rust/tvm-rt/src/lib.rs | 4 ++-- rust/tvm/src/ir/mod.rs | 1 - rust/tvm/src/ir/relay/mod.rs | 23 +++++++++++++++----- 4 files changed, 23 insertions(+), 32 deletions(-) rename rust/{tvm/src/ir => tvm-rt/src}/array.rs (73%) diff --git a/rust/tvm/src/ir/array.rs b/rust/tvm-rt/src/array.rs similarity index 73% rename from rust/tvm/src/ir/array.rs rename to rust/tvm-rt/src/array.rs index 4dd6b116eb73..b75c169a7d0f 100644 --- a/rust/tvm/src/ir/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -20,10 +20,9 @@ use std::convert::{TryFrom, TryInto}; use std::marker::PhantomData; -use crate::runtime::object::{ObjectRef, IsObjectRef}; - -use tvm_rt::{external, RetValue, function::{Function, Result}}; -use tvm_rt::errors::Error; +use crate::object::{ObjectRef, IsObjectRef}; +use crate::{external, RetValue, function::{Function, Result}}; +use crate::errors::Error; #[repr(C)] #[derive(Clone)] @@ -62,23 +61,3 @@ impl Array { oref.downcast() } } - -#[cfg(test)] -mod tests { - use super::Array; - use crate::ir::relay::Var; - use crate::runtime::object::ObjectRef; - use anyhow::Result; - - #[test] - fn create_array_and_get() -> Result<()> { - let vec = vec![ - Var::new("foo".into(), ObjectRef::null()), - Var::new("bar".into(), ObjectRef::null()), - ]; - let array = Array::from_vec(vec)?; - assert_eq!(array.get(0)?.name_hint().to_string()?, "foo"); - assert_eq!(array.get(1)?.name_hint().to_string()?, "bar"); - Ok(()) - } -} diff --git a/rust/tvm-rt/src/lib.rs b/rust/tvm-rt/src/lib.rs index 10f8317bf7bd..a56a25be82fb 100644 --- a/rust/tvm-rt/src/lib.rs +++ b/rust/tvm-rt/src/lib.rs @@ -91,10 +91,10 @@ pub(crate) fn set_last_error(err: &E) { } } -#[macro_use] -pub mod function; +pub mod array; pub mod context; pub mod errors; +pub mod function; pub mod module; pub mod ndarray; pub mod to_boxed_fn; diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs index d0b71bd7c07c..9c0caddf85fc 100644 --- a/rust/tvm/src/ir/mod.rs +++ b/rust/tvm/src/ir/mod.rs @@ -20,7 +20,6 @@ use crate::runtime::Object; use crate::DataType; -pub mod array; pub mod relay; #[repr(C)] diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index 22d4eddaa7a6..b24aa62ac707 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -17,8 +17,8 @@ * under the License. */ -use super::array::Array; use crate::runtime::{IsObject, Object, ObjectPtr, ObjectRef, String as TString, IsObjectRef}; +use crate::runtime::array::Array; use crate::DataType; use tvm_macros::Object; @@ -41,8 +41,6 @@ impl Id { } } -// define_ref!(Id, IdNode); - #[repr(C)] #[derive(Object)] #[ref_name = "BaseExpr"] @@ -98,8 +96,6 @@ impl GlobalVar { pub fn new(name_hint: String, _span: ObjectRef) -> GlobalVar { let node = GlobalVarNode { base: RelayExpr::base::(), - // span: span, - // checked_type: ObjectRef(None),, name_hint: TString::new(name_hint).unwrap(), }; GlobalVar(Some(ObjectPtr::new(node))) @@ -266,4 +262,21 @@ mod tests { assert!(cstr.into_string()?.contains("%local")); Ok(()) } + + + use super::Array; + use crate::ir::relay::Var; + use crate::runtime::object::ObjectRef; + + #[test] + fn create_array_and_get() -> Result<()> { + let vec = vec![ + Var::new("foo".into(), ObjectRef::null()), + Var::new("bar".into(), ObjectRef::null()), + ]; + let array = Array::from_vec(vec)?; + assert_eq!(array.get(0)?.name_hint().to_string()?, "foo"); + assert_eq!(array.get(1)?.name_hint().to_string()?, "bar"); + Ok(()) + } } From 4d38c09ac8aaf021c622c3d59ba02e2fd45926ab Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 16 Jun 2020 14:11:52 -0700 Subject: [PATCH 08/13] Remove outdated tests will restore in CI PR --- rust/tvm/examples/resnet/Cargo.toml | 29 ---- rust/tvm/examples/resnet/README.md | 45 ------ rust/tvm/examples/resnet/build.rs | 42 ----- rust/tvm/examples/resnet/src/build_resnet.py | 134 ---------------- rust/tvm/examples/resnet/src/main.rs | 160 ------------------- rust/tvm/src/ir/mod.rs | 16 +- rust/tvm/src/ir/relay/mod.rs | 17 +- rust/tvm/src/transform.rs | 4 +- rust/tvm/tests/basics/.gitignore | 7 - rust/tvm/tests/basics/Cargo.toml | 32 ---- rust/tvm/tests/basics/build.rs | 46 ------ rust/tvm/tests/basics/src/main.rs | 55 ------- rust/tvm/tests/basics/src/tvm_add.py | 50 ------ rust/tvm/tests/callback/Cargo.toml | 26 --- rust/tvm/tests/callback/src/bin/array.rs | 72 --------- rust/tvm/tests/callback/src/bin/error.rs | 56 ------- rust/tvm/tests/callback/src/bin/float.rs | 50 ------ rust/tvm/tests/callback/src/bin/int.rs | 49 ------ rust/tvm/tests/callback/src/bin/string.rs | 54 ------- rust/tvm/tests/test_ir.rs | 56 ------- src/printer/relay_text_printer.cc | 2 - 21 files changed, 26 insertions(+), 976 deletions(-) delete mode 100644 rust/tvm/examples/resnet/Cargo.toml delete mode 100644 rust/tvm/examples/resnet/README.md delete mode 100644 rust/tvm/examples/resnet/build.rs delete mode 100644 rust/tvm/examples/resnet/src/build_resnet.py delete mode 100644 rust/tvm/examples/resnet/src/main.rs delete mode 100644 rust/tvm/tests/basics/.gitignore delete mode 100644 rust/tvm/tests/basics/Cargo.toml delete mode 100644 rust/tvm/tests/basics/build.rs delete mode 100644 rust/tvm/tests/basics/src/main.rs delete mode 100755 rust/tvm/tests/basics/src/tvm_add.py delete mode 100644 rust/tvm/tests/callback/Cargo.toml delete mode 100644 rust/tvm/tests/callback/src/bin/array.rs delete mode 100644 rust/tvm/tests/callback/src/bin/error.rs delete mode 100644 rust/tvm/tests/callback/src/bin/float.rs delete mode 100644 rust/tvm/tests/callback/src/bin/int.rs delete mode 100644 rust/tvm/tests/callback/src/bin/string.rs delete mode 100644 rust/tvm/tests/test_ir.rs diff --git a/rust/tvm/examples/resnet/Cargo.toml b/rust/tvm/examples/resnet/Cargo.toml deleted file mode 100644 index e1a474eb5479..000000000000 --- a/rust/tvm/examples/resnet/Cargo.toml +++ /dev/null @@ -1,29 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -[package] -name = "resnet" -version = "0.0.0" -authors = ["TVM Contributors"] -license = "Apache-2.0" -build = "build.rs" - -[dependencies] -ndarray = "0.12" -tvm = { path = "../../" } -image = "0.20" -csv = "1.1" diff --git a/rust/tvm/examples/resnet/README.md b/rust/tvm/examples/resnet/README.md deleted file mode 100644 index d6e32f7fa768..000000000000 --- a/rust/tvm/examples/resnet/README.md +++ /dev/null @@ -1,45 +0,0 @@ - - - - - - - - - - - - - - - - - -## Resnet example - -This end-to-end example shows how to: -* build `Resnet 18` with `tvm` from Python -* use the provided Rust frontend API to test for an input image - -To run the example with pretrained resnet weights, first `tvm` and `mxnet` must be installed for the python build. To install mxnet for cpu, run `pip install mxnet` -and to install `tvm` with `llvm` follow the [TVM installation guide](https://tvm.apache.org/docs/install/index.html). - -* **Build the example**: `cargo build - -To have a successful build, note that it is required to instruct Rust compiler to link to the compiled shared library, for example with -`println!("cargo:rustc-link-search=native={}", build_path)`. See the `build.rs` for more details. - -* **Run the example**: `cargo run` - -Note: To use pretrained weights, one can enable `--pretrained` in `build.rs` with - -``` -let output = Command::new("python") - .arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py")) - .arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR"))) - .arg(&format!("--pretrained")) - .output() - .expect("Failed to execute command"); -``` - -Otherwise, *random weights* are used, therefore, the prediction will be `limpkin, Aramus pictus`! diff --git a/rust/tvm/examples/resnet/build.rs b/rust/tvm/examples/resnet/build.rs deleted file mode 100644 index b9a3c4ccdf12..000000000000 --- a/rust/tvm/examples/resnet/build.rs +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::{path::Path, process::Command}; - -fn main() { - let output = Command::new("python3") - .arg(concat!(env!("CARGO_MANIFEST_DIR"), "/src/build_resnet.py")) - .arg(&format!("--build-dir={}", env!("CARGO_MANIFEST_DIR"))) - .output() - .expect("Failed to execute command"); - assert!( - Path::new(&format!("{}/deploy_lib.o", env!("CARGO_MANIFEST_DIR"))).exists(), - "Could not prepare demo: {}", - String::from_utf8(output.stderr) - .unwrap() - .trim() - .split("\n") - .last() - .unwrap_or("") - ); - println!( - "cargo:rustc-link-search=native={}", - env!("CARGO_MANIFEST_DIR") - ); -} diff --git a/rust/tvm/examples/resnet/src/build_resnet.py b/rust/tvm/examples/resnet/src/build_resnet.py deleted file mode 100644 index 49c67bf1c4f3..000000000000 --- a/rust/tvm/examples/resnet/src/build_resnet.py +++ /dev/null @@ -1,134 +0,0 @@ -#!/usr/bin/env python3 -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import argparse -import csv -import logging -from os import path as osp -import sys - -import numpy as np - -import tvm -from tvm import te -from tvm import relay -from tvm.relay import testing -from tvm.contrib import graph_runtime, cc - -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger(__name__) - -parser = argparse.ArgumentParser(description='Resnet build example') -aa = parser.add_argument -aa('--build-dir', type=str, required=True, help='directory to put the build artifacts') -aa('--pretrained', action='store_true', help='use a pretrained resnet') -aa('--batch-size', type=int, default=1, help='input image batch size') -aa('--opt-level', type=int, default=3, - help='level of optimization. 0 is unoptimized and 3 is the highest level') -aa('--target', type=str, default='llvm', help='target context for compilation') -aa('--image-shape', type=str, default='3,224,224', help='input image dimensions') -aa('--image-name', type=str, default='cat.png', help='name of input image to download') -args = parser.parse_args() - -build_dir = args.build_dir -batch_size = args.batch_size -opt_level = args.opt_level -target = tvm.target.create(args.target) -image_shape = tuple(map(int, args.image_shape.split(","))) -data_shape = (batch_size,) + image_shape - -def build(target_dir): - """ Compiles resnet18 with TVM""" - deploy_lib = osp.join(target_dir, 'deploy_lib.o') - if osp.exists(deploy_lib): - return - - if args.pretrained: - # needs mxnet installed - from mxnet.gluon.model_zoo.vision import get_model - - # if `--pretrained` is enabled, it downloads a pretrained - # resnet18 trained on imagenet1k dataset for image classification task - block = get_model('resnet18_v1', pretrained=True) - net, params = relay.frontend.from_mxnet(block, {"data": data_shape}) - # we want a probability so add a softmax operator - net = relay.Function(net.params, relay.nn.softmax(net.body), - None, net.type_params, net.attrs) - else: - # use random weights from relay.testing - net, params = relay.testing.resnet.get_workload( - num_layers=18, batch_size=batch_size, image_shape=image_shape) - - # compile the model - with relay.build_config(opt_level=opt_level): - graph, lib, params = relay.build_module.build(net, target, params=params) - - # save the model artifacts - lib.save(deploy_lib) - cc.create_shared(osp.join(target_dir, "deploy_lib.so"), - [osp.join(target_dir, "deploy_lib.o")]) - - with open(osp.join(target_dir, "deploy_graph.json"), "w") as fo: - fo.write(graph) - - with open(osp.join(target_dir,"deploy_param.params"), "wb") as fo: - fo.write(relay.save_param_dict(params)) - -def download_img_labels(): - """ Download an image and imagenet1k class labels for test""" - from mxnet.gluon.utils import download - - img_name = 'cat.png' - synset_url = ''.join(['https://gist.githubusercontent.com/zhreshold/', - '4d0b62f3d01426887599d4f7ede23ee5/raw/', - '596b27d23537e5a1b5751d2b0481ef172f58b539/', - 'imagenet1000_clsid_to_human.txt']) - synset_name = 'synset.txt' - download('https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', img_name) - download(synset_url, synset_name) - - with open(synset_name) as fin: - synset = eval(fin.read()) - - with open("synset.csv", "w") as fout: - w = csv.writer(fout) - w.writerows(synset.items()) - -def test_build(build_dir): - """ Sanity check with random input""" - graph = open(osp.join(build_dir, "deploy_graph.json")).read() - lib = tvm.runtime.load(osp.join(build_dir, "deploy_lib.so")) - params = bytearray(open(osp.join(build_dir,"deploy_param.params"), "rb").read()) - input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32")) - ctx = tvm.cpu() - module = graph_runtime.create(graph, lib, ctx) - module.load_params(params) - module.run(data=input_data) - out = module.get_output(0).asnumpy() - - -if __name__ == '__main__': - logger.info("building the model") - build(build_dir) - logger.info("build was successful") - logger.info("test the build artifacts") - test_build(build_dir) - logger.info("test was successful") - if args.pretrained: - download_img_labels() - logger.info("image and synset downloads are successful") diff --git a/rust/tvm/examples/resnet/src/main.rs b/rust/tvm/examples/resnet/src/main.rs deleted file mode 100644 index 0aed72b1eb52..000000000000 --- a/rust/tvm/examples/resnet/src/main.rs +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -extern crate csv; -extern crate image; -extern crate ndarray; -extern crate tvm_frontend as tvm; - -use std::{ - collections::HashMap, - convert::TryInto, - fs::{self, File}, - path::Path, - str::FromStr, -}; - -use image::{FilterType, GenericImageView}; -use ndarray::{Array, ArrayD, Axis}; - -use tvm::*; - -fn main() { - let ctx = TVMContext::cpu(0); - let img = image::open(concat!(env!("CARGO_MANIFEST_DIR"), "/cat.png")).unwrap(); - println!("original image dimensions: {:?}", img.dimensions()); - // for bigger size images, one needs to first resize to 256x256 - // with `img.resize_exact` method and then `image.crop` to 224x224 - let img = img.resize(224, 224, FilterType::Nearest).to_rgb(); - println!("resized image dimensions: {:?}", img.dimensions()); - let mut pixels: Vec = vec![]; - for pixel in img.pixels() { - let tmp = pixel.data; - // normalize the RGB channels using mean, std of imagenet1k - let tmp = [ - (tmp[0] as f32 - 123.0) / 58.395, // R - (tmp[1] as f32 - 117.0) / 57.12, // G - (tmp[2] as f32 - 104.0) / 57.375, // B - ]; - for e in &tmp { - pixels.push(*e); - } - } - - let arr = Array::from_shape_vec((224, 224, 3), pixels).unwrap(); - let arr: ArrayD = arr.permuted_axes([2, 0, 1]).into_dyn(); - // make arr shape as [1, 3, 224, 224] acceptable to resnet - let arr = arr.insert_axis(Axis(0)); - // create input tensor from rust's ndarray - let input = NDArray::from_rust_ndarray( - &arr, - TVMContext::cpu(0), - DLDataType::from_str("float32").unwrap(), - ) - .unwrap(); - println!( - "input size is {:?}", - input.shape().expect("cannot get the input shape") - ); - let graph = - fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_graph.json")).unwrap(); - // load the built module - let lib = Module::load(&Path::new(concat!( - env!("CARGO_MANIFEST_DIR"), - "/deploy_lib.so" - ))) - .unwrap(); - // get the global TVM graph runtime function - let runtime_create_fn = Function::get("tvm.graph_runtime.create").unwrap(); - let runtime_create_fn_ret = call_packed!( - runtime_create_fn, - graph, - &lib, - &ctx.device_type, - &ctx.device_id - ) - .unwrap(); - // get graph runtime module - let graph_runtime_module: Module = runtime_create_fn_ret.try_into().unwrap(); - // get the registered `load_params` from runtime module - let ref load_param_fn = graph_runtime_module - .get_function("load_params", false) - .unwrap(); - // parse parameters and convert to TVMByteArray - let params: Vec = - fs::read(concat!(env!("CARGO_MANIFEST_DIR"), "/deploy_param.params")).unwrap(); - let barr = TVMByteArray::from(¶ms); - // load the parameters - call_packed!(load_param_fn, &barr).unwrap(); - // get the set_input function - let ref set_input_fn = graph_runtime_module - .get_function("set_input", false) - .unwrap(); - - call_packed!(set_input_fn, "data".to_string(), &input).unwrap(); - // get `run` function from runtime module - let ref run_fn = graph_runtime_module.get_function("run", false).unwrap(); - // execute the run function. Note that it has no argument - call_packed!(run_fn,).unwrap(); - // prepare to get the output - let output_shape = &mut [1, 1000]; - let output = NDArray::empty( - output_shape, - TVMContext::cpu(0), - DLDataType::from_str("float32").unwrap(), - ); - // get the `get_output` function from runtime module - let ref get_output_fn = graph_runtime_module - .get_function("get_output", false) - .unwrap(); - // execute the get output function - call_packed!(get_output_fn, &0, &output).unwrap(); - // flatten the output as Vec - let output = output.to_vec::().unwrap(); - // find the maximum entry in the output and its index - let mut argmax = -1; - let mut max_prob = 0.; - for i in 0..output.len() { - if output[i] > max_prob { - max_prob = output[i]; - argmax = i as i32; - } - } - // create a hash map of (class id, class name) - let mut synset: HashMap = HashMap::new(); - let file = File::open("synset.csv").unwrap(); - let mut rdr = csv::ReaderBuilder::new() - .has_headers(true) - .from_reader(file); - - for result in rdr.records() { - let record = result.unwrap(); - let id: i32 = record[0].parse().unwrap(); - let cls = record[1].to_string(); - synset.insert(id, cls); - } - - println!( - "input image belongs to the class `{}` with probability {}", - synset - .get(&argmax) - .expect("cannot find the class id for argmax"), - max_prob - ); -} diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs index 9c0caddf85fc..7b1a1134a695 100644 --- a/rust/tvm/src/ir/mod.rs +++ b/rust/tvm/src/ir/mod.rs @@ -16,12 +16,26 @@ * specific language governing permissions and limitations * under the License. */ +use std::ffi::CString; -use crate::runtime::Object; use crate::DataType; +use crate::runtime::{Object, IsObjectRef, ObjectRef, external}; + pub mod relay; + +// TODO: figure out how to type the last argument runtime::TypedPackedFunc annotate) +// fn _as_text(object: ObjectRef, show_meta_data: i32, annotate: runtime::Function) -> CString; +external! { + #[name("ir.AsText")] + fn _as_text(object: ObjectRef, show_meta_data: i32, annotate: ObjectRef) -> CString; +} + +pub fn as_text(object: T) -> String { + _as_text(object.to_object_ref(), 0, ObjectRef::null()).unwrap().into_string().unwrap() +} + #[repr(C)] pub struct PrimExprNode { pub base: Object, diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index b24aa62ac707..5b6b9e525df5 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -17,7 +17,7 @@ * under the License. */ -use crate::runtime::{IsObject, Object, ObjectPtr, ObjectRef, String as TString, IsObjectRef}; +use crate::runtime::{IsObject, Object, ObjectPtr, ObjectRef, String as TString}; use crate::runtime::array::Array; use crate::DataType; use tvm_macros::Object; @@ -235,31 +235,32 @@ impl Function { #[cfg(test)] mod tests { use super::*; - use crate::runtime::{as_text, String as TString}; + use crate::runtime::{String as TString}; + use crate::ir::as_text; use anyhow::Result; #[test] fn test_id() -> Result<()> { let string = TString::new("foo".to_string()).expect("bar"); let id = Id::new(string); - let cstr = as_text(&id.upcast())?; - assert!(cstr.into_string()?.contains("relay.Id")); + let text = as_text(id.clone()); + assert!(text.contains("relay.Id")); Ok(()) } #[test] fn test_global() -> Result<()> { let gv = GlobalVar::new("main".to_string(), ObjectRef::null()); - let cstr = as_text(&gv.upcast())?; - assert!(cstr.into_string()?.contains("@main")); + let text = as_text(gv.clone()); + assert!(text.contains("@main")); Ok(()) } #[test] fn test_var() -> Result<()> { let var = Var::new("local".to_string(), ObjectRef::null()); - let cstr = as_text(&var.upcast())?; - assert!(cstr.into_string()?.contains("%local")); + let text = as_text(var.clone()); + assert!(text.contains("%local")); Ok(()) } diff --git a/rust/tvm/src/transform.rs b/rust/tvm/src/transform.rs index 88c83b2dc494..906f8a129c96 100644 --- a/rust/tvm/src/transform.rs +++ b/rust/tvm/src/transform.rs @@ -17,8 +17,8 @@ * under the License. */ -use crate::ir::array::Array; -use crate::runtime::{external, function::{self, Result, ToFunction, Typed}, String as TString}; +use crate::runtime::array::Array; +use crate::runtime::{external, function::{self, Result, ToFunction}, String as TString}; use crate::runtime::{Object, ObjectPtr, ObjectRef}; use crate::ir::relay::Function; diff --git a/rust/tvm/tests/basics/.gitignore b/rust/tvm/tests/basics/.gitignore deleted file mode 100644 index 10a4b225a705..000000000000 --- a/rust/tvm/tests/basics/.gitignore +++ /dev/null @@ -1,7 +0,0 @@ -/target -**/*.rs.bk -Cargo.lock -*.o -*.so -*.ptx -*.json diff --git a/rust/tvm/tests/basics/Cargo.toml b/rust/tvm/tests/basics/Cargo.toml deleted file mode 100644 index 0b059da7727b..000000000000 --- a/rust/tvm/tests/basics/Cargo.toml +++ /dev/null @@ -1,32 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -[package] -name = "basics" -version = "0.0.0" -authors = ["TVM Contributors"] -license = "Apache-2.0" -build = "build.rs" - -[dependencies] -ndarray = "0.12" -tvm = { path = "../../" } - -[features] -default = ["cpu"] -cpu = [] -gpu = [] diff --git a/rust/tvm/tests/basics/build.rs b/rust/tvm/tests/basics/build.rs deleted file mode 100644 index 77a3bae3627d..000000000000 --- a/rust/tvm/tests/basics/build.rs +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -fn main() { - let out_dir = std::env::var("OUT_DIR").unwrap(); - - let output = std::process::Command::new(concat!(env!("CARGO_MANIFEST_DIR"), "/src/tvm_add.py")) - .args(&[ - if cfg!(feature = "cpu") { - "llvm" - } else { - "cuda" - }, - &std::env::var("OUT_DIR").unwrap(), - ]) - .output() - .expect("Failed to execute command"); - assert!( - std::path::Path::new(&format!("{}/test_add.so", out_dir)).exists(), - "Could not build tvm lib: {}", - String::from_utf8(output.stderr) - .unwrap() - .trim() - .split("\n") - .last() - .unwrap_or("") - ); - - println!("cargo:rustc-link-search=native={}", out_dir); -} diff --git a/rust/tvm/tests/basics/src/main.rs b/rust/tvm/tests/basics/src/main.rs deleted file mode 100644 index ca53dcf999dc..000000000000 --- a/rust/tvm/tests/basics/src/main.rs +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -extern crate ndarray as rust_ndarray; -extern crate tvm_frontend as tvm; - -use std::str::FromStr; - -use tvm::*; - -fn main() { - let shape = &mut [2]; - let mut data = vec![3f32, 4.0]; - - let (ctx, ctx_name) = if cfg!(feature = "cpu") { - (TVMContext::cpu(0), "cpu") - } else { - (TVMContext::gpu(0), "gpu") - }; - let dtype = DLDataType::from_str("float32").unwrap(); - let mut arr = NDArray::empty(shape, ctx, dtype); - arr.copy_from_buffer(data.as_mut_slice()); - let mut ret = NDArray::empty(shape, ctx, dtype); - let mut fadd = Module::load(&concat!(env!("OUT_DIR"), "/test_add.so")).unwrap(); - if !fadd.enabled(ctx_name) { - return; - } - if cfg!(feature = "gpu") { - fadd.import_module(Module::load(&concat!(env!("OUT_DIR"), "/test_add.ptx")).unwrap()); - } - function::Builder::from(&mut fadd) - .arg(&arr) - .arg(&arr) - .arg(&mut ret) - .invoke() - .unwrap(); - - assert_eq!(ret.to_vec::().unwrap(), vec![6f32, 8.0]); -} diff --git a/rust/tvm/tests/basics/src/tvm_add.py b/rust/tvm/tests/basics/src/tvm_add.py deleted file mode 100755 index 3911d4074e45..000000000000 --- a/rust/tvm/tests/basics/src/tvm_add.py +++ /dev/null @@ -1,50 +0,0 @@ -#!/usr/bin/env python3 -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import os.path as osp -import sys - -import tvm -from tvm import te -from tvm.contrib import cc - - -def main(target, out_dir): - n = te.var('n') - A = te.placeholder((n,), name='A') - B = te.placeholder((n,), name='B') - C = te.compute(A.shape, lambda i: A[i] + B[i], name='C') - s = te.create_schedule(C.op) - - if target == 'cuda': - bx, tx = s[C].split(C.op.axis[0], factor=64) - s[C].bind(bx, te.thread_axis('blockIdx.x')) - s[C].bind(tx, te.thread_axis('threadIdx.x')) - - fadd = tvm.build(s, [A, B, C], target, target_host='llvm', name='myadd') - - fadd.save(osp.join(out_dir, 'test_add.o')) - if target == 'cuda': - fadd.imported_modules[0].save(osp.join(out_dir, 'test_add.ptx')) - cc.create_shared( - osp.join(out_dir, 'test_add.so'), [osp.join(out_dir, 'test_add.o')]) - - -if __name__ == '__main__': - main(sys.argv[1], sys.argv[2]) - diff --git a/rust/tvm/tests/callback/Cargo.toml b/rust/tvm/tests/callback/Cargo.toml deleted file mode 100644 index 5c89d2ac6375..000000000000 --- a/rust/tvm/tests/callback/Cargo.toml +++ /dev/null @@ -1,26 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -[package] -name = "callback" -version = "0.0.0" -authors = ["TVM Contributors"] -edition = "2018" - -[dependencies] -ndarray = "0.12" -tvm = { path = "../../" } diff --git a/rust/tvm/tests/callback/src/bin/array.rs b/rust/tvm/tests/callback/src/bin/array.rs deleted file mode 100644 index cb4a8229c401..000000000000 --- a/rust/tvm/tests/callback/src/bin/array.rs +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#![allow(unused_imports)] - -extern crate ndarray as rust_ndarray; -#[macro_use] -extern crate tvm_frontend as tvm; - -use rust_ndarray::ArrayD; -use std::{ - convert::{TryFrom, TryInto}, - str::FromStr, -}; - -use tvm::{errors::Error, *}; - -fn main() { - register_global_func! { - fn sum(args: &[TVMArgValue]) -> Result { - let mut ret = 0f32; - let shape = &mut [2]; - for arg in args.iter() { - let e = NDArray::empty( - shape, TVMContext::cpu(0), - DLDataType::from_str("float32").unwrap() - ); - let arg: NDArray = arg.try_into()?; - let arr = arg.copy_to_ndarray(e)?; - let rnd: ArrayD = ArrayD::try_from(&arr)?; - ret += rnd.scalar_sum(); - } - Ok(TVMRetValue::from(ret)) - } - } - - let shape = &mut [2]; - let mut data = vec![3f32, 4.0]; - let mut arr = NDArray::empty( - shape, - TVMContext::cpu(0), - DLDataType::from_str("float32").unwrap(), - ); - arr.copy_from_buffer(data.as_mut_slice()); - - let mut registered = function::Builder::default(); - let ret: f32 = registered - .get_function("sum") - .arg(&arr) - .arg(&arr) - .invoke() - .unwrap() - .try_into() - .unwrap(); - assert_eq!(ret, 7f32); -} diff --git a/rust/tvm/tests/callback/src/bin/error.rs b/rust/tvm/tests/callback/src/bin/error.rs deleted file mode 100644 index c9f9a6f771cf..000000000000 --- a/rust/tvm/tests/callback/src/bin/error.rs +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::panic; - -use tvm_frontend::{errors::Error, *}; - -fn main() { - register_global_func! { - fn error(_args: &[TVMArgValue]) -> Result { - Err(errors::TypeMismatchError{ - expected: "i64".to_string(), - actual: "f64".to_string(), - }.into()) - } - } - - let mut registered = function::Builder::default(); - registered.get_function("error"); - assert!(registered.func.is_some()); - registered.args(&[10, 20]); - - println!("expected error message is:"); - panic::set_hook(Box::new(|panic_info| { - // if let Some(msg) = panic_info.message() { - // println!("{:?}", msg); - // } - if let Some(location) = panic_info.location() { - println!( - "panic occurred in file '{}' at line {}", - location.file(), - location.line() - ); - } else { - println!("panic occurred but can't get location information"); - } - })); - - let _result = registered.invoke(); -} diff --git a/rust/tvm/tests/callback/src/bin/float.rs b/rust/tvm/tests/callback/src/bin/float.rs deleted file mode 100644 index 7111e287187f..000000000000 --- a/rust/tvm/tests/callback/src/bin/float.rs +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#![allow(unused_imports)] - -#[macro_use] -extern crate tvm_frontend as tvm; - -use std::convert::TryInto; -use tvm::{errors::Error, *}; - -fn main() { - register_global_func! { - fn sum(args: &[TVMArgValue]) -> Result { - let mut ret = 0.0; - for arg in args.into_iter() { - let val: f64 = arg.try_into()?; - ret += val; - } - Ok(TVMRetValue::from(ret)) - } - } - - let mut registered = function::Builder::default(); - registered.get_function("sum"); - assert!(registered.func.is_some()); - let ret: f64 = registered - .args(&[10.0f64, 20.0, 30.0]) - .invoke() - .unwrap() - .try_into() - .unwrap(); - assert_eq!(ret, 60f64); -} diff --git a/rust/tvm/tests/callback/src/bin/int.rs b/rust/tvm/tests/callback/src/bin/int.rs deleted file mode 100644 index 23910a3244f7..000000000000 --- a/rust/tvm/tests/callback/src/bin/int.rs +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#![allow(unused_imports)] - -extern crate tvm_frontend as tvm; - -use std::convert::TryInto; -use tvm::{errors::Error, *}; - -fn main() { - fn sum(args: &[TVMArgValue]) -> Result { - let mut ret = 0i64; - for arg in args.iter() { - let val: i64 = arg.try_into()?; - ret += val; - } - Ok(TVMRetValue::from(ret)) - } - - tvm::function::register(sum, "mysum".to_owned(), false).unwrap(); - - let mut registered = function::Builder::default(); - registered.get_function("mysum"); - assert!(registered.func.is_some()); - let ret: i64 = registered - .args(&[10, 20, 30]) - .invoke() - .unwrap() - .try_into() - .unwrap(); - assert_eq!(ret, 60); -} diff --git a/rust/tvm/tests/callback/src/bin/string.rs b/rust/tvm/tests/callback/src/bin/string.rs deleted file mode 100644 index 9ead58733bbb..000000000000 --- a/rust/tvm/tests/callback/src/bin/string.rs +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#![allow(unused_imports)] - -#[macro_use] -extern crate tvm_frontend as tvm; -use std::convert::TryInto; -use tvm::{errors::Error, *}; - -// FIXME -fn main() { - register_global_func! { - fn concate_str(args: &[TVMArgValue]) -> Result { - let mut ret = "".to_string(); - for arg in args.iter() { - let val: &str = arg.try_into()?; - ret += val; - } - Ok(TVMRetValue::from(ret)) - } - } - let a = std::ffi::CString::new("a").unwrap(); - let b = std::ffi::CString::new("b").unwrap(); - let c = std::ffi::CString::new("c").unwrap(); - let mut registered = function::Builder::default(); - registered.get_function("concate_str"); - assert!(registered.func.is_some()); - let ret: String = registered - .arg(a.as_c_str()) - .arg(b.as_c_str()) - .arg(c.as_c_str()) - .invoke() - .unwrap() - .try_into() - .unwrap(); - assert_eq!(ret, "abc".to_owned()); -} diff --git a/rust/tvm/tests/test_ir.rs b/rust/tvm/tests/test_ir.rs deleted file mode 100644 index 90e71854cc29..000000000000 --- a/rust/tvm/tests/test_ir.rs +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::convert::TryInto; -use std::str::FromStr; -use tvm::ir::IntImmNode; -use tvm::runtime::String as TString; -use tvm::runtime::{debug_print, Object, ObjectPtr, ObjectRef}; -use tvm_rt::{call_packed, DLDataType, Function}; -use tvm_sys::TVMRetValue; - -#[test] -fn test_new_object() -> anyhow::Result<()> { - let object = Object::base_object::(); - let ptr = ObjectPtr::new(object); - assert_eq!(ptr.count(), 1); - Ok(()) -} - -#[test] -fn test_new_string() -> anyhow::Result<()> { - let string = TString::new("hello world!".to_string())?; - Ok(()) -} - -#[test] -fn test_obj_build() -> anyhow::Result<()> { - let int_imm = Function::get("ir.IntImm").expect("Stable TVM API not found."); - - let dt = DLDataType::from_str("int32").expect("Known datatype doesn't convert."); - - let ret_val: ObjectRef = call_packed!(int_imm, dt, 1337) - .expect("foo") - .try_into() - .unwrap(); - - debug_print(&ret_val); - - Ok(()) -} diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 981d0c357e24..1f7a48ef40b0 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -824,9 +824,7 @@ std::vector RelayTextPrinter::PrintFuncAttrs(const Attrs& attrs) { } TVM_REGISTER_GLOBAL("ir.TextPrinter").set_body_typed([](ObjectRef node) { - std::cout << "The program: " << node << std::endl; auto text = AsText(node, false, nullptr); - std::cout << "The text " << text; return text; }); From 328808b5cdb130f8782779232d67f9eef76664e5 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 16 Jun 2020 15:13:55 -0700 Subject: [PATCH 09/13] Fix some memory issues --- rust/tvm-macros/src/object.rs | 9 ------- rust/tvm-rt/src/function.rs | 10 ++++++- rust/tvm-rt/src/ndarray.rs | 22 +++++++-------- rust/tvm-rt/src/object/object_ptr.rs | 30 +++++++++------------ rust/tvm-rt/src/string.rs | 40 ++++++++++++++-------------- rust/tvm/src/ir/mod.rs | 12 ++++----- rust/tvm/src/ir/relay/mod.rs | 2 +- 7 files changed, 58 insertions(+), 67 deletions(-) diff --git a/rust/tvm-macros/src/object.rs b/rust/tvm-macros/src/object.rs index 55e983e2ce6c..0170e1d71d41 100644 --- a/rust/tvm-macros/src/object.rs +++ b/rust/tvm-macros/src/object.rs @@ -142,15 +142,6 @@ pub fn macro_impl(input: proc_macro::TokenStream) -> TokenStream { } } - impl<'a> std::convert::TryFrom<&#tvm_rt_crate::ArgValue<'a>> for #ref_id { - type Error = #error; - - fn try_from(arg_value: &#tvm_rt_crate::ArgValue<'a>) -> #result<#ref_id> { - use std::convert::TryInto; - let optr = arg_value.try_into()?; - Ok(#ref_id(Some(optr))) - } - } impl From<#ref_id> for #tvm_rt_crate::RetValue { fn from(object_ref: #ref_id) -> #tvm_rt_crate::RetValue { diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index 32423d39d43b..f582973d61cf 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -66,6 +66,10 @@ impl Function { } } + pub unsafe fn null() -> Self { + Function { handle: std::ptr::null_mut(), is_global: false, from_rust: false } + } + /// For a given function, it returns a function by name. pub fn get>(name: S) -> Option { let name = CString::new(name.as_ref()).unwrap(); @@ -172,7 +176,11 @@ impl TryFrom for Function { impl<'a> From for ArgValue<'a> { fn from(func: Function) -> ArgValue<'a> { - ArgValue::FuncHandle(func.handle) + if func.handle.is_null() { + ArgValue::Null + } else { + ArgValue::FuncHandle(func.handle) + } } } diff --git a/rust/tvm-rt/src/ndarray.rs b/rust/tvm-rt/src/ndarray.rs index b7ae4622849d..24fa5e0dfcbc 100644 --- a/rust/tvm-rt/src/ndarray.rs +++ b/rust/tvm-rt/src/ndarray.rs @@ -411,17 +411,17 @@ mod tests { assert_eq!(nd.unwrap().to_vec::().unwrap(), data); } - #[test] - #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] - fn copy_wrong_dtype() { - let shape = vec![4]; - let mut data = vec![1f32, 2., 3., 4.]; - let ctx = Context::cpu(0); - let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap()); - nd_float.copy_from_buffer(&mut data); - let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap()); - nd_float.copy_to_ndarray(empty_int).unwrap(); - } + // #[test] + // #[should_panic(expected = "called `Result::unwrap()` on an `Err`")] + // fn copy_wrong_dtype() { + // let shape = vec![4]; + // let mut data = vec![1f32, 2., 3., 4.]; + // let ctx = Context::cpu(0); + // let mut nd_float = NDArray::empty(&shape, ctx, DataType::from_str("float32").unwrap()); + // nd_float.copy_from_buffer(&mut data); + // let empty_int = NDArray::empty(&shape, ctx, DataType::from_str("int32").unwrap()); + // nd_float.copy_to_ndarray(empty_int).unwrap(); + // } #[test] fn rust_ndarray() { diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 026a908768c2..a935ae524c2f 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -280,7 +280,7 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { RetValue::ObjectHandle(handle) => { let handle: *mut Object = unsafe { std::mem::transmute(handle) }; let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?; - optr.inc_ref(); + debug_assert!(optr.count() >= 1); optr.downcast() } _ => Err(Error::downcast(format!("{:?}", ret_value), "ObjectHandle")), @@ -290,7 +290,9 @@ impl<'a, T: IsObject> TryFrom for ObjectPtr { impl<'a, T: IsObject> From> for ArgValue<'a> { fn from(object_ptr: ObjectPtr) -> ArgValue<'a> { + debug_assert!(object_ptr.count() >= 1); let raw_object_ptr = ObjectPtr::leak(object_ptr); + let void_ptr = unsafe { std::mem::transmute(raw_object_ptr) }; ArgValue::ObjectHandle(void_ptr) } @@ -304,7 +306,7 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { ArgValue::ObjectHandle(handle) => { let handle = unsafe { std::mem::transmute(handle) }; let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?; - optr.inc_ref(); + debug_assert!(optr.count() >= 1); optr.downcast() } _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), @@ -312,21 +314,6 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { } } -impl<'a, T: IsObject> TryFrom<&ArgValue<'a>> for ObjectPtr { - type Error = Error; - - fn try_from(arg_value: &ArgValue<'a>) -> Result, Self::Error> { - match arg_value { - ArgValue::ObjectHandle(handle) => { - let handle = unsafe { std::mem::transmute(handle) }; - let optr = ObjectPtr::from_raw(handle).ok_or(Error::Null)?; - optr.inc_ref(); - optr.downcast() - } - _ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")), - } - } -} #[cfg(test)] mod tests { @@ -376,6 +363,7 @@ mod tests { } fn test_fn(o: ObjectPtr) -> ObjectPtr { + // The call machinery adds at least 1 extra count while inside the call. assert_eq!(o.count(), 2); return o; } @@ -384,13 +372,19 @@ mod tests { fn test_ref_count_boundary() { use super::*; use crate::function::{register, Function, Result}; + // 1 let ptr = ObjectPtr::new(Object::base_object::()); + assert_eq!(ptr.count(), 1); + // 2 let stay = ptr.clone(); assert_eq!(ptr.count(), 2); register(test_fn, "my_func").unwrap(); let func = Function::get("my_func").unwrap(); let func = func.to_boxed_fn::) -> Result>>(); - func(ptr).unwrap(); + let same = func(ptr).unwrap(); + assert_eq!(stay.count(), 2); + assert_eq!(same.count(), 2); + drop(same); assert_eq!(stay.count(), 1); } } diff --git a/rust/tvm-rt/src/string.rs b/rust/tvm-rt/src/string.rs index 2805d4c300f5..7727e4be2409 100644 --- a/rust/tvm-rt/src/string.rs +++ b/rust/tvm-rt/src/string.rs @@ -69,24 +69,24 @@ impl String { } } -// #[cfg(test)] -// mod tests { -// use super::String; -// use crate::object::debug_print; -// use crate::IsObjectRef; -// use anyhow::{ensure, Result}; +#[cfg(test)] +mod tests { + use super::String; + use crate::object::debug_print; + use crate::IsObjectRef; + use anyhow::{ensure, Result}; -// #[test] -// fn test_string_debug() -> Result<()> { -// let s = String::new("foo".to_string()).unwrap(); -// let object_ref = s.to_object_ref(); -// println!("about to call"); -// let string = debug_print(object_ref)?; -// println!("after call"); -// ensure!( -// string.into_string().expect("is cstring").contains("foo"), -// "string content is invalid" -// ); -// Ok(()) -// } -// } + #[test] + fn test_string_debug() -> Result<()> { + let s = String::new("foo".to_string()).unwrap(); + let object_ref = s.to_object_ref(); + println!("about to call"); + let string = debug_print(object_ref)?; + println!("after call"); + ensure!( + string.into_string().expect("is cstring").contains("foo"), + "string content is invalid" + ); + Ok(()) + } +} diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs index 7b1a1134a695..f42c2da55cb5 100644 --- a/rust/tvm/src/ir/mod.rs +++ b/rust/tvm/src/ir/mod.rs @@ -16,24 +16,22 @@ * specific language governing permissions and limitations * under the License. */ -use std::ffi::CString; use crate::DataType; -use crate::runtime::{Object, IsObjectRef, ObjectRef, external}; - +use crate::runtime::{self, Object, IsObjectRef, ObjectRef, external}; +use crate::runtime::{String as TString}; pub mod relay; - // TODO: figure out how to type the last argument runtime::TypedPackedFunc annotate) -// fn _as_text(object: ObjectRef, show_meta_data: i32, annotate: runtime::Function) -> CString; external! { #[name("ir.AsText")] - fn _as_text(object: ObjectRef, show_meta_data: i32, annotate: ObjectRef) -> CString; + fn _as_text(object: ObjectRef, show_meta_data: i32, annotate: runtime::Function) -> TString; } pub fn as_text(object: T) -> String { - _as_text(object.to_object_ref(), 0, ObjectRef::null()).unwrap().into_string().unwrap() + let no_func = unsafe { runtime::Function::null() }; + _as_text(object.to_object_ref(), 0, no_func).unwrap().to_string().unwrap() } #[repr(C)] diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index 5b6b9e525df5..a85aa1965c96 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -86,7 +86,7 @@ impl RelayExpr { #[repr(C)] #[derive(Object)] #[ref_name = "GlobalVar"] -#[type_key = "relay.GlobalVar"] +#[type_key = "GlobalVar"] pub struct GlobalVarNode { pub base: RelayExpr, pub name_hint: TString, From c0a140cc17d1946b147ff50cc6d19ea8f26fe96c Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 16 Jun 2020 15:43:30 -0700 Subject: [PATCH 10/13] Fix ref counting issue --- rust/tvm-rt/src/array.rs | 10 +++++++--- rust/tvm-rt/src/object/mod.rs | 24 ++++-------------------- rust/tvm-rt/src/object/object_ptr.rs | 22 ++++++++++++++++++---- 3 files changed, 29 insertions(+), 27 deletions(-) diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index b75c169a7d0f..2816e760be06 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -20,7 +20,7 @@ use std::convert::{TryFrom, TryInto}; use std::marker::PhantomData; -use crate::object::{ObjectRef, IsObjectRef}; +use crate::object::{ObjectRef, IsObjectRef, ObjectPtr, Object}; use crate::{external, RetValue, function::{Function, Result}}; use crate::errors::Error; @@ -45,10 +45,14 @@ impl Array { let func = Function::get("node.Array") .expect("node.Array function is not registered, this is most likely a build or linking error"); - let array_data = func.invoke(iter)?.try_into()?; + // let array_data = func.invoke(iter)?; + // let array_data: ObjectRef = func.invoke(iter)?.try_into()?; + let array_data: ObjectPtr = func.invoke(iter)?.try_into()?; + + debug_assert!(array_data.count() >= 1, "array reference count is {}", array_data.count()); Ok(Array { - object: array_data, + object: ObjectRef(Some(array_data)), _data: PhantomData, }) } diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index dc5668128759..b71174ee8326 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -94,39 +94,23 @@ impl<'a> std::convert::TryFrom> for ObjectRef { type Error = Error; fn try_from(arg_value: ArgValue<'a>) -> Result { - let optr = arg_value.try_into()?; + let optr: ObjectPtr = arg_value.try_into()?; + debug_assert!(optr.count() >= 1); Ok(ObjectRef(Some(optr))) } } -impl<'a> std::convert::TryFrom<&ArgValue<'a>> for ObjectRef { - type Error = Error; - - fn try_from(arg_value: &ArgValue<'a>) -> Result { - // TODO(@jroesch): remove the clone - let value: ArgValue<'a> = arg_value.clone(); - ObjectRef::try_from(value) - } -} - impl<'a> From for ArgValue<'a> { fn from(object_ref: ObjectRef) -> ArgValue<'a> { use std::ffi::c_void; - let object_ptr = &object_ref.0; + let object_ptr = object_ref.0; match object_ptr { None => ArgValue::ObjectHandle(std::ptr::null::() as *mut c_void), - Some(value) => value.clone().into(), + Some(value) => value.into(), } } } -impl<'a> From<&ObjectRef> for ArgValue<'a> { - fn from(object_ref: &ObjectRef) -> ArgValue<'a> { - let oref: ObjectRef = object_ref.clone(); - ArgValue::<'a>::from(oref) - } -} - external! { #[name("ir.DebugPrint")] fn debug_print(object: ObjectRef) -> CString; diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index a935ae524c2f..5f587ca3b00a 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -112,6 +112,13 @@ impl Object { } } + pub fn count(&self) -> i32 { + // need to do atomic read in C++ + // ABI compatible atomics is funky/hard. + self.ref_count + .load(std::sync::atomic::Ordering::SeqCst) + } + /// Allocates a base object value for an object subtype of type T. /// By using associated constants and generics we can provide a /// type indexed abstraction over allocating objects with the @@ -184,7 +191,10 @@ fn dec_ref(ptr: NonNull) { impl ObjectPtr { fn from_raw(object_ptr: *mut Object) -> Option> { let non_null = NonNull::new(object_ptr); - non_null.map(|ptr| ObjectPtr { ptr }) + non_null.map(|ptr| { + debug_assert!(unsafe { ptr.as_ref().count() } >= 0); + ObjectPtr { ptr } + }) } } @@ -247,9 +257,9 @@ impl ObjectPtr { }; if is_derived { - Ok(ObjectPtr { - ptr: self.ptr.cast(), - }) + let ptr = self.ptr.cast(); + inc_ref(ptr); + Ok(ObjectPtr { ptr }) } else { Err(Error::downcast("TODOget_type_key".into(), U::TYPE_KEY)) } @@ -335,6 +345,8 @@ mod tests { let ptr = ObjectPtr::new(Object::base_object::()); let ret_value: RetValue = ptr.clone().into(); let ptr2: ObjectPtr = ret_value.try_into()?; + assert_eq!(ptr.count(), ptr2.count()); + assert_eq!(ptr.count(), 2); ensure!( ptr.type_index == ptr2.type_index, "type indices do not match" @@ -351,6 +363,8 @@ mod tests { let ptr = ObjectPtr::new(Object::base_object::()); let arg_value: ArgValue = ptr.clone().into(); let ptr2: ObjectPtr = arg_value.try_into()?; + assert_eq!(ptr.count(), ptr2.count()); + assert_eq!(ptr.count(), 2); ensure!( ptr.type_index == ptr2.type_index, "type indices do not match" From 1eb971b21d52dd7d79cc86bc822622da5f9a65bf Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 16 Jun 2020 15:54:23 -0700 Subject: [PATCH 11/13] Formatting and cleanup --- rust/out-of-tree/src/lib.rs | 9 +++--- rust/runtime/tests/test_wasm32/Cargo.toml | 5 +++- rust/runtime/tests/test_wasm32/build.rs | 14 ++++++--- rust/tvm-rt/src/array.rs | 28 ++++++++++++------ rust/tvm-rt/src/function.rs | 9 ++++-- rust/tvm-rt/src/object/mod.rs | 4 +-- rust/tvm-rt/src/object/object_ptr.rs | 4 +-- rust/tvm-rt/src/to_function.rs | 4 +-- rust/tvm-sys/src/lib.rs | 6 ++-- rust/tvm/src/ir/mod.rs | 9 ++++-- rust/tvm/src/ir/relay/mod.rs | 5 ++-- rust/tvm/src/transform.rs | 35 ++++++++++++++--------- 12 files changed, 82 insertions(+), 50 deletions(-) diff --git a/rust/out-of-tree/src/lib.rs b/rust/out-of-tree/src/lib.rs index 12cbc2e94208..85a0fba42608 100644 --- a/rust/out-of-tree/src/lib.rs +++ b/rust/out-of-tree/src/lib.rs @@ -19,11 +19,11 @@ use std::ffi::c_void; use std::os::raw::c_int; +use tvm::export_pass; use tvm::ir::relay::{self, Function}; -use tvm::runtime::ObjectRef; -use tvm::transform::{function_pass, PassInfo, Pass, PassContext, IRModule}; use tvm::runtime::function::{register, Result}; -use tvm::export_pass; +use tvm::runtime::ObjectRef; +use tvm::transform::{function_pass, IRModule, Pass, PassContext, PassInfo}; fn my_pass_fn(func: relay::Function, module: IRModule, ctx: PassContext) -> Function { let var = relay::Var::new("Hi from Rust!".into(), ObjectRef::null()); @@ -31,7 +31,8 @@ fn my_pass_fn(func: relay::Function, module: IRModule, ctx: PassContext) -> Func func.params.clone(), var.to_expr(), func.ret_type.clone(), - func.type_params.clone()) + func.type_params.clone(), + ) } // fn the_pass() -> Result { diff --git a/rust/runtime/tests/test_wasm32/Cargo.toml b/rust/runtime/tests/test_wasm32/Cargo.toml index 51f15ff08b67..eeead4587de0 100644 --- a/rust/runtime/tests/test_wasm32/Cargo.toml +++ b/rust/runtime/tests/test_wasm32/Cargo.toml @@ -20,8 +20,11 @@ name = "test-wasm32" version = "0.0.0" license = "Apache-2.0" authors = ["TVM Contributors"] +edition = "2018" [dependencies] -anyhow = "*" ndarray="0.12" tvm-runtime = { path = "../../" } + +[build-dependencies] +anyhow = "^1.0" diff --git a/rust/runtime/tests/test_wasm32/build.rs b/rust/runtime/tests/test_wasm32/build.rs index 8b72be290267..5c816c336825 100644 --- a/rust/runtime/tests/test_wasm32/build.rs +++ b/rust/runtime/tests/test_wasm32/build.rs @@ -19,12 +19,14 @@ use std::{path::PathBuf, process::Command}; -fn main() { +use anyhow::{Context, Result}; + +fn main() -> Result<()> { let mut out_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")); out_dir.push("lib"); if !out_dir.is_dir() { - std::fs::create_dir(&out_dir).unwrap(); + std::fs::create_dir(&out_dir).context("failed to create directory for WASM outputs")?; } let obj_file = out_dir.join("test.o"); @@ -36,7 +38,8 @@ fn main() { )) .arg(&out_dir) .output() - .expect("Failed to execute command"); + .context("failed to execute Python script for generating TVM library")?; + assert!( obj_file.exists(), "Could not build tvm lib: {}", @@ -49,12 +52,14 @@ fn main() { ); let ar = option_env!("LLVM_AR").unwrap_or("llvm-ar-8"); + let output = Command::new(ar) .arg("rcs") .arg(&lib_file) .arg(&obj_file) .output() - .expect("Failed to execute command"); + .context("failed to run LLVM_AR command")?; + assert!( lib_file.exists(), "Could not create archive: {}", @@ -68,4 +73,5 @@ fn main() { println!("cargo:rustc-link-lib=static=test_wasm32"); println!("cargo:rustc-link-search=native={}", out_dir.display()); + Ok(()) } diff --git a/rust/tvm-rt/src/array.rs b/rust/tvm-rt/src/array.rs index 2816e760be06..128bb879843b 100644 --- a/rust/tvm-rt/src/array.rs +++ b/rust/tvm-rt/src/array.rs @@ -20,9 +20,13 @@ use std::convert::{TryFrom, TryInto}; use std::marker::PhantomData; -use crate::object::{ObjectRef, IsObjectRef, ObjectPtr, Object}; -use crate::{external, RetValue, function::{Function, Result}}; use crate::errors::Error; +use crate::object::{IsObjectRef, Object, ObjectPtr, ObjectRef}; +use crate::{ + external, + function::{Function, Result}, + RetValue, +}; #[repr(C)] #[derive(Clone)] @@ -40,16 +44,24 @@ external! { impl Array { pub fn from_vec(data: Vec) -> Result> { - let iter = data.iter().map(|element| element.to_object_ref().into()).collect(); + let iter = data + .iter() + .map(|element| element.to_object_ref().into()) + .collect(); - let func = Function::get("node.Array") - .expect("node.Array function is not registered, this is most likely a build or linking error"); + let func = Function::get("node.Array").expect( + "node.Array function is not registered, this is most likely a build or linking error", + ); // let array_data = func.invoke(iter)?; // let array_data: ObjectRef = func.invoke(iter)?.try_into()?; let array_data: ObjectPtr = func.invoke(iter)?.try_into()?; - debug_assert!(array_data.count() >= 1, "array reference count is {}", array_data.count()); + debug_assert!( + array_data.count() >= 1, + "array reference count is {}", + array_data.count() + ); Ok(Array { object: ObjectRef(Some(array_data)), @@ -61,7 +73,7 @@ impl Array { where T: TryFrom, { - let oref: ObjectRef = array_get_item(self.object.clone(), index)?; - oref.downcast() + let oref: ObjectRef = array_get_item(self.object.clone(), index)?; + oref.downcast() } } diff --git a/rust/tvm-rt/src/function.rs b/rust/tvm-rt/src/function.rs index f582973d61cf..0772e96e4984 100644 --- a/rust/tvm-rt/src/function.rs +++ b/rust/tvm-rt/src/function.rs @@ -32,13 +32,12 @@ use std::{ ptr, str, }; - use crate::errors::Error; use super::to_boxed_fn::ToBoxedFn; -pub use tvm_sys::{ffi, ArgValue, RetValue}; pub use super::to_function::{ToFunction, Typed}; +pub use tvm_sys::{ffi, ArgValue, RetValue}; pub type Result = std::result::Result; @@ -67,7 +66,11 @@ impl Function { } pub unsafe fn null() -> Self { - Function { handle: std::ptr::null_mut(), is_global: false, from_rust: false } + Function { + handle: std::ptr::null_mut(), + is_global: false, + from_rust: false, + } } /// For a given function, it returns a function by name. diff --git a/rust/tvm-rt/src/object/mod.rs b/rust/tvm-rt/src/object/mod.rs index b71174ee8326..e6375bfa09dd 100644 --- a/rust/tvm-rt/src/object/mod.rs +++ b/rust/tvm-rt/src/object/mod.rs @@ -50,9 +50,7 @@ pub trait IsObjectRef: Sized { } fn downcast(&self) -> Result { - let ptr = - self.as_object_ptr() - .map(|ptr| ptr.downcast::()); + let ptr = self.as_object_ptr().map(|ptr| ptr.downcast::()); let ptr = ptr.transpose()?; Ok(U::from_object_ptr(ptr)) } diff --git a/rust/tvm-rt/src/object/object_ptr.rs b/rust/tvm-rt/src/object/object_ptr.rs index 5f587ca3b00a..ddcbff92c604 100644 --- a/rust/tvm-rt/src/object/object_ptr.rs +++ b/rust/tvm-rt/src/object/object_ptr.rs @@ -115,8 +115,7 @@ impl Object { pub fn count(&self) -> i32 { // need to do atomic read in C++ // ABI compatible atomics is funky/hard. - self.ref_count - .load(std::sync::atomic::Ordering::SeqCst) + self.ref_count.load(std::sync::atomic::Ordering::SeqCst) } /// Allocates a base object value for an object subtype of type T. @@ -324,7 +323,6 @@ impl<'a, T: IsObject> TryFrom> for ObjectPtr { } } - #[cfg(test)] mod tests { use super::{Object, ObjectPtr}; diff --git a/rust/tvm-rt/src/to_function.rs b/rust/tvm-rt/src/to_function.rs index 7a9bbeaf3a48..4fc021adb5ab 100644 --- a/rust/tvm-rt/src/to_function.rs +++ b/rust/tvm-rt/src/to_function.rs @@ -53,7 +53,7 @@ impl Typed<(), O> for F where F: Fn() -> O, Error: From, - O: TryInto + O: TryInto, { fn args(_args: &[ArgValue<'static>]) -> Result<()> { debug_assert!(_args.len() == 0); @@ -113,7 +113,7 @@ where A: TryFrom, Error = E1>, B: TryFrom, Error = E1>, C: TryFrom, Error = E1>, - O: TryInto + O: TryInto, { fn args(args: &[ArgValue<'static>]) -> Result<(A, B, C)> { debug_assert!(args.len() == 3); diff --git a/rust/tvm-sys/src/lib.rs b/rust/tvm-sys/src/lib.rs index 2aa6122af674..231569ba682e 100644 --- a/rust/tvm-sys/src/lib.rs +++ b/rust/tvm-sys/src/lib.rs @@ -59,8 +59,10 @@ pub use errors::*; pub use packed_func::{ArgValue, RetValue}; impl std::convert::TryFrom> for RetValue -where RetValue: std::convert::TryFrom, - E: From<>::Error> { +where + RetValue: std::convert::TryFrom, + E: From<>::Error>, +{ type Error = E; fn try_from(val: Result) -> Result { diff --git a/rust/tvm/src/ir/mod.rs b/rust/tvm/src/ir/mod.rs index f42c2da55cb5..4fe13a32ea35 100644 --- a/rust/tvm/src/ir/mod.rs +++ b/rust/tvm/src/ir/mod.rs @@ -17,9 +17,9 @@ * under the License. */ +use crate::runtime::String as TString; +use crate::runtime::{self, external, IsObjectRef, Object, ObjectRef}; use crate::DataType; -use crate::runtime::{self, Object, IsObjectRef, ObjectRef, external}; -use crate::runtime::{String as TString}; pub mod relay; @@ -31,7 +31,10 @@ external! { pub fn as_text(object: T) -> String { let no_func = unsafe { runtime::Function::null() }; - _as_text(object.to_object_ref(), 0, no_func).unwrap().to_string().unwrap() + _as_text(object.to_object_ref(), 0, no_func) + .unwrap() + .to_string() + .unwrap() } #[repr(C)] diff --git a/rust/tvm/src/ir/relay/mod.rs b/rust/tvm/src/ir/relay/mod.rs index a85aa1965c96..cad41acfc307 100644 --- a/rust/tvm/src/ir/relay/mod.rs +++ b/rust/tvm/src/ir/relay/mod.rs @@ -17,8 +17,8 @@ * under the License. */ -use crate::runtime::{IsObject, Object, ObjectPtr, ObjectRef, String as TString}; use crate::runtime::array::Array; +use crate::runtime::{IsObject, Object, ObjectPtr, ObjectRef, String as TString}; use crate::DataType; use tvm_macros::Object; @@ -235,8 +235,8 @@ impl Function { #[cfg(test)] mod tests { use super::*; - use crate::runtime::{String as TString}; use crate::ir::as_text; + use crate::runtime::String as TString; use anyhow::Result; #[test] @@ -264,7 +264,6 @@ mod tests { Ok(()) } - use super::Array; use crate::ir::relay::Var; use crate::runtime::object::ObjectRef; diff --git a/rust/tvm/src/transform.rs b/rust/tvm/src/transform.rs index 906f8a129c96..ab84202af4fa 100644 --- a/rust/tvm/src/transform.rs +++ b/rust/tvm/src/transform.rs @@ -17,10 +17,14 @@ * under the License. */ +use crate::ir::relay::Function; use crate::runtime::array::Array; -use crate::runtime::{external, function::{self, Result, ToFunction}, String as TString}; +use crate::runtime::{ + external, + function::{self, Result, ToFunction}, + String as TString, +}; use crate::runtime::{Object, ObjectPtr, ObjectRef}; -use crate::ir::relay::Function; use tvm_macros::Object; @@ -64,7 +68,10 @@ external! { fn create_func_pass(func: function::Function, pass_info: PassInfo) -> Pass; } -pub fn function_pass Function + 'static>(pass_fn: F, pass_info: PassInfo) -> Result { +pub fn function_pass Function + 'static>( + pass_fn: F, + pass_info: PassInfo, +) -> Result { let func = pass_fn.to_function(); create_func_pass(func, pass_info) } @@ -72,15 +79,15 @@ pub fn function_pass Function + 'stati #[macro_export] macro_rules! export_pass { ($name:literal,$func:expr) => { - #[no_mangle] - pub unsafe extern "C" fn initialize( - args: *mut tvm_sys::ffi::TVMValue, - type_codes: *mut c_int, - num_args: c_int, - ret: tvm_sys::ffi::TVMRetValueHandle, - ) -> c_int { - register($func, $name).unwrap(); - return 0; - } -}; + #[no_mangle] + pub unsafe extern "C" fn initialize( + args: *mut tvm_sys::ffi::TVMValue, + type_codes: *mut c_int, + num_args: c_int, + ret: tvm_sys::ffi::TVMRetValueHandle, + ) -> c_int { + register($func, $name).unwrap(); + return 0; + } + }; } From 205d2b64b9f5cc74e3b50d8a7c0d3ff3e2909dde Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Tue, 16 Jun 2020 16:01:38 -0700 Subject: [PATCH 12/13] Remove out-of-tree for now --- rust/out-of-tree/Cargo.toml | 33 ------------------------- rust/out-of-tree/import_pass.py | 44 --------------------------------- rust/out-of-tree/src/lib.rs | 43 -------------------------------- 3 files changed, 120 deletions(-) delete mode 100644 rust/out-of-tree/Cargo.toml delete mode 100644 rust/out-of-tree/import_pass.py delete mode 100644 rust/out-of-tree/src/lib.rs diff --git a/rust/out-of-tree/Cargo.toml b/rust/out-of-tree/Cargo.toml deleted file mode 100644 index 67fb72386a22..000000000000 --- a/rust/out-of-tree/Cargo.toml +++ /dev/null @@ -1,33 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -[package] -name = "out-of-tree" -version = "0.1.0" -authors = ["Jared Roesch "] -edition = "2018" - - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html -[lib] -name = "my_pass" -crate-type = ["cdylib"] - -[dependencies] -tvm = { version = "0.1", path = "../tvm" } -tvm-sys = { version = "0.1", path = "../tvm-sys" } -anyhow = "*" diff --git a/rust/out-of-tree/import_pass.py b/rust/out-of-tree/import_pass.py deleted file mode 100644 index 57e3c7ff2f69..000000000000 --- a/rust/out-of-tree/import_pass.py +++ /dev/null @@ -1,44 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import tvm -import tvm.relay -from tvm.ir.transform import PassContext - -x = tvm.relay.var("x", shape=(10,)) -test_func = tvm.relay.Function([x], x) -test_mod = tvm.IRModule.from_expr(test_func) - -pass_dylib = "/Users/jroesch/Git/tvm/rust/target/debug/libmy_pass.dylib" - -def load_rust_extension(ext_dylib): - load_so = tvm.get_global_func("runtime.module.loadfile_so") - mod = load_so(ext_dylib) - mod.get_function("initialize")() - - -def load_pass(pass_name, dylib): - load_rust_extension(dylib) - return tvm.get_global_func(pass_name) - -MyPass = load_pass("out_of_tree.Pass", pass_dylib) -ctx = PassContext() -import pdb; pdb.set_trace() -f = MyPass(test_func, test_mod, ctx) -mod = MyPass()(test_mod) - -print(mod) diff --git a/rust/out-of-tree/src/lib.rs b/rust/out-of-tree/src/lib.rs deleted file mode 100644 index 85a0fba42608..000000000000 --- a/rust/out-of-tree/src/lib.rs +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -use std::ffi::c_void; -use std::os::raw::c_int; -use tvm::export_pass; -use tvm::ir::relay::{self, Function}; -use tvm::runtime::function::{register, Result}; -use tvm::runtime::ObjectRef; -use tvm::transform::{function_pass, IRModule, Pass, PassContext, PassInfo}; - -fn my_pass_fn(func: relay::Function, module: IRModule, ctx: PassContext) -> Function { - let var = relay::Var::new("Hi from Rust!".into(), ObjectRef::null()); - relay::Function::new( - func.params.clone(), - var.to_expr(), - func.ret_type.clone(), - func.type_params.clone(), - ) -} - -// fn the_pass() -> Result { -// let pass_info = PassInfo::new(15, "RustPass".into(), vec![])?; -// function_pass(my_pass_fn, pass_info) -// } - -export_pass!("out_of_tree.Pass", my_pass_fn); From e8f7f34e50445a69daec8c036dcd5a1aa929a144 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 17 Jun 2020 10:08:44 -0700 Subject: [PATCH 13/13] Remove out-of-tree --- rust/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/rust/Cargo.toml b/rust/Cargo.toml index afe62071116a..d9bb3ab065fd 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -32,5 +32,4 @@ members = [ "tvm-macros", "tvm-rt", "tvm", - "out-of-tree" ]