diff --git a/.gitignore b/.gitignore
index 833eee1a0774..d24fccb6f513 100644
--- a/.gitignore
+++ b/.gitignore
@@ -91,10 +91,8 @@ ENV/
 *~
 *.pyc
 *~
-build
 config.mk
 config.cmake
-build_*
 Win32
 *.dir
 perf
@@ -187,7 +185,6 @@ tvm_u.*
 tvm_t.*
 # Mac OS X
 .DS_Store
-build*
 
 # Jetbrain
 .idea
@@ -201,3 +198,11 @@ build*
 
 # tmp file
 .nfs*
+
+# keys
+*.pem
+*.p12
+*.pfx
+*.cer
+*.crt
+*.der
diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core
index 4f0564ec7694..946a54012d0c 160000
--- a/3rdparty/dmlc-core
+++ b/3rdparty/dmlc-core
@@ -1 +1 @@
-Subproject commit 4f0564ec769477c66d480dd966088f172050c874
+Subproject commit 946a54012d0c390675ab5b46cd990838d4183d6f
diff --git a/CMakeLists.txt b/CMakeLists.txt
index a47fe1f8b889..7bd76bbd7906 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -213,7 +213,7 @@ target_include_directories(
 # Tests
 set(TEST_EXECS "")
 file(GLOB TEST_SRCS tests/cpp/*.cc)
-find_library(GTEST_LIB gtest)
+find_library(GTEST_LIB gtest "$ENV{GTEST_LIB}")
 
 if(GTEST_LIB)
   foreach(__srcpath ${TEST_SRCS})
diff --git a/apps/sgx/Makefile b/apps/sgx/Makefile
index 1038f57c3ba1..422d3e4f03ab 100644
--- a/apps/sgx/Makefile
+++ b/apps/sgx/Makefile
@@ -1,13 +1,12 @@
-# Makefile for example to deploy TVM modules in SGX.
-
-TVM_ROOT := $(shell cd ../..; pwd)
-NNVM_PATH := nnvm
-DMLC_CORE := ${TVM_ROOT}/dmlc-core
-
 SGX_SDK ?= /opt/sgxsdk
+RUST_SGX_SDK ?= /opt/rust-sgx-sdk
 SGX_MODE ?= SIM
-SGX_ARCH ?= x64
-SGX_DEBUG ?= 1
+DEBUG ?= true
+NUM_THREADS ?= 4
+
+TVM_DIR ?= $(shell git rev-parse --show-toplevel)
+
+export
 
 sgx_edger8r := $(SGX_SDK)/bin/x64/sgx_edger8r
 sgx_enclave_signer := $(SGX_SDK)/bin/x64/sgx_sign
@@ -20,69 +19,71 @@ trts_library_name := sgx_trts$(sgx_sim)
 tservice_library_name := sgx_tservice$(sgx_sim)
 uservice_library_name := sgx_uae_service$(sgx_sim)
 
-pkg_cflags := -std=c++11 -O2 -fPIC\
-	-I${TVM_ROOT}/include\
-	-I${DMLC_CORE}/include\
-	-I${TVM_ROOT}/3rdparty/dlpack/include\
-	-I.\
-	-DDMLC_LOG_STACK_TRACE=0\
-	-fmax-errors=4
-
-pkg_ldflags := -L${TVM_ROOT}/lib
-
-enclave_include_paths := -I$(SGX_SDK)/include\
-	-I$(SGX_SDK)/include/tlibc\
-	-I$(SGX_SDK)/include/libcxx\
-	-I$(SGX_SDK)/include/stdc++\
+pkg_cflags := -std=c++11 -fPIC \
+	-I$(SGX_SDK)/include \
+	-I$(TVM_DIR)/include \
+	-I$(TVM_DIR)/dlpack/include \
+	-I$(TVM_DIR)/dmlc-core/include
+
+pkg_ldflags := -L$(TVM_DIR)/build -ltvm_runtime
+
+ifneq ($(DEBUG), false)
+	debug := debug
+	enclave_cflags += -Og -g
+	pkg_cflags += -Og -g
+else
+	debug := release
+	enclave_cflags += -O2
+	pkg_cflags += -O2
+endif
 
-enclave_cflags := -static -nostdinc\
-	-fvisibility=hidden -fpie -fstack-protector-strong\
-	-ffunction-sections -fdata-sections\
-	-DDMLC_CXX11_THREAD_LOCAL=0\
-	-include "lib/tvm_t.h"\
-	$(enclave_include_paths)\
+build_dir := build
 
-enclave_cxxflags := -nostdinc++ $(enclave_cflags) -DTVM_SGX_MAX_CONCURRENCY=4
+enclave_cflags := \
+	-I$(SGX_SDK)/include \
+	-I$(SGX_SDK)/include/tlibc \
+	-I$(SGX_SDK)/include/stdport \
+	-I$(SGX_SDK)/include/epid \
+	-I$(TVM_DIR)/include \
+	-I$(TVM_DIR)/dlpack/include \
+	-I$(TVM_DIR)/dmlc-core/include
 
 enclave_ldflags :=\
+	-L$(build_dir) -L$(TVM_DIR)/build \
 	-Wl,--no-undefined -nostdlib -nodefaultlibs -nostartfiles -L$(SGX_SDK)/lib64\
 	-Wl,--whole-archive -l$(trts_library_name) -Wl,--no-whole-archive\
 	-Wl,--start-group\
 	-lsgx_tstdc -lsgx_tstdcxx -lsgx_tcxx -lsgx_tcrypto -lsgx_tkey_exchange -l$(tservice_library_name)\
+	-lenclave -ltvm_t\
 	-Wl,--end-group\
 	-Wl,-Bstatic -Wl,-Bsymbolic -Wl,--no-undefined\
 	-Wl,-pie,-eenclave_entry -Wl,--export-dynamic\
-	-Wl,--defsym,__ImageBase=0
-	-Wl,--gc-sections
-
-.PHONY: clean all
+	-Wl,--defsym,__ImageBase=0 -Wl,--gc-sections\
+	-Wl,--version-script=enclave/enclave.lds
 
-all: lib/test_addone.signed.so
+.PHONY: enclave clean
 
-# The code library built by TVM
-lib/test_addone_sys.o: prepare_test_libs.py
-	python prepare_test_libs.py
+enclave: $(build_dir)/enclave.signed.so
 
-lib/tvm_t.h: ../../src/runtime/sgx/tvm.edl
-	$(sgx_edger8r) --trusted $< --trusted-dir lib --search-path $(SGX_SDK)/include
-	mv $@ $@.in
-	awk 'NR==4{print "#include "}1' $@.in > $@
+$(build_dir)/enclave.signed.so: $(build_dir)/enclave.so build/enclave_config.xml enclave/enclave.pem
+	$(sgx_enclave_signer) sign -key enclave/enclave.pem -enclave $< -out $@ -config build/enclave_config.xml
 
-lib/tvm_t.c: lib/tvm_t.h
+enclave/enclave.pem:
+	curl -sSo $@ 'https://gist.githubusercontent.com/nhynes/8a2d80068a92e672f8b0b7d710ceb404/raw/2d5ae5fbe83198ede49465fdc6535065e093543b/tvm_sgx_demo.pem'
 
-lib/tvm_t.o: lib/tvm_t.c
-	$(CC) $(enclave_cflags) $(pkg_cflags) -c $< -o $@ -include $(TVM_ROOT)/include/tvm/runtime/c_runtime_api.h
+build/enclave_config.xml: enclave/enclave_config.xml.in
+	cpp $^ -P -o $@ -DNUM_THREADS=$$(( $(NUM_THREADS) + 1 ))
 
-# The enclave library
-lib/test_addone.so: $(TVM_ROOT)/src/runtime/sgx/trusted/runtime.cc lib/tvm_t.o lib/test_addone_sys.o
-	$(CXX) $^ -o $@ $(pkg_cflags) $(pkg_ldflags) $(enclave_cxxflags) $(enclave_ldflags) -g
+$(build_dir)/enclave.so: $(build_dir)/libenclave.a $(TVM_DIR)/build/libtvm_t.a
+	$(CXX) $< -o $@ $(enclave_ldflags) $(enclave_cflags) -ltvm_t
 
-# The demo enclave signing key
-lib/enclave.pem:
-	curl -Lso $@ https://gist.githubusercontent.com/nhynes/8a2d80068a92e672f8b0b7d710ceb404/raw/2d5ae5fbe83198ede49465fdc6535065e093543b/tvm_sgx_demo.pem
+$(build_dir)/libenclave.a: enclave/target/x86_64-unknown-linux-sgx/$(debug)/libmodel_enclave.a
+	@mkdir -p $(@D)
+	@cp $< $@
 
-# The signed enclave
-lib/test_addone.signed.so: lib/test_addone.so enclave_config.xml lib/enclave.pem
-	$(sgx_enclave_signer) sign -key lib/enclave.pem -enclave $< -out $@ -config enclave_config.xml
+enclave/target/x86_64-unknown-linux-sgx/$(debug)/libmodel_enclave.a: enclave/**/*
+	$(MAKE) -C enclave
 
 clean:
-	rm -rf lib
+	$(MAKE) -s -C enclave clean
+	rm -rf build
diff --git a/apps/sgx/README.md b/apps/sgx/README.md
index 565519d457ce..10989ba4b90d 100644
--- a/apps/sgx/README.md
+++ b/apps/sgx/README.md
@@ -4,13 +4,41 @@ This application demonstrates the use of a simple TVM model in the [Intel SGX](h
 
 ## Prerequisites
 
+1. The TVM premade Docker image
+
+or
+
 1. A GNU/Linux environment
 2. TVM compiled with LLVM and SGX; and the `tvm` Python module
 3. The [Linux SGX SDK](https://github.com/intel/linux-sgx) [link to pre-built libraries](https://01.org/intel-software-guard-extensions/downloads)
+4. [Rust](https://rustup.sh)
+5. The [rust-sgx-sdk](https://github.com/baidu/rust-sgx-sdk)
+6. [xargo](https://github.com/japaric/xargo)
+
+Check out `/tvm/install/ubuntu_install_sgx.sh` for the commands to get these dependencies.
 
 ## Running the example
 
-`SGX_SDK=/path/to/sgxsdk bash run_example.sh`
+If using Docker, start by running
+
+```
+git clone --recursive https://github.com/dmlc/tvm.git
+docker run --rm -it -v $(pwd)/tvm:/mnt tvmai/ci-cpu /bin/bash
+```
+then, in the container
+```
+cd /mnt
+mkdir build && cd build
+cmake .. -DUSE_LLVM=ON -DUSE_SGX=/opt/sgxsdk -DRUST_SGX_SDK=/opt/rust-sgx-sdk
+make -j4
+cd ..
+pip install -e python -e topi/python -e nnvm/python
+cd apps/sgx
+```
+
+Once TVM is built and installed, just
+
+`./run_example.sh`
 
 If everything goes well, you should see a lot of build messages and below them
 the text `It works!`.
@@ -24,10 +52,9 @@ In this library, one can use other libraries like TVM.
 
 Building this example performs the following steps:
 
 1. Creates a simple TVM module that computes `x + 1` and save it as a system library.
-2. Builds a minimal TVM runtime pack that can load the module.
-3. Links the TVM module into an SGX enclave along with some code that runs the module.
-4. Compiles and runs an executable that loads the enclave and calls a function
-   which invokes the TVM module.
+2. Builds a TVM runtime that links the module and allows running it using the TVM Python runtime.
+3. Packages the bundle into an SGX enclave.
+4. Runs the enclave using the usual TVM Python `module` API.
 
 For more information on building, please refer to the `Makefile`.
 For more information on the TVM module, please refer to `../howto_deploy`.
diff --git a/apps/sgx/enclave/.rustfmt.toml b/apps/sgx/enclave/.rustfmt.toml
new file mode 120000
index 000000000000..ec1baa2f89be
--- /dev/null
+++ b/apps/sgx/enclave/.rustfmt.toml
@@ -0,0 +1 @@
+../../../rust/.rustfmt.toml
\ No newline at end of file
diff --git a/apps/sgx/enclave/Cargo.toml b/apps/sgx/enclave/Cargo.toml
new file mode 100644
index 000000000000..cb128f3fbf94
--- /dev/null
+++ b/apps/sgx/enclave/Cargo.toml
@@ -0,0 +1,15 @@
+[package]
+name = "model-enclave"
+version = "0.1.0"
+authors = ["Nick Hynes "]
+
+[lib]
+crate-type = ["staticlib"]
+
+[dependencies]
+lazy_static = "1.1.0"
+tvm = { path = "../../../rust", default-features = false, features = ["sgx"] }
+
+[profile.release]
+lto = true
+opt-level = 3
diff --git a/apps/sgx/enclave/Makefile b/apps/sgx/enclave/Makefile
new file mode 100644
index 000000000000..a28e05e03b13
--- /dev/null
+++ b/apps/sgx/enclave/Makefile
@@ -0,0 +1,42 @@
+MODEL ?= resnet
+NUM_THREADS ?= 4
+BATCH_SIZE ?= 64
+TRAINING ?= true
+DEBUG ?= false
+
+build_dir := ../build
+
+ifeq ($(DEBUG), false)
+	debug := release
+	xargo_args := --release
+else
+	debug := debug
+endif
+
+target=target/x86_64-unknown-linux-sgx/$(debug)/libmodel-enclave.a
+
+$(target): $(build_dir)/libmodel.a **/* $(TVM_DIR)/rust/patched.txt
+	RUST_TARGET_PATH=$(shell pwd) \
+	RUST_TARGET_DIR=$(shell pwd)/target \
+	RUSTFLAGS="-Z force-unstable-if-unmarked" \
+	TVM_NUM_THREADS=$(NUM_THREADS) \
+	BUILD_DIR=../build \
+	xargo build --target x86_64-unknown-linux-sgx $(xargo_args) -q
+
+$(TVM_DIR)/rust/patched.txt: $(shell pwd)/sgx-deps.diff
+	echo $(TVM_DIR)
+	cd $(TVM_DIR) && git apply $<
+	touch $@
+
+$(build_dir)/libmodel.a: $(build_dir)/model.o
+	$(AR) cr $@ $^
+
+$(build_dir)/model.o: $(build_dir)/model.bc
+	$(CC) -c $< -o $@ -fPIC -O3
+	objcopy --globalize-symbol __tvm_module_startup $@
+
+$(build_dir)/model.bc: src/build_model.py
+	python3 $< -o $(build_dir)
+
+clean:
+	xargo clean
diff --git a/apps/sgx/enclave/Xargo.toml b/apps/sgx/enclave/Xargo.toml
new file mode 100644
index 000000000000..57acf092b4d6
--- /dev/null
+++ b/apps/sgx/enclave/Xargo.toml
@@ -0,0 +1,13 @@
+[dependencies]
+alloc = {}
+panic_unwind = {}
+panic_abort = {}
+
+[dependencies.std]
+path = "/opt/rust-sgx-sdk/xargo/sgx_tstd"
+features = ["backtrace", "stdio", "untrusted_time"]
+stage = 2
+
+[dependencies.xargo_sgx_rand]
+path = "/opt/rust-sgx-sdk/xargo/sgx_rand"
+stage = 3
diff --git a/apps/sgx/enclave/build.rs b/apps/sgx/enclave/build.rs
new file mode 100644
index 000000000000..a3beedaacda6
--- /dev/null
+++ b/apps/sgx/enclave/build.rs
@@ -0,0 +1,9 @@
+use std::env;
+
+fn main() {
+    println!(
+        "cargo:rustc-link-search=native={}",
+        env::var("BUILD_DIR").unwrap()
+    );
+    println!("cargo:rustc-link-lib=static=model");
+}
diff --git a/apps/sgx/enclave/enclave.lds b/apps/sgx/enclave/enclave.lds
new file mode 100644
index 000000000000..e3d9d0ee0d90
--- /dev/null
+++ b/apps/sgx/enclave/enclave.lds
@@ -0,0 +1,9 @@
+enclave.so
+{
+    global:
+        g_global_data_sim;
+        g_global_data;
+        enclave_entry;
+    local:
+        *;
+};
diff --git a/apps/sgx/enclave_config.xml b/apps/sgx/enclave/enclave_config.xml.in
similarity index 50%
rename from apps/sgx/enclave_config.xml
rename to apps/sgx/enclave/enclave_config.xml.in
index 07be0d7a7ad2..2423f93086b8 100644
--- a/apps/sgx/enclave_config.xml
+++ b/apps/sgx/enclave/enclave_config.xml.in
@@ -1,10 +1,10 @@
 <EnclaveConfiguration>
   <ProdID>0</ProdID>
   <ISVSVN>0</ISVSVN>
-  <StackMaxSize>0x2000</StackMaxSize>
-  <HeapMaxSize>0x2000</HeapMaxSize>
-  <TCSNum>5</TCSNum>
-  <TCSPolicy>1</TCSPolicy>
+  <StackMaxSize>0x20000</StackMaxSize>
+  <HeapMaxSize>0x5000000</HeapMaxSize>
+  <TCSNum>NUM_THREADS</TCSNum>
+  <TCSPolicy>0</TCSPolicy>
   <DisableDebug>0</DisableDebug>
   <MiscSelect>0</MiscSelect>
   <MiscMask>0xFFFFFFFF</MiscMask>
diff --git a/apps/sgx/enclave/sgx-deps.diff b/apps/sgx/enclave/sgx-deps.diff
new file mode 100644
index 000000000000..1c67e7957f38
--- /dev/null
+++ b/apps/sgx/enclave/sgx-deps.diff
@@ -0,0 +1,13 @@
+diff --git a/rust/Cargo.toml b/rust/Cargo.toml
+index 0819e0c7..e56f4ef2 100644
+--- a/rust/Cargo.toml
++++ b/rust/Cargo.toml
+@@ -14,7 +14,7 @@ default = ["nom/std"]
+ sgx = ["nom/alloc"]
+ 
+ [dependencies]
+-bounded-spsc-queue = "0.4.0"
++bounded-spsc-queue = { git = "https://github.com/nhynes/bounded-spsc-queue", branch = "sgx" }
+ error-chain = { version = "0.12.0", default-features = false }
+ itertools = "0.7.8"
+ lazy_static = "1.1.0"
diff --git a/apps/sgx/enclave/src/build_model.py b/apps/sgx/enclave/src/build_model.py
new file mode 100644
index 000000000000..d1b45cc4a4df
--- /dev/null
+++ b/apps/sgx/enclave/src/build_model.py
@@ -0,0 +1,38 @@
+"""Creates a simple TVM module."""
+
+import argparse
+import os
+from os import path as osp
+
+import nnvm.compiler
+import nnvm.testing
+import tvm
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('-o', '--out-dir', default='.')
+    opts = parser.parse_args()
+
+    # from tutorials/nnvm_quick_start.py
+    dshape = (1, 3, 224, 224)
+    net, params = nnvm.testing.resnet.get_workload(
+        layers=18, batch_size=dshape[0], image_shape=dshape[1:])
+
+    with nnvm.compiler.build_config(opt_level=3):
+        graph, lib, params = nnvm.compiler.build(
+            net, 'llvm --system-lib', shape={'data': dshape}, params=params)
+
+    build_dir = osp.abspath(opts.out_dir)
+    if not osp.isdir(build_dir):
+        os.makedirs(build_dir, exist_ok=True)
+
+    lib.save(osp.join(build_dir, 'model.bc'))
+    with open(osp.join(build_dir, 'graph.json'), 'w') as f_graph_json:
+        f_graph_json.write(graph.json())
+    with open(osp.join(build_dir, 'params.bin'), 'wb') as f_params:
+        f_params.write(nnvm.compiler.save_param_dict(params))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/apps/sgx/enclave/src/lib.rs b/apps/sgx/enclave/src/lib.rs
new file mode 100644
index 000000000000..d74015a92510
--- /dev/null
+++ b/apps/sgx/enclave/src/lib.rs
@@ -0,0 +1,119 @@
+#![feature(try_from)]
+
+#[macro_use]
+extern crate lazy_static;
+extern crate tvm;
+
+use std::{convert::TryFrom, sync::Mutex};
+
+use tvm::runtime::{sgx, Graph, GraphExecutor, SystemLibModule, TVMArgValue, TVMRetValue};
+
+lazy_static! {
+    static ref SYSLIB: SystemLibModule = { SystemLibModule::default() };
+    static ref MODEL: Mutex<GraphExecutor<'static, 'static>> = {
+        let _params = include_bytes!(concat!("../", env!("BUILD_DIR"), "/params.bin"));
+        let graph_json = include_str!(concat!("../", env!("BUILD_DIR"), "/graph.json"));
+
+        let graph = Graph::try_from(graph_json).unwrap();
+        Mutex::new(GraphExecutor::new(graph, &*SYSLIB).unwrap())
+    };
+}
+
+fn ecall_init(_args: &[TVMArgValue]) -> TVMRetValue {
+    lazy_static::initialize(&MODEL);
+    TVMRetValue::from(0)
+}
+
+fn ecall_main(_args: &[TVMArgValue]) -> TVMRetValue {
+    let model = MODEL.lock().unwrap();
+    // model.set_input("data", args[0]);
+    model.run();
+    sgx::shutdown();
+    // model.get_output(0).into()
+    TVMRetValue::from(42)
+}
+
+pub mod ecalls {
+    //! todo: generate this using proc_macros
+
+    use super::*;
+
+    use std::{
+        ffi::CString,
+        os::raw::{c_char, c_int},
+        slice,
+    };
+
+    use tvm::{
+        ffi::runtime::{TVMRetValueHandle, TVMValue},
+        runtime::{
+            sgx::{run_worker, SgxStatus},
+            PackedFunc,
+        },
+    };
+
+    macro_rules! tvm_ocall {
+        ($func: expr) => {
+            match $func {
+                0 => Ok(()),
+                err => Err(err),
+            }
+        };
+    }
+
+    const ECALLS: &'static [&'static str] = &["__tvm_run_worker__", "__tvm_main__", "init"];
+
+    lazy_static! {
+        static ref ECALL_FUNCS: Vec<PackedFunc> = {
+            vec![
+                Box::new(run_worker),
+                Box::new(ecall_main),
+                Box::new(ecall_init),
+            ]
+        };
+    }
+
+    extern "C" {
+        fn __tvm_module_startup() -> ();
+        fn tvm_ocall_register_export(name: *const c_char, func_id: c_int) -> SgxStatus;
+    }
+
+    #[no_mangle]
+    pub extern "C" fn tvm_ecall_init(_ret: TVMRetValueHandle) {
+        unsafe {
+            __tvm_module_startup();
+
+            ECALLS.into_iter().enumerate().for_each(|(i, ecall)| {
+                tvm_ocall!(tvm_ocall_register_export(
+                    CString::new(*ecall).unwrap().as_ptr(),
+                    i as i32
+                )).expect(&format!("Error registering `{}`", ecall));
+            });
+        }
+    }
+
+    #[no_mangle]
+    pub extern "C" fn tvm_ecall_packed_func(
+        func_id: c_int,
+        arg_values: *const TVMValue,
+        type_codes: *const c_int,
+        num_args: c_int,
+        ret_val: *mut TVMValue,
+        ret_type_code: *mut i64,
+    ) {
+        let args = unsafe {
+            let values = slice::from_raw_parts(arg_values, num_args as usize);
+            let type_codes = slice::from_raw_parts(type_codes, num_args as usize);
+            values
+                .into_iter()
+                .zip(type_codes.into_iter())
+                .map(|(v, t)| TVMArgValue::new(*v, *t as i64))
+                .collect::<Vec<TVMArgValue>>()
+        };
+        let (rv, tc) = ECALL_FUNCS[func_id as usize](&args).into_tvm_value();
+        unsafe {
+            *ret_val = rv;
+            *ret_type_code = tc;
+        }
+    }
+}
diff --git a/apps/sgx/enclave/x86_64-unknown-linux-sgx.json b/apps/sgx/enclave/x86_64-unknown-linux-sgx.json
new file mode 100644
index 000000000000..6cbb524f4439
--- /dev/null
+++ b/apps/sgx/enclave/x86_64-unknown-linux-sgx.json
@@ -0,0 +1,31 @@
+{
+  "arch": "x86_64",
+  "cpu": "x86-64",
+  "data-layout": "e-m:e-i64:64-f80:128-n8:16:32:64-S128",
+  "dynamic-linking": true,
+  "env": "sgx",
+  "exe-allocation-crate": "alloc_system",
+  "executables": true,
+  "has-elf-tls": true,
+  "has-rpath": true,
+  "linker-flavor": "gcc",
+  "linker-is-gnu": true,
+  "llvm-target": "x86_64-unknown-linux-gnu",
+  "max-atomic-width": 64,
+  "os": "linux",
+  "position-independent-executables": true,
+  "pre-link-args": {
+    "gcc": [
+      "-Wl,--as-needed",
+      "-Wl,-z,noexecstack",
+      "-m64"
+    ]
+  },
+  "relro-level": "full",
+  "stack-probes": true,
+  "target-c-int-width": "32",
+  "target-endian": "little",
+  "target-family": "unix",
+  "target-pointer-width": "64",
+  "vendor": "unknown"
+}
diff --git a/apps/sgx/prepare_test_libs.py b/apps/sgx/prepare_test_libs.py
deleted file mode 100644
index f676f46b7ff0..000000000000
--- a/apps/sgx/prepare_test_libs.py
+++ /dev/null
@@ -1,26 +0,0 @@
-"""Script to prepare test_addone_sys.o"""
-
-from os import path as osp
-
-import tvm
-
-CWD = osp.dirname(osp.abspath(osp.expanduser(__file__)))
-
-
-def main():
-    out_dir = osp.join(CWD, 'lib')
-
-    n = tvm.var('n')
-    A = tvm.placeholder((n,), name='A')
-    B = tvm.compute(A.shape, lambda *i: A(*i) + 1, name='B')
-    s = tvm.create_schedule(B.op)
-    s[B].parallel(s[B].op.axis[0])
-    print(tvm.lower(s, [A, B], simple_mode=True))
-
-    # Compile library in system library mode
-    fadd_syslib = tvm.build(s, [A, B], 'llvm --system-lib')
-    fadd_syslib.save(osp.join(out_dir, 'test_addone_sys.o'))
-
-
-if __name__ == '__main__':
-    main()
diff --git a/apps/sgx/run_example.sh b/apps/sgx/run_example.sh
index 9334b260cbf3..811da3938dd6 100755
--- a/apps/sgx/run_example.sh
+++ b/apps/sgx/run_example.sh
@@ -1,6 +1,10 @@
 #!/bin/bash
 
 sgx_sdk=${SGX_SDK:=/opt/sgxsdk}
-make
-echo "========================="
-LD_LIBRARY_PATH="$sgx_sdk/lib64":${LD_LIBRARY_PATH} TVM_CACHE_DIR=/tmp python test_addone.py
+
+export LD_LIBRARY_PATH="$sgx_sdk/lib64":${LD_LIBRARY_PATH}
+export CC=clang-6.0
+export AR=llvm-ar-6.0
+export TVM_CACHE_DIR=/tmp
+
+make && printf "\n" && python3 run_model.py
diff --git a/apps/sgx/run_model.py b/apps/sgx/run_model.py
new file mode 100644
index 000000000000..491a5ccbda3c
--- /dev/null
+++ b/apps/sgx/run_model.py
@@ -0,0 +1,20 @@
+import os.path as osp
+import numpy as np
+import tvm
+
+CWD = osp.abspath(osp.dirname(__file__))
+
+
+def main():
+    ctx = tvm.context('cpu', 0)
+    model = tvm.module.load(osp.join(CWD, 'build', 'enclave.signed.so'))
+    out = model()
+    if out == 42:
+        print('It works!')
+    else:
+        print('It doesn\'t work!')
+        exit(1)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/apps/sgx/test_addone.py b/apps/sgx/test_addone.py
deleted file mode 100644
index 5ddccfa425cc..000000000000
--- a/apps/sgx/test_addone.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import tvm
-import numpy as np
-
-ctx = tvm.context('cpu', 0)
-fadd1 = tvm.module.load('lib/test_addone.signed.so')
-
-n = 10
-x = tvm.nd.array(np.random.uniform(size=n).astype('float32'), ctx)
-y = tvm.nd.array(np.zeros(n, dtype='float32'), ctx)
-fadd1(x, y)
-
-np.testing.assert_allclose(y.asnumpy(), x.asnumpy() + 1)
-print("It works!")
diff --git a/cmake/modules/SGX.cmake b/cmake/modules/SGX.cmake
index c9894de11f8b..608d6ff5a4bd 100644
--- a/cmake/modules/SGX.cmake
+++ b/cmake/modules/SGX.cmake
@@ -1,5 +1,4 @@
 if(NOT USE_SGX STREQUAL "OFF")
-  message(STATUS "Build with SGX support")
   set(_sgx_src ${CMAKE_CURRENT_SOURCE_DIR}/src/runtime/sgx)
   set(_tvm_u_h ${_sgx_src}/untrusted/tvm_u.h)
 
@@ -9,8 +8,11 @@ if(NOT USE_SGX STREQUAL "OFF")
   set(_sgx_ustdc ${RUST_SGX_SDK}/sgx_ustdc)
 
   set(_urts_lib "sgx_urts")
-  if(SGX_MODE STREQUAL "SIM")
+  if(NOT SGX_MODE STREQUAL "HW")
+    message(STATUS "Build with SGX support (SIM)")
     set(_urts_lib "${_urts_lib}_sim")
+  else()
+    message(STATUS "Build with SGX support (HW)")
   endif()
 
   # build edge routines
diff --git a/docker/Dockerfile.ci_cpu b/docker/Dockerfile.ci_cpu
index 60d811344b07..b2bebea0b892 100644
--- a/docker/Dockerfile.ci_cpu
+++ b/docker/Dockerfile.ci_cpu
@@ -15,6 +15,18 @@ RUN bash /install/ubuntu_install_python_package.sh
 COPY install/ubuntu_install_llvm.sh /install/ubuntu_install_llvm.sh
 RUN bash /install/ubuntu_install_llvm.sh
 
+# SGX deps (build early; changes infrequently)
+COPY install/ubuntu_install_sgx.sh /install/ubuntu_install_sgx.sh
+RUN bash /install/ubuntu_install_sgx.sh
+ENV LD_LIBRARY_PATH /opt/sgxsdk/lib64:${LD_LIBRARY_PATH}
+
+# Rust env (build early; takes a while)
+COPY install/ubuntu_install_rust.sh /install/ubuntu_install_rust.sh
+RUN bash /install/ubuntu_install_rust.sh
+ENV RUSTUP_HOME /opt/rust
+ENV CARGO_HOME /opt/rust
+ENV RUSTC_WRAPPER sccache
+
 # AutoTVM deps
 COPY install/ubuntu_install_redis.sh /install/ubuntu_install_redis.sh
 RUN bash /install/ubuntu_install_redis.sh
@@ -23,13 +35,4 @@ RUN bash /install/ubuntu_install_redis.sh
 COPY install/ubuntu_install_golang.sh /install/ubuntu_install_golang.sh
 RUN bash /install/ubuntu_install_golang.sh
 
-# Rust env
-COPY install/ubuntu_install_rust.sh /install/ubuntu_install_rust.sh
-RUN bash /install/ubuntu_install_rust.sh
-
-# SGX deps
-COPY install/ubuntu_install_sgx.sh /install/ubuntu_install_sgx.sh
-RUN bash /install/ubuntu_install_sgx.sh
-
-
-ENV PATH $PATH:/root/.cargo/bin
+ENV PATH $PATH:$CARGO_HOME/bin:/usr/lib/go-1.10/bin
diff --git a/docker/install/ubuntu_install_golang.sh b/docker/install/ubuntu_install_golang.sh
index 9585824091a7..e15a456bc15a 100644
--- a/docker/install/ubuntu_install_golang.sh
+++ b/docker/install/ubuntu_install_golang.sh
@@ -1,4 +1,4 @@
 #install the necessary dependancies for golang build
-apt-get update && apt-get install -y golang-go
+apt-get update && apt-get install -y golang-1.10-go
 apt-get update && apt-get install -y godoc
 apt-get update && apt-get install -y golint
diff --git a/docker/install/ubuntu_install_rust.sh b/docker/install/ubuntu_install_rust.sh
index 1d17b66164c9..9a51afeea79b 100644
--- a/docker/install/ubuntu_install_rust.sh
+++ b/docker/install/ubuntu_install_rust.sh
@@ -1,9 +1,15 @@
 apt-get update && apt-get install -y --no-install-recommends --force-yes curl
-curl -sSo rustup.sh 'https://sh.rustup.rs'
-# rustc nightly-2018-08-25 is the version supported by the above version of rust-sgx-sdk
-bash rustup.sh -y --no-modify-path --default-toolchain nightly-2018-08-25
-. $HOME/.cargo/env
+
+export RUSTUP_HOME=/opt/rust
+export CARGO_HOME=/opt/rust
+# this rustc is one supported by the installed version of rust-sgx-sdk
+curl https://sh.rustup.rs -sSf | sh -s -- -y --no-modify-path --default-toolchain nightly-2018-09-25
+. $CARGO_HOME/env
+rustup toolchain add nightly
 rustup component add rust-src
-cargo install rustfmt-nightly --force
-cargo install xargo
+cargo +nightly install sccache
+cargo +nightly install rustfmt-nightly --version 0.99.5 --force
+cargo +nightly install xargo
+
+# make rust usable by all users
+chmod -R a+w /opt/rust
diff --git a/docker/install/ubuntu_install_sgx.sh b/docker/install/ubuntu_install_sgx.sh
index 917fd4b55954..a8201ac74a97 100644
--- a/docker/install/ubuntu_install_sgx.sh
+++ b/docker/install/ubuntu_install_sgx.sh
@@ -2,18 +2,20 @@ apt-get update && apt-get install -y --no-install-recommends --force-yes \
 	build-essential git cmake \
 	wget python pkg-config software-properties-common \
 	autoconf automake libtool ocaml \
+	protobuf-compiler libprotobuf-dev \
 	libssl-dev libcurl4-openssl-dev curl
 
 git clone https://github.com/intel/linux-sgx.git
 cd linux-sgx
 git checkout sgx_2.2
-curl 'https://gist.github.com/nhynes/c770b0e91610f8c020a8d1a803a1e7cb' | git am
+curl 'https://gist.githubusercontent.com/nhynes/c770b0e91610f8c020a8d1a803a1e7cb/raw/8f5372d9cb88929b3cc49a384943bb363bc06827/intel-sgx.patch' | git apply
 ./download_prebuilt.sh
-make -j sdk && make -j sdk_install_pkg
-./linux/installer/bin/sgx_linux_x64_sdk_2.2.100.45311.bin --prefix /opt
+make -j4 sdk && make -j4 sdk_install_pkg
+./linux/installer/bin/sgx_linux_x64_sdk*.bin --prefix /opt
 cd -
 
 git clone https://github.com/baidu/rust-sgx-sdk.git /opt/rust-sgx-sdk
 cd /opt/rust-sgx-sdk
-git checkout bdd75ca05f66d1f5df637182ec335970f769b03a
+git checkout v1.0.4
+curl 'https://gist.githubusercontent.com/nhynes/37164039c5d3f33aa4f123e4ba720036/raw/5b7fc24d4faa0bd6efce19f8324f79d5562991e0/rust-sgx-sdk.diff' | git apply
 cd -
diff --git a/docs/api/python/nnvm/frontend.rst b/docs/api/python/nnvm/frontend.rst
index f872a6b878e2..eb07a13e8340 100644
--- a/docs/api/python/nnvm/frontend.rst
+++ b/docs/api/python/nnvm/frontend.rst
@@ -10,3 +10,7 @@ nnvm.frontend
 .. autofunction:: nnvm.frontend.from_coreml
 
 .. autofunction:: nnvm.frontend.from_keras
+
+.. autofunction:: nnvm.frontend.from_tensorflow
+
+.. autofunction:: nnvm.frontend.from_darknet
diff --git a/docs/contribute/pull_request.rst b/docs/contribute/pull_request.rst
index 80a0448c08dd..c83edc6cf7d1 100644
--- a/docs/contribute/pull_request.rst
+++ b/docs/contribute/pull_request.rst
@@ -24,3 +24,50 @@ This is a quick guide to submit a pull request, please also refer to the detaile
   - The detailed guidelines and summarizes useful lessons.
 
 - The patch can be merged after the reviewers approve the pull request.
+
+Testing
+-------
+Even though we have hooks to run unit tests automatically for each pull request, it's always recommended to run unit tests
+locally beforehand to reduce the reviewers' burden and speed up the review process.
+
+C++
+^^^
+.. code:: bash
+
+  # assume you are in tvm source root
+  TVM_ROOT=`pwd`
+
+  # you need to install google test first, gtest will be installed to $TVM_ROOT/lib
+  CACHE_PREFIX=. make -f 3rdparty/dmlc-core/scripts/packages.mk gtest
+
+  mkdir build
+  cd build
+  GTEST_LIB=$TVM_ROOT/lib cmake ..
+  make cpptest -j
+  for test in *_test; do
+    ./$test || exit -1
+  done
+
+Python
+^^^^^^
+If you want to run all tests:
+
+.. code:: bash
+
+  # build tvm
+  make
+
+  ./tests/scripts/task_python_unittest.sh
+
+If you want to run a single test:
+
+.. code:: bash
+
+  # build tvm
+  make
+
+  # let python know where to find tvm related libraries
+  export PYTHONPATH=python:topi/python
+  rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc
+
+  TVM_FFI=ctypes python -m nose -v tests/python/unittest/test_pass_storage_rewrite.py
\ No newline at end of file
diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index bcaece0bf0a1..fe5356557e55 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -25,8 +25,19 @@ This level enables fully connected multi-layer perceptron.
    tvm.relay.log
    tvm.relay.sqrt
    tvm.relay.exp
+   tvm.relay.sigmoid
    tvm.relay.add
    tvm.relay.expand_dims
+   tvm.relay.concatenate
+   tvm.relay.nn.softmax
+   tvm.relay.nn.log_softmax
+   tvm.relay.subtract
+   tvm.relay.multiply
+   tvm.relay.divide
+   tvm.relay.mod
+   tvm.relay.tanh
+   tvm.relay.nn.relu
+
 
 **Level 2: Convolutions**
@@ -36,45 +47,128 @@ This level enables typical convnet models.
    :nosignatures:
 
    tvm.relay.nn.conv2d
+   tvm.relay.nn.conv2d_transpose
+   tvm.relay.nn.max_pool2d
+   tvm.relay.nn.avg_pool2d
+   tvm.relay.nn.global_max_pool2d
+   tvm.relay.nn.global_avg_pool2d
+   tvm.relay.nn.upsampling
+   tvm.relay.nn.batch_flatten
+   tvm.relay.nn.lrn
+   tvm.relay.nn.l2_normalize
 
 
 **Level 3: Additional Math And Transform Operators**
 
+This level enables additional math and transform operators.
+
+.. autosummary::
+   :nosignatures:
+
+   tvm.relay.zeros_like
+   tvm.relay.ones_like
+   tvm.relay.reshape
+   tvm.relay.copy
+   tvm.relay.transpose
+   tvm.relay.floor
+   tvm.relay.ceil
+   tvm.relay.trunc
+   tvm.relay.round
+   tvm.relay.abs
+   tvm.relay.negative
+   tvm.relay.take
+   tvm.relay.full
+   tvm.relay.full_like
+
+
 **Level 4: Broadcast and Reductions**
 
 .. autosummary::
    :nosignatures:
 
    tvm.relay.right_shift
+   tvm.relay.left_shift
    tvm.relay.equal
   tvm.relay.not_equal
   tvm.relay.greater
   tvm.relay.greater_equal
   tvm.relay.less
   tvm.relay.less_equal
+   tvm.relay.maximum
+   tvm.relay.minimum
+   tvm.relay.pow
+
 
 **Level 5: Vision/Image Operators**
 
+.. autosummary::
+   :nosignatures:
+
+   tvm.relay.image.resize
+
 
 Level 1 Definitions
 -------------------
 .. autofunction:: tvm.relay.log
 .. autofunction:: tvm.relay.sqrt
 .. autofunction:: tvm.relay.exp
+.. autofunction:: tvm.relay.sigmoid
 .. autofunction:: tvm.relay.add
+.. autofunction:: tvm.relay.subtract
+.. autofunction:: tvm.relay.multiply
+.. autofunction:: tvm.relay.divide
+.. autofunction:: tvm.relay.mod
+.. autofunction:: tvm.relay.tanh
+.. autofunction:: tvm.relay.concatenate
+.. autofunction:: tvm.relay.nn.softmax
+.. autofunction:: tvm.relay.nn.log_softmax
+.. autofunction:: tvm.relay.nn.relu
 
 
 Level 2 Definitions
 -------------------
 .. autofunction:: tvm.relay.nn.conv2d
+.. autofunction:: tvm.relay.nn.conv2d_transpose
+.. autofunction:: tvm.relay.nn.max_pool2d
+.. autofunction:: tvm.relay.nn.avg_pool2d
+.. autofunction:: tvm.relay.nn.global_max_pool2d
+.. autofunction:: tvm.relay.nn.global_avg_pool2d
+.. autofunction:: tvm.relay.nn.upsampling
+.. autofunction:: tvm.relay.nn.batch_flatten
+.. autofunction:: tvm.relay.nn.lrn
+.. autofunction:: tvm.relay.nn.l2_normalize
+
+
+Level 3 Definitions
+-------------------
+.. autofunction:: tvm.relay.floor
+.. autofunction:: tvm.relay.ceil
+.. autofunction:: tvm.relay.trunc
+.. autofunction:: tvm.relay.round
+.. autofunction:: tvm.relay.abs
+.. autofunction:: tvm.relay.negative
+.. autofunction:: tvm.relay.reshape
+.. autofunction:: tvm.relay.copy
+.. autofunction:: tvm.relay.transpose
+.. autofunction:: tvm.relay.take
+.. autofunction:: tvm.relay.zeros_like
+.. autofunction:: tvm.relay.ones_like
 
 
 Level 4 Definitions
 -------------------
 .. autofunction:: tvm.relay.right_shift
+.. autofunction:: tvm.relay.left_shift
 .. autofunction:: tvm.relay.equal
 .. autofunction:: tvm.relay.not_equal
 .. autofunction:: tvm.relay.greater
 .. autofunction:: tvm.relay.greater_equal
 .. autofunction:: tvm.relay.less
 .. autofunction:: tvm.relay.less_equal
+.. autofunction:: tvm.relay.maximum
+.. autofunction:: tvm.relay.minimum
+.. autofunction:: tvm.relay.pow
+
+
+Level 5 Definitions
+-------------------
+.. autofunction:: tvm.relay.image.resize
diff --git a/include/tvm/expr.h b/include/tvm/expr.h
index 050ab4c334e2..7fdca7f6af8e 100644
--- a/include/tvm/expr.h
+++ b/include/tvm/expr.h
@@ -56,6 +56,8 @@ inline TVMType Type2TVMType(Type t) {
 // Get number of bytes considering vector type.
 inline int GetVectorBytes(Type dtype) {
   int data_bits = dtype.bits() * dtype.lanes();
+  // allow bool to exist
+  if (dtype == Bool()) return 1;
   CHECK_EQ(data_bits % 8, 0U)
       << "Need to load/store by multiple of bytes";
   return data_bits / 8;
@@ -108,6 +110,8 @@ class Range : public HalideIR::IR::Range {
   TVM_DLL static Range make_by_min_extent(Expr min, Expr extent);
 };
 
+using Region = Array<Range>;
+
 /*!
  * \brief Type of iteration variable.
  *  Each IterVar have a specific type.
diff --git a/include/tvm/operation.h b/include/tvm/operation.h
index c11242c0a55d..1a1d28ab71bb 100644
--- a/include/tvm/operation.h
+++ b/include/tvm/operation.h
@@ -49,7 +49,7 @@ class OperationNode : public FunctionBaseNode {
   }
   /*!
    * \return The list of iteration variable at root
-   * \note root_iter_vars dedides the shape of the outputs.
+   * \note root_iter_vars decides the shape of the outputs.
    */
   virtual Array<IterVar> root_iter_vars() const = 0;
   /*!
@@ -239,6 +239,74 @@ class TVM_DLL ComputeOpNode : public OperationNode {
   TVM_DECLARE_NODE_TYPE_INFO(ComputeOpNode, OperationNode);
 };
 
+/*!
+ * \brief A TensorCompute op that computes a tensor with a tensor intrinsic.
+ */
+class TensorComputeOpNode : public OperationNode {
+ public:
+  /*! \brief IterVar on each axis */
+  Array<IterVar> axis;
+  /*! \brief IterVar on each reduction axis, if the intrin will use the reduce axis */
+  Array<IterVar> reduce_axis;
+  /*! \brief number of axes that can be scheduled */
+  int schedulable_ndim;
+  /*! \brief TensorIntrin used to compute */
+  TensorIntrin intrin;
+  /*! \brief input tensors of intrin */
+  Array<Tensor> inputs;
+  /*! \brief region of input tensors */
+  Array<Region> input_regions;
+  /*! \brief constructor */
+  TensorComputeOpNode() {}
+  // override functions
+  int num_outputs() const final;
+  Array<IterVar> root_iter_vars() const final;
+  Type output_dtype(size_t i) const final;
+  Array<Expr> output_shape(size_t i) const final;
+  Array<Tensor> InputTensors() const final;
+  Operation ReplaceInputs(
+      const Operation& self,
+      const std::unordered_map<Tensor, Tensor>& rmap) const final;
+  void PropBoundToInputs(
+      const Operation& self,
+      const std::unordered_map<const Variable*, IntSet>& dom_map,
+      std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
+  void GatherBound(
+      const Operation& self,
+      const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+      std::unordered_map<IterVar, Range>* out_dom_map) const final;
+  Stmt BuildRealize(
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& realize_map,
+      const Stmt& body) const final;
+  Stmt BuildProvide(
+      const Stage& stage,
+      const std::unordered_map<IterVar, Range>& dom_map,
+      bool debug_keep_trivial_loop) const final;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("name", &name);
+    v->Visit("tag", &tag);
+    v->Visit("axis", &axis);
+    v->Visit("reduce_axis", &reduce_axis);
+    v->Visit("schedulable_ndim", &schedulable_ndim);
+    v->Visit("intrin", &intrin);
+    v->Visit("inputs", &inputs);
+    v->Visit("input_regions", &input_regions);
+  }
+  static Operation make(std::string name,
+                        std::string tag,
+                        Array<IterVar> axis,
+                        Array<IterVar> reduce_axis,
+                        int schedulable_ndim,
+                        TensorIntrin intrin,
+                        Array<Tensor> tensors,
+                        Array<Region> regions);
+
+  static constexpr const char* _type_key = "TensorComputeOp";
+  TVM_DECLARE_NODE_TYPE_INFO(TensorComputeOpNode, OperationNode);
+};
+
 /*!
  * \brief Symbolic scan.
  */
@@ -326,7 +394,7 @@ class ExternOpNode : public OperationNode {
  public:
   /*! \brief The input tensors */
   Array<Tensor> inputs;
-  /*! \brief Symbolic placeholder representationinputs */
+  /*! \brief Symbolic placeholder representation of inputs */
   Array<Buffer> input_placeholders;
   /*! \brief Symbolic placeholder representation of outputs */
   Array<Buffer> output_placeholders;
diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h
new file mode 100644
index 000000000000..527bb647314f
--- /dev/null
+++ b/include/tvm/relay/attrs/image.h
@@ -0,0 +1,41 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file tvm/relay/attrs/image.h
+ * \brief Auxiliary attributes for image operators.
+ */
+#ifndef TVM_RELAY_ATTRS_IMAGE_H_
+#define TVM_RELAY_ATTRS_IMAGE_H_
+
+#include <tvm/attrs.h>
+#include <string>
+
+namespace tvm {
+namespace relay {
+
+/*! \brief Attributes used in image resize operator */
+struct ResizeAttrs : public tvm::AttrsNode<ResizeAttrs> {
+  Array<IndexExpr> size;
+  std::string layout;
+  std::string method;
+  bool align_corners;
+
+  TVM_DECLARE_ATTRS(ResizeAttrs, "relay.attrs.ResizeAttrs") {
+    TVM_ATTR_FIELD(size).set_default(NullValue<Array<IndexExpr> >())
+        .describe("Output Size.");
+    TVM_ATTR_FIELD(layout).set_default("NCHW")
+        .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+                  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+                  "dimensions respectively. Resize is applied on the 'H' and"
+                  "'W' dimensions.");
+    TVM_ATTR_FIELD(method).set_default("BILINEAR")
+        .describe("Specify the mode to use for scaling."
+                  "NEAREST_NEIGHBOR - Nearest Neighbor"
+                  "BILINEAR - Bilinear Interpolation");
+    TVM_ATTR_FIELD(align_corners).set_default(false)
+        .describe("Should be true to preserve the values at the corner pixels");
+  }
+};
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_ATTRS_IMAGE_H_
diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index b364079f06fc..7eb7a83605ac 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -67,6 +67,201 @@ struct ConvAttrs : public tvm::AttrsNode<ConvAttrs> {
   }
 };
 
+/*! \brief Attributes used in softmax operators */
+struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
+  int axis;
+
+  TVM_DECLARE_ATTRS(SoftmaxAttrs, "relay.attrs.SoftmaxAttrs") {
+    TVM_ATTR_FIELD(axis).set_default(1)
+        .describe("The axis to sum over when computing softmax.");
+  }
+};
+
+/*! \brief Attributes used in transposed convolution operator */
+struct Conv2DTransposeAttrs : public tvm::AttrsNode<Conv2DTransposeAttrs> {
+  IndexExpr channels;
+  Array<IndexExpr> kernel_size;
+  Array<IndexExpr> strides;
+  Array<IndexExpr> padding;
+  Array<IndexExpr> output_padding;
+  Array<IndexExpr> dilation;
+  int groups;
+  std::string data_layout;
+  std::string weight_layout;
+  DataType out_dtype;
+
+  TVM_DECLARE_ATTRS(Conv2DTransposeAttrs, "relay.attrs.Conv2DTransposeAttrs") {
+    TVM_ATTR_FIELD(channels)
+        .set_default(NullValue<IndexExpr>())
+        .describe("The dimensionality of the output space"
+                  "i.e. the number of output channels in the convolution.");
+    TVM_ATTR_FIELD(kernel_size)
+        .describe("The dimensions of the convolution window.")
+        .set_default(NullValue<Array<IndexExpr> >());
+    TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
+        .describe("The strides of the convolution.");
+    TVM_ATTR_FIELD(output_padding).set_default(Array<IndexExpr>({0, 0}))
+        .describe("Zero-padding added to one side of the output.");
+    TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
+        .describe("If padding is non-zero, then the input is implicitly zero-padded"
+                  "on both sides for padding number of points");
+    TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1}))
+        .describe("Specifies the dilation rate to use for dilated convolution.");
+    TVM_ATTR_FIELD(groups).set_default(1)
+        .describe("Controls the connections between inputs and outputs."
+                  "At groups=1, all inputs are convolved to all outputs."
+                  "At groups=2, the operation becomes equivalent to having two convolution"
+                  "layers side by side, each seeing half the input channels, and producing"
+                  "half the output channels, and both subsequently concatenated.");
+    TVM_ATTR_FIELD(data_layout).set_default("NCHW")
+        .describe("Dimension ordering of data. Can be 'NCHW', 'NHWC', etc."
+                  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+                  "dimensions respectively. Convolution is applied on the 'H' and"
+                  "'W' dimensions.");
+    TVM_ATTR_FIELD(weight_layout).set_default("OIHW")
+        .describe("Dimension ordering of data and weight. Can be 'OIHW', 'OIHW16o16i', etc."
+                  "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
+                  "dimensions respectively.");
+    TVM_ATTR_FIELD(out_dtype)
+        .set_default(Int(0))
+        .describe("Output data type, set to explicit type under mixed precision setting");
+  }
+};
+
+/*! \brief Attributes for max pool operator */
+struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
+  Array<IndexExpr> pool_size;
+  Array<IndexExpr> strides;
+  Array<IndexExpr> padding;
+  std::string layout;
+  bool ceil_mode;
+
+  TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relay.attrs.MaxPool2DAttrs") {
+    TVM_ATTR_FIELD(pool_size)
+        .describe("Size of the pooling windows.");
+    TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
+        .describe("Specifies the strides of the convolution.");
+    TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
+        .describe("If padding is non-zero, then the input is implicitly zero-padded"
+                  "Padding support both symmetric and asymmetric as"
+                  "one int : same padding used on all sides"
+                  "two int : bottom, right will use same padding as top, left"
+                  "four int : padding width in the order of (top, left, bottom, right)");
+    TVM_ATTR_FIELD(layout).set_default("NCHW")
+        .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
+                  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+                  "dimensions respectively. Convolution is applied on the 'H' and"
+                  "'W' dimensions.");
+    TVM_ATTR_FIELD(ceil_mode).set_default(false)
+        .describe("When true, will use ceil instead of floor to compute the output shape.");
+  }
+};
+
+/*! \brief Attributes for avg pool operator */
+struct AvgPool2DAttrs : public tvm::AttrsNode<AvgPool2DAttrs> {
+  Array<IndexExpr> pool_size;
+  Array<IndexExpr> strides;
+  Array<IndexExpr> padding;
+  std::string layout;
+  bool ceil_mode;
+  bool count_include_pad;
+
+  TVM_DECLARE_ATTRS(AvgPool2DAttrs, "relay.attrs.AvgPool2DAttrs") {
+    TVM_ATTR_FIELD(pool_size)
+        .describe("Size of the pooling windows.");
+    TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1}))
+        .describe("Specifies the strides of the convolution.");
+    TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0}))
+        .describe("If padding is non-zero, then the input is implicitly zero-padded"
+                  "Padding support both symmetric and asymmetric as"
+                  "one int : same padding used on all sides"
+                  "two int : bottom, right will use same padding as top, left"
+                  "four int : padding width in the order of (top, left, bottom, right)");
+    TVM_ATTR_FIELD(layout).set_default("NCHW")
+        .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
+                  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+                  "dimensions respectively. Convolution is applied on the 'H' and"
+                  "'W' dimensions.");
+    TVM_ATTR_FIELD(ceil_mode).set_default(false)
+        .describe("When true, will use ceil instead of floor to compute the output shape.");
+    TVM_ATTR_FIELD(count_include_pad).set_default(false)
+        .describe("When true, will include padding to compute the average");
+  }
+};
+
+/*! \brief Attributes for global pool operator */
+struct GlobalPool2DAttrs : public tvm::AttrsNode<GlobalPool2DAttrs> {
+  std::string layout;
+
+  TVM_DECLARE_ATTRS(GlobalPool2DAttrs, "relay.attrs.GlobalPool2DAttrs") {
+    TVM_ATTR_FIELD(layout).set_default("NCHW")
+        .describe("Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc."
+                  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+                  "dimensions respectively. Convolution is applied on the 'H' and"
+                  "'W' dimensions.");
+  }
+};
+
+/*! \brief Attributes for upsampling operator */
+struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
+  int scale;
+  std::string layout;
+  std::string method;
+
+  TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") {
+    TVM_ATTR_FIELD(scale)
+        .describe("The upsampling scaling factor.");
+    TVM_ATTR_FIELD(layout).set_default("NCHW")
+        .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+                  "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+                  "dimensions respectively. Upsampling is applied on the 'H' and"
+                  "'W' dimensions.");
+    TVM_ATTR_FIELD(method).set_default("NEAREST_NEIGHBOR")
+        .describe("Specify the mode to use for scaling."
+                  "NEAREST_NEIGHBOR - Nearest Neighbor"
+                  "BILINEAR - Bilinear Interpolation");
+  }
+};
+
+
+/*! \brief Attributes for LRN operator */
+struct LRNAttrs : public tvm::AttrsNode<LRNAttrs> {
+  IndexExpr size;
+  IndexExpr axis;
+  double bias;
+  double alpha;
+  double beta;
+
+  TVM_DECLARE_ATTRS(LRNAttrs, "relay.attrs.LRNAttrs") {
+    TVM_ATTR_FIELD(size).set_default(5)
+        .describe("The size of the local region to be considered for normalization.");
+    TVM_ATTR_FIELD(axis).set_default(1)
+        .describe("Axis of input data layout channel.");
+    TVM_ATTR_FIELD(bias).set_default(2)
+        .describe("The offset parameter to avoid division by 0.");
+    TVM_ATTR_FIELD(alpha).set_default(0.0001)
+        .describe("The scaling parameter.");
+    TVM_ATTR_FIELD(beta).set_default(0.75)
+        .describe("The exponent parameter.");
+  }
+};
+
+
+/*! \brief Attributes for L2Normalize operator */
+struct L2NormalizeAttrs : public tvm::AttrsNode<L2NormalizeAttrs> {
+  double eps;
+  Array<Integer> axis;
+
+  TVM_DECLARE_ATTRS(L2NormalizeAttrs, "relay.attrs.L2NormalizeAttrs") {
+    TVM_ATTR_FIELD(eps)
+        .describe("A lower bound value for the norm, to avoid division by 0.");
+    TVM_ATTR_FIELD(axis)
+        .describe("Axis over the normalization applied.");
+  }
+};
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_NN_H_
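Editor's note: the attrs structs above are the C++ backing for the new operators registered in `docs/langref/relay_op.rst` earlier in this patch; each keyword argument on the Python side is parsed into the matching `*Attrs` field via `TVM_DECLARE_ATTRS`. A rough illustration of how they surface in Python (a sketch only -- it assumes this branch's `tvm.relay` bindings, and the shapes and variable names here are made up):

```python
# Illustrative sketch; assumes the tvm.relay Python bindings on this branch.
from tvm import relay

data = relay.var("data", shape=(1, 3, 224, 224))
weight = relay.var("weight")  # type inferred later

conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3),
                       channels=16, padding=(1, 1))   # ConvAttrs fields
pooled = relay.nn.max_pool2d(conv, pool_size=(2, 2),
                             strides=(2, 2))          # MaxPool2DAttrs fields
up = relay.nn.upsampling(pooled, scale=2)             # UpSamplingAttrs.scale
flat = relay.nn.batch_flatten(up)
out = relay.nn.softmax(flat, axis=1)                  # SoftmaxAttrs.axis
```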
diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index b14e8f22722e..080a375cf1e2 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -30,6 +30,58 @@ struct ExpandDimsAttrs : public tvm::AttrsNode<ExpandDimsAttrs> {
   }
 };  // struct ExpandDimsAttrs
 
+/*! \brief Attributes used in concatenate operators */
+struct ConcatenateAttrs : public tvm::AttrsNode<ConcatenateAttrs> {
+  int axis;
+  TVM_DECLARE_ATTRS(ConcatenateAttrs, "relay.attrs.ConcatenateAttrs") {
+    TVM_ATTR_FIELD(axis)
+        .describe("The axis at which the input arrays are concatenated."
+                  "Should lie in range `[-ndim, ndim)`.")
+        .set_default(0);
+  }
+};  // struct ConcatenateAttrs
+
+/*! \brief Attributes used in transpose operators */
+struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
+  Array<Integer> axes;
+  TVM_DECLARE_ATTRS(TransposeAttrs, "relay.attrs.TransposeAttrs") {
+    TVM_ATTR_FIELD(axes)
+        .describe("The target axes order, reverse order if not specified.");
+  }
+};  // struct TransposeAttrs
+
+/*! \brief Attributes used in reshape operators */
+struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
+  Array<IndexExpr> newshape;
+  TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
+    TVM_ATTR_FIELD(newshape)
+        .describe("The new shape. Should be compatible with the original shape.");
+  }
+};  // struct ReshapeAttrs
+
+struct TakeAttrs : public tvm::AttrsNode<TakeAttrs> {
+  IndexExpr axis;
+
+  TVM_DECLARE_ATTRS(TakeAttrs, "relay.attrs.TakeAttrs") {
+    TVM_ATTR_FIELD(axis).set_default(NullValue<IndexExpr>())
+        .describe("The axis over which to select values.");
+  }
+};
+
+/*! \brief Attributes used in full operator */
+struct FullAttrs : public tvm::AttrsNode<FullAttrs> {
+  Array<IndexExpr> shape;
+  DataType dtype;
+
+  TVM_DECLARE_ATTRS(FullAttrs, "relay.attrs.FullAttrs") {
+    TVM_ATTR_FIELD(shape)
+        .describe("Target shape.");
+    TVM_ATTR_FIELD(dtype)
+        .describe("Target data type.")
+        .set_default(Int(0));
+  }
+};  // struct FullAttrs
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_TRANSFORM_H_
diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h
new file mode 100644
index 000000000000..a2f7360f1f71
--- /dev/null
+++ b/include/tvm/relay/attrs/vision.h
@@ -0,0 +1,17 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file tvm/relay/attrs/vision.h
+ * \brief Auxiliary attributes for vision operators.
+ */
+#ifndef TVM_RELAY_ATTRS_VISION_H_
+#define TVM_RELAY_ATTRS_VISION_H_
+
+#include <tvm/attrs.h>
+#include <string>
+
+namespace tvm {
+namespace relay {
+
+}  // namespace relay
+}  // namespace tvm
+#endif  // TVM_RELAY_ATTRS_VISION_H_
diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index 909b702bc1a1..c6e5573d9413 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -360,8 +360,6 @@ class IfNode : public ExprNode {
   /*! \brief The expression evaluated when condition is false */
   Expr false_branch;
 
-  IfNode() {}
-
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("cond", &cond);
     v->Visit("true_branch", &true_branch);
@@ -378,6 +376,28 @@ class IfNode : public ExprNode {
 
 RELAY_DEFINE_NODE_REF(If, IfNode, Expr);
 
+/*! \brief Get a field out of a tuple. */
+class TupleGetItem;
+class TupleGetItemNode : public ExprNode {
+ public:
+  /*! \brief The tuple */
+  Expr tuple;
+  /*! \brief which value to get */
+  int index;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("tuple", &tuple);
+    v->Visit("index", &index);
+  }
+
+  TVM_DLL static TupleGetItem make(Expr tuple, int index);
+
+  static constexpr const char * _type_key = "relay.GetItem";
+  TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);
+
 /*! \brief Print a debug representation of the expression to the stream.
 * \param env The environment.
 * \param e The expression
diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h
index 1da66bc95f57..be174d33b4c8 100644
--- a/include/tvm/relay/expr_functor.h
+++ b/include/tvm/relay/expr_functor.h
@@ -89,6 +89,7 @@ class ExprFunctor {
                        Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const OpNode* op,
                        Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExprDefault_(const Node* op, Args...) {
     throw Error(std::string("Do not have a default for ") + op->type_key());
   }
@@ -108,6 +109,7 @@ class ExprFunctor {
     RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
     RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
     RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
+    RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
     return vtable;
   }
 };
@@ -131,6 +133,7 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
   void VisitExpr_(const LetNode* op) override;
   void VisitExpr_(const IfNode* op) override;
   void VisitExpr_(const OpNode* op) override;
+  void VisitExpr_(const TupleGetItemNode* op) override;
   virtual void VisitType(const Type& t);
 };
 
@@ -153,6 +156,7 @@ class ExprMutator
   Expr VisitExpr_(const CallNode* call_node) override;
   Expr VisitExpr_(const LetNode* op) override;
   Expr VisitExpr_(const IfNode* op) override;
+  Expr VisitExpr_(const TupleGetItemNode* op) override;
   /*! \brief Used to visit the types inside of expressions.
   *
   * Can be overloaded to transform the types in arbitrary
diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h
index 8b2a5fafd8f0..3678aee32850 100644
--- a/include/tvm/relay/pass.h
+++ b/include/tvm/relay/pass.h
@@ -80,7 +80,7 @@ bool AlphaEqual(const Expr& e1, const Expr& e2);
 */
 bool AlphaEqual(const Type& t1, const Type& t2);
 
-/*! brief Check that each Var is only bind once.
+/*! \brief Check that each Var is only bound once.
 *
 * For example, the expression `let x = 1 in let x = 2 in 3` bound x twice.
 *
@@ -88,9 +88,9 @@ bool AlphaEqual(const Type& t1, const Type& t2);
 *
 * \param e the expression to check.
 *
- * \return true iff all Var in e is bind at most once.
+ * \return true iff all Var in e is bound at most once.
 */
-bool WellFormed(const Expr & e);
+bool WellFormed(const Expr& e);
 
 /*! \brief Get free variables from expression e.
 *
@@ -100,7 +100,7 @@ bool WellFormed(const Expr & e);
 *
 * \return the set of free variable.
 */
-tvm::Array<Var> FreeVariables(const Expr & e);
+tvm::Array<Var> FreeVariables(const Expr& e);
 
 /*! \brief Get free type parameters from expression e.
 *
@@ -110,7 +110,7 @@ tvm::Array<Var> FreeVariables(const Expr & e);
 *
 * \return the set of free type variables.
 */
-tvm::Array<TypeParam> FreeTypeVariables(const Expr & e);
+tvm::Array<TypeParam> FreeTypeVariables(const Expr& e);
 
 /*! \brief Get free type parameters from type t.
 *
@@ -120,7 +120,20 @@ tvm::Array<TypeParam> FreeTypeVariables(const Expr & e);
 *
 * \return the set of free type variables.
 */
-tvm::Array<TypeParam> FreeTypeVariables(const Type & t);
+tvm::Array<TypeParam> FreeTypeVariables(const Type& t);
+
+/*! \brief Remove expressions which do not affect the program result.
+ *
+ * It will remove let bindings that are not referenced, and if branches that are not entered.
+ *
+ * For example, this pass should turn `let a = 1 in 2` into `2`, as the value of the expression does not depend on a.
+ * Another example is `if (true) then 1 else 2` will be optimized into 1.
+ *
+ * \param e the expression to optimize.
+ *
+ * \return the optimized expression.
+ */
+Expr DeadCodeElimination(const Expr& e);
 
 }  // namespace relay
 }  // namespace tvm
diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h
index 313e0a5c3da8..0fc8e42b8bcb 100644
--- a/include/tvm/runtime/ndarray.h
+++ b/include/tvm/runtime/ndarray.h
@@ -282,6 +282,21 @@ inline void NDArray::reset() {
   }
 }
 
+/*! \brief return the size of data the DLTensor holds, in terms of number of bytes
+ *
+ *  \param arr the input DLTensor
+ *
+ *  \return number of bytes of data in the DLTensor.
+ */
+inline size_t GetDataSize(const DLTensor& arr) {
+  size_t size = 1;
+  for (tvm_index_t i = 0; i < arr.ndim; ++i) {
+    size *= static_cast<size_t>(arr.shape[i]);
+  }
+  size *= (arr.dtype.bits * arr.dtype.lanes + 7) / 8;
+  return size;
+}
+
 inline void NDArray::CopyFrom(DLTensor* other) {
   CHECK(data_ != nullptr);
   CopyFromTo(other, &(data_->dl_tensor));
diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h
index d204f8624a64..a8fa096e51c4 100644
--- a/include/tvm/runtime/packed_func.h
+++ b/include/tvm/runtime/packed_func.h
@@ -873,6 +873,9 @@ inline const char* TypeCode2Str(int type_code) {
 
 #ifndef _LIBCPP_SGX_NO_IOSTREAMS
 inline std::ostream& operator<<(std::ostream& os, TVMType t) {  // NOLINT(*)
+  if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
+    os << "bool"; return os;
+  }
   os << TypeCode2Str(t.code);
   if (t.code == kHandle) return os;
   os << static_cast<int>(t.bits);
@@ -890,7 +893,10 @@ inline std::string TVMType2String(TVMType t) {
   os << t;
   return os.str();
 #else
-  std::string repr = "";
+  if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) {
+    return "bool";
+  }
+  std::string repr = "";
   repr += TypeCode2Str(t.code);
   if (t.code == kHandle) return repr;
   repr += std::to_string(static_cast<int>(t.bits));
@@ -920,6 +925,11 @@ inline TVMType String2TVMType(std::string s) {
     t.code = kHandle; t.bits = 64;  // handle uses 64 bit by default.
     scan = s.c_str() + 6;
+  } else if (s == "bool") {
+    t.code = kDLUInt;
+    t.bits = 1;
+    t.lanes = 1;
+    return t;
   } else {
     scan = s.c_str();
     LOG(FATAL) << "unknown type " << s;
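Editor's note: the `expr.h`, `packed_func.h`, and (later in this patch) `runtime_ctypes.py` changes together make `bool` a printable and parseable alias for 1-bit unsigned int on both sides of the FFI. A quick round-trip sanity check, assuming a build with this patch applied (illustrative only, not part of the patch):

```python
# Illustrative only; assumes a TVM build with the bool-type patches above.
import tvm
from tvm._ffi.runtime_ctypes import TVMType  # the module patched below

t = TVMType("bool")                 # parsed by the new "bool" branch
assert (t.bits, t.lanes, t.type_code) == (1, 1, 1)  # uint with 1 bit, 1 lane
assert str(t) == "bool"             # __repr__ / TVMType2String print it back

# GetVectorBytes in expr.h now returns 1 byte for a scalar bool, so a
# bool placeholder no longer trips the byte-alignment check:
x = tvm.placeholder((4,), dtype="bool", name="x")
print(x.dtype)                      # -> bool
```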
diff --git a/include/tvm/tensor_intrin.h b/include/tvm/tensor_intrin.h
index 944498d1e615..fbee4bccc0bf 100644
--- a/include/tvm/tensor_intrin.h
+++ b/include/tvm/tensor_intrin.h
@@ -89,5 +89,58 @@ class TensorIntrinNode : public Node {
 inline const TensorIntrinNode* TensorIntrin::operator->() const {
   return static_cast<const TensorIntrinNode*>(node_.get());
 }
+
+
+// Internal node container of tensor intrinsic calling.
+class TensorIntrinCallNode;
+
+/*! \brief Tensor intrinsic calling node. */
+class TensorIntrinCall : public NodeRef {
+ public:
+  TensorIntrinCall() {}
+  explicit TensorIntrinCall(NodePtr<Node> n) : NodeRef(n) {}
+  /*!
+   * \brief access the internal node container
+   * \return the pointer to the internal node container
+   */
+  inline const TensorIntrinCallNode* operator->() const;
+
+  /*! \brief specify container node */
+  using ContainerType = TensorIntrinCallNode;
+};
+
+class TensorIntrinCallNode : public Node {
+ public:
+  /*! \brief the tensor intrinsic */
+  TensorIntrin intrin;
+  /*! \brief input tensors of the intrinsic */
+  Array<Tensor> tensors;
+  /*! \brief regions of input tensors */
+  Array<Region> regions;
+  /*!
+   * \brief IterVar on each reduction axis, if the
+   *  intrin will use the reduce axis
+   */
+  Array<IterVar> reduce_axis;
+
+  void VisitAttrs(AttrVisitor* v) final {
+    v->Visit("intrin", &intrin);
+    v->Visit("tensors", &tensors);
+    v->Visit("regions", &regions);
+    v->Visit("reduce_axis", &reduce_axis);
+  }
+  static TensorIntrinCall make(TensorIntrin intrin,
+                               Array<Tensor> tensors,
+                               Array<Region> regions,
+                               Array<IterVar> reduce_axis);
+
+  static constexpr const char* _type_key = "TensorIntrinCall";
+  TVM_DECLARE_NODE_TYPE_INFO(TensorIntrinCallNode, Node);
+};
+
+inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const {
+  return static_cast<const TensorIntrinCallNode*>(node_.get());
+}
+
 }  // namespace tvm
 #endif  // TVM_TENSOR_INTRIN_H_
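Editor's note: `TensorIntrinCall` is the value a compute body returns when it calls a tensor intrinsic on tensor slices; the `api.py` hunk at the end of this patch detects that case and dispatches to `_TensorComputeOp` instead of `_ComputeOp`. A minimal sketch of the user-facing pattern, assuming this branch's Python API (the external `vadd` kernel name is hypothetical):

```python
# Illustrative sketch; assumes the TensorComputeOp support added on this branch.
import tvm

def intrin_vadd(n):
    # Declare the computation the intrinsic stands for.
    x = tvm.placeholder((n,), name='x')
    y = tvm.placeholder((n,), name='y')
    z = tvm.compute((n,), lambda i: x[i] + y[i], name='z')

    def intrin_func(ins, outs):
        # Lower the intrinsic to a call into a hypothetical external kernel.
        ib = tvm.ir_builder.create()
        ib.emit(tvm.call_extern(outs[0].dtype, 'vadd',
                                ins[0].access_ptr('r'),
                                ins[1].access_ptr('r'),
                                outs[0].access_ptr('w')))
        return ib.get()

    with tvm.build_config(offset_factor=n):
        return tvm.decl_tensor_intrin(z.op, intrin_func)

m, factor = 1024, 16
A = tvm.placeholder((m // factor, factor), name='A')
B = tvm.placeholder((m // factor, factor), name='B')
vadd = intrin_vadd(factor)
# The lambda returns a TensorIntrinCall, so compute() builds a TensorComputeOp.
C = tvm.compute((m // factor, factor),
                lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name='C')
```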
'w') as fo:
+        fo.write(nnvm.graph.create(sym).json())
+    sym_str = open(tempdir.relpath(json_filename), 'r').read()
+    sym = nnvm.graph.load_json(sym_str).symbol()
+    sym = nnvm.sym.relu(sym)
+
+
+if __name__ == '__main__':
+    test_variable_node_parsed()
diff --git a/python/tvm/_ffi/_ctypes/node.py b/python/tvm/_ffi/_ctypes/node.py
index eb9e930b30eb..ccfaa6dd77a2 100644
--- a/python/tvm/_ffi/_ctypes/node.py
+++ b/python/tvm/_ffi/_ctypes/node.py
@@ -76,6 +76,8 @@ def __init_handle_by_constructor__(self, fconstructor, *args):
         So the return handle is directly set into the Node object
         instead of creating a new Node.
         """
+        # assign handle first to avoid error raising
+        self.handle = None
         handle = __init_by_constructor__(fconstructor, args)
         if not isinstance(handle, NodeHandle):
             handle = NodeHandle(handle)
diff --git a/python/tvm/_ffi/_cython/node.pxi b/python/tvm/_ffi/_cython/node.pxi
index c62e4ab44cef..73ead2b4b447 100644
--- a/python/tvm/_ffi/_cython/node.pxi
+++ b/python/tvm/_ffi/_cython/node.pxi
@@ -82,6 +82,8 @@ cdef class NodeBase:
         So the return handle is directly set into the Node object
         instead of creating a new Node.
         """
+        # avoid error raised during construction.
+        self.chandle = NULL
         cdef void* chandle
         ConstructorCall(
             (<FunctionBase>fconstructor).chandle,
diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index 2aced1aef7d2..b17487559e50 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -48,6 +48,13 @@ def __init__(self, type_str):
         super(TVMType, self).__init__()
         if isinstance(type_str, np.dtype):
             type_str = str(type_str)
+
+        if type_str == "bool":
+            self.bits = 1
+            self.type_code = 1
+            self.lanes = 1
+            return
+
         arr = type_str.split("x")
         head = arr[0]
         self.lanes = int(arr[1]) if len(arr) > 1 else 1
@@ -73,6 +80,8 @@ def __init__(self, type_str):
 
     def __repr__(self):
+        if self.bits == 1 and self.lanes == 1:
+            return "bool"
         x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits)
         if self.lanes != 1:
             x += "x%d" % self.lanes
diff --git a/python/tvm/api.py b/python/tvm/api.py
index 8cf507de6386..e275c1122c36 100644
--- a/python/tvm/api.py
+++ b/python/tvm/api.py
@@ -243,24 +243,43 @@ def compute(shape, fcompute, name="compute", tag="", attrs=None):
         raise ValueError("nested tag is not allowed for now")
     tag = _tag.TagScope.get_current().tag
     shape = (shape,) if isinstance(shape, _expr.Expr) else shape
+    # for python3
+    shape = tuple([int(s) if isinstance(s, float) else s for s in shape])
     ndim = len(shape)
     code = fcompute.__code__
 
-    if fcompute.__code__.co_argcount == 0:
+    out_ndim = ndim
+    if code.co_argcount == 0:
         arg_names = ["i%d" % i for i in range(ndim)]
     else:
         arg_names = code.co_varnames[:code.co_argcount]
+        out_ndim = code.co_argcount
 
-    if ndim != len(arg_names):
+    if out_ndim != len(arg_names):
         raise ValueError("fcompute do not match dimension, ndim=%d" % ndim)
 
-    dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape)]
+    dim_var = [_IterVar((0, s), x, 0) for x, s in zip(arg_names, shape[:out_ndim])]
     body = fcompute(*[v.var for v in dim_var])
-    if not isinstance(body, (list, tuple)):
-        body = [body]
-    body = convert(body)
-    op_node = _api_internal._ComputeOp(
-        name, tag, attrs, dim_var, body)
+
+    if isinstance(body, _tensor.TensorIntrinCall):
+        for i, s in enumerate(shape[out_ndim:]):
+            var_name = "ax" + str(i)
+            dim_var.append(_IterVar((0, s), var_name, 4))
+        op_node = _api_internal._TensorComputeOp(name,
+                                                 tag,
+                                                 dim_var,
+                                                 body.reduce_axis,
+                                                 out_ndim,
+                                                 body.intrin,
+                                                 body.tensors,
+                                                 body.regions)
+    else:
+        if 
not isinstance(body, (list, tuple)): + body = [body] + body = convert(body) + op_node = _api_internal._ComputeOp( + name, tag, attrs, dim_var, body) + num = op_node.num_outputs outputs = tuple(op_node.output(i) for i in range(num)) return outputs[0] if num == 1 else outputs @@ -529,14 +548,14 @@ def decl_buffer(shape, dtype = float32 if dtype is None else dtype strides = () if strides is None else strides if offset_factor != 0 and elem_offset is None: - elem_offset = var('%s_elem_offset' % name, shape[0].dtype) + shape_dtype = shape[0].dtype if hasattr(shape[0], "dtype") else "int32" + elem_offset = var('%s_elem_offset' % name, shape_dtype) if data is None: data = var(name, "handle") return _api_internal._Buffer( data, dtype, shape, strides, elem_offset, name, scope, data_alignment, offset_factor) - def _IterVar(dom, name, iter_type, thread_tag=''): """Internal function to create IterVar diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index c1299636eed2..18c02a416d6b 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -1,4 +1,4 @@ -# pylint: disable=wildcard-import +# pylint: disable=wildcard-import, redefined-builtin """The Relay IR namespace containing the IR definition and compiler.""" from . import base from . import ty @@ -10,8 +10,10 @@ # Root operators from .op import Op from .op.tensor import * -from . import nn from .op.transform import * +from . import nn +from . import vision +from . import image # Span Span = base.Span @@ -25,6 +27,7 @@ TypeConstraint = ty.TypeConstraint FuncType = ty.FuncType TypeRelation = ty.TypeRelation +IncompleteType = ty.IncompleteType # Expr Constant = expr.Constant @@ -36,3 +39,4 @@ Call = expr.Call Let = expr.Let If = expr.If +TupleGetItem = expr.TupleGetItem diff --git a/python/tvm/relay/_ir_pass.pyi b/python/tvm/relay/_ir_pass.pyi index f321083aa443..f1432803e9e2 100644 --- a/python/tvm/relay/_ir_pass.pyi +++ b/python/tvm/relay/_ir_pass.pyi @@ -4,4 +4,5 @@ from . import ir def check_expr(env: Environment, expr: ir.Expr) -> ir.Type: ... def generalize(env: Environment, expr: ir.Expr) -> ir.Expr: ... def _get_checked_type(expr: ir.Expr) -> ir.Type: ... -def well_formed(expr: ir.Expr) -> bool: ... \ No newline at end of file +def well_formed(expr: ir.Expr) -> bool: ... +def dead_code_elimination(expr: ir.Expr) -> ir.Expr: ... \ No newline at end of file diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 3f90a3af64a5..6ed8df0d736b 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -9,7 +9,15 @@ class Expr(NodeBase): """The base type for all Relay expressions.""" + @property def checked_type(self): + """Get the checked type of relay. + + Returns + ------- + checked_type : relay.Type + The checked type. 
+ """ ret = self._checked_type_ if ret is None: raise ValueError("The type checker has not populated" @@ -104,7 +112,7 @@ def __init__(self, op, args, attrs, ty_args=None): class Let(Expr): """A variable bindings in Relay, see tvm/relay/expr.h for more details.""" - def __init__(self, var, value, body, value_type): + def __init__(self, var, value, body, value_type=None): self.__init_handle_by_constructor__( _make.Let, var, value, body, value_type) @@ -117,4 +125,12 @@ def __init__(self, cond, true_value, false_value): self.__init_handle_by_constructor__( _make.If, cond, true_value, false_value) +@register_relay_node +class TupleGetItem(Expr): + """An expression that get field from tuple in Relay, see tvm/relay/expr.h for more details.""" + + def __init__(self, tuple_, index): + self.__init_handle_by_constructor__( + _make.TupleGetItem, tuple_, index) + debug_print = _expr._debug_print diff --git a/python/tvm/relay/image.py b/python/tvm/relay/image.py new file mode 100644 index 000000000000..43cee89b3483 --- /dev/null +++ b/python/tvm/relay/image.py @@ -0,0 +1,4 @@ +# pylint: disable=wildcard-import, unused-import, unused-wildcard-import +"""Image nets related operators.""" +# Re-export in a specific file name so that autodoc can pick it up +from .op.image import * diff --git a/python/tvm/relay/ir_builder.py b/python/tvm/relay/ir_builder.py index 6e52f209d0c6..accb782659df 100644 --- a/python/tvm/relay/ir_builder.py +++ b/python/tvm/relay/ir_builder.py @@ -16,12 +16,12 @@ def _convert_to_value(arg, ctxt=tvm.cpu(0)): """Convert Python values into the appropriate types for the Relay evaluator. """ - if isinstance(arg, int): + if isinstance(arg, bool): # bool is subclass of int + return tvm.nd.array(np.array(arg, dtype='uint8'), ctxt) + elif isinstance(arg, int): return tvm.nd.array(np.array(arg, dtype='int32'), ctxt) elif isinstance(arg, float): return tvm.nd.array(arg, ctxt) - elif isinstance(arg, bool): - return tvm.nd.array(np.array(arg, dtype='float32'), ctxt) elif isinstance(arg, np.ndarray): return tvm.nd.array(arg, ctxt) elif isinstance(arg, tvm.ndarray.NDArray): diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index 78cc5027c32c..6de6437b9eb9 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -6,15 +6,16 @@ them in Python. """ from . import _ir_pass +from . import _make # pylint: disable=invalid-name def infer_type(env, expr): - """Infer the type of expr under the context of env + """Infer the type of expr under the context of env. Parameters ---------- env : relay.Environment - The global environmemt. + The global environment. expr : relay.Expr The input expression. @@ -34,3 +35,37 @@ def infer_type(env, expr): free_vars = _ir_pass.free_vars free_type_vars = _ir_pass.free_type_vars + +def dead_code_elimination(e): + """ Remove expressions which does not effect the program result (dead code). + + Parameters + ---------- + e: relay.Expr + The input Expression + + Returns + ------- + result: relay.Expr + An expression which is semantically equal to the input expression, + but with dead code removed. + """ + return _ir_pass.dead_code_elimination(e) + +def alpha_equal(lhs, rhs): + """Compare two Relay expr for structural equivalence (alpha equivalence). + + Parameters + ---------- + lhs: relay.Expr + One of the input Expression. + rhs: relay.Expr + One of the input Expression. + + + Returns + ------- + result: bool + True iff lhs is alpha equal to rhs. 
+ """ + return bool(_make._alpha_equal(lhs, rhs)) diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py index 444dc74a31cb..bfd368356d89 100644 --- a/python/tvm/relay/op/__init__.py +++ b/python/tvm/relay/op/__init__.py @@ -1,13 +1,14 @@ -#pylint: disable=wildcard-import +#pylint: disable=wildcard-import, redefined-builtin """Relay core operators.""" # operator defs from .op import get, register, Op # Operators from .tensor import * -from . import nn from .transform import * - +from . import nn +from . import image +from . import vision # operator registry from . import _tensor diff --git a/python/tvm/relay/op/image/__init__.py b/python/tvm/relay/op/image/__init__.py new file mode 100644 index 000000000000..9d1415b1dca4 --- /dev/null +++ b/python/tvm/relay/op/image/__init__.py @@ -0,0 +1,4 @@ +# pylint: disable=wildcard-import +"""Image network related operators.""" +from __future__ import absolute_import as _abs +from .image import * diff --git a/python/tvm/relay/op/image/_make.py b/python/tvm/relay/op/image/_make.py new file mode 100644 index 000000000000..1198258553fe --- /dev/null +++ b/python/tvm/relay/op/image/_make.py @@ -0,0 +1,4 @@ +"""Constructor APIs""" +from ...._ffi.function import _init_api + +_init_api("relay.op.image._make", __name__) diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py new file mode 100644 index 000000000000..36c8dd5fa548 --- /dev/null +++ b/python/tvm/relay/op/image/image.py @@ -0,0 +1,42 @@ +"""Image operations.""" +from __future__ import absolute_import as _abs +from . import _make + +def resize(data, + size, + layout="NCHW", + method="BILINEAR", + align_corners=False): + """Image resize operator. + + This operator takes data as input and does 2D scaling to the given scale factor. + In the default case, where the data_layout is `NCHW` + with data of shape (n, c, h, w) + out will have a shape (n, c, size[0], size[1]) + + method indicates the algorithm to be used while calculating ghe out value + and method can be one of ("BILINEAR", "NEAREST_NEIGHBOR") + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + size: Tuple of Expr + The out size to which the image will be resized. + + layout : str, optional + Layout of the input. + + method : str, optional + Scale method to used [NEAREST_NEIGHBOR, BILINEAR]. + + align_corners : int, optional + Should be true to preserve the values at the corner pixels + + Returns + ------- + result: relay.Expr + The resized result. + """ + return _make.resize(data, size, layout, method, align_corners) diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index f2d60d48eaad..52414df8e444 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -86,3 +86,427 @@ def conv2d(data, return _make.conv2d(data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout, weight_layout, out_layout, out_dtype) + + +def conv2d_transpose(data, + weight, + strides=(1, 1), + padding=(0, 0), + dilation=(1, 1), + groups=1, + channels=None, + kernel_size=None, + data_layout="NCHW", + weight_layout="OIHW", + output_padding=(0, 0), + out_dtype=""): + """Two dimensional trnasposed convolution operator. + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + weight : relay.Expr + The weight expressions. + + strides : Tuple[int], optional + The strides of convoltution. + + padding : Tuple[int], optional + The padding of convolution on both sides of inputs. 
+
+    dilation : Tuple[int], optional
+        Specifies the dilation rate to be used for dilated convolution.
+
+    groups : int, optional
+        Number of groups for grouped convolution.
+
+    channels : int, optional
+        Number of output channels of the convolution.
+
+    kernel_size : Tuple[int], optional
+        The spatial dimensions of the convolution kernel.
+
+    data_layout : str, optional
+        Layout of the input.
+
+    weight_layout : str, optional
+        Layout of the weight.
+
+    output_padding : Tuple[int], optional
+        Additional zero-padding to be added to one side of the output.
+
+    out_dtype : str, optional
+        Specifies the output data type for mixed precision conv2d.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.conv2d_transpose(data, weight, strides, padding, dilation,
+                                  groups, channels, kernel_size, data_layout,
+                                  weight_layout, output_padding, out_dtype)
+
+
+def softmax(data, axis):
+    r"""Computes softmax.
+
+    .. math:: \text{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
+
+    .. note::
+        This operator can be optimized away for inference.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    axis : int
+        The axis to sum over when computing softmax.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.softmax(data, axis)
+
+
+def log_softmax(data, axis):
+    r"""Computes log softmax.
+
+    .. math::
+
+        \text{log_softmax}(x)_i = \log \frac{\exp(x_i)}{\sum_j \exp(x_j)}
+
+    .. note::
+        This operator can be optimized away for inference.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    axis : int
+        The axis to sum over when computing log softmax.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.log_softmax(data, axis)
+
+
+def max_pool2d(data,
+               pool_size=(1, 1),
+               strides=(1, 1),
+               padding=(0, 0),
+               layout="NCHW",
+               ceil_mode=False):
+    r"""2D maximum pooling operator.
+
+    This operator takes data as input and does 2D max value calculation
+    within pool_size sized windows, with strides defined by stride.
+
+    In the default case, where the data_layout is `NCHW`,
+    a data Tensor with shape `(batch_size, in_channels, height, width)`
+    produces an output Tensor according to the following rule:
+
+    with data of shape (b, c, h, w) and pool_size (kh, kw)
+
+    .. math::
+
+        \mbox{out}(b, c, y, x) = \max_{m=0, \ldots, kh-1} \max_{n=0, \ldots, kw-1}
+            \mbox{data}(b, c, \mbox{stride}[0] * y + m, \mbox{stride}[1] * x + n)
+
+    Padding is applied to data before the computation.
+    ceil_mode is used to take ceil or floor while computing out shape.
+    This operator accepts data layout specification.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    pool_size : tuple of int, optional
+        The size of the pooling window.
+
+    strides : tuple of int, optional
+        The strides of pooling.
+
+    padding : tuple of int, optional
+        The padding for pooling.
+
+    layout : str, optional
+        Layout of the input.
+
+    ceil_mode : bool, optional
+        To enable or disable ceil while pooling.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.max_pool2d(data, pool_size, strides, padding,
+                            layout, ceil_mode)
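As a quick numerical sanity check of the softmax formula above (plain NumPy, not the TVM kernel):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # shift for stability
    return e / e.sum(axis=axis, keepdims=True)

x = np.array([1.0, 2.0, 3.0])
v = softmax(x)
assert np.isclose(v.sum(), 1.0)                      # probabilities sum to one
assert np.allclose(np.log(v), x - np.log(np.exp(x).sum()))  # log_softmax identity
```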
+
+def avg_pool2d(data,
+               pool_size=(1, 1),
+               strides=(1, 1),
+               padding=(0, 0),
+               layout="NCHW",
+               ceil_mode=False,
+               count_include_pad=False):
+    r"""2D average pooling operator.
+
+    This operator takes data as input and does 2D average value calculation
+    within pool_size sized windows, with strides defined by stride.
+
+    In the default case, where the data_layout is `NCHW`,
+    a data Tensor with shape `(batch_size, in_channels, height, width)`
+    produces an output Tensor according to the following rule:
+
+    with data of shape (b, c, h, w), pool_size (kh, kw)
+
+    .. math::
+
+        \mbox{out}(b, c, y, x) = \frac{1}{kh * kw} \sum_{m=0}^{kh-1} \sum_{n=0}^{kw-1}
+            \mbox{data}(b, c, \mbox{stride}[0] * y + m, \mbox{stride}[1] * x + n)
+
+    Padding is applied to data before the computation.
+    ceil_mode is used to take ceil or floor while computing out shape.
+    count_include_pad indicates including or excluding padded input values in computation.
+    This operator accepts data layout specification.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    pool_size : tuple of int, optional
+        The size of the pooling window.
+
+    strides : tuple of int, optional
+        The strides of pooling.
+
+    padding : tuple of int, optional
+        The padding for pooling.
+
+    layout : str, optional
+        Layout of the input.
+
+    ceil_mode : bool, optional
+        To enable or disable ceil while pooling.
+
+    count_include_pad : bool, optional
+        To include padding to compute the average.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.avg_pool2d(data, pool_size, strides, padding,
+                            layout, ceil_mode, count_include_pad)
+
+def global_max_pool2d(data,
+                      layout="NCHW"):
+    r"""2D global maximum pooling operator.
+
+    This operator takes data as input and does 2D max value calculation
+    across the entire spatial extent, H x W.
+
+    In the default case, where the data_layout is `NCHW`,
+    a data Tensor with shape `(batch_size, in_channels, height, width)`
+    produces an output Tensor according to the following rule:
+
+    with data of shape (b, c, h, w)
+
+    .. math::
+
+        \mbox{out}(b, c, 1, 1) = \max_{m=0, \ldots, h} \max_{n=0, \ldots, w}
+            \mbox{data}(b, c, m, n)
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    layout : str, optional
+        Layout of the input.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.global_max_pool2d(data, layout)
+
+def global_avg_pool2d(data,
+                      layout="NCHW"):
+    r"""2D global average pooling operator.
+
+    This operator takes data as input and does 2D average value calculation
+    across the entire spatial extent, H x W.
+
+    In the default case, where the data_layout is `NCHW`,
+    a data Tensor with shape `(batch_size, in_channels, height, width)`
+    produces an output Tensor according to the following rule:
+
+    with data of shape (b, c, h, w)
+
+    .. math::
+
+        \mbox{out}(b, c, 1, 1) = \frac{1}{h * w} \sum_{m=0}^{h-1} \sum_{n=0}^{w-1}
+            \mbox{data}(b, c, m, n)
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    layout : str, optional
+        Layout of the input.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.global_avg_pool2d(data, layout)
+
+
+def upsampling(data,
+               scale=1,
+               layout="NCHW",
+               method="NEAREST_NEIGHBOR"):
+    """Upsampling.
+
+    This operator takes data as input and does 2D scaling to the given scale factor.
+    In the default case, where the data_layout is `NCHW`
+    with data of shape (n, c, h, w),
+    out will have a shape (n, c, h*scale, w*scale).
+
+    method indicates the algorithm to be used while calculating the out value
+    and method can be one of ("BILINEAR", "NEAREST_NEIGHBOR").
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    scale : int, optional
+        The scale factor for upsampling.
+
+    layout : str, optional
+        Layout of the input.
+
+    method : str, optional
+        Scale method to be used [NEAREST_NEIGHBOR, BILINEAR].
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.upsampling(data, scale, layout, method)
+
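The `ceil_mode` flag in the pooling ops above decides how a window that overhangs the input edge is counted. A minimal sketch of the usual output-size rule, as an illustration rather than the TVM implementation:

```python
import math

def pool2d_out_dim(size, pool, stride, pad, ceil_mode=False):
    # floor by default; ceil_mode=True rounds partial windows up
    rounding = math.ceil if ceil_mode else math.floor
    return int(rounding((size + 2 * pad - pool) / stride)) + 1

assert pool2d_out_dim(32, pool=2, stride=2, pad=0) == 16
assert pool2d_out_dim(6, pool=3, stride=2, pad=0) == 2                 # floor
assert pool2d_out_dim(6, pool=3, stride=2, pad=0, ceil_mode=True) == 3
```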
+def batch_flatten(data):
+    """BatchFlatten.
+
+    This operator flattens all the dimensions except for the batch dimension,
+    which results in a 2D output.
+
+    For data with shape ``(d1, d2, ..., dk)``,
+    batch_flatten(data) returns reshaped output of shape ``(d1, d2*...*dk)``.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    Returns
+    -------
+    result : relay.Expr
+        The flattened result.
+    """
+    return _make.batch_flatten(data)
+
+
+def relu(data):
+    """Rectified linear unit.
+
+    .. math::
+       out = max(x, 0)
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.relu(data)
+
+
+def lrn(data, size=5, axis=1, bias=2, alpha=.00001, beta=0.75):
+    """This operator takes data as input and does local response normalization.
+
+    Normalize the input in a local region across or within feature maps.
+    Each input value is divided by ``(bias + (alpha * sum_data^2 / size))^beta``,
+    where the sum is taken over a local region of ``size`` values centered at
+    that value (zero padding is added where necessary).
+
+    .. math::
+        data / (bias + (alpha * sum\_data^2 / size))^{beta}
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    size : int, optional
+        The size of the local region to be considered for normalization.
+
+    axis : int, optional
+        Input data layout channel axis. Default value is 1 for NCHW format.
+
+    bias : float, optional
+        The offset parameter to avoid dividing by 0.
+
+    alpha : float, optional
+        The scaling parameter.
+
+    beta : float, optional
+        The exponent parameter.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+
+    return _make.lrn(data, size, axis, alpha, beta, bias)
+
+def l2_normalize(data, eps, axis=None):
+    """Perform L2 normalization on the input data.
+
+    .. math::
+        y(i, j) = x(i, j) / sqrt(max(sum(x^2), eps))
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    eps : float
+        The epsilon value.
+
+    axis : list of int, optional
+        The axis (or axes) over which the normalization is applied.
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.l2_normalize(data, eps, axis)
diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py
index c8c42c1a6ca4..316514801fd6 100644
--- a/python/tvm/relay/op/tensor.py
+++ b/python/tvm/relay/op/tensor.py
@@ -1,4 +1,5 @@
 """Basic tensor operations."""
+# pylint: disable=redefined-builtin
 from __future__ import absolute_import as _abs
 from . import _make
 from ..expr import Tuple
@@ -59,6 +60,133 @@ def sqrt(data):
     """
     return _make.sqrt(data)
 
+def sigmoid(data):
+    """Compute elementwise sigmoid of data.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.sigmoid(data)
+
+
+def floor(data):
+    """Compute element-wise floor of data.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.floor(data)
+
+
+def ceil(data):
+    """Compute element-wise ceil of data.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.ceil(data)
+
+
+def trunc(data):
+    """Compute element-wise trunc of data.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+ """ + return _make.trunc(data) + + +def round(data): + """Compute element-wise round of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.round(data) + + +def abs(data): + """Compute element-wise absolute of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.abs(data) + + +def tanh(data): + """Compute element-wise tanh of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.tanh(data) + + +def negative(data): + """Compute element-wise negative of data. + + Parameters + ---------- + data : relay.Expr + The input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.negative(data) + def add(lhs, rhs): """Addition with numpy-style broadcasting. @@ -86,8 +214,80 @@ def add(lhs, rhs): return _make.add(lhs, rhs) +def multiply(lhs, rhs): + """Multiplication with numpy-style broadcasting. + + Parameters + ---------- + lhs : relay.Expr + The left hand side input data + rhs : relay.Expr + The right hand side input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.multiply(lhs, rhs) + + +def divide(lhs, rhs): + """Division with numpy-style broadcasting. + + Parameters + ---------- + lhs : relay.Expr + The left hand side input data + rhs : relay.Expr + The right hand side input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.divide(lhs, rhs) + + +def pow(lhs, rhs): + """Power with numpy-style broadcasting. + + Parameters + ---------- + lhs : relay.Expr + The left hand side input data + rhs : relay.Expr + The right hand side input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.pow(lhs, rhs) + + +def mod(lhs, rhs): + """Mod with numpy-style broadcasting. + + Parameters + ---------- + lhs : relay.Expr + The left hand side input data + rhs : relay.Expr + The right hand side input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.mod(lhs, rhs) + + def subtract(lhs, rhs): - """Elementwise subtraction with broadcasting. + """Subtraction with numpy-style broadcasting. Parameters ---------- @@ -104,7 +304,6 @@ def subtract(lhs, rhs): return _make.subtract(lhs, rhs) - def equal(lhs, rhs): """Broadcasted elementwise test for (lhs == rhs). @@ -213,6 +412,42 @@ def greater_equal(lhs, rhs): return _make.greater_equal(lhs, rhs) +def maximum(lhs, rhs): + """Maximum with numpy-style broadcasting. + + Parameters + ---------- + lhs : relay.Expr + The left hand side input data + rhs : relay.Expr + The right hand side input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.maximum(lhs, rhs) + + +def minimum(lhs, rhs): + """Minimum with numpy-style broadcasting. + + Parameters + ---------- + lhs : relay.Expr + The left hand side input data + rhs : relay.Expr + The right hand side input data + + Returns + ------- + result : relay.Expr + The computed result. + """ + return _make.minimum(lhs, rhs) + + def right_shift(lhs, rhs): """Right shift with numpy-style broadcasting. @@ -231,16 +466,118 @@ def right_shift(lhs, rhs): return _make.right_shift(lhs, rhs) -def concat(*args): - """Concatenate the input tensors along the zero axis. 
+def left_shift(lhs, rhs):
+    """Left shift with numpy-style broadcasting.
 
     Parameters
     ----------
-    args: list of Tensor
+    lhs : relay.Expr
+        The left hand side input data
+    rhs : relay.Expr
+        The right hand side input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.left_shift(lhs, rhs)
+
+
+def zeros_like(data):
+    """Returns an array of zeros, with same type and shape as the input.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.zeros_like(data)
+
+
+def ones_like(data):
+    """Returns an array of ones, with same type and shape as the input.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.ones_like(data)
+
+
+def clip(a, a_min, a_max):
+    """Clip the elements in `a` between `a_min` and `a_max`.
+    `a_min` and `a_max` are cast to `a`'s dtype.
+
+    Parameters
+    ----------
+    a : relay.Expr
+        The input tensor.
+    a_min : float
+        The clip minimum.
+    a_max : float
+        The clip maximum.
+
+    Returns
+    -------
+    result : relay.Expr
+        `a` with elements clipped between `a_min` and `a_max`.
+
+    Examples
+    --------
+    .. code:: python
+
+      x = relay.Constant(tvm.nd.array([0, 1, 5, 3, 4, 2]))
+      relay.clip(x, 1., 4.)
+      # [1, 1, 4, 3, 4, 2]
+    """
+    return _make.clip(a, a_min, a_max)
+
+
+def concatenate(data, axis):
+    """Concatenate the input tensors along the given axis.
+
+    Parameters
+    ----------
+    data : Union(List[relay.Expr], Tuple[relay.Expr])
+        A list of tensors.
+    axis : int
+        The axis along which the tensors are concatenated.
+
+    Returns
+    -------
+    result: relay.Expr
+        The concatenated tensor.
+    """
+    data = list(data)
+    if not data:
+        raise ValueError("relay.concatenate requires data to be non-empty.")
+    if not isinstance(axis, int):
+        raise ValueError("For now, we only support integer axis")
+    return _make.concatenate(Tuple(data), axis)
+
+
+def copy(data):
+    """Copy a tensor.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The tensor to be copied.
 
     Returns
     -------
-    tensor: The concatenated tensor.
+    result: relay.Expr
+        The copied result.
     """
-    tup = Tuple(list(args))
-    return _make.concat(tup)
+    return _make.copy(data)
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 21f61735e58a..757297db9109 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -18,7 +18,7 @@ def expand_dims(data, axis, num_newaxis=1):
         If `axis >= 0`, it is the last axis inserted in Python's negative indexing.
 
     num_newaxis : int
-        Number of axises to be inserted. Should be >= 0.
+        Number of axes to be inserted. Should be >= 0.
 
     Returns
     -------
@@ -26,3 +26,157 @@ def expand_dims(data, axis, num_newaxis=1):
         The reshaped result.
     """
     return _make.expand_dims(data, axis, num_newaxis)
+
+
+def transpose(data, axes=None):
+    """Permutes the dimensions of an array.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    axes : None or List[int]
+        The target axes order, reverse order if not specified.
+
+    Returns
+    -------
+    result : relay.Expr
+        The transposed result.
+    """
+    axes = axes or []
+    return _make.transpose(data, list(axes))
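For reference, the default `axes=None` behavior mirrors NumPy's transpose; an illustration in NumPy, not Relay:

```python
import numpy as np

x = np.empty((2, 3, 4))
assert np.transpose(x).shape == (4, 3, 2)          # axes omitted: reverse order
assert np.transpose(x, (1, 0, 2)).shape == (3, 2, 4)
```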
+
+
+def reshape(data, newshape):
+    """Reshapes the input array.
+
+    To spare the user manual shape inference, some dimensions of the shape
+    can take special values from the set {0, -1, -2, -3, -4}.
+    The significance of each is explained below:
+
+    - ``0``  copy this dimension from the input to the output shape.
+
+        Example::
+
+        - data.shape = (2,3,4), newshape = (4,0,2), result.shape = (4,3,2)
+        - data.shape = (2,3,4), newshape = (2,0,0), result.shape = (2,3,4)
+
+    - ``-1`` infers the dimension of the output shape by using the remainder of the input dimensions
+      keeping the size of the new array same as that of the input array.
+      At most one dimension of shape can be -1.
+
+        Example::
+
+        - data.shape = (2,3,4), newshape = (6,1,-1), result.shape = (6,1,4)
+        - data.shape = (2,3,4), newshape = (3,-1,8), result.shape = (3,1,8)
+        - data.shape = (2,3,4), newshape = (-1,), result.shape = (24,)
+
+    - ``-2`` copy all/remainder of the input dimensions to the output shape.
+
+        Example::
+
+        - data.shape = (2,3,4), newshape = (-2,), result.shape = (2,3,4)
+        - data.shape = (2,3,4), newshape = (2,-2), result.shape = (2,3,4)
+        - data.shape = (2,3,4), newshape = (-2,1,1), result.shape = (2,3,4,1,1)
+
+    - ``-3`` use the product of two consecutive dimensions of the input shape
+      as the output dimension.
+
+        Example::
+
+        - data.shape = (2,3,4), newshape = (-3,4), result.shape = (6,4)
+        - data.shape = (2,3,4,5), newshape = (-3,-3), result.shape = (6,20)
+        - data.shape = (2,3,4), newshape = (0,-3), result.shape = (2,12)
+        - data.shape = (2,3,4), newshape = (-3,-2), result.shape = (6,4)
+
+    - ``-4`` split one dimension of the input into two dimensions passed subsequent
+      to -4 in shape (can contain -1).
+
+        Example::
+
+        - data.shape = (2,3,4), newshape = (-4,1,2,-2), result.shape = (1,2,3,4)
+        - data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4)
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    newshape : Union[int, Tuple[int], List[int]]
+        The new shape. Should be compatible with the original shape.
+
+    Returns
+    -------
+    result : relay.Expr
+        The reshaped result.
+    """
+    if isinstance(newshape, int):
+        newshape = [newshape]
+    return _make.reshape(data, list(newshape))
+
+
+def take(data, indices, axis=None):
+    """Take elements from an array along an axis.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The source array.
+
+    indices : relay.Expr
+        The indices of the values to extract.
+
+    axis : int, optional
+        The axis over which to select values. By default,
+        the flattened input array is used.
+
+    Returns
+    -------
+    ret : relay.Expr
+        The computed result.
+    """
+    return _make.take(data, indices, axis)
+
+
+def full(fill_value, shape=(), dtype=""):
+    """Fill array with scalar value.
+
+    Parameters
+    ----------
+    fill_value : relay.Expr
+        The value to fill. Must be a scalar.
+
+    shape : tuple of int
+        The shape of the target.
+
+    dtype : data type, optional (defaults to data type of the fill value)
+        The data type of the target.
+
+    Returns
+    -------
+    result : relay.Expr
+        The resulting tensor.
+    """
+    return _make.full(fill_value, shape, dtype)
+
+
+def full_like(data, fill_value):
+    """Return an array of the same shape and type as the input array,
+    filled with a scalar value.
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input tensor.
+
+    fill_value : relay.Expr
+        The scalar value to fill.
+
+    Returns
+    -------
+    result : relay.Expr
+        The resulting tensor.
+ """ + return _make.full_like(data, fill_value) diff --git a/python/tvm/relay/op/vision/__init__.py b/python/tvm/relay/op/vision/__init__.py new file mode 100644 index 000000000000..3569093b95e6 --- /dev/null +++ b/python/tvm/relay/op/vision/__init__.py @@ -0,0 +1,3 @@ +# pylint: disable=wildcard-import +"""Vision network related operators.""" +from __future__ import absolute_import as _abs diff --git a/python/tvm/relay/op/vision/_make.py b/python/tvm/relay/op/vision/_make.py new file mode 100644 index 000000000000..614d42f47176 --- /dev/null +++ b/python/tvm/relay/op/vision/_make.py @@ -0,0 +1,4 @@ +"""Constructor APIs""" +from ...._ffi.function import _init_api + +_init_api("relay.op.vision._make", __name__) diff --git a/python/tvm/relay/ty.py b/python/tvm/relay/ty.py index c7cf9a346b68..a6ac1857bfa8 100644 --- a/python/tvm/relay/ty.py +++ b/python/tvm/relay/ty.py @@ -12,7 +12,7 @@ def __eq__(self, other): """Compare two Relay types for structural equivalence using alpha equivalence. """ - return bool(_make._type_alpha_eq(self, other)) + return bool(_make._type_alpha_equal(self, other)) def __ne__(self, other): return not self.__eq__(other) diff --git a/python/tvm/relay/vision.py b/python/tvm/relay/vision.py new file mode 100644 index 000000000000..d2c08bc0cc45 --- /dev/null +++ b/python/tvm/relay/vision.py @@ -0,0 +1,4 @@ +# pylint: disable=wildcard-import, unused-import, unused-wildcard-import +"""Vision network related operators.""" +# Re-export in a specific file name so that autodoc can pick it up +from .op.vision import * diff --git a/python/tvm/tensor.py b/python/tvm/tensor.py index f0d60f514a37..f32b70eb9a12 100644 --- a/python/tvm/tensor.py +++ b/python/tvm/tensor.py @@ -30,6 +30,11 @@ def dtype(self): """Data content of the tensor.""" return self.tensor.dtype +@register_node +class TensorIntrinCall(NodeBase): + """Intermediate structure for calling a tensor intrinsic.""" + pass + itervar_cls = None @@ -106,6 +111,7 @@ def name(self): return "%s.v%d" % (op.name, self.value_index) + class Operation(NodeBase): """Represent an operation that generate a tensor""" @@ -155,6 +161,12 @@ def reduce_axis(self): return self.__getattr__("reduce_axis") +@register_node +class TensorComputeOp(Operation): + """Tensor operation.""" + pass + + @register_node class ScanOp(Operation): """Scan operation.""" diff --git a/python/tvm/tensor_intrin.py b/python/tvm/tensor_intrin.py index 193124b2f946..f1f26655fe27 100644 --- a/python/tvm/tensor_intrin.py +++ b/python/tvm/tensor_intrin.py @@ -6,9 +6,25 @@ from . import stmt as _stmt from . import make as _make from . import tensor as _tensor +from . import schedule as _schedule from .build_module import current_build_config from ._ffi.node import NodeBase, register_node + +def _get_region(tslice): + region = [] + for idx in tslice.indices: + if isinstance(idx, slice): + assert idx.step is None + region.append(_api.Range(idx.start, idx.stop)) + else: + if isinstance(idx, _schedule.IterVar): + begin = idx.var + else: + begin = idx + region.append(_make.range_by_min_extent(begin, 1)) + return region + @register_node class TensorIntrin(NodeBase): """Tensor intrinsic functions for certain computation. 
@@ -17,8 +33,16 @@ class TensorIntrin(NodeBase): -------- decl_tensor_intrin: Construct a TensorIntrin """ - pass - + def __call__(self, *args, **kwargs): + tensors = [x.tensor for x in args] + regions = [_get_region(x) for x in args] + reduce_axis = [] + if "reduce_axis" in kwargs: + reduce_axis = kwargs["reduce_axis"] + if not isinstance(reduce_axis, (list, tuple)): + reduce_axis = [reduce_axis] + reduce_axis = _api.convert(reduce_axis) + return _api_internal._TensorIntrinCall(self, tensors, regions, reduce_axis) def decl_tensor_intrin(op, fcompute, diff --git a/rust/.gitignore b/rust/.gitignore new file mode 100644 index 000000000000..230ab66104df --- /dev/null +++ b/rust/.gitignore @@ -0,0 +1,3 @@ +Cargo.lock +target/ +**/*.rs.bk diff --git a/rust/.rustfmt.toml b/rust/.rustfmt.toml new file mode 100644 index 000000000000..dbf3347a32bd --- /dev/null +++ b/rust/.rustfmt.toml @@ -0,0 +1,59 @@ +max_width = 100 +hard_tabs = false +tab_spaces = 2 +newline_style = "Auto" +use_small_heuristics = "Default" +indent_style = "Block" +wrap_comments = false +comment_width = 80 +normalize_comments = false +format_strings = false +format_macro_matchers = false +format_macro_bodies = true +empty_item_single_line = true +struct_lit_single_line = true +fn_single_line = false +where_single_line = false +imports_indent = "Block" +imports_layout = "Mixed" +merge_imports = true +reorder_imports = true +reorder_modules = true +reorder_impl_items = false +type_punctuation_density = "Wide" +space_before_colon = false +space_after_colon = true +spaces_around_ranges = false +binop_separator = "Front" +remove_nested_parens = true +combine_control_expr = true +struct_field_align_threshold = 0 +match_arm_blocks = true +force_multiline_blocks = false +fn_args_density = "Tall" +brace_style = "SameLineWhere" +control_brace_style = "AlwaysSameLine" +trailing_semicolon = true +trailing_comma = "Vertical" +match_block_trailing_comma = false +blank_lines_upper_bound = 1 +blank_lines_lower_bound = 0 +edition = "2015" +merge_derives = true +use_try_shorthand = true +use_field_init_shorthand = false +force_explicit_abi = true +condense_wildcard_suffixes = false +color = "Auto" +required_version = "0.99.5" +unstable_features = false +disable_all_formatting = false +skip_children = false +hide_parse_errors = false +error_on_line_overflow = false +error_on_unformatted = false +report_todo = "Never" +report_fixme = "Never" +ignore = [] +emit_mode = "Files" +make_backup = false diff --git a/rust/.travis.yml b/rust/.travis.yml new file mode 100644 index 000000000000..63a3d0277c1b --- /dev/null +++ b/rust/.travis.yml @@ -0,0 +1,5 @@ +language: rust +rust: + - nightly +matrix: + fast_finish: true diff --git a/rust/Cargo.toml b/rust/Cargo.toml new file mode 100644 index 000000000000..0819e0c70023 --- /dev/null +++ b/rust/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "tvm" +version = "0.1.0" +license = "Apache-2.0" +description = "TVM Rust runtime" +repository = "https://github.com/dmlc/tvm" +readme = "README.md" +keywords = ["tvm", "nnvm"] +categories = ["api-bindings", "science"] +authors = ["Nick Hynes "] + +[features] +default = ["nom/std"] +sgx = ["nom/alloc"] + +[dependencies] +bounded-spsc-queue = "0.4.0" +error-chain = { version = "0.12.0", default-features = false } +itertools = "0.7.8" +lazy_static = "1.1.0" +ndarray = "0.11.2" +nom = {version = "4.0.0", default-features = false } +serde = "1.0.59" +serde_derive = "1.0.79" +serde_json = "1.0.17" + +[target.'cfg(not(target_env = "sgx"))'.dependencies] +num_cpus = "1.8.0" diff 
--git a/rust/src/errors.rs b/rust/src/errors.rs
new file mode 100644
index 000000000000..f9da7180b8cc
--- /dev/null
+++ b/rust/src/errors.rs
@@ -0,0 +1,39 @@
+#[cfg(target_env = "sgx")]
+use alloc::alloc;
+#[cfg(not(target_env = "sgx"))]
+use std::alloc;
+use std::num;
+
+use ndarray;
+use serde_json;
+
+error_chain! {
+  errors {
+    TryFromTVMRetValueError(expected: String, actual: i64) {
+      description("mismatched types while downcasting TVMRetValue")
+      display("invalid downcast: expected `{}` but was `{}`", expected, actual)
+    }
+
+    GraphFormatError(msg: String) {
+      description("unable to load graph")
+      display("could not load graph json: {}", msg)
+    }
+
+    LoadGraphParamsError(msg: String) {
+      description("unable to load graph params")
+      display("could not load graph params: {}", msg)
+    }
+  }
+  foreign_links {
+    Alloc(alloc::AllocErr);
+    GraphDeserialize(serde_json::Error);
+    ParseInt(num::ParseIntError);
+    ShapeError(ndarray::ShapeError);
+  }
+}
+
+impl From<alloc::LayoutErr> for Error {
+  fn from(_err: alloc::LayoutErr) -> Error {
+    Error::from_kind(ErrorKind::Msg("Layout error".to_string()))
+  }
+}
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
new file mode 100644
index 000000000000..e17c66911b18
--- /dev/null
+++ b/rust/src/lib.rs
@@ -0,0 +1,67 @@
+//! This crate is an implementation of the TVM runtime for modules compiled with `--system-lib`.
+//! It's mainly useful for compiling to WebAssembly and SGX,
+//! but also native if you prefer Rust to C++.
+//!
+//! For TVM graphs, the main entrypoint to this crate is `runtime::GraphExecutor`.
+//! Single-function modules are used via the `packed_func!` macro after obtaining
+//! the function from `runtime::SystemLibModule`.
+//!
+//! For examples of use, please refer to the multi-file tests in the `tests` directory.
+
+#![feature(
+  alloc,
+  allocator_api,
+  box_syntax,
+  fn_traits,
+  try_from,
+  unboxed_closures,
+  vec_remove_item
+)]
+
+#[cfg(target_env = "sgx")]
+extern crate alloc;
+extern crate bounded_spsc_queue;
+#[cfg(target_env = "sgx")]
+extern crate core;
+#[macro_use]
+extern crate error_chain;
+#[macro_use]
+extern crate itertools;
+#[macro_use]
+extern crate lazy_static;
+extern crate ndarray;
+#[macro_use]
+extern crate nom;
+#[cfg(not(target_env = "sgx"))]
+extern crate num_cpus;
+extern crate serde;
+#[macro_use]
+extern crate serde_derive;
+extern crate serde_json;
+
+pub mod ffi {
+  #![allow(
+    non_camel_case_types,
+    non_snake_case,
+    non_upper_case_globals,
+    unused
+  )]
+
+  pub mod runtime {
+    use std::os::raw::{c_char, c_int, c_void};
+
+    include!(concat!(
+      env!("CARGO_MANIFEST_DIR"),
+      "/src/runtime/c_runtime_api.rs"
+    ));
+
+    pub type BackendPackedCFunc =
+      extern "C" fn(args: *const TVMValue, type_codes: *const c_int, num_args: c_int) -> c_int;
+  }
+}
+
+pub mod errors;
+pub mod runtime;
+
+pub use errors::*;
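The crate doc above assumes modules built with the `--system-lib` flag, which registers each compiled function in a global registry that `SystemLibModule` can query at load time. A hedged Python-side sketch of producing such an object file; the function name `addone` and the output path are illustrative:

```python
import tvm

n = tvm.var('n')
A = tvm.placeholder((n,), name='A', dtype='float32')
B = tvm.compute(A.shape, lambda i: A[i] + 1.0, name='B')
s = tvm.create_schedule(B.op)

# '--system-lib' embeds the function into a self-registering system library,
# which the Rust runtime can then look up by name.
module = tvm.build(s, [A, B], 'llvm --system-lib', name='addone')
module.save('addone.o')   # link this object into the Rust binary
```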
diff --git a/rust/src/runtime/allocator.rs b/rust/src/runtime/allocator.rs
new file mode 100644
index 000000000000..d704336bff1f
--- /dev/null
+++ b/rust/src/runtime/allocator.rs
@@ -0,0 +1,52 @@
+#[cfg(target_env = "sgx")]
+use alloc::alloc::{self, Layout};
+#[cfg(not(target_env = "sgx"))]
+use std::alloc::{self, Layout};
+
+use errors::*;
+
+const DEFAULT_ALIGN_BYTES: usize = 4;
+
+#[derive(PartialEq, Eq)]
+pub struct Allocation {
+  layout: Layout,
+  ptr: *mut u8,
+}
+
+impl Allocation {
+  /// Allocates a chunk of memory of `size` bytes with optional alignment.
+  pub fn new(size: usize, align: Option<usize>) -> Result<Self> {
+    let alignment = align.unwrap_or(DEFAULT_ALIGN_BYTES);
+    let layout = Layout::from_size_align(size, alignment)?;
+    let ptr = unsafe { alloc::alloc(layout.clone()) };
+    if ptr.is_null() {
+      alloc::handle_alloc_error(layout);
+    }
+    Ok(Self {
+      ptr: ptr,
+      layout: layout,
+    })
+  }
+
+  pub fn as_mut_ptr(&self) -> *mut u8 {
+    self.ptr
+  }
+
+  /// Returns the size of the Allocation in bytes.
+  pub fn size(&self) -> usize {
+    self.layout.size()
+  }
+
+  /// Returns the byte alignment of the Allocation.
+  pub fn align(&self) -> usize {
+    self.layout.align()
+  }
+}
+
+impl Drop for Allocation {
+  fn drop(&mut self) {
+    unsafe {
+      alloc::dealloc(self.ptr, self.layout.clone());
+    }
+  }
+}
diff --git a/rust/src/runtime/array.rs b/rust/src/runtime/array.rs
new file mode 100644
index 000000000000..9d0941811758
--- /dev/null
+++ b/rust/src/runtime/array.rs
@@ -0,0 +1,462 @@
+use std::{
+  any::TypeId,
+  convert::TryFrom,
+  mem,
+  os::raw::{c_int, c_void},
+  ptr, slice,
+};
+
+use ndarray;
+
+use super::allocator::Allocation;
+use errors::*;
+use ffi::runtime::{
+  DLContext, DLDataType, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt,
+  DLDeviceType_kDLCPU, DLTensor,
+};
+
+/// A `Storage` is a container which holds `Tensor` data.
+#[derive(PartialEq)]
+pub enum Storage<'a> {
+  /// A `Storage` which owns its contained bytes.
+  Owned(Allocation),
+
+  /// A view of an existing `Storage`.
+  View(&'a mut [u8], usize), // ptr, align
+}
+
+impl<'a> Storage<'a> {
+  pub fn new(size: usize, align: Option<usize>) -> Result<Storage<'a>> {
+    Ok(Storage::Owned(Allocation::new(size, align)?))
+  }
+
+  pub fn as_mut_ptr(&self) -> *mut u8 {
+    match self {
+      Storage::Owned(alloc) => alloc.as_mut_ptr(),
+      Storage::View(slice, _) => slice.as_ptr() as *mut u8,
+    }
+  }
+
+  pub fn size(&self) -> usize {
+    match self {
+      Storage::Owned(alloc) => alloc.size(),
+      Storage::View(slice, _) => slice.len(),
+    }
+  }
+
+  pub fn align(&self) -> usize {
+    match self {
+      Storage::Owned(alloc) => alloc.align(),
+      Storage::View(_, align) => *align,
+    }
+  }
+
+  pub fn as_ptr(&self) -> *const u8 {
+    self.as_mut_ptr() as *const _
+  }
+
+  /// Returns a `Storage::View` which points to an owned `Storage::Owned`.
+  pub fn view(&self) -> Storage<'a> {
+    match self {
+      Storage::Owned(alloc) => Storage::View(
+        unsafe { slice::from_raw_parts_mut(alloc.as_mut_ptr(), self.size()) },
+        self.align(),
+      ),
+      Storage::View(slice, _) => Storage::View(
+        unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), slice.len()) },
+        self.align(),
+      ),
+    }
+  }
+
+  pub fn is_owned(&self) -> bool {
+    match self {
+      Storage::Owned(_) => true,
+      _ => false,
+    }
+  }
+
+  /// Returns an owned version of this storage via cloning.
+  pub fn to_owned(&self) -> Storage<'static> {
+    let s = Storage::new(self.size(), Some(self.align())).unwrap();
+    unsafe {
+      s.as_mut_ptr()
+        .copy_from_nonoverlapping(self.as_ptr(), self.size())
+    }
+    s
+  }
+}
+
+impl<'a, T> From<&'a [T]> for Storage<'a> {
+  fn from(data: &'a [T]) -> Self {
+    let data = unsafe {
+      slice::from_raw_parts_mut(
+        data.as_ptr() as *const u8 as *mut u8,
+        data.len() * mem::size_of::<T>() as usize,
+      )
+    };
+    Storage::View(data, mem::align_of::<T>())
+  }
+}
+
+/// An n-dimensional array type which can be converted to/from `tvm::DLTensor` and `ndarray::Array`.
+/// `Tensor` is primarily a holder of data which can be operated on via TVM (via `DLTensor`) or
+/// converted to `ndarray::Array` for non-TVM processing.
+///
+/// # Examples
+///
+/// ```
+/// extern crate ndarray;
+///
+/// let mut a_nd: ndarray::Array1<f32> = ndarray::Array::from_vec(vec![1f32, 2., 3., 4.]);
+/// let mut a: Tensor = a_nd.into();
+/// let mut a_dl: DLTensor = (&mut a).into();
+/// call_packed!(tvm_fn, &mut a_dl);
+///
+/// // Array -> Tensor is mostly useful when post-processing TVM graph outputs.
+/// let mut a_nd = ndarray::ArrayD::<f32>::try_from(&a).unwrap();
+/// ```
+#[derive(PartialEq)]
+pub struct Tensor<'a> {
+  /// The bytes which contain the data this `Tensor` represents.
+  pub(super) data: Storage<'a>,
+  pub(super) ctx: TVMContext,
+  pub(super) dtype: DataType,
+  pub(super) shape: Vec<i64>, // not usize because `typedef int64_t tvm_index_t` in c_runtime_api.h
+  /// The `Tensor` strides. Can be `None` if the `Tensor` is contiguous.
+  pub(super) strides: Option<Vec<usize>>,
+  pub(super) byte_offset: isize,
+  pub(super) size: usize,
+}
+
+unsafe impl<'a> Send for Tensor<'a> {}
+
+impl<'a> Tensor<'a> {
+  pub fn shape(&self) -> Vec<i64> {
+    self.shape.clone()
+  }
+
+  /// Returns the data of this `Tensor` as a `Vec`.
+  ///
+  /// # Panics
+  ///
+  /// Panics if the `Tensor` is not contiguous or does not contain elements of type `T`.
+  pub fn to_vec<T: 'static + Copy>(&self) -> Vec<T> {
+    assert!(self.is_contiguous());
+    assert!(self.dtype.is_type::<T>());
+    let mut vec: Vec<T> = Vec::with_capacity(self.size * self.dtype.itemsize());
+    unsafe {
+      vec.as_mut_ptr().copy_from_nonoverlapping(
+        self.data.as_ptr().offset(self.byte_offset) as *const T,
+        self.size,
+      );
+      vec.set_len(self.size);
+    }
+    vec
+  }
+
+  /// Returns `true` iff this `Tensor` is represented by a contiguous region of memory.
+  pub fn is_contiguous(&self) -> bool {
+    match self.strides {
+      None => true,
+      Some(ref strides) => {
+        // check that stride for each dimension is the product of all trailing dimensions' shapes
+        self
+          .shape
+          .iter()
+          .zip(strides)
+          .rfold(
+            (true, 1),
+            |(is_contig, expected_stride), (shape, stride)| {
+              (
+                is_contig && *stride == expected_stride,
+                expected_stride * (*shape as usize),
+              )
+            },
+          )
+          .0
+      }
+    }
+  }
+
+  /// Copies the data of `other` into this `Tensor`.
+  ///
+  /// # Panics
+  ///
+  /// Panics if the `Tensor`s differ in shape or dtype, or are not contiguous.
+  pub fn copy(&mut self, other: &Tensor) {
+    assert!(
+      self.dtype == other.dtype && self.size == other.size,
+      "Tensor shape/dtype mismatch."
+    );
+    assert!(
+      self.is_contiguous() && other.is_contiguous(),
+      "copy currently requires contiguous tensors\n`self.strides = {:?}` `other.strides = {:?}`",
+      self.strides,
+      other.strides
+    );
+    unsafe {
+      self
+        .data
+        .as_mut_ptr()
+        .offset(self.byte_offset as isize)
+        .copy_from_nonoverlapping(
+          other.data.as_mut_ptr().offset(other.byte_offset),
+          other.size * other.dtype.itemsize(),
+        );
+    }
+  }
+
+  /// Returns an owned version of this `Tensor` via cloning.
+  pub fn to_owned(&self) -> Tensor<'static> {
+    let t = Tensor {
+      data: self.data.to_owned(),
+      ctx: self.ctx.clone(),
+      dtype: self.dtype.clone(),
+      size: self.size.clone(),
+      shape: self.shape.clone(),
+      strides: None,
+      byte_offset: 0,
+    };
+    unsafe { mem::transmute::<Tensor<'a>, Tensor<'static>>(t) }
+  }
+
+  fn from_array_storage<'s, T, D: ndarray::Dimension>(
+    arr: &ndarray::Array<T, D>,
+    storage: Storage<'s>,
+    type_code: usize,
+  ) -> Tensor<'s> {
+    let type_width = mem::size_of::<T>() as usize;
+    Tensor {
+      data: storage,
+      ctx: TVMContext::default(),
+      dtype: DataType {
+        code: type_code,
+        bits: 8 * type_width,
+        lanes: 1,
+      },
+      size: arr.len(),
+      shape: arr.shape().iter().map(|&v| v as i64).collect(),
+      strides: Some(arr.strides().into_iter().map(|&v| v as usize).collect()),
+      byte_offset: 0,
+    }
+  }
+}
+
+/// Conversions to `ndarray::Array` from `Tensor`, if the types match.
+macro_rules! impl_ndarray_try_from_tensor {
+  ($type:ty, $dtype:expr) => {
+    impl<'a, 't> TryFrom<&'a Tensor<'t>> for ndarray::ArrayD<$type> {
+      type Error = Error;
+      fn try_from(tensor: &'a Tensor) -> Result<ndarray::ArrayD<$type>> {
+        ensure!(
+          tensor.dtype == $dtype,
+          "Cannot convert Tensor with dtype {:?} to ndarray",
+          tensor.dtype
+        );
+        Ok(ndarray::Array::from_shape_vec(
+          tensor
+            .shape
+            .iter()
+            .map(|s| *s as usize)
+            .collect::<Vec<usize>>(),
+          tensor.to_vec::<$type>(),
+        )?)
+      }
+    }
+  };
+}
+
+impl_ndarray_try_from_tensor!(i32, DTYPE_INT32);
+impl_ndarray_try_from_tensor!(u32, DTYPE_UINT32);
+impl_ndarray_try_from_tensor!(f32, DTYPE_FLOAT32);
+impl_ndarray_try_from_tensor!(f64, DTYPE_FLOAT64);
+
+impl DLTensor {
+  pub(super) fn from_tensor<'a>(tensor: &'a Tensor, flatten: bool) -> Self {
+    assert!(!flatten || tensor.is_contiguous());
+    Self {
+      data: unsafe { tensor.data.as_mut_ptr().offset(tensor.byte_offset) } as *mut c_void,
+      ctx: DLContext::from(&tensor.ctx),
+      ndim: if flatten { 1 } else { tensor.shape.len() } as i32,
+      dtype: DLDataType::from(&tensor.dtype),
+      shape: if flatten {
+        &tensor.size as *const _ as *mut i64
+      } else {
+        tensor.shape.as_ptr()
+      } as *mut i64,
+      strides: if flatten || tensor.is_contiguous() {
+        ptr::null_mut()
+      } else {
+        tensor.strides.as_ref().unwrap().as_ptr()
+      } as *mut i64,
+      byte_offset: 0,
+    }
+  }
+}
+
+impl<'a, 't> From<&'a Tensor<'t>> for DLTensor {
+  fn from(tensor: &'a Tensor<'t>) -> Self {
+    DLTensor::from_tensor(tensor, false /* flatten */)
+  }
+}
+
+impl<'a, 't> From<&'a mut Tensor<'t>> for DLTensor {
+  fn from(tensor: &'a mut Tensor<'t>) -> Self {
+    DLTensor::from_tensor(tensor, false /* flatten */)
+  }
+}
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+pub struct DataType {
+  pub(super) code: usize,
+  pub(super) bits: usize,
+  pub(super) lanes: usize,
+}
+
+impl DataType {
+  /// Returns the number of bytes occupied by an element of this `DataType`.
+  fn itemsize(&self) -> usize {
+    (self.bits * self.lanes) >> 3
+  }
+
+  /// Returns whether this `DataType` represents primitive type `T`.
+  fn is_type<T: 'static>(&self) -> bool {
+    if self.lanes != 1 {
+      return false;
+    }
+    let typ = TypeId::of::<T>();
+    (typ == TypeId::of::<i32>() && self.code == 0 && self.bits == 32)
+      || (typ == TypeId::of::<i64>() && self.code == 0 && self.bits == 64)
+      || (typ == TypeId::of::<u32>() && self.code == 1 && self.bits == 32)
+      || (typ == TypeId::of::<u64>() && self.code == 1 && self.bits == 64)
+      || (typ == TypeId::of::<f32>() && self.code == 2 && self.bits == 32)
+      || (typ == TypeId::of::<f64>() && self.code == 2 && self.bits == 64)
+  }
+}
+
+impl<'a> From<&'a DataType> for DLDataType {
+  fn from(dtype: &'a DataType) -> Self {
+    Self {
+      code: dtype.code as u8,
+      bits: dtype.bits as u8,
+      lanes: dtype.lanes as u16,
+    }
+  }
+}
+
+macro_rules! make_dtype_const {
+  ($name: ident, $code: ident, $bits: expr, $lanes: expr) => {
+    const $name: DataType = DataType {
+      code: $code as usize,
+      bits: $bits,
+      lanes: $lanes,
+    };
+  };
+}
+
+make_dtype_const!(DTYPE_INT32, DLDataTypeCode_kDLInt, 32, 1);
+make_dtype_const!(DTYPE_UINT32, DLDataTypeCode_kDLUInt, 32, 1);
+// make_dtype_const!(DTYPE_FLOAT16, DLDataTypeCode_kDLFloat, 16, 1);
+make_dtype_const!(DTYPE_FLOAT32, DLDataTypeCode_kDLFloat, 32, 1);
+make_dtype_const!(DTYPE_FLOAT64, DLDataTypeCode_kDLFloat, 64, 1);
+
+impl Default for DLContext {
+  fn default() -> Self {
+    DLContext {
+      device_type: DLDeviceType_kDLCPU,
+      device_id: 0,
+    }
+  }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub struct TVMContext {
+  pub(super) device_type: usize,
+  pub(super) device_id: usize,
+}
+
+impl<'a> From<&'a TVMContext> for DLContext {
+  fn from(ctx: &'a TVMContext) -> Self {
+    Self {
+      device_type: ctx.device_type as u32,
+      device_id: ctx.device_id as i32,
+    }
+  }
+}
+
+impl Default for TVMContext {
+  fn default() -> Self {
+    Self {
+      device_type: DLDeviceType_kDLCPU as usize,
+      device_id: 0,
+    }
+  }
+}
+
+/// `From` conversions to `Tensor` for owned or borrowed `ndarray::Array`.
+///
+/// # Panics
+///
+/// Panics if the ndarray is not contiguous.
+macro_rules! impl_tensor_from_ndarray {
+  ($type:ty, $typecode:expr) => {
+    impl<D: ndarray::Dimension> From<ndarray::Array<$type, D>> for Tensor<'static> {
+      fn from(arr: ndarray::Array<$type, D>) -> Self {
+        assert!(arr.is_standard_layout(), "Array must be contiguous.");
+        let size = arr.len() * mem::size_of::<$type>() as usize;
+        let storage =
+          Storage::from(unsafe { slice::from_raw_parts(arr.as_ptr() as *const u8, size) });
+        Tensor::from_array_storage(&arr, storage, $typecode as usize)
+      }
+    }
+    impl<'a, D: ndarray::Dimension> From<&'a ndarray::Array<$type, D>> for Tensor<'a> {
+      fn from(arr: &'a ndarray::Array<$type, D>) -> Self {
+        assert!(arr.is_standard_layout(), "Array must be contiguous.");
+        Tensor::from_array_storage(
+          arr,
+          Storage::from(arr.as_slice().unwrap()),
+          $typecode as usize,
+        )
+      }
+    }
+  };
+}
+
+/// `From` conversions to `DLTensor` for `ndarray::Array`.
+/// Takes a reference to the `ndarray` since `DLTensor` is not owned.
+macro_rules!
impl_dltensor_from_ndarray { + ($type:ty, $typecode:expr) => { + impl<'a, D: ndarray::Dimension> From<&'a mut ndarray::Array<$type, D>> for DLTensor { + fn from(arr: &'a mut ndarray::Array<$type, D>) -> Self { + DLTensor { + data: arr.as_mut_ptr() as *mut c_void, + ctx: DLContext::default(), + ndim: arr.ndim() as c_int, + dtype: DLDataType { + code: $typecode as u8, + bits: 8 * mem::size_of::<$type>() as u8, + lanes: 1, + }, + shape: arr.shape().as_ptr() as *const i64 as *mut i64, + strides: arr.strides().as_ptr() as *const isize as *mut i64, + byte_offset: 0, + } + } + } + }; +} + +impl_dltensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); +impl_dltensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); +impl_dltensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); +impl_dltensor_from_ndarray!(i64, DLDataTypeCode_kDLInt); +impl_dltensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt); +impl_dltensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt); + +impl_tensor_from_ndarray!(f32, DLDataTypeCode_kDLFloat); +impl_tensor_from_ndarray!(f64, DLDataTypeCode_kDLFloat); +impl_tensor_from_ndarray!(i32, DLDataTypeCode_kDLInt); +impl_tensor_from_ndarray!(i64, DLDataTypeCode_kDLInt); +impl_tensor_from_ndarray!(u32, DLDataTypeCode_kDLUInt); +impl_tensor_from_ndarray!(u64, DLDataTypeCode_kDLUInt); diff --git a/rust/src/runtime/c_runtime_api.rs b/rust/src/runtime/c_runtime_api.rs new file mode 100644 index 000000000000..62cfa0d15451 --- /dev/null +++ b/rust/src/runtime/c_runtime_api.rs @@ -0,0 +1,770 @@ +/* automatically generated by rust-bindgen for TVM revision 6292c78 */ + +pub const TVM_VERSION: &'static [u8; 8usize] = b"0.5.dev\0"; +pub const DLPACK_VERSION: u32 = 8; +pub const _STDINT_H: u32 = 1; +pub const _FEATURES_H: u32 = 1; +pub const _DEFAULT_SOURCE: u32 = 1; +pub const __USE_ISOC11: u32 = 1; +pub const __USE_ISOC99: u32 = 1; +pub const __USE_ISOC95: u32 = 1; +pub const __USE_POSIX_IMPLICITLY: u32 = 1; +pub const _POSIX_SOURCE: u32 = 1; +pub const _POSIX_C_SOURCE: u32 = 200809; +pub const __USE_POSIX: u32 = 1; +pub const __USE_POSIX2: u32 = 1; +pub const __USE_POSIX199309: u32 = 1; +pub const __USE_POSIX199506: u32 = 1; +pub const __USE_XOPEN2K: u32 = 1; +pub const __USE_XOPEN2K8: u32 = 1; +pub const _ATFILE_SOURCE: u32 = 1; +pub const __USE_MISC: u32 = 1; +pub const __USE_ATFILE: u32 = 1; +pub const __USE_FORTIFY_LEVEL: u32 = 0; +pub const _STDC_PREDEF_H: u32 = 1; +pub const __STDC_IEC_559__: u32 = 1; +pub const __STDC_IEC_559_COMPLEX__: u32 = 1; +pub const __STDC_ISO_10646__: u32 = 201505; +pub const __STDC_NO_THREADS__: u32 = 1; +pub const __GNU_LIBRARY__: u32 = 6; +pub const __GLIBC__: u32 = 2; +pub const __GLIBC_MINOR__: u32 = 23; +pub const _SYS_CDEFS_H: u32 = 1; +pub const __WORDSIZE: u32 = 64; +pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1; +pub const __SYSCALL_WORDSIZE: u32 = 64; +pub const _BITS_WCHAR_H: u32 = 1; +pub const INT8_MIN: i32 = -128; +pub const INT16_MIN: i32 = -32768; +pub const INT32_MIN: i32 = -2147483648; +pub const INT8_MAX: u32 = 127; +pub const INT16_MAX: u32 = 32767; +pub const INT32_MAX: u32 = 2147483647; +pub const UINT8_MAX: u32 = 255; +pub const UINT16_MAX: u32 = 65535; +pub const UINT32_MAX: u32 = 4294967295; +pub const INT_LEAST8_MIN: i32 = -128; +pub const INT_LEAST16_MIN: i32 = -32768; +pub const INT_LEAST32_MIN: i32 = -2147483648; +pub const INT_LEAST8_MAX: u32 = 127; +pub const INT_LEAST16_MAX: u32 = 32767; +pub const INT_LEAST32_MAX: u32 = 2147483647; +pub const UINT_LEAST8_MAX: u32 = 255; +pub const UINT_LEAST16_MAX: u32 = 65535; +pub const UINT_LEAST32_MAX: u32 = 
4294967295; +pub const INT_FAST8_MIN: i32 = -128; +pub const INT_FAST16_MIN: i64 = -9223372036854775808; +pub const INT_FAST32_MIN: i64 = -9223372036854775808; +pub const INT_FAST8_MAX: u32 = 127; +pub const INT_FAST16_MAX: u64 = 9223372036854775807; +pub const INT_FAST32_MAX: u64 = 9223372036854775807; +pub const UINT_FAST8_MAX: u32 = 255; +pub const UINT_FAST16_MAX: i32 = -1; +pub const UINT_FAST32_MAX: i32 = -1; +pub const INTPTR_MIN: i64 = -9223372036854775808; +pub const INTPTR_MAX: u64 = 9223372036854775807; +pub const UINTPTR_MAX: i32 = -1; +pub const PTRDIFF_MIN: i64 = -9223372036854775808; +pub const PTRDIFF_MAX: u64 = 9223372036854775807; +pub const SIG_ATOMIC_MIN: i32 = -2147483648; +pub const SIG_ATOMIC_MAX: u32 = 2147483647; +pub const SIZE_MAX: i32 = -1; +pub const WINT_MIN: u32 = 0; +pub const WINT_MAX: u32 = 4294967295; +pub type int_least8_t = ::std::os::raw::c_schar; +pub type int_least16_t = ::std::os::raw::c_short; +pub type int_least32_t = ::std::os::raw::c_int; +pub type int_least64_t = ::std::os::raw::c_long; +pub type uint_least8_t = ::std::os::raw::c_uchar; +pub type uint_least16_t = ::std::os::raw::c_ushort; +pub type uint_least32_t = ::std::os::raw::c_uint; +pub type uint_least64_t = ::std::os::raw::c_ulong; +pub type int_fast8_t = ::std::os::raw::c_schar; +pub type int_fast16_t = ::std::os::raw::c_long; +pub type int_fast32_t = ::std::os::raw::c_long; +pub type int_fast64_t = ::std::os::raw::c_long; +pub type uint_fast8_t = ::std::os::raw::c_uchar; +pub type uint_fast16_t = ::std::os::raw::c_ulong; +pub type uint_fast32_t = ::std::os::raw::c_ulong; +pub type uint_fast64_t = ::std::os::raw::c_ulong; +pub type intmax_t = ::std::os::raw::c_long; +pub type uintmax_t = ::std::os::raw::c_ulong; +pub type wchar_t = ::std::os::raw::c_int; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct max_align_t { + pub __clang_max_align_nonce1: ::std::os::raw::c_longlong, + pub __bindgen_padding_0: u64, + pub __clang_max_align_nonce2: f64, +} +pub const DLDeviceType_kDLCPU: DLDeviceType = 1; +pub const DLDeviceType_kDLGPU: DLDeviceType = 2; +pub const DLDeviceType_kDLCPUPinned: DLDeviceType = 3; +pub const DLDeviceType_kDLOpenCL: DLDeviceType = 4; +pub const DLDeviceType_kDLMetal: DLDeviceType = 8; +pub const DLDeviceType_kDLVPI: DLDeviceType = 9; +pub const DLDeviceType_kDLROCM: DLDeviceType = 10; +/// \brief The device type in DLContext. +pub type DLDeviceType = u32; +/// \brief A Device context for Tensor and operator. +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct DLContext { + /// \brief The device type used in the device. + pub device_type: DLDeviceType, + /// \brief The device index + pub device_id: ::std::os::raw::c_int, +} +pub const DLDataTypeCode_kDLInt: DLDataTypeCode = 0; +pub const DLDataTypeCode_kDLUInt: DLDataTypeCode = 1; +pub const DLDataTypeCode_kDLFloat: DLDataTypeCode = 2; +/// \brief The type code options DLDataType. +pub type DLDataTypeCode = u32; +/// \brief The data type the tensor can hold. +/// +/// Examples +/// - float: type_code = 2, bits = 32, lanes=1 +/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 +/// - int8: type_code = 0, bits = 8, lanes=1 +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct DLDataType { + /// \brief Type code of base types. + /// We keep it uint8_t instead of DLDataTypeCode for minimal memory + /// footprint, but the value should be one of DLDataTypeCode enum values. + /// + pub code: u8, + /// \brief Number of bits, common choices are 8, 16, 32. 
+ pub bits: u8, + /// \brief Number of lanes in the type, used for vector types. + pub lanes: u16, +} +/// \brief Plain C Tensor object, does not manage memory. +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct DLTensor { + /// \brief The opaque data pointer points to the allocated data. + /// This will be CUDA device pointer or cl_mem handle in OpenCL. + /// This pointer is always aligns to 256 bytes as in CUDA. + pub data: *mut ::std::os::raw::c_void, + /// \brief The device context of the tensor + pub ctx: DLContext, + /// \brief Number of dimensions + pub ndim: ::std::os::raw::c_int, + /// \brief The data type of the pointer + pub dtype: DLDataType, + /// \brief The shape of the tensor + pub shape: *mut i64, + /// \brief strides of the tensor, + /// can be NULL, indicating tensor is compact. + pub strides: *mut i64, + /// \brief The offset in bytes to the beginning pointer to data + pub byte_offset: u64, +} +/// \brief C Tensor object, manage memory of DLTensor. This data structure is +/// intended to faciliate the borrowing of DLTensor by another framework. It is +/// not meant to transfer the tensor. When the borrowing framework doesn't need +/// the tensor, it should call the deleter to notify the host that the resource +/// is no longer needed. +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct DLManagedTensor { + /// \brief DLTensor which is being memory managed + pub dl_tensor: DLTensor, + /// \brief the context of the original host framework of DLManagedTensor in + /// which DLManagedTensor is used in the framework. It can also be NULL. + pub manager_ctx: *mut ::std::os::raw::c_void, + /// \brief Destructor signature void (*)(void*) - this should be called + /// to destruct manager_ctx which holds the DLManagedTensor. It can be NULL + /// if there is no way for the caller to provide a reasonable destructor. + pub deleter: ::std::option::Option, +} +/// \brief type of array index. +pub type tvm_index_t = i64; +pub const TVMDeviceExtType_kDLAOCL: TVMDeviceExtType = 5; +pub const TVMDeviceExtType_kDLSDAccel: TVMDeviceExtType = 6; +pub const TVMDeviceExtType_kDLVulkan: TVMDeviceExtType = 7; +pub const TVMDeviceExtType_kOpenGL: TVMDeviceExtType = 11; +pub const TVMDeviceExtType_kExtDev: TVMDeviceExtType = 12; +/// \brief Extension device types in TVM +pub type TVMDeviceExtType = u32; +pub const TVMTypeCode_kHandle: TVMTypeCode = 3; +pub const TVMTypeCode_kNull: TVMTypeCode = 4; +pub const TVMTypeCode_kTVMType: TVMTypeCode = 5; +pub const TVMTypeCode_kTVMContext: TVMTypeCode = 6; +pub const TVMTypeCode_kArrayHandle: TVMTypeCode = 7; +pub const TVMTypeCode_kNodeHandle: TVMTypeCode = 8; +pub const TVMTypeCode_kModuleHandle: TVMTypeCode = 9; +pub const TVMTypeCode_kFuncHandle: TVMTypeCode = 10; +pub const TVMTypeCode_kStr: TVMTypeCode = 11; +pub const TVMTypeCode_kBytes: TVMTypeCode = 12; +pub const TVMTypeCode_kNDArrayContainer: TVMTypeCode = 13; +pub const TVMTypeCode_kExtBegin: TVMTypeCode = 15; +pub const TVMTypeCode_kNNVMFirst: TVMTypeCode = 16; +pub const TVMTypeCode_kNNVMLast: TVMTypeCode = 20; +pub const TVMTypeCode_kExtReserveEnd: TVMTypeCode = 64; +pub const TVMTypeCode_kExtEnd: TVMTypeCode = 128; +/// \brief The type code in TVMType +/// \note TVMType is used in two places. +pub type TVMTypeCode = u32; +/// \brief The data type used in TVM Runtime. 
+/// +/// Examples +/// - float: type_code = 2, bits = 32, lanes=1 +/// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4 +/// - int8: type_code = 0, bits = 8, lanes=1 +/// +/// \note Arguments TVM API function always takes bits=64 and lanes=1 +pub type TVMType = DLDataType; +/// \brief The Device information, abstract away common device types. +pub type TVMContext = DLContext; +/// \brief The tensor array stucture to TVM API. +pub type TVMArray = DLTensor; +/// \brief the array handle +pub type TVMArrayHandle = *mut TVMArray; +/// \brief Union type of values +/// being passed through API and function calls. +#[repr(C)] +#[derive(Copy, Clone)] +pub union TVMValue { + pub v_int64: i64, + pub v_float64: f64, + pub v_handle: *mut ::std::os::raw::c_void, + pub v_str: *const ::std::os::raw::c_char, + pub v_type: TVMType, + pub v_ctx: TVMContext, + _bindgen_union_align: u64, +} +/// \brief Byte array type used to pass in byte array +/// When kBytes is used as data type. +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct TVMByteArray { + pub data: *const ::std::os::raw::c_char, + pub size: usize, +} +/// \brief Handle to TVM runtime modules. +pub type TVMModuleHandle = *mut ::std::os::raw::c_void; +/// \brief Handle to packed function handle. +pub type TVMFunctionHandle = *mut ::std::os::raw::c_void; +/// \brief Handle to hold return value. +pub type TVMRetValueHandle = *mut ::std::os::raw::c_void; +/// \brief The stream that is specific to device +/// can be NULL, which indicates the default one. +pub type TVMStreamHandle = *mut ::std::os::raw::c_void; +extern "C" { + /// \brief Used for implementing C API function. + /// Set last error message before return. + /// \param msg The error message to be set. + pub fn TVMAPISetLastError(msg: *const ::std::os::raw::c_char); +} +extern "C" { + /// \brief return str message of the last error + /// all function in this file will return 0 when success + /// and -1 when an error occured, + /// TVMGetLastError can be called to retrieve the error + /// + /// this function is threadsafe and can be called by different thread + /// \return error info + pub fn TVMGetLastError() -> *const ::std::os::raw::c_char; +} +extern "C" { + /// \brief Load module from file. + /// \param file_name The file name to load the module from. + /// \param format The format of the module. + /// \param out The result module + /// + /// \return 0 when success, -1 when failure happens + /// \note The resulting module do not contain import relation. + /// It can be reconstructed by TVMModImport. + pub fn TVMModLoadFromFile( + file_name: *const ::std::os::raw::c_char, + format: *const ::std::os::raw::c_char, + out: *mut TVMModuleHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Add dep to mod's dependency. + /// This allows functions in this module to use modules. + /// + /// \param mod The module handle. + /// \param dep The dependent module to be imported. + /// \return 0 when success, -1 when failure happens + pub fn TVMModImport(mod_: TVMModuleHandle, dep: TVMModuleHandle) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Get function from the module. + /// \param mod The module handle. + /// \param func_name The name of the function. + /// \param query_imports Whether to query imported modules + /// \param out The result function, can be NULL if it is not available. 
+ /// \return 0 when no error is thrown, -1 when failure happens + pub fn TVMModGetFunction( + mod_: TVMModuleHandle, + func_name: *const ::std::os::raw::c_char, + query_imports: ::std::os::raw::c_int, + out: *mut TVMFunctionHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Free front-end extension type resource. + /// \param handle The extension handle. + /// \param type_code The type of of the extension type. + /// \return 0 when success, -1 when failure happens + pub fn TVMExtTypeFree( + handle: *mut ::std::os::raw::c_void, + type_code: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Free the Module + /// \param mod The module to be freed. + /// + /// \note This may not free up the module's resources. + /// If there is active TVMFunctionHandle uses the module + /// Or if this module is imported by another active module. + /// + /// The all functions remains valid until TVMFuncFree is called. + /// \return 0 when success, -1 when failure happens + pub fn TVMModFree(mod_: TVMModuleHandle) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Free the function when it is no longer needed. + /// \param func The function handle + /// \return 0 when success, -1 when failure happens + pub fn TVMFuncFree(func: TVMFunctionHandle) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Call a Packed TVM Function. + /// + /// \param func node handle of the function. + /// \param arg_values The arguments + /// \param type_codes The type codes of the arguments + /// \param num_args Number of arguments. + /// + /// \param ret_val The return value. + /// \param ret_type_code the type code of return value. + /// + /// \return 0 when success, -1 when failure happens + /// \note TVM calls always exchanges with type bits=64, lanes=1 + /// + /// \note API calls always exchanges with type bits=64, lanes=1 + /// If API call returns container handles (e.g. FunctionHandle) + /// these handles should be managed by the front-end. + /// The front-end need to call free function (e.g. TVMFuncFree) + /// to free these handles. + pub fn TVMFuncCall( + func: TVMFunctionHandle, + arg_values: *mut TVMValue, + type_codes: *mut ::std::os::raw::c_int, + num_args: ::std::os::raw::c_int, + ret_val: *mut TVMValue, + ret_type_code: *mut ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Set the return value of TVMPackedCFunc. + /// + /// This function is called by TVMPackedCFunc to set the return value. + /// When this function is not called, the function returns null by default. + /// + /// \param ret The return value handle, pass by ret in TVMPackedCFunc + /// \param value The value to be returned. + /// \param type_code The type of the value to be returned. + /// \param num_ret Number of return values, for now only 1 is supported. + pub fn TVMCFuncSetReturn( + ret: TVMRetValueHandle, + value: *mut TVMValue, + type_code: *mut ::std::os::raw::c_int, + num_ret: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Inplace translate callback argument value to return value. + /// This is only needed for non-POD arguments. + /// + /// \param value The value to be translated. + /// \param code The type code to be translated. + /// \note This function will do a shallow copy when necessary. + /// + /// \return 0 when success, -1 when failure happens. + pub fn TVMCbArgToReturn( + value: *mut TVMValue, + code: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +/// \brief C type of packed function. 
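+/// As an illustrative sketch (not part of the generated bindings), a packed
+/// function implemented in Rust might look like the following; `echo_first_arg`
+/// is a hypothetical name:
+///
+/// ```ignore
+/// unsafe extern "C" fn echo_first_arg(
+///     args: *mut TVMValue,
+///     type_codes: *mut ::std::os::raw::c_int,
+///     _num_args: ::std::os::raw::c_int,
+///     ret: TVMRetValueHandle,
+///     _resource_handle: *mut ::std::os::raw::c_void,
+/// ) -> ::std::os::raw::c_int {
+///     // Forward the first argument as the return value.
+///     TVMCFuncSetReturn(ret, args, type_codes, 1)
+/// }
+/// ```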
+/// +/// \param args The arguments +/// \param type_codes The type codes of the arguments +/// \param num_args Number of arguments. +/// \param ret The return value handle. +/// \param resource_handle The handle additional resouce handle from fron-end. +/// \return 0 if success, -1 if failure happens, set error via TVMAPISetLastError. +/// \sa TVMCFuncSetReturn +pub type TVMPackedCFunc = ::std::option::Option< + unsafe extern "C" fn( + args: *mut TVMValue, + type_codes: *mut ::std::os::raw::c_int, + num_args: ::std::os::raw::c_int, + ret: TVMRetValueHandle, + resource_handle: *mut ::std::os::raw::c_void, + ) -> ::std::os::raw::c_int, +>; +/// \brief C callback to free the resource handle in C packed function. +/// \param resource_handle The handle additional resouce handle from fron-end. +pub type TVMPackedCFuncFinalizer = + ::std::option::Option; +/// \brief Signature for extension function declarer. +/// +/// TVM call this function to get the extension functions +/// The declarer will call register_func to register function and their name. +/// +/// \param register_func_handle The register function +/// \return 0 if success, -1 if failure happens +pub type TVMExtensionFuncDeclarer = ::std::option::Option< + unsafe extern "C" fn(register_func_handle: TVMFunctionHandle) -> ::std::os::raw::c_int, +>; +extern "C" { + /// \brief Wrap a TVMPackedCFunc to become a FunctionHandle. + /// + /// The resource_handle will be managed by TVM API, until the function is no longer used. + /// + /// \param func The packed C function. + /// \param resource_handle The resource handle from front-end, can be NULL. + /// \param fin The finalizer on resource handle when the FunctionHandle get freed, can be NULL + /// \param out the result function handle. + /// \return 0 when success, -1 when failure happens + pub fn TVMFuncCreateFromCFunc( + func: TVMPackedCFunc, + resource_handle: *mut ::std::os::raw::c_void, + fin: TVMPackedCFuncFinalizer, + out: *mut TVMFunctionHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Register the function to runtime's global table. + /// + /// The registered function then can be pulled by the backend by the name. + /// + /// \param name The name of the function. + /// \param f The function to be registered. + /// \param override Whether allow override already registered function. + pub fn TVMFuncRegisterGlobal( + name: *const ::std::os::raw::c_char, + f: TVMFunctionHandle, + override_: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Get a global function. + /// + /// \param name The name of the function. + /// \param out the result function pointer, NULL if it does not exist. + /// + /// \note The function handle of global function is managed by TVM runtime, + /// So TVMFuncFree is should not be called when it get deleted. + pub fn TVMFuncGetGlobal( + name: *const ::std::os::raw::c_char, + out: *mut TVMFunctionHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief List all the globally registered function name + /// \param out_size The number of functions + /// \param out_array The array of function names. + /// \return 0 when success, -1 when failure happens + pub fn TVMFuncListGlobalNames( + out_size: *mut ::std::os::raw::c_int, + out_array: *mut *mut *const ::std::os::raw::c_char, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Allocate a nd-array's memory, + /// including space of shape, of given spec. 
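+    /// A sketch of allocating and freeing a 2x2 float32 array on CPU
+    /// (dtype and device codes as defined by `DLDataTypeCode` and
+    /// `DLDeviceType` above):
+    ///
+    /// ```ignore
+    /// let shape: [tvm_index_t; 2] = [2, 2];
+    /// let mut arr: TVMArrayHandle = std::ptr::null_mut();
+    /// let rc = unsafe {
+    ///     TVMArrayAlloc(shape.as_ptr(), 2, /*dtype_code=*/2, /*dtype_bits=*/32,
+    ///                   /*dtype_lanes=*/1, /*device_type=*/1, /*device_id=*/0, &mut arr)
+    /// };
+    /// assert_eq!(rc, 0);
+    /// unsafe { TVMArrayFree(arr) };
+    /// ```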
+ /// + /// \param shape The shape of the array, the data content will be copied to out + /// \param ndim The number of dimension of the array. + /// \param dtype_code The type code of the dtype + /// \param dtype_bits The number of bits of dtype + /// \param dtype_lanes The number of lanes in the dtype. + /// \param device_type The device type of context + /// \param device_id The device id of context. + /// \param out The output handle. + /// \return 0 when success, -1 when failure happens + pub fn TVMArrayAlloc( + shape: *const tvm_index_t, + ndim: ::std::os::raw::c_int, + dtype_code: ::std::os::raw::c_int, + dtype_bits: ::std::os::raw::c_int, + dtype_lanes: ::std::os::raw::c_int, + device_type: ::std::os::raw::c_int, + device_id: ::std::os::raw::c_int, + out: *mut TVMArrayHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Free the TVM Array. + /// \param handle The array handle to be freed. + /// \return 0 when success, -1 when failure happens + pub fn TVMArrayFree(handle: TVMArrayHandle) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Copy array data from CPU byte array. + /// \param handle The array handle. + /// \param data the data pointer + /// \param nbytes The number of bytes to copy. + /// \return 0 when success, -1 when failure happens + pub fn TVMArrayCopyFromBytes( + handle: TVMArrayHandle, + data: *mut ::std::os::raw::c_void, + nbytes: usize, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Copy array data to CPU byte array. + /// \param handle The array handle. + /// \param data the data pointer + /// \param nbytes The number of bytes to copy. + /// \return 0 when success, -1 when failure happens + pub fn TVMArrayCopyToBytes( + handle: TVMArrayHandle, + data: *mut ::std::os::raw::c_void, + nbytes: usize, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Copy the array, both from and to must be valid during the copy. + /// \param from The array to be copied from. + /// \param to The target space. + /// \param stream The stream where the copy happens, can be NULL. + /// \return 0 when success, -1 when failure happens + pub fn TVMArrayCopyFromTo( + from: TVMArrayHandle, + to: TVMArrayHandle, + stream: TVMStreamHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Produce an array from the DLManagedTensor that shares data memory + /// with the DLManagedTensor. + /// \param from The source DLManagedTensor. + /// \param out The output array handle. + /// \return 0 when success, -1 when failure happens + pub fn TVMArrayFromDLPack( + from: *mut DLManagedTensor, + out: *mut TVMArrayHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Produce a DLMangedTensor from the array that shares data memory with + /// the array. + /// \param from The source array. + /// \param out The DLManagedTensor handle. + /// \return 0 when success, -1 when failure happens + pub fn TVMArrayToDLPack( + from: TVMArrayHandle, + out: *mut *mut DLManagedTensor, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Delete (free) a DLManagedTensor's data. + /// \param dltensor Pointer to the DLManagedTensor. + pub fn TVMDLManagedTensorCallDeleter(dltensor: *mut DLManagedTensor); +} +extern "C" { + /// \brief Create a new runtime stream. 
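+    /// A sketch of the stream lifecycle on a hypothetical GPU device 0
+    /// (`TVMSynchronize` and `TVMStreamFree` are declared below):
+    ///
+    /// ```ignore
+    /// let dev = DLDeviceType_kDLGPU as ::std::os::raw::c_int;
+    /// let mut stream: TVMStreamHandle = std::ptr::null_mut();
+    /// assert_eq!(unsafe { TVMStreamCreate(dev, 0, &mut stream) }, 0);
+    /// // ... launch asynchronous work on `stream` ...
+    /// assert_eq!(unsafe { TVMSynchronize(dev, 0, stream) }, 0);
+    /// assert_eq!(unsafe { TVMStreamFree(dev, 0, stream) }, 0);
+    /// ```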
+ /// + /// \param device_type The device type of context + /// \param device_id The device id of context + /// \param out The new stream handle + /// \return 0 when success, -1 when failure happens + pub fn TVMStreamCreate( + device_type: ::std::os::raw::c_int, + device_id: ::std::os::raw::c_int, + out: *mut TVMStreamHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Free a created stream handle. + /// + /// \param device_type The device type of context + /// \param device_id The device id of context + /// \param stream The stream to be freed + /// \return 0 when success, -1 when failure happens + pub fn TVMStreamFree( + device_type: ::std::os::raw::c_int, + device_id: ::std::os::raw::c_int, + stream: TVMStreamHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Set the runtime stream of current thread to be stream. + /// The subsequent calls to the same device_type + /// will use the setted stream handle. + /// The specific type of stream is runtime device dependent. + /// + /// \param device_type The device type of context + /// \param device_id The device id of context. + /// \param handle The stream handle. + /// \return 0 when success, -1 when failure happens + pub fn TVMSetStream( + device_type: ::std::os::raw::c_int, + device_id: ::std::os::raw::c_int, + handle: TVMStreamHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Wait until all computations on stream completes. + /// + /// \param device_type The device type of context + /// \param device_id The device id of context. + /// \param stream The stream to be synchronized. + /// \return 0 when success, -1 when failure happens + pub fn TVMSynchronize( + device_type: ::std::os::raw::c_int, + device_id: ::std::os::raw::c_int, + stream: TVMStreamHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Synchronize two streams of execution. + /// + /// \param device_type The device type of context + /// \param device_id The device id of context + /// \param src The source stream to synchronize. + /// \param dst The destination stream to synchronize. + /// \return 0 when success, -1 when failure happens + pub fn TVMStreamStreamSynchronize( + device_type: ::std::os::raw::c_int, + device_id: ::std::os::raw::c_int, + src: TVMStreamHandle, + dst: TVMStreamHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Backend function for modules to get function + /// from its environment mod_node (its imports and global function). + /// The user do should not call TVMFuncFree on func. + /// + /// \param mod_node The module handle. + /// \param func_name The name of the function. + /// \param out The result function. + /// \return 0 when no error is thrown, -1 when failure happens + pub fn TVMBackendGetFuncFromEnv( + mod_node: *mut ::std::os::raw::c_void, + func_name: *const ::std::os::raw::c_char, + out: *mut TVMFunctionHandle, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Backend function to register system-wide library symbol. + /// + /// \param name The name of the symbol + /// \param ptr The symbol address. + /// \return 0 when no error is thrown, -1 when failure happens + pub fn TVMBackendRegisterSystemLibSymbol( + name: *const ::std::os::raw::c_char, + ptr: *mut ::std::os::raw::c_void, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Backend function to allocate temporal workspace. + /// + /// \note The result allocate spaced is ensured to be aligned to kTempAllocaAlignment. + /// + /// \param nbytes The size of the space requested. 
+ /// \param device_type The device type which the space will be allocated. + /// \param device_id The device id which the space will be allocated. + /// \param dtype_code_hint The type code of the array elements. Only used in + /// certain backends such as OpenGL. + /// \param dtype_bits_hint The type bits of the array elements. Only used in + /// certain backends such as OpenGL. + /// \return nullptr when error is thrown, a valid ptr if success + pub fn TVMBackendAllocWorkspace( + device_type: ::std::os::raw::c_int, + device_id: ::std::os::raw::c_int, + nbytes: u64, + dtype_code_hint: ::std::os::raw::c_int, + dtype_bits_hint: ::std::os::raw::c_int, + ) -> *mut ::std::os::raw::c_void; +} +extern "C" { + /// \brief Backend function to free temporal workspace. + /// + /// \param ptr The result allocated space pointer. + /// \param device_type The device type which the space will be allocated. + /// \param device_id The device id which the space will be allocated. + /// \return 0 when no error is thrown, -1 when failure happens + /// + /// \sa TVMBackendAllocWorkspace + pub fn TVMBackendFreeWorkspace( + device_type: ::std::os::raw::c_int, + device_id: ::std::os::raw::c_int, + ptr: *mut ::std::os::raw::c_void, + ) -> ::std::os::raw::c_int; +} +/// \brief Environment for TVM parallel task. +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct TVMParallelGroupEnv { + /// \brief Auxiliary used for synchronization + pub sync_handle: *mut ::std::os::raw::c_void, + /// \brief total amount of task + pub num_task: i32, +} +/// \brief The callback function to execute a parallel lambda +/// \param task_id the task id of the function. +/// \param penv The parallel environment backs the execution. +/// \param cdata The supporting closure data. +pub type FTVMParallelLambda = ::std::option::Option< + unsafe extern "C" fn( + task_id: ::std::os::raw::c_int, + penv: *mut TVMParallelGroupEnv, + cdata: *mut ::std::os::raw::c_void, + ) -> ::std::os::raw::c_int, +>; +extern "C" { + /// \brief Backend function for running parallel jobs. + /// + /// \param flambda The parallel function to be launched. + /// \param cdata The closure data. + /// \param num_task Number of tasks to launch, can be 0, means launch + /// with all available threads. + /// + /// \return 0 when no error is thrown, -1 when failure happens + pub fn TVMBackendParallelLaunch( + flambda: FTVMParallelLambda, + cdata: *mut ::std::os::raw::c_void, + num_task: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief BSP barrrier between parallel threads + /// \param task_id the task id of the function. + /// \param penv The parallel environment backs the execution. + /// \return 0 when no error is thrown, -1 when failure happens + pub fn TVMBackendParallelBarrier( + task_id: ::std::os::raw::c_int, + penv: *mut TVMParallelGroupEnv, + ) -> ::std::os::raw::c_int; +} +extern "C" { + /// \brief Simple static initialization fucntion. + /// Run f once and set handle to be not null. + /// This function is mainly used for test purpose. + /// + /// \param handle An global address to indicate f + /// \param f The function to be ran + /// \param cdata The closure data to pass to the function. + /// \param nbytes Number of bytes in the closure data. 
+    /// \return 0 when no error is thrown, -1 when failure happens
+    pub fn TVMBackendRunOnce(
+        handle: *mut *mut ::std::os::raw::c_void,
+        f: ::std::option::Option<
+            unsafe extern "C" fn(arg1: *mut ::std::os::raw::c_void) -> ::std::os::raw::c_int,
+        >,
+        cdata: *mut ::std::os::raw::c_void,
+        nbytes: ::std::os::raw::c_int,
+    ) -> ::std::os::raw::c_int;
+}
diff --git a/rust/src/runtime/graph.rs b/rust/src/runtime/graph.rs
new file mode 100644
index 000000000000..08fbd5938380
--- /dev/null
+++ b/rust/src/runtime/graph.rs
@@ -0,0 +1,472 @@
+use std::{cmp, collections::HashMap, convert::TryFrom, iter::FromIterator, mem, str};
+
+use nom::{alpha1, digit1, le_i32, le_i64, le_u16, le_u32, le_u64, le_u8, types::CompleteStr};
+use serde;
+use serde_json;
+
+use super::{DataType, Module, Storage, TVMArgValue, TVMContext, Tensor};
+use errors::{Error, ErrorKind, Result};
+use ffi::runtime::{
+  DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLDataTypeCode_kDLUInt, DLTensor,
+};
+
+// Magic number for NDArray file. @see `kTVMNDArrayMagic` in `ndarray.h`
+const _NDARRAY_MAGIC: u64 = 0xDD5E40F096B4A13F;
+// Magic number for NDArray list file. @see `kTVMNDArrayListMagic` in `graph_runtime.h`
+const _NDARRAY_LIST_MAGIC: u64 = 0xF7E58D4F05049CB7;
+
+/// A TVM computation graph.
+///
+/// # Examples
+///
+/// ```
+/// let graph_json = fs::read_to_string("graph.json").unwrap();
+/// let graph = Graph::try_from(&graph_json).unwrap();
+/// ```
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Graph {
+  pub nodes: Vec<Node>,
+  pub arg_nodes: Vec<usize>,
+  pub heads: Vec<Entry>,
+  pub node_row_ptr: Option<Vec<usize>>,
+  pub attrs: Option<HashMap<String, serde_json::Value>>,
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Entry {
+  pub id: usize,
+  pub index: usize,
+  pub version: usize,
+}
+
+impl Graph {
+  fn entry_index(&self, entry: &Entry) -> Result<usize> {
+    self
+      .node_row_ptr
+      .as_ref()
+      .map(|nrp| nrp[entry.id] + entry.index)
+      .ok_or("Missing node_row_ptr.".into())
+  }
+
+  /// Attempts to deserialize a JSON attribute to a type `T`.
+  fn get_attr<T: serde::de::DeserializeOwned>(&self, attr: &str) -> Result<T> {
+    Ok(serde_json::from_value::<T>(
+      self
+        .attrs
+        .as_ref()
+        .ok_or(ErrorKind::GraphFormatError(
+          "Missing graph attrs".to_string(),
+        ))?
+        .get(attr)
+        .ok_or(ErrorKind::GraphFormatError(format!(
+          "Missing {} attr",
+          attr
+        )))?
+        .to_owned(),
+    )?)
+  }
+}
+
+#[derive(Serialize, Deserialize, Debug)]
+pub struct Node {
+  pub op: String,
+  pub name: String,
+  pub inputs: Vec<Entry>,
+  pub attrs: Option<HashMap<String, String>>,
+  pub control_deps: Option<Vec<Entry>>,
+}
+
+struct NodeAttrs {
+  func_name: String,
+  num_outputs: usize,
+  flatten_data: bool,
+}
+
+impl Node {
+  fn parse_attrs(&self) -> Result<NodeAttrs> {
+    let attrs = self
+      .attrs
+      .as_ref()
+      .ok_or(format!("Missing node.attrs for `{}`", self.name))?;
+    let func_name = attrs
+      .get("func_name")
+      .ok_or(format!("Node `{}` is missing attrs.func_name", self.name))?
+      .to_string();
+    let num_outputs = attrs
+      .get("num_outputs")
+      .ok_or(format!("Node `{}` is missing attrs.num_outputs", self.name))?
+      .parse::<usize>()?;
+    let flatten_data = attrs
+      .get("flatten_data")
+      .ok_or(format!(
+        "Node `{}` is missing attrs.flatten_data",
+        self.name
+      ))?
+      .parse::<u8>()?
+ == 1; + Ok(NodeAttrs { + func_name, + num_outputs, + flatten_data, + }) + } +} + +impl<'a> TryFrom<&'a String> for Graph { + type Error = Error; + fn try_from(graph_json: &String) -> Result { + let graph = serde_json::from_str(graph_json)?; + Ok(graph) + } +} + +impl<'a> TryFrom<&'a str> for Graph { + type Error = Error; + fn try_from(graph_json: &'a str) -> Result { + let graph = serde_json::from_str(graph_json)?; + Ok(graph) + } +} + +/// A executor for a TVM computation graph. +/// +/// # Examples +/// +/// ``` +/// use ndarray::Array; +/// +/// let syslib = SystemLibModule::default(); // a provider of TVM functions +/// +/// let mut params_bytes = Vec::new(); +/// fs::File::open("graph.params").unwrap().read_to_end(&mut params_bytes).unwrap(); +/// let params = tvm::runtime::load_param_dict(¶ms_bytes).unwrap(); +/// +/// let graph = Graph::try_from(&fs::read_to_string("graph.json").unwrap()).unwrap(); +/// +/// let mut exec = GraphExecutor::new(graph, &syslib).unwrap(); +/// exec.load_params(params); +/// +/// let x = Array::from_vec(vec![1f32, 2., 3., 4.]); +/// exec.set_input("data", x.into()); +/// exec.run(); +/// let output = exec.get_output(0).unwrap(); +/// +/// println!("{:#?}", Array::try_from(output).unwrap()); +/// ``` +pub struct GraphExecutor<'m, 't> { + graph: Graph, + op_execs: Vec>, + tensors: Vec>, +} + +unsafe impl<'m, 't> Send for GraphExecutor<'m, 't> {} + +impl<'m, 't> GraphExecutor<'m, 't> { + pub fn new(graph: Graph, lib: &'m M) -> Result { + let tensors = Self::setup_storages(&graph)?; + Ok(GraphExecutor { + op_execs: Self::setup_op_execs(&graph, lib, &tensors)?, + tensors: tensors, + graph: graph, + }) + } + + /// Runs the computation graph. + pub fn run(&self) { + self.op_execs.iter().for_each(|op_exec| { + op_exec(); + }); + } + + /// Allocates `Storages` for each `storage_id` and returns `Tensor`s to hold each output. + fn setup_storages<'a>(graph: &'a Graph) -> Result>> { + let storage_ids = graph.get_attr::<(String, Vec)>("storage_id")?.1; + let shapes = graph.get_attr::<(String, Vec>)>("shape")?.1; + let dtypes = graph + .get_attr::<(String, Vec)>("dltype")? + .1 + .iter() + .map(|dltype| { + if let Ok((_, dtype)) = tvm_str_to_type(CompleteStr(dltype)) { + Ok(dtype) + } else { + Err(ErrorKind::GraphFormatError(format!("Invalid dltype: {}", dltype).to_string()).into()) + } + }) + .collect::>>()?; + + let align = dtypes.iter().map(|dtype| dtype.bits as usize).max(); + let mut storage_num_bytes = vec![0usize; *storage_ids.iter().max().unwrap_or(&1) + 1]; + for (i, &storage_id) in storage_ids.iter().enumerate() { + let dtype_size = dtypes[i].bits * dtypes[i].lanes >> 3; + let nbytes = dtype_size * shapes[i].iter().product::() as usize; + storage_num_bytes[storage_id] = cmp::max(nbytes, storage_num_bytes[storage_id]); + } + + let mut storages: Vec = storage_num_bytes + .into_iter() + .map(|nbytes| Storage::new(nbytes, align)) + .collect::>>()?; + + let tensors = izip!(storage_ids, shapes, dtypes) + .map(|(storage_id, shape, dtype)| { + let storage = storages[storage_id].view(); + Tensor { + data: mem::replace(&mut storages[storage_id], storage), + ctx: TVMContext::default(), + dtype: dtype, + size: shape.iter().product::() as usize, + shape: shape, + strides: None, + byte_offset: 0, + } + }) + .collect(); + + Ok(tensors) + } + + /// Creates closures which represent the computation performed by this graph. 
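+  /// Each closure invokes one `tvm_op` with its argument `DLTensor`s laid out
+  /// as `[inputs..., outputs...]`; output positions come from `node_row_ptr`.
+  /// A sketch of the index computation for node `i` (mirroring the code below):
+  ///
+  /// ```ignore
+  /// let arg_indices = node.inputs.iter()
+  ///   .map(|entry| graph.entry_index(entry))
+  ///   .chain((0..num_outputs).map(|oi| Ok(node_row_ptr[i] + oi)));
+  /// ```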
+ fn setup_op_execs( + graph: &Graph, + lib: &'m M, + tensors: &Vec>, + ) -> Result>> { + ensure!(graph.node_row_ptr.is_some(), "Missing node_row_ptr."); + let node_row_ptr = graph.node_row_ptr.as_ref().unwrap(); + + let mut op_execs = Vec::new(); + for (i, node) in graph.nodes.iter().enumerate() { + if node.op == "null" { + continue; + } + ensure!(node.op == "tvm_op", "Only TVM ops are supported."); + ensure!(node.attrs.is_some(), "Missing node attrs."); + + let attrs = node.parse_attrs()?; + + if attrs.func_name == "__nop" { + continue; + } + + let func = lib + .get_function(&attrs.func_name) + .ok_or(format!("Missing function {}", attrs.func_name))?; + let arg_indices = node + .inputs + .iter() + .map(|entry| graph.entry_index(entry)) + .chain((0..attrs.num_outputs).map(|oi| Ok(node_row_ptr[i].clone() + oi))); + + let dl_tensors = arg_indices + .map(|idx| { + let tensor = &tensors[idx?]; + Ok(if attrs.flatten_data { + DLTensor::from_tensor(tensor, true /* flatten */) + } else { + DLTensor::from(tensor) + }) + }) + .collect::>>() + .unwrap(); + let op: Box = box move || { + let args = dl_tensors + .iter() + .map(|t| t.into()) + .collect::>(); + func(args.as_slice()); + }; + op_execs.push(op); + } + Ok(op_execs) + } + + pub fn load_params(&mut self, params: HashMap>) { + params.into_iter().for_each(|(name, param)| { + self.set_input(name, param); + }) + } + + pub fn set_input>(&mut self, name: S, value: Tensor<'t>) { + if let Some(idx) = self.get_input_index(name.as_ref()) { + // TODO: consider `new_with_params` to avoid ever allocating + let ptr = self.tensors[idx].data.as_ptr(); + let mut to_replace = self.tensors.iter_mut().filter(|t| t.data.as_ptr() == ptr); + let mut owner = to_replace.nth(0).unwrap(); + if value.data.is_owned() { + // FIXME: for no-copy, need setup_op_execs to not capture tensor ptr + // mem::replace(&mut (*owner), value); + // to_replace.for_each(|t| { + // panic!("replacing"); + // t.data = owner.data.view(); + // }); + owner.copy(&value); + } else { + owner.copy(&value); + } + } else { + println!("Unexpected input `{}`", name.as_ref()); + } + } + + /// Returns the graph input with name `name`, if it exists. + pub fn get_input>(&mut self, name: S) -> Option<&Tensor> { + self + .get_input_index(name.as_ref()) + .and_then(move |idx| Some(&self.tensors[idx])) + } + + /// Returns the graph output with index `index`, if it exists. + pub fn get_output(&self, idx: usize) -> Option<&Tensor> { + let graph = &self.graph; + graph.heads.get(idx).and_then(|entry| { + graph + .entry_index(entry) + .map(|idx| self.tensors.get(idx)) + .unwrap_or(None) + }) + } + + /// Returns the index for graph input with name `name`, if it exists. + pub fn get_input_index>(&self, name: S) -> Option { + let graph = &self.graph; + (0..graph.nodes.len()) + .skip_while(|&i| graph.nodes[i].name != name.as_ref()) + .nth(0) + .and_then(|i| { + if graph.arg_nodes.iter().any(|&id| id == i) { + graph.node_row_ptr.as_ref().map(|nrp| nrp[i]) + } else { + None + } + }) + } +} + +/// Converts a string to TVM DLDataTypeCode. 
@see `String2TVMType` in packed_func.h
+named!(
+  tvm_str_to_type<CompleteStr, DataType>,
+  do_parse!(
+    type_name: alpha1 >>
+    bits: digit1 >>
+    lanes: opt!(tuple!(tag!("x"), digit1)) >>
+    (DataType {
+      code: match type_name {
+        CompleteStr("int") => DLDataTypeCode_kDLInt,
+        CompleteStr("uint") => DLDataTypeCode_kDLUInt,
+        CompleteStr("float") => DLDataTypeCode_kDLFloat,
+        _ => DLDataTypeCode_kDLFloat,
+      } as usize,
+      bits: bits.parse::<u8>().unwrap() as usize,
+      lanes: match lanes {
+        Some(lanes) => lanes.1.parse::<u16>().unwrap() as usize,
+        None => 1,
+      },
+    })
+  )
+);
+
+/// Converts length-prefixed bytes to a `String`.
+named!(
+  name<String>,
+  map_res!(length_bytes!(le_u64), |b: &[u8]| String::from_utf8(
+    b.to_vec()
+  ))
+);
+
+/// Parses a TVMContext.
+named!(
+  tvm_ctx<&[u8], TVMContext>,
+  do_parse!(
+    device_type: le_u32 >>
+    device_id: le_i32 >>
+    (TVMContext { device_type: device_type as usize, device_id: device_id as usize })
+  )
+);
+
+/// Parses a DataType.
+named!(
+  data_type<&[u8], DataType>,
+  do_parse!(
+    code: le_u8 >>
+    bits: le_u8 >>
+    lanes: le_u16 >>
+    (DataType { code: code as usize, bits: bits as usize, lanes: lanes as usize })
+  )
+);
+
+/// Parses a Tensor from a TVM array file.
+named!(
+  tensor<Tensor>,
+  do_parse!(
+    take!(8)
+      >> bits!(tag_bits!(u64, 64, 0))
+      >> ctx: tvm_ctx
+      >> ndim: le_u32
+      >> dtype: data_type
+      >> shape: count!(map!(le_i64, |sz| sz as i64), ndim as usize)
+      >> length: le_i64
+      >> data: take!(length)
+      >> (Tensor {
+        data: Storage::from(data),
+        ctx: ctx,
+        dtype: dtype,
+        size: shape.iter().product::<i64>() as usize,
+        shape: shape,
+        strides: None,
+        byte_offset: 0,
+      })
+  )
+);
+
+/// Parses a graph params dict from a params binary file.
+named!(
+  parse_param_dict<HashMap<String, Tensor>>,
+  do_parse!(
+    take!(8)
+      >> bits!(tag_bits!(u64, 64, 0))
+      >> names: length_count!(le_u64, name)
+      >> tensors: length_count!(le_u64, tensor)
+      >> (HashMap::from_iter(names.into_iter().zip(tensors.into_iter())))
+  )
+);
+
+/// Loads a param dict saved using `nnvm.compiler.save_param_dict`.
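+///
+/// # Examples
+///
+/// A sketch, assuming `graph.params` was produced by `tests/build_model.py`:
+///
+/// ```ignore
+/// let mut params_bytes = Vec::new();
+/// fs::File::open("graph.params").unwrap().read_to_end(&mut params_bytes).unwrap();
+/// let params = load_param_dict(&params_bytes).unwrap();
+/// for (name, tensor) in &params {
+///   println!("{}: {} elements", name, tensor.size);
+/// }
+/// ```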
+pub fn load_param_dict(bytes: &[u8]) -> Result> { + if let Ok((remaining_bytes, param_dict)) = parse_param_dict(bytes) { + if remaining_bytes.len() > 0 { + bail!(ErrorKind::LoadGraphParamsError("extra input".to_string())) + } else { + Ok(param_dict) + } + } else { + bail!(ErrorKind::LoadGraphParamsError( + "invalid parameters file".to_string() + )) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_str_to_type() { + assert_eq!( + tvm_str_to_type(CompleteStr("float24")).unwrap().1, + DataType { + code: DLDataTypeCode_kDLFloat as usize, + bits: 24, + lanes: 1 + } + ); + assert_eq!( + tvm_str_to_type(CompleteStr("uint111x44")).unwrap().1, + DataType { + code: DLDataTypeCode_kDLUInt as usize, + bits: 111, + lanes: 44 + } + ); + } +} diff --git a/rust/src/runtime/mod.rs b/rust/src/runtime/mod.rs new file mode 100644 index 000000000000..bdf7094113d8 --- /dev/null +++ b/rust/src/runtime/mod.rs @@ -0,0 +1,25 @@ +mod allocator; +mod array; +mod module; +#[macro_use] +mod packed_func; +mod graph; +#[cfg(target_env = "sgx")] +#[macro_use] +pub mod sgx; +mod threading; +mod workspace; + +use std::os::raw::c_char; + +pub use self::{array::*, graph::*, module::*, packed_func::*, threading::*, workspace::*}; + +#[no_mangle] +pub extern "C" fn TVMAPISetLastError(cmsg: *const c_char) { + #[cfg(not(target_env = "sgx"))] + unsafe { + panic!(std::ffi::CStr::from_ptr(cmsg).to_str().unwrap()); + } + #[cfg(target_env = "sgx")] + ocall_packed!("__sgx_set_last_error__", cmsg); +} diff --git a/rust/src/runtime/module.rs b/rust/src/runtime/module.rs new file mode 100644 index 000000000000..2594756d9885 --- /dev/null +++ b/rust/src/runtime/module.rs @@ -0,0 +1,46 @@ +use std::{ + collections::HashMap, convert::AsRef, ffi::CStr, os::raw::c_char, string::String, sync::Mutex, +}; + +use ffi::runtime::BackendPackedCFunc; +use runtime::packed_func::{wrap_backend_packed_func, PackedFunc}; + +pub trait Module { + fn get_function>(&self, name: S) -> Option; +} + +pub struct SystemLibModule; + +lazy_static! { + static ref SYSTEM_LIB_FUNCTIONS: Mutex> = + Mutex::new(HashMap::new()); +} + +impl Module for SystemLibModule { + fn get_function>(&self, name: S) -> Option { + SYSTEM_LIB_FUNCTIONS + .lock() + .unwrap() + .get(name.as_ref()) + .map(|func| wrap_backend_packed_func(func.to_owned())) + } +} + +impl Default for SystemLibModule { + fn default() -> Self { + SystemLibModule {} + } +} + +#[no_mangle] +pub extern "C" fn TVMBackendRegisterSystemLibSymbol( + cname: *const c_char, + func: BackendPackedCFunc, +) -> i32 { + let name = unsafe { CStr::from_ptr(cname).to_str().unwrap() }; + SYSTEM_LIB_FUNCTIONS + .lock() + .unwrap() + .insert(name.to_string(), func); + return 0; +} diff --git a/rust/src/runtime/packed_func.rs b/rust/src/runtime/packed_func.rs new file mode 100644 index 000000000000..030d677329c0 --- /dev/null +++ b/rust/src/runtime/packed_func.rs @@ -0,0 +1,286 @@ +use std::{any::Any, convert::TryFrom, marker::PhantomData, os::raw::c_void}; + +use ffi::runtime::{ + BackendPackedCFunc, DLDataTypeCode_kDLFloat, DLDataTypeCode_kDLInt, DLTensor, + TVMTypeCode_kArrayHandle, TVMTypeCode_kHandle, TVMValue, +}; + +use errors::*; + +pub type PackedFunc = Box TVMRetValue + Send + Sync>; + +/// Calls a packed function and returns a `TVMRetValue`. +/// +/// # Example +/// +/// `call_packed!(my_tvm_func, &mut arg1, &mut arg2)` +#[macro_export] +macro_rules! 
call_packed { + ($fn:expr, $($args:expr),+) => { + $fn(&[$($args.into(),)+]) + }; + ($fn:expr) => { + $fn(&Vec::new()) + }; +} + +/// A borrowed TVMPODValue. Can be constructed using `into()` but the preferred way +/// to obtain a `TVMArgValue` is automatically via `call_packed!`. +#[derive(Clone, Copy)] +pub struct TVMArgValue<'a> { + _lifetime: PhantomData<&'a ()>, + pub(crate) value: TVMValue, + pub(crate) type_code: i64, +} + +impl<'a> TVMArgValue<'a> { + pub fn new(value: TVMValue, type_code: i64) -> Self { + TVMArgValue { + _lifetime: PhantomData, + value: value, + type_code: type_code, + } + } +} + +/// Creates a conversion to a `TVMArgValue` for a primitive type and DLDataTypeCode. +macro_rules! impl_prim_tvm_arg { + ($type:ty, $field:ident, $code:expr, $as:ty) => { + impl<'a> From<$type> for TVMArgValue<'a> { + fn from(val: $type) -> Self { + TVMArgValue { + value: TVMValue { $field: val as $as }, + type_code: $code as i64, + _lifetime: PhantomData, + } + } + } + }; + ($type:ty, $field:ident, $code:expr) => { + impl_prim_tvm_arg!($type, $field, $code, $type); + }; + ($type:ty,v_int64) => { + impl_prim_tvm_arg!($type, v_int64, DLDataTypeCode_kDLInt, i64); + }; + ($type:ty,v_float64) => { + impl_prim_tvm_arg!($type, v_float64, DLDataTypeCode_kDLFloat, f64); + }; +} + +impl_prim_tvm_arg!(f32, v_float64); +impl_prim_tvm_arg!(f64, v_float64); +impl_prim_tvm_arg!(i8, v_int64); +impl_prim_tvm_arg!(u8, v_int64); +impl_prim_tvm_arg!(i32, v_int64); +impl_prim_tvm_arg!(u32, v_int64); +impl_prim_tvm_arg!(i64, v_int64); +impl_prim_tvm_arg!(u64, v_int64); +impl_prim_tvm_arg!(bool, v_int64); + +/// Creates a conversion to a `TVMArgValue` for an object handle. +impl<'a, T> From<*const T> for TVMArgValue<'a> { + fn from(ptr: *const T) -> Self { + TVMArgValue { + value: TVMValue { + v_handle: ptr as *mut T as *mut c_void, + }, + type_code: TVMTypeCode_kArrayHandle as i64, + _lifetime: PhantomData, + } + } +} + +/// Creates a conversion to a `TVMArgValue` for a mutable object handle. +impl<'a, T> From<*mut T> for TVMArgValue<'a> { + fn from(ptr: *mut T) -> Self { + TVMArgValue { + value: TVMValue { + v_handle: ptr as *mut c_void, + }, + type_code: TVMTypeCode_kHandle as i64, + _lifetime: PhantomData, + } + } +} + +impl<'a> From<&'a mut DLTensor> for TVMArgValue<'a> { + fn from(arr: &'a mut DLTensor) -> Self { + TVMArgValue { + value: TVMValue { + v_handle: arr as *mut _ as *mut c_void, + }, + type_code: TVMTypeCode_kArrayHandle as i64, + _lifetime: PhantomData, + } + } +} + +impl<'a> From<&'a DLTensor> for TVMArgValue<'a> { + fn from(arr: &'a DLTensor) -> Self { + TVMArgValue { + value: TVMValue { + v_handle: arr as *const _ as *mut DLTensor as *mut c_void, + }, + type_code: TVMTypeCode_kArrayHandle as i64, + _lifetime: PhantomData, + } + } +} + +/// An owned TVMPODValue. Can be converted from a variety of primitive and object types. +/// Can be downcasted using `try_from` if it contains the desired type. +/// +/// # Example +/// +/// ``` +/// let a = 42u32; +/// let b: i64 = TVMRetValue::from(a).try_into().unwrap(); +/// +/// let s = "hello, world!"; +/// let t: TVMRetValue = s.into(); +/// assert_eq!(String::try_from(t).unwrap(), s); +/// ``` +pub struct TVMRetValue { + /// A primitive return value, if any. + prim_value: u64, + /// An object return value, if any. + box_value: Box, + /// The DLDataTypeCode which determines whether `prim_value` or `box_value` is in use. 
+ type_code: i64, +} + +#[cfg(target_env = "sgx")] +impl TVMRetValue { + pub(crate) fn from_tvm_value(value: TVMValue, type_code: i64) -> Self { + unsafe { + Self { + prim_value: match type_code { + 0 | 1 => value.v_int64 as u64, + 2 => value.v_float64 as u64, + 3 | 7 | 8 | 9 | 10 => value.v_handle as u64, + 11 | 12 => value.v_str as u64, + _ => 0, + } as u64, + box_value: box (), + type_code: type_code, + } + } + } + + pub fn into_tvm_value(self) -> (TVMValue, i64) { + let val = match self.type_code { + 0 | 1 => TVMValue { + v_int64: self.prim_value.clone() as i64, + }, + 2 => TVMValue { + v_float64: self.prim_value.clone() as f64, + }, + 3 | 7 | 8 | 9 | 10 => TVMValue { + v_handle: Box::into_raw(self.box_value) as *mut c_void, + }, + 11 | 12 => TVMValue { + v_str: Box::into_raw(self.box_value) as *const _, + }, + _ => unreachable!(), + }; + (val, self.type_code) + } +} + +impl Default for TVMRetValue { + fn default() -> Self { + TVMRetValue { + prim_value: 0, + box_value: box (), + type_code: 0, + } + } +} + +macro_rules! impl_prim_ret_value { + ($type:ty, $code:expr) => { + impl From<$type> for TVMRetValue { + fn from(val: $type) -> Self { + TVMRetValue { + prim_value: val as u64, + box_value: box (), + type_code: $code, + } + } + } + impl TryFrom for $type { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result<$type> { + if ret.type_code == $code { + Ok(ret.prim_value as $type) + } else { + bail!(ErrorKind::TryFromTVMRetValueError( + stringify!($type).to_string(), + ret.type_code + )) + } + } + } + }; +} + +macro_rules! impl_boxed_ret_value { + ($type:ty, $code:expr) => { + impl From<$type> for TVMRetValue { + fn from(val: $type) -> Self { + TVMRetValue { + prim_value: 0, + box_value: box val, + type_code: $code, + } + } + } + impl TryFrom for $type { + type Error = Error; + fn try_from(ret: TVMRetValue) -> Result<$type> { + if let Ok(val) = ret.box_value.downcast::<$type>() { + Ok(*val) + } else { + bail!(ErrorKind::TryFromTVMRetValueError( + stringify!($type).to_string(), + ret.type_code + )) + } + } + } + }; +} + +impl_prim_ret_value!(i8, 0); +impl_prim_ret_value!(u8, 1); +impl_prim_ret_value!(i16, 0); +impl_prim_ret_value!(u16, 1); +impl_prim_ret_value!(i32, 0); +impl_prim_ret_value!(u32, 1); +impl_prim_ret_value!(f32, 2); +impl_prim_ret_value!(i64, 0); +impl_prim_ret_value!(u64, 1); +impl_prim_ret_value!(f64, 2); +impl_prim_ret_value!(isize, 0); +impl_prim_ret_value!(usize, 1); +impl_boxed_ret_value!(String, 11); + +// @see `WrapPackedFunc` in `llvm_module.cc`. +pub(super) fn wrap_backend_packed_func(func: BackendPackedCFunc) -> PackedFunc { + box move |args: &[TVMArgValue]| { + func( + args + .iter() + .map(|ref arg| arg.value) + .collect::>() + .as_ptr(), + args + .iter() + .map(|ref arg| arg.type_code as i32) + .collect::>() + .as_ptr() as *const i32, + args.len() as i32, + ); + TVMRetValue::default() + } +} diff --git a/rust/src/runtime/sgx.rs b/rust/src/runtime/sgx.rs new file mode 100644 index 000000000000..bf9d54a4af65 --- /dev/null +++ b/rust/src/runtime/sgx.rs @@ -0,0 +1,82 @@ +use std::{ + ffi::CString, + os::raw::{c_char, c_int}, +}; + +use errors::Result; +use ffi::runtime::TVMValue; +use runtime::{threading::sgx_join_threads, SystemLibModule, TVMArgValue, TVMRetValue}; + +pub use runtime::threading::tvm_run_worker as run_worker; + +#[macro_export] +macro_rules! 
tvm_ocall { + ($func: expr) => { + match $func { + 0 => Ok(()), + err => Err(format!("SGX error: {}", err)), + } + }; +} + +pub type SgxStatus = u32; + +#[cfg(target_env = "sgx")] +extern "C" { + fn tvm_ocall_packed_func( + name: *const c_char, + arg_values: *const TVMValue, + type_codes: *const c_int, + num_args: c_int, + ret_val: *mut TVMValue, + ret_type_code: *mut c_int, + ) -> SgxStatus; +} + +pub fn ocall_packed_func>(fn_name: S, args: &[TVMArgValue]) -> Result { + let mut ret_val = TVMValue { v_int64: 0 }; + let ret_type_code = 0i64; + unsafe { + tvm_ocall!(tvm_ocall_packed_func( + CString::new(fn_name.as_ref()).unwrap().as_ptr(), + args + .iter() + .map(|ref arg| arg.value) + .collect::>() + .as_ptr(), + args + .iter() + .map(|ref arg| arg.type_code as i32) + .collect::>() + .as_ptr() as *const i32, + args.len() as i32, + &mut ret_val as *mut TVMValue, + &mut (ret_type_code as i32) as *mut c_int, + ))?; + } + Ok(TVMRetValue::from_tvm_value(ret_val, ret_type_code as i64)) +} + +#[macro_export] +macro_rules! ocall_packed { + ($fn_name:expr, $($args:expr),+) => { + ::runtime::sgx::ocall_packed_func($fn_name, &[$($args.into(),)+]) + .expect(concat!("Error calling `", $fn_name, "`")) + }; + ($fn_name:expr) => { + ::runtime::sgx::ocall_packed_func($fn_name, &Vec::new()) + .expect(concat!("Error calling `", $fn_name, "`")) + } +} + +pub fn shutdown() { + if env!("TVM_NUM_THREADS") != "0" { + sgx_join_threads() + } +} + +impl Drop for SystemLibModule { + fn drop(&mut self) { + shutdown() + } +} diff --git a/rust/src/runtime/threading.rs b/rust/src/runtime/threading.rs new file mode 100644 index 000000000000..693ebf7c4a33 --- /dev/null +++ b/rust/src/runtime/threading.rs @@ -0,0 +1,337 @@ +use std::{ + os::raw::{c_int, c_void}, + sync::{ + atomic::{AtomicUsize, Ordering, ATOMIC_USIZE_INIT}, + Arc, Barrier, + }, +}; + +#[cfg(not(target_env = "sgx"))] +use num_cpus; +#[cfg(not(target_env = "sgx"))] +use std::{ + env, + thread::{self, JoinHandle}, +}; + +#[cfg(target_env = "sgx")] +use std::{collections::VecDeque, ptr, sync::Mutex}; + +use bounded_spsc_queue::{self, Producer}; + +use super::super::errors::*; +use ffi::runtime::TVMParallelGroupEnv; + +#[cfg(target_env = "sgx")] +use super::{TVMArgValue, TVMRetValue}; + +type FTVMParallelLambda = + extern "C" fn(task_id: usize, penv: *const TVMParallelGroupEnv, cdata: *const c_void) -> i32; + +/// Holds a parallel job request made by a TVM library function. +struct Job { + cb: FTVMParallelLambda, + cdata: *const c_void, + req_num_tasks: usize, + pending: Arc, +} + +impl Job { + /// Splits this job into a number of `Task`s which can be scheduled. + fn tasks(&self, num_workers: usize) -> Vec { + let num_tasks = if self.req_num_tasks == 0 { + num_workers + } else { + self.req_num_tasks.min(num_workers) + }; + self.pending.store(num_tasks, Ordering::SeqCst); + + let barrier = Arc::new(Barrier::new(num_tasks)); + + (0..num_tasks) + .map(move |i| Task { + id: i, + flambda: self.cb, + penv: TVMParallelGroupEnv { + sync_handle: &Arc::clone(&barrier) as *const _ as *mut c_void, + num_task: num_tasks as i32, + }, + cdata: self.cdata, + pending: Arc::clone(&self.pending), + }) + .collect() + } + + /// Waits for all tasks in this `Job` to be completed. + fn wait(&self) -> Result<()> { + while self.pending.load(Ordering::Acquire) > 0 { + #[cfg(not(target_env = "sgx"))] + thread::yield_now(); + } + Ok(()) + } +} + +/// A chunk of work requested by a TVM function. 
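+/// Calling a `Task` (via its `FnOnce` impl below) runs `flambda` with this
+/// task's id and then decrements the shared `pending` counter, e.g.:
+///
+/// ```ignore
+/// let status = task(); // flambda(id, &penv, cdata); then pending -= 1
+/// assert_eq!(status, 0);
+/// ```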
+struct Task { + id: usize, + flambda: FTVMParallelLambda, + penv: TVMParallelGroupEnv, + cdata: *const c_void, + pending: Arc, +} +unsafe impl Send for Task {} +unsafe impl Sync for Task {} + +impl FnOnce<()> for Task { + type Output = i32; + extern "rust-call" fn call_once(self, _args: ()) -> Self::Output { + let status = (self.flambda)(self.id, &self.penv as *const _, self.cdata); + self.pending.fetch_sub(1, Ordering::AcqRel); + status + } +} + +#[derive(Default)] +struct Threads { + #[allow(unused)] + #[cfg(not(target_env = "sgx"))] + handles: Vec>, + queues: Vec>, +} + +impl<'a> Threads { + #[cfg(not(target_env = "sgx"))] + fn launch) + 'static + Copy>( + num_threads: usize, + cb: F, + ) -> Self { + let (handles, queues) = (0..num_threads) + .map(|_| { + let (p, c) = bounded_spsc_queue::make(2); + let handle = thread::spawn(move || cb(c.into())); + (handle, p) + }) + .unzip(); + Threads { + handles: handles, + queues: queues, + } + } + + #[cfg(target_env = "sgx")] + fn launch) + 'static + Copy>( + num_threads: usize, + _cb: F, + ) -> Self { + let mut consumer_queues = SGX_QUEUES.lock().unwrap(); + let queues = (0..num_threads) + .map(|_| { + let (p, c) = bounded_spsc_queue::make(2); + consumer_queues.push_back(c.into()); + p + }) + .collect(); + ocall_packed!("__sgx_thread_group_launch__", num_threads as u64); + Threads { queues: queues } + } +} + +struct ThreadPool { + num_workers: usize, + #[allow(unused)] + threads: Threads, +} + +thread_local!(static THREAD_POOL: ThreadPool = ThreadPool::new()); + +impl ThreadPool { + fn new() -> Self { + let num_workers = max_concurrency(); + ThreadPool { + num_workers: num_workers, + threads: Threads::launch(num_workers, ThreadPool::run_worker), + } + } + + fn launch(&self, job: Job) { + let mut tasks = job.tasks(self.num_workers + 1); + + for (i, task) in tasks.split_off(1).into_iter().enumerate() { + self.threads.queues[i].push(task); + } + + tasks.pop().unwrap()(); + job.wait().unwrap(); + } + + fn run_worker(queue: Consumer) { + loop { + let task = queue.pop(); + let result = task(); + if result == ::min_value() { + break; + } else if result != 0 { + panic!("Error running task."); + } + } + } +} + +// Send + Sync wrapper for bounded_spsc_queue::Consumer +struct Consumer { + consumer: bounded_spsc_queue::Consumer, +} +impl From> for Consumer { + fn from(c: bounded_spsc_queue::Consumer) -> Self { + Consumer { consumer: c } + } +} +impl Consumer { + fn pop(&self) -> T { + self.consumer.pop() + } +} +unsafe impl Send for Consumer {} +unsafe impl Sync for Consumer {} + +#[cfg(target_env = "sgx")] +lazy_static! { + /// Holds tasks for untrusted threads which re-enter the enclave to execute. 
+ static ref SGX_QUEUES: Mutex>> = Mutex::new(VecDeque::new()); +} + +#[cfg(all(not(target_arch = "wasm32"), not(target_env = "sgx")))] +fn max_concurrency() -> usize { + if let Ok(threads_str) = env::var("TVM_NUM_THREADS").or(env::var("OMP_NUM_THREADS")) { + if let Ok(threads) = usize::from_str_radix(&threads_str, 10) { + return threads; + } + } + num_cpus::get_physical() +} + +#[cfg(target_env = "sgx")] +fn max_concurrency() -> usize { + usize::from_str_radix(env!("TVM_NUM_THREADS"), 10).unwrap_or(1) +} + +#[cfg(target_arch = "wasm32")] +fn max_concurrency() -> usize { + 0 // wasm doesn't support threads yet +} + +#[cfg(target_env = "sgx")] +pub fn tvm_run_worker(_args: &[TVMArgValue]) -> TVMRetValue { + let q = { + let mut qs = SGX_QUEUES.lock().unwrap(); + qs.pop_front() + // `qs: MutexGuard` needs to be dropped here since `run_worker` won't return + }; + if let Some(q) = q { + ThreadPool::run_worker(q); + } + TVMRetValue::default() +} + +#[no_mangle] +pub extern "C" fn TVMBackendParallelLaunch( + cb: FTVMParallelLambda, + cdata: *const c_void, + num_task: usize, +) -> c_int { + if max_concurrency() == 0 { + let penv = TVMParallelGroupEnv { + sync_handle: 0 as *mut c_void, + num_task: 1, + }; + cb(0, &penv as *const _, cdata); + } else { + THREAD_POOL.with(|pool| { + pool.launch(Job { + cb: cb, + cdata: cdata, + req_num_tasks: num_task, + pending: Arc::new(ATOMIC_USIZE_INIT), + }); + }); + } + return 0; +} + +#[cfg(target_env = "sgx")] +pub(crate) fn sgx_join_threads() { + extern "C" fn poison_pill( + _task_id: usize, + _penv: *const TVMParallelGroupEnv, + _cdata: *const c_void, + ) -> i32 { + ::min_value() + } + + THREAD_POOL.with(|pool| { + pool.launch(Job { + cb: poison_pill, + cdata: ptr::null(), + req_num_tasks: 0, + pending: Arc::new(ATOMIC_USIZE_INIT), + }); + }); + ocall_packed!("__sgx_thread_group_join__", 0); +} + +// @see https://github.com/dmlc/tvm/issues/988 for information on why this function is used. 
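+// A sketch of a parallel lambda that relies on the barrier between two phases
+// (hypothetical kernel; real lambdas are emitted by TVM codegen):
+//
+//   extern "C" fn flambda(task_id: usize, penv: *const TVMParallelGroupEnv,
+//                         cdata: *const c_void) -> i32 {
+//     // phase 1: each task writes its partial result
+//     TVMBackendParallelBarrier(task_id, penv);
+//     // phase 2: tasks may now read each other's results
+//     0
+//   }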
+// @see https://github.com/dmlc/tvm/issues/988 for information on why this function is used.
+#[no_mangle]
+pub extern "C" fn TVMBackendParallelBarrier(_task_id: usize, penv: *const TVMParallelGroupEnv) {
+  let barrier: &Arc<Barrier> = unsafe { &*((*penv).sync_handle as *const Arc<Barrier>) };
+  barrier.wait();
+}
+
+#[cfg(test)]
+mod tests {
+  use std::{ptr, thread, time::Duration};
+
+  use super::*;
+
+  #[test]
+  fn test_max_concurrency() {
+    env::set_var("TVM_NUM_THREADS", "42");
+    env::set_var("OMP_NUM_THREADS", "24");
+    assert_eq!(max_concurrency(), 42);
+    env::remove_var("TVM_NUM_THREADS");
+    assert_eq!(max_concurrency(), 24);
+  }
+
+  extern "C" fn flambda(
+    task_id: usize,
+    penv: *const TVMParallelGroupEnv,
+    cdata: *const c_void,
+  ) -> i32 {
+    if cdata == ptr::null() {
+      return 0;
+    }
+    unsafe {
+      let &(ref counter, ref task_ids_sum) = &*(cdata as *const (AtomicUsize, AtomicUsize));
+      thread::sleep(Duration::from_millis(50 * task_id as u64));
+      counter.fetch_add(1, Ordering::SeqCst);
+      task_ids_sum.fetch_add(task_id, Ordering::SeqCst);
+      assert_eq!((*penv).num_task, 3);
+    }
+    0
+  }
+
+  #[test]
+  fn test_parallel_launch() {
+    TVMBackendParallelLaunch(flambda, ptr::null(), 6);
+    let counter = ATOMIC_USIZE_INIT;
+    let task_ids_sum = ATOMIC_USIZE_INIT;
+    let cdata = (counter, task_ids_sum);
+    let num_tasks = 3;
+    TVMBackendParallelLaunch(flambda, &cdata as *const _ as *const c_void, num_tasks);
+    assert_eq!(cdata.0.load(Ordering::SeqCst), num_tasks);
+    assert_eq!(
+      cdata.1.load(Ordering::SeqCst),
+      (0..num_tasks).sum::<usize>()
+    );
+  }
+}
diff --git a/rust/src/runtime/workspace.rs b/rust/src/runtime/workspace.rs
new file mode 100644
index 000000000000..d0e6d8c89255
--- /dev/null
+++ b/rust/src/runtime/workspace.rs
@@ -0,0 +1,119 @@
+use std::{
+  cell::RefCell,
+  os::raw::{c_int, c_void},
+  ptr,
+};
+
+use super::allocator::Allocation;
+use errors::*;
+
+const WS_ALIGN: usize = 64; // taken from `kTempAllocaAlignment` in `device_api.h`
+
+struct WorkspacePool {
+  workspaces: Vec<Allocation>,
+  free: Vec<usize>,
+  in_use: Vec<usize>,
+}
+
+impl WorkspacePool {
+  fn new() -> Self {
+    WorkspacePool {
+      workspaces: Vec::new(),
+      free: Vec::new(),
+      in_use: Vec::new(),
+    }
+  }
+
+  fn alloc_new(&mut self, size: usize) -> Result<*mut u8> {
+    self.workspaces.push(Allocation::new(size, Some(WS_ALIGN))?);
+    self.in_use.push(self.workspaces.len() - 1);
+    Ok(self.workspaces[self.workspaces.len() - 1].as_mut_ptr())
+  }
+
+  fn alloc(&mut self, size: usize) -> Result<*mut u8> {
+    if self.free.len() == 0 {
+      return self.alloc_new(size);
+    }
+    let idx = self
+      .free
+      .iter()
+      .fold(None, |cur_ws_idx: Option<usize>, &idx| {
+        let ws_size = self.workspaces[idx].size();
+        if ws_size < size {
+          return cur_ws_idx;
+        }
+        cur_ws_idx.or(Some(idx)).and_then(|cur_idx| {
+          let cur_size = self.workspaces[cur_idx].size();
+          Some(match ws_size <= cur_size {
+            true => idx,
+            false => cur_idx,
+          })
+        })
+      });
+    match idx {
+      Some(idx) => {
+        self.free.remove_item(&idx).unwrap();
+        self.in_use.push(idx);
+        Ok(self.workspaces[idx].as_mut_ptr())
+      }
+      None => self.alloc_new(size),
+    }
+  }
+
+  fn free(&mut self, ptr: *mut u8) -> Result<()> {
+    let mut ws_idx = None;
+    for i in 0..self.in_use.len() {
+      let idx = self.in_use[i];
+      if self.workspaces[idx].as_mut_ptr() == ptr {
+        self.in_use.remove(i);
+        ws_idx = Some(idx);
+        break;
+      }
+    }
+    Ok(
+      self
+        .free
+        .push(ws_idx.ok_or("Tried to free nonexistent workspace.")?),
+    )
+  }
+}
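`WorkspacePool::alloc` reuses the smallest cached workspace that still fits the request and only falls back to a fresh allocation when nothing fits. A Python sketch of that best-fit scan (hypothetical `sizes` list of cached workspace sizes):

```
def pick_workspace(sizes, free, request):
    """Index of the smallest free workspace that still fits, else None."""
    best = None
    for idx in free:
        if sizes[idx] < request:
            continue  # too small to reuse
        if best is None or sizes[idx] <= sizes[best]:
            best = idx
    return best

assert pick_workspace([256, 64, 1024], [0, 1, 2], 100) == 0  # 256 is the tightest fit
```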
+thread_local!(static WORKSPACE_POOL: RefCell<WorkspacePool> = RefCell::new(WorkspacePool::new()));
+
+const WORKSPACE_PAGE_SIZE: usize = 4 << 10;
+
+#[no_mangle]
+pub extern "C" fn TVMBackendAllocWorkspace(
+  _device_type: c_int,
+  _device_id: c_int,
+  size: u64,
+  _dtype_code_hint: c_int,
+  _dtype_bits_hint: c_int,
+) -> *mut c_void {
+  let nbytes = if size == 0 {
+    WORKSPACE_PAGE_SIZE
+  } else {
+    size as usize
+  };
+  WORKSPACE_POOL.with(|pool_cell| {
+    pool_cell
+      .borrow_mut()
+      .alloc(nbytes as usize)
+      .unwrap_or(ptr::null_mut()) as *mut c_void
+  })
+}
+
+#[no_mangle]
+pub extern "C" fn TVMBackendFreeWorkspace(
+  _device_type: c_int,
+  _device_id: c_int,
+  ptr: *mut c_void,
+) -> c_int {
+  WORKSPACE_POOL.with(|pool_cell| {
+    (match pool_cell.borrow_mut().free(ptr as *mut u8) {
+      Ok(()) => 0,
+      Err(_) => -1,
+    }) as c_int
+  })
+}
diff --git a/rust/tests/.gitignore b/rust/tests/.gitignore
new file mode 100644
index 000000000000..811076739bfa
--- /dev/null
+++ b/rust/tests/.gitignore
@@ -0,0 +1,3 @@
+*.json
+*.params
+*.o
diff --git a/rust/tests/build_model.py b/rust/tests/build_model.py
new file mode 100644
index 000000000000..e0b90495159f
--- /dev/null
+++ b/rust/tests/build_model.py
@@ -0,0 +1,53 @@
+"""Builds a simple NNVM graph for testing."""
+
+from os import path as osp
+
+import nnvm
+from nnvm import sym
+from nnvm.compiler import graph_util
+from nnvm.testing import init
+import numpy as np
+import tvm
+
+CWD = osp.dirname(osp.abspath(osp.expanduser(__file__)))
+
+
+def _get_model(dshape):
+    data = sym.Variable('data', shape=dshape)
+    fc1 = sym.dense(data, units=dshape[-1]*2, use_bias=True)
+    left, right = sym.split(fc1, indices_or_sections=2, axis=1)
+    return sym.Group(((left + 1), (right - 1)))
+
+
+def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10):
+    if isinstance(graph, sym.Symbol):
+        graph = nnvm.graph.create(graph)
+    ishapes, _ = graph_util.infer_shape(graph, **input_shapes)
+    param_shapes = dict(zip(graph.index.input_names, ishapes))
+    np.random.seed(seed)
+    params = {}
+    for param, shape in param_shapes.items():
+        if param in {'data', 'label'} or not shape:
+            continue
+        init_value = np.empty(shape).astype('float32')
+        initializer(param, init_value)
+        params[param] = tvm.nd.array(init_value)
+    return params
+
+def main():
+    dshape = (32, 16)
+    net = _get_model(dshape)
+    ishape_dict = {'data': dshape}
+    params = _init_params(net, ishape_dict)
+    graph, lib, params = nnvm.compiler.build(net, 'llvm',
+                                             shape=ishape_dict,
+                                             params=params,
+                                             dtype='float32')
+
+    with open(osp.join(CWD, 'graph.json'), 'w') as f_resnet:
+        f_resnet.write(graph.json())
+    with open(osp.join(CWD, 'graph.params'), 'wb') as f_params:
+        f_params.write(nnvm.compiler.save_param_dict(params))
+
+if __name__ == '__main__':
+    main()
diff --git a/rust/tests/test_graph_serde.rs b/rust/tests/test_graph_serde.rs
new file mode 100644
index 000000000000..b02c12889794
--- /dev/null
+++ b/rust/tests/test_graph_serde.rs
@@ -0,0 +1,39 @@
+#![feature(try_from)]
+
+extern crate serde;
+extern crate serde_json;
+
+extern crate tvm;
+
+use std::{convert::TryFrom, fs, io::Read};
+
+use tvm::runtime::Graph;
+
+#[test]
+fn test_load_graph() {
+  let mut params_bytes = Vec::new();
+  fs::File::open(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.params"))
+    .expect("Could not find TVM graph. Did you run `tests/build_model.py`?")
+    .read_to_end(&mut params_bytes)
+    .unwrap();
+  let _params = tvm::runtime::load_param_dict(&params_bytes);
+
+  let graph = Graph::try_from(
+    &fs::read_to_string(concat!(env!("CARGO_MANIFEST_DIR"), "/tests/graph.json")).unwrap(),
+  )
+  .unwrap();
+
+  assert_eq!(graph.nodes[3].op, "tvm_op");
+  assert_eq!(
+    graph.nodes[3]
+      .attrs
+      .as_ref()
+      .unwrap()
+      .get("func_name")
+      .unwrap(),
+    "fuse_dense"
+  );
+  assert_eq!(graph.nodes[5].inputs[0].index, 0);
+  assert_eq!(graph.nodes[6].inputs[0].index, 1);
+  assert_eq!(graph.heads.len(), 2);
+}
diff --git a/rust/tests/test_nnvm/Cargo.toml b/rust/tests/test_nnvm/Cargo.toml
new file mode 100644
index 000000000000..7e6ce5fb729c
--- /dev/null
+++ b/rust/tests/test_nnvm/Cargo.toml
@@ -0,0 +1,14 @@
+[package]
+name = "test-nnvm"
+version = "0.0.0"
+license = "Apache-2.0"
+authors = ["Nick Hynes "]
+
+[dependencies]
+ndarray = "0.11.2"
+tvm = { path = "../../" }
+serde = "1.0.59"
+serde_json = "1.0.17"
+
+[build-dependencies]
+ar = "0.6.0"
diff --git a/rust/tests/test_nnvm/build.rs b/rust/tests/test_nnvm/build.rs
new file mode 100644
index 000000000000..cb3a4e0d574d
--- /dev/null
+++ b/rust/tests/test_nnvm/build.rs
@@ -0,0 +1,28 @@
+extern crate ar;
+
+use std::{env, path::PathBuf, process::Command};
+
+use ar::Builder;
+use std::fs::File;
+
+fn main() {
+  let out_dir = env::var("OUT_DIR").unwrap();
+
+  let output = Command::new(concat!(
+    env!("CARGO_MANIFEST_DIR"),
+    "/src/build_test_graph.py"
+  )).arg(&out_dir)
+    .output()
+    .expect("Failed to execute command");
+  if output.stderr.len() > 0 {
+    panic!(String::from_utf8(output.stderr).unwrap());
+  }
+
+  let in_path: PathBuf = [&out_dir, "graph.o"].iter().collect();
+  let out_path: PathBuf = [&out_dir, "libgraph.a"].iter().collect();
+  let mut builder = Builder::new(File::create(out_path.to_str().unwrap()).unwrap());
+  builder.append_path(in_path.to_str().unwrap()).unwrap();
+
+  println!("cargo:rustc-link-lib=static=graph");
+  println!("cargo:rustc-link-search=native={}", out_dir);
+}
diff --git a/rust/tests/test_nnvm/src/build_test_graph.py b/rust/tests/test_nnvm/src/build_test_graph.py
new file mode 100755
index 000000000000..429cc2128931
--- /dev/null
+++ b/rust/tests/test_nnvm/src/build_test_graph.py
@@ -0,0 +1,63 @@
+#!/usr/bin/env python3
+
+"""Builds a simple NNVM graph for testing."""
+
+from os import path as osp
+import sys
+
+import nnvm
+from nnvm import sym
+from nnvm.compiler import graph_util
+from nnvm.testing import init
+import numpy as np
+import tvm
+
+
+def _get_model(dshape):
+    data = sym.Variable('data', shape=dshape)
+    fc = sym.dense(data, units=dshape[-1]*2, use_bias=True)
+    left, right = sym.split(fc, indices_or_sections=2, axis=1)
+    return sym.Group(((left + 1), (right - 1), fc))
+
+
+def _init_params(graph, input_shapes, initializer=init.Xavier(), seed=10):
+    if isinstance(graph, sym.Symbol):
+        graph = nnvm.graph.create(graph)
+    ishapes, _ = graph_util.infer_shape(graph, **input_shapes)
+    param_shapes = dict(zip(graph.index.input_names, ishapes))
+    np.random.seed(seed)
+    params = {}
+    for param, shape in param_shapes.items():
+        if param in {'data', 'label'} or not shape:
+            continue
+
+        init_value = np.arange(np.product(shape), 0, -1).reshape(*shape).astype('float32')
+        if param.endswith('_bias'):
+            params[param] = tvm.nd.array(init_value)
+            continue
+
+        init_value = np.empty(shape).astype('float32')
+        initializer(param, init_value)
+        # init_value /= init_value.sum() + 1e-10
+        params[param] = tvm.nd.array(init_value)
+    return params
+
+def main():
+    dshape = (4, 8)
+    net = _get_model(dshape)
+    ishape_dict = {'data': dshape}
+    params = _init_params(net, ishape_dict)
+    graph, lib, params = nnvm.compiler.build(net, 'llvm --system-lib',
+                                             shape=ishape_dict,
+                                             params=params,
+                                             dtype='float32')
+
+    out_dir = sys.argv[1]
+    lib.save(osp.join(sys.argv[1], 'graph.o'))
+    with open(osp.join(out_dir, 'graph.json'), 'w') as f_resnet:
+        f_resnet.write(graph.json())
+    with open(osp.join(out_dir, 'graph.params'), 'wb') as f_params:
+        f_params.write(nnvm.compiler.save_param_dict(params))
+
+if __name__ == '__main__':
+    main()
diff --git a/rust/tests/test_nnvm/src/main.rs b/rust/tests/test_nnvm/src/main.rs
new file mode 100644
index 000000000000..0953ce2a2603
--- /dev/null
+++ b/rust/tests/test_nnvm/src/main.rs
@@ -0,0 +1,80 @@
+#![feature(try_from)]
+
+#[macro_use]
+extern crate ndarray;
+extern crate serde;
+extern crate serde_json;
+
+extern crate tvm;
+
+use std::{collections::HashMap, convert::TryFrom, fs, io::Read};
+
+use ndarray::Array;
+use tvm::runtime::{Graph, GraphExecutor, SystemLibModule, Tensor};
+
+const BATCH_SIZE: usize = 4;
+const IN_DIM: usize = 8;
+
+macro_rules! check_sum {
+  ($e:expr, $a:ident, $b:ident) => {
+    let a = Array::try_from($e.get_input(stringify!($a)).unwrap()).unwrap();
+    check_sum!(a, $b);
+  };
+  ($e:expr, $a:expr, $b:ident) => {
+    let a = Array::try_from($e.get_output($a).unwrap()).unwrap();
+    check_sum!(a, $b);
+  };
+  ($a:ident, $b:ident) => {
+    let a_sum: f32 = $a.scalar_sum();
+    let b_sum: f32 = $b.scalar_sum();
+    assert!((a_sum - b_sum).abs() < 1e-2, "{} != {}", a_sum, b_sum);
+  };
+}
+
+fn main() {
+  let syslib = SystemLibModule::default();
+
+  let mut params_bytes = Vec::new();
+  fs::File::open(concat!(env!("OUT_DIR"), "/graph.params"))
+    .unwrap()
+    .read_to_end(&mut params_bytes)
+    .unwrap();
+  let params = tvm::runtime::load_param_dict(&params_bytes)
+    .unwrap()
+    .into_iter()
+    .map(|(k, v)| (k, v.to_owned()))
+    .collect::<HashMap<String, Tensor>>();
+
+  let graph =
+    Graph::try_from(&fs::read_to_string(concat!(env!("OUT_DIR"), "/graph.json")).unwrap()).unwrap();
+  let mut exec = GraphExecutor::new(graph, &syslib).unwrap();
+
+  let x = Array::from_shape_vec(
+    (BATCH_SIZE, IN_DIM),
+    (0..BATCH_SIZE * IN_DIM)
+      .map(|x| x as f32)
+      .collect::<Vec<f32>>(),
+  ).unwrap();
+  let w = Array::try_from(params.get("dense0_weight").unwrap())
+    .unwrap()
+    .into_shape((IN_DIM * 2, IN_DIM))
+    .unwrap();
+  let b = Array::try_from(params.get("dense0_bias").unwrap()).unwrap();
+  let dense = x.dot(&w.t()) + &b;
+  let left = dense.slice(s![.., 0..IN_DIM]);
+  let right = dense.slice(s![.., IN_DIM..]);
+  let expected_o0 = &left + 1f32;
+  let expected_o1 = &right - 1f32;
+
+  exec.load_params(params);
+  exec.set_input("data", x.clone().into());
+
+  check_sum!(exec, data, x);
+  check_sum!(exec, dense0_weight, w);
+  check_sum!(exec, dense0_bias, b);
+
+  exec.run();
+
+  check_sum!(exec, 0, expected_o0);
+  check_sum!(exec, 1, expected_o1);
+  check_sum!(exec, 2, dense);
+}
diff --git a/rust/tests/test_tvm_basic/Cargo.toml b/rust/tests/test_tvm_basic/Cargo.toml
new file mode 100644
index 000000000000..bd4193bcb8fb
--- /dev/null
+++ b/rust/tests/test_tvm_basic/Cargo.toml
@@ -0,0 +1,12 @@
+[package]
+name = "test-tvm-basic"
+version = "0.0.0"
+license = "Apache-2.0"
+authors = ["Nick Hynes "]
+
+[dependencies]
+ndarray = "0.11.2"
+tvm = { path = "../../" }
+
+[build-dependencies]
+ar = "0.6.0"
diff --git a/rust/tests/test_tvm_basic/build.rs b/rust/tests/test_tvm_basic/build.rs
new file mode 100644
index 000000000000..778dd1cab1ca
--- /dev/null
+++ b/rust/tests/test_tvm_basic/build.rs
@@ -0,0 +1,28 @@
+extern crate ar;
+
+use std::{env, path::PathBuf, process::Command};
+
+use ar::Builder;
+use std::fs::File;
+
+fn main() {
+  let out_dir = env::var("OUT_DIR").unwrap();
+
+  let output = Command::new(concat!(
+    env!("CARGO_MANIFEST_DIR"),
+    "/src/build_test_lib.py"
+  )).arg(&out_dir)
+    .output()
+    .expect("Failed to execute command");
+  if output.stderr.len() > 0 {
+    panic!(String::from_utf8(output.stderr).unwrap());
+  }
+
+  let in_path: PathBuf = [&out_dir, "test.o"].iter().collect();
+  let out_path: PathBuf = [&out_dir, "libtest.a"].iter().collect();
+  let mut builder = Builder::new(File::create(out_path.to_str().unwrap()).unwrap());
+  builder.append_path(in_path.to_str().unwrap()).unwrap();
+
+  println!("cargo:rustc-link-lib=static=test");
+  println!("cargo:rustc-link-search=native={}", out_dir);
+}
diff --git a/rust/tests/test_tvm_basic/src/build_test_lib.py b/rust/tests/test_tvm_basic/src/build_test_lib.py
new file mode 100755
index 000000000000..7289a778fcec
--- /dev/null
+++ b/rust/tests/test_tvm_basic/src/build_test_lib.py
@@ -0,0 +1,21 @@
+#!/usr/bin/env python3
+
+"""Prepares a simple TVM library for testing."""
+
+from os import path as osp
+import sys
+
+import tvm
+
+def main():
+    n = tvm.var('n')
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.placeholder((n,), name='B')
+    C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
+    s = tvm.create_schedule(C.op)
+    s[C].parallel(s[C].op.axis[0])
+    print(tvm.lower(s, [A, B, C], simple_mode=True))
+    tvm.build(s, [A, B, C], 'llvm --system-lib').save(osp.join(sys.argv[1], 'test.o'))
+
+if __name__ == '__main__':
+    main()
diff --git a/rust/tests/test_tvm_basic/src/main.rs b/rust/tests/test_tvm_basic/src/main.rs
new file mode 100644
index 000000000000..b6c11451d12a
--- /dev/null
+++ b/rust/tests/test_tvm_basic/src/main.rs
@@ -0,0 +1,25 @@
+extern crate ndarray;
+#[macro_use]
+extern crate tvm;
+
+use ndarray::Array;
+use tvm::{
+  ffi::runtime::DLTensor,
+  runtime::{Module, SystemLibModule},
+};
+
+fn main() {
+  let syslib = SystemLibModule::default();
+  let add = syslib
+    .get_function("default_function")
+    .expect("main function not found");
+  let mut a = Array::from_vec(vec![1f32, 2., 3., 4.]);
+  let mut b = Array::from_vec(vec![1f32, 0., 1., 0.]);
+  let mut c = Array::from_vec(vec![0f32; 4]);
+  let e = Array::from_vec(vec![2f32, 2., 4., 4.]);
+  let mut a_dl: DLTensor = (&mut a).into();
+  let mut b_dl: DLTensor = (&mut b).into();
+  let mut c_dl: DLTensor = (&mut c).into();
+  call_packed!(add, &mut a_dl, &mut b_dl, &mut c_dl);
+  assert!(c.all_close(&e, 1e-8f32));
+}
diff --git a/src/api/api_lang.cc b/src/api/api_lang.cc
index 8ca49f19baec..75365da5bf50 100644
--- a/src/api/api_lang.cc
+++ b/src/api/api_lang.cc
@@ -239,6 +239,14 @@ TVM_REGISTER_API("_TensorIntrin")
         args[6]);
   });
 
+TVM_REGISTER_API("_TensorIntrinCall")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = TensorIntrinCallNode::make(args[0],
+                                      args[1],
+                                      args[2],
+                                      args[3]);
+  });
+
 TVM_REGISTER_API("_TensorEqual")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     *ret = args[0].operator Tensor() == args[1].operator Tensor();
@@ -278,6 +286,18 @@ TVM_REGISTER_API("_ScanOp")
         args[7]);
   });
 
+TVM_REGISTER_API("_TensorComputeOp")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = TensorComputeOpNode::make(args[0],
+                                     args[1],
+                                     args[2],
+                                     args[3],
+                                     args[4],
+                                     args[5],
+                                     args[6],
+                                     args[7]);
+  });
+
 TVM_REGISTER_API("_ExternOp")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     *ret = ExternOpNode::make(args[0],
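Both `build.rs` scripts above follow the same pattern: run a Python script that saves a TVM-emitted object file, wrap that object in a static archive, and tell cargo to link it. A rough Python equivalent of the archiving step (a sketch; assumes a GNU `ar` binary on `PATH`):

```
import subprocess

def make_static_lib(obj_path, lib_path):
    # same effect as ar::Builder appending the object into a fresh archive
    subprocess.run(['ar', 'rcs', lib_path, obj_path], check=True)

make_static_lib('test.o', 'libtest.a')
print('cargo:rustc-link-lib=static=test')  # what build.rs emits for cargo
```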
diff --git a/src/codegen/codegen_cuda.cc b/src/codegen/codegen_cuda.cc
index 0960106ae471..2ed8d8e3ff78 100644
--- a/src/codegen/codegen_cuda.cc
+++ b/src/codegen/codegen_cuda.cc
@@ -77,6 +77,8 @@ void CodeGenCUDA::PrintType(Type t, std::ostream& os) {  // NOLINT(*)
     if (!fail && (lanes >= 2 && lanes <= 4)) {
       os << lanes; return;
     }
+  } else if (t == Bool()) {
+    os << "bool"; return;
   } else if (t.is_uint() || t.is_int()) {
     if (t.is_uint()) {
       if (t.lanes() != 1) {
diff --git a/src/codegen/codegen_metal.cc b/src/codegen/codegen_metal.cc
index 3bbe98289439..031313190370 100644
--- a/src/codegen/codegen_metal.cc
+++ b/src/codegen/codegen_metal.cc
@@ -141,6 +141,9 @@ void CodeGenMetal::PrintType(Type t, std::ostream& os) {  // NOLINT(*)
         << "do not yet support vector types";
     os << "void*"; return;
   }
+  if (t == Bool()) {
+    os << "bool"; return;
+  }
   bool fail = false;
   if (t.is_float()) {
     switch (t.bits()) {
diff --git a/src/codegen/codegen_opencl.cc b/src/codegen/codegen_opencl.cc
index 3d3de5e3bcf4..a0b3c2000a80 100644
--- a/src/codegen/codegen_opencl.cc
+++ b/src/codegen/codegen_opencl.cc
@@ -80,6 +80,9 @@ void CodeGenOpenCL::PrintType(Type t, std::ostream& os) {  // NOLINT(*)
         << "do not yet support vector types";
     os << "void*"; return;
   }
+  if (t == Bool()) {
+    os << "bool"; return;
+  }
   bool fail = false;
   if (t.is_float()) {
     switch (t.bits()) {
diff --git a/src/codegen/spirv/ir_builder.cc b/src/codegen/spirv/ir_builder.cc
index 41cb48c5854b..fdf4b9852430 100644
--- a/src/codegen/spirv/ir_builder.cc
+++ b/src/codegen/spirv/ir_builder.cc
@@ -438,8 +438,25 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) {
   const tvm::Type& from = value.stype.type;
   const tvm::Type& to = dst_type.type;
   CHECK_EQ(from.lanes(), to.lanes());
-
-  if (from.is_int() && to.is_int()) {
+  if (from == Bool()) {
+    if (to.is_int()) {
+      return Select(value, IntImm(dst_type, 1), IntImm(dst_type, 0));
+    } else if (to.is_uint()) {
+      return Select(value, UIntImm(dst_type, 1), UIntImm(dst_type, 0));
+    } else {
+      LOG(FATAL) << "cannot cast from " << from << " to " << to;
+      return Value();
+    }
+  } else if (to == Bool()) {
+    if (from.is_int()) {
+      return NE(value, IntImm(value.stype, 0));
+    } else if (from.is_uint()) {
+      return NE(value, UIntImm(value.stype, 0));
+    } else {
+      LOG(FATAL) << "cannot cast from " << from << " to " << to;
+      return Value();
+    }
+  } else if (from.is_int() && to.is_int()) {
     return MakeValue(spv::OpSConvert, dst_type, value);
   } else if (from.is_uint() && to.is_uint()) {
     return MakeValue(spv::OpUConvert, dst_type, value);
diff --git a/src/lang/buffer.cc b/src/lang/buffer.cc
index 69967c55a7ff..183a52f785bd 100644
--- a/src/lang/buffer.cc
+++ b/src/lang/buffer.cc
@@ -260,25 +260,42 @@ inline Expr BufferOffset(const BufferNode* n, Array<Expr> index, Type dtype) {
 }
 
 Expr Buffer::vload(Array<Expr> begin, Type dtype) const {
+  // specially handle bool, stored as Int(8)
   const BufferNode* n = operator->();
   CHECK(dtype.element_of() == n->dtype.element_of() &&
         dtype.lanes() % n->dtype.lanes() == 0)
       << "Cannot load " << dtype << " from buffer of " << n->dtype;
-  return ir::Load::make(
-      dtype, n->data, BufferOffset(n, begin, dtype),
-      const_true(dtype.lanes()));
+  if (dtype == Bool()) {
+    return ir::Cast::make(
+        Bool(),
+        ir::Load::make(
+            Int(8), n->data, BufferOffset(n, begin, Int(8)),
+            const_true()));
+  } else {
+    return ir::Load::make(
+        dtype, n->data, BufferOffset(n, begin, dtype),
+        const_true(dtype.lanes()));
+  }
 }
 
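Since one-bit bools have no stable storage layout, `vload`/`vstore` round-trip them through `Int(8)`: stores cast bool to int8, loads read an int8 byte and cast back. The convention in Python terms (a sketch with a plain bytearray standing in for the buffer):

```
buf = bytearray(8)  # backing storage is bytes, i.e. Int(8)

def vstore_bool(offset, value):
    buf[offset] = 1 if value else 0   # Cast::make(Int(8), value) then Store

def vload_bool(offset):
    return buf[offset] != 0           # Load as Int(8), then Cast::make(Bool(), ...)

vstore_bool(3, True)
assert vload_bool(3) is True
```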
 Stmt Buffer::vstore(Array<Expr> begin, Expr value) const {
+  // specially handle bool, stored as Int(8)
   const BufferNode* n = operator->();
   Type dtype = value.type();
   CHECK(dtype.element_of() == n->dtype.element_of() &&
         dtype.lanes() % n->dtype.lanes() == 0)
       << "Cannot load " << dtype << " from buffer of " << n->dtype;
-  return ir::Store::make(n->data, value, BufferOffset(n, begin, dtype),
-                         const_true(dtype.lanes()));
+  if (value.type() == Bool()) {
+    return ir::Store::make(n->data,
+                           ir::Cast::make(Int(8), value),
+                           BufferOffset(n, begin, Int(8)),
+                           const_true());
+  } else {
+    return ir::Store::make(n->data, value, BufferOffset(n, begin, dtype),
+                           const_true(dtype.lanes()));
+  }
 }
 
 Buffer Buffer::MakeStrideView() const {
diff --git a/src/lang/tensor.cc b/src/lang/tensor.cc
index 4f9c3e9d1782..9b1a58abcee4 100644
--- a/src/lang/tensor.cc
+++ b/src/lang/tensor.cc
@@ -10,6 +10,8 @@
 
 namespace tvm {
 
+// Tensor
+
 Expr Tensor::operator()(Array<Var> indices) const {
   Array<Expr> arr(indices.begin(), indices.end());
   return operator()(arr);
@@ -26,6 +28,15 @@ Expr Tensor::operator()(Array<Expr> indices) const {
   return n;
 }
 
+Tensor Operation::output(size_t i) const {
+  auto node = make_node<TensorNode>();
+  node->op = *this;
+  node->value_index = i;
+  node->dtype = (*this)->output_dtype(i);
+  node->shape = (*this)->output_shape(i);
+  return Tensor(node);
+}
+
 Tensor TensorNode::make(Array<Expr> shape,
                         Type dtype,
                         Operation op,
@@ -46,14 +57,8 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
 
 TVM_REGISTER_NODE_TYPE(TensorNode);
 
-Tensor Operation::output(size_t i) const {
-  auto node = make_node<TensorNode>();
-  node->op = *this;
-  node->value_index = i;
-  node->dtype = (*this)->output_dtype(i);
-  node->shape = (*this)->output_shape(i);
-  return Tensor(node);
-}
+
+// TensorIntrin
 
 TensorIntrin TensorIntrinNode::make(std::string name,
                                     Operation op,
@@ -79,4 +84,27 @@ TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
   });
 
 TVM_REGISTER_NODE_TYPE(TensorIntrinNode);
+
+
+// TensorIntrinCall
+
+TensorIntrinCall TensorIntrinCallNode::make(TensorIntrin intrin,
+                                            Array<Tensor> tensors,
+                                            Array<Region> regions,
+                                            Array<IterVar> reduce_axis) {
+  auto n = make_node<TensorIntrinCallNode>();
+  n->intrin = std::move(intrin);
+  n->tensors = std::move(tensors);
+  n->regions = std::move(regions);
+  n->reduce_axis = std::move(reduce_axis);
+  return TensorIntrinCall(n);
+}
+
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<TensorIntrinCallNode>([](const TensorIntrinCallNode *n, IRPrinter *p) {
+    p->stream << "TensorIntrinCall(intrin=" << n->intrin << ", " << n << ")";
+  });
+
+TVM_REGISTER_NODE_TYPE(TensorIntrinCallNode);
+
 }  // namespace tvm
diff --git a/src/op/compute_op.cc b/src/op/compute_op.cc
index daafac21b180..5c972595ff00 100644
--- a/src/op/compute_op.cc
+++ b/src/op/compute_op.cc
@@ -13,6 +13,7 @@
 #include "compute_op.h"
 #include "op_util.h"
 #include "../schedule/message_passing.h"
+#include "../arithmetic/compute_expr.h"
 
 namespace tvm {
 
@@ -545,4 +546,38 @@ static void VerifyComputeOp(const ComputeOpNode* op) {
   v.Run();
 }
 
+Stmt TransformUpdate(const Stage& stage,
+                     const std::unordered_map<IterVar, Range>& dom_map,
+                     const ComputeLoopNest& n,
+                     Stmt body,
+                     Stmt update) {
+  Array<Expr> conds;
+  std::unordered_set<const Variable*> banned;
+  for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
+    IterVar iv = stage->leaf_iter_vars[i];
+    auto iit = stage->iter_var_attrs.find(iv);
+    if (iit != stage->iter_var_attrs.end()) {
+      const IterVarAttr& attr = (*iit).second;
+      if (attr->iter_type == kTensorized) {
+        break;
+      }
+    }
+    if (iv->iter_type == kCommReduce) {
+      auto vit = dom_map.find(iv);
+      CHECK(vit != dom_map.end());
+      const Range& vrange = vit->second;
+      conds.push_back(likely(iv->var > vrange->min));
+      banned.insert(iv->var.get());
+    }
+  }
+  for (const Expr& pred : n.main_predicates) {
+    if (ir::ExprUseVar(pred, banned)) {
+      LOG(FATAL) << "Tensorize update transform failed, the condition "
+                 << pred << " has a conflict with the reset condition";
+    }
+  }
+
+  return IfThenElse::make(arith::ComputeReduce<ir::Or>(conds, const_true(1)),
+                          update, body);
+}
 }  // namespace tvm
diff --git a/src/op/compute_op.h b/src/op/compute_op.h
index 996764c6cdc1..87b0814c1ad9 100644
--- a/src/op/compute_op.h
+++ b/src/op/compute_op.h
@@ -14,7 +14,7 @@
 namespace tvm {
 
 // loop nest structure for general compute
-// This the the loop nest structured used in compute.
+// This is the loop nest structure used in compute.
 // Does not include the loop body.
 struct ComputeLoopNest {
   // The common number of loops between init and main
@@ -73,6 +73,21 @@ Stmt MakeTensorize(const ComputeOpNode* self,
                    const Stage& stage,
                    const std::unordered_map<IterVar, Range>& dom_map,
                    bool debug_keep_trivial_loop);
+
+/*!
+ * \brief Transform the update part when there is no init func in tensorizing
+ * \param stage The stage for tensorizing.
+ * \param dom_map The range of each iter var.
+ * \param n The loop nest structured used in compute.
+ * \param body The body func in tensorize intrin
+ * \param update The update func in tensorize intrin
+ * \return Transformed result.
+ */
+Stmt TransformUpdate(const Stage& stage,
+                     const std::unordered_map<IterVar, Range>& dom_map,
+                     const ComputeLoopNest& n,
+                     Stmt body,
+                     Stmt update);
 }  // namespace tvm
 
 #endif  // TVM_OP_COMPUTE_OP_H_
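`TransformUpdate` guards tensorized reductions that lack an init intrinsic: on the first iteration of every reduction loop (all reduce vars at their minimum) it runs `body` to reset the output, and runs `update` otherwise. A Python sketch of the predicate it generates (illustrative only):

```
def select_reset_or_update(reduce_vars, mins):
    # ComputeReduce<ir::Or>(conds): is any reduce var past its first iteration?
    past_first = any(iv > lo for iv, lo in zip(reduce_vars, mins))
    return 'update' if past_first else 'body (reset)'

assert select_reset_or_update([0, 0], [0, 0]) == 'body (reset)'
assert select_reset_or_update([0, 3], [0, 0]) == 'update'
```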
diff --git a/src/op/tensor_compute_op.cc b/src/op/tensor_compute_op.cc
new file mode 100644
index 000000000000..f9b8188d4685
--- /dev/null
+++ b/src/op/tensor_compute_op.cc
@@ -0,0 +1,361 @@
+/*!
+ *  Copyright (c) 2017 by Contributors
+ * \brief Tensor Compute Op.
+ * \file tensor_compute_op.cc
+ */
+#include <tvm/operation.h>
+#include <tvm/arithmetic.h>
+#include <tvm/ir.h>
+#include <tvm/ir_mutator.h>
+#include <tvm/ir_pass.h>
+#include <unordered_set>
+#include "./op_util.h"
+#include "./compute_op.h"
+#include "../arithmetic/compute_expr.h"
+
+namespace tvm {
+using namespace ir;
+// TensorComputeOpNode
+TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable)
+.set_dispatch<TensorComputeOpNode>([](const TensorComputeOpNode *op,
+                                      IRPrinter *p) {
+    p->stream << "tensor_compute_op(" << op->name << ", " << op << ")";
+  });
+
+TVM_REGISTER_NODE_TYPE(TensorComputeOpNode);
+
+int TensorComputeOpNode::num_outputs() const {
+  return static_cast<int>(this->intrin->buffers.size() - this->inputs.size());
+}
+
+Array<IterVar> TensorComputeOpNode::root_iter_vars() const {
+  Array<IterVar> ret = axis;
+  for (IterVar iv : reduce_axis) {
+    ret.push_back(iv);
+  }
+  return ret;
+}
+
+Type TensorComputeOpNode::output_dtype(size_t i) const {
+  return this->intrin->buffers[this->inputs.size() + i]->dtype;
+}
+
+Array<Expr> TensorComputeOpNode::output_shape(size_t i) const {
+  Array<Expr> shape;
+  for (const auto& ivar : this->axis) {
+    shape.push_back(ivar->dom->extent);
+  }
+  return shape;
+}
+
+
+Operation TensorComputeOpNode::make(std::string name,
+                                    std::string tag,
+                                    Array<IterVar> axis,
+                                    Array<IterVar> reduce_axis,
+                                    int schedulable_ndim,
+                                    TensorIntrin intrin,
+                                    Array<Tensor> tensors,
+                                    Array<Region> regions) {
+  auto n = make_node<TensorComputeOpNode>();
+  n->name = std::move(name);
+  n->tag = std::move(tag);
+  n->axis = std::move(axis);
+  n->reduce_axis = std::move(reduce_axis);
+  n->schedulable_ndim = std::move(schedulable_ndim);
+  n->intrin = std::move(intrin);
+  n->inputs = std::move(tensors);
+  n->input_regions = std::move(regions);
+  return Operation(n);
+}
+
+Array<Tensor> TensorComputeOpNode::InputTensors() const {
+  return inputs;
+}
+
+Operation TensorComputeOpNode::ReplaceInputs(
+    const Operation& self,
+    const std::unordered_map<Tensor, Tensor>& rmap) const {
+  CHECK_EQ(self.operator->(), this);
+  auto n = make_node<TensorComputeOpNode>(*this);
+  auto intrin = make_node<TensorIntrinNode>(*(this->intrin.operator->()));
+  intrin->body = op::ReplaceTensor(this->intrin->body, rmap);
+  if (intrin->reduce_init.defined()) {
+    intrin->reduce_init = op::ReplaceTensor(this->intrin->reduce_init, rmap);
+  }
+  if (intrin->reduce_update.defined()) {
+    intrin->reduce_update = op::ReplaceTensor(this->intrin->reduce_update, rmap);
+  }
+  for (size_t i = 0; i < n->inputs.size(); ++i) {
+    Tensor t = n->inputs[i];
+    if (rmap.count(t)) {
+      n->inputs.Set(i, rmap.at(t));
+    }
+  }
+
+  if (intrin->body.same_as(n->intrin->body) &&
+      intrin->reduce_init.same_as(n->intrin->reduce_init) &&
+      intrin->reduce_update.same_as(n->intrin->reduce_update) &&
+      inputs.same_as(n->inputs)) {
+    return self;
+  } else {
+    n->intrin = TensorIntrin(intrin);
+    return Operation(n);
+  }
+}
+
+void TensorComputeOpNode::PropBoundToInputs(
+    const Operation& self,
+    const std::unordered_map<const Variable*, IntSet>& dom_map,
+    std::unordered_map<Tensor, TensorDom>* out_dom_map) const {
+  for (size_t i = 0; i < this->inputs.size(); ++i) {
+    Tensor t = this->inputs[i];
+    Region region = input_regions[i];
+
+    auto it = out_dom_map->find(t);
+    if (it == out_dom_map->end()) continue;
+    TensorDom& dom = it->second;
+    for (size_t j = 0; j < t.ndim(); ++j) {
+      dom.data[j].emplace_back(EvalSet(region[j], dom_map));
+    }
+  }
+}
+
+void TensorComputeOpNode::GatherBound(
+    const Operation& self,
+    const std::unordered_map<Tensor, TensorDom>& tensor_dom,
+    std::unordered_map<IterVar, Range>* out_dom_map) const {
+  const TensorDom& tdom = tensor_dom.at(self.output(0));
+  for (size_t i = 0; i < this->axis.size(); ++i) {
+    Range r = arith::Union(tdom.data.at(i)).cover_range(this->axis[i]->dom);
+    CHECK(!out_dom_map->count(this->axis[i]));
+    (*out_dom_map)[this->axis[i]] = r;
+  }
+  for (size_t i = 0; i < this->reduce_axis.size(); ++i) {
+    CHECK(!out_dom_map->count(this->reduce_axis[i]));
+    (*out_dom_map)[this->reduce_axis[i]] = this->reduce_axis[i]->dom;
+  }
+}
+
+Stmt TensorComputeOpNode::BuildRealize(
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& realize_map,
+    const Stmt& body) const {
+  CHECK_EQ(stage->op.get(), this);
+  HalideIR::Internal::Region bounds;
+  for (IterVar iv : this->axis) {
+    bounds.push_back(realize_map.at(iv));
+  }
+  Stmt realize = body;
+  for (int i = this->num_outputs(); i > 0; --i) {
+    Tensor t = stage->op.output(i-1);
+    realize = ir::Realize::make(t->op, t->value_index,
+      t->dtype, bounds, const_true(), realize);
+    // alignment requirement, only useful for compute
+    for (int i = 0; i < schedulable_ndim; ++i) {
+      auto it = stage->iter_var_attrs.find(this->axis[i]);
+      if (it != stage->iter_var_attrs.end()) {
+        IterVarAttr attr = (*it).second;
+        if (attr->dim_align_factor != 0) {
+          Array<Expr> tuple = {static_cast<int>(i),
+                               attr->dim_align_factor,
+                               attr->dim_align_offset};
+          realize = ir::AttrStmt::make(
+              t, ir::attr::buffer_dim_align,
+              Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic),
+              realize);
+        }
+      }
+    }
+  }
+  return realize;
+}
+
+ComputeLoopNest MakeLoopNest(
+    const TensorComputeOpNode* self,
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
+    bool debug_keep_trivial_loop) {
+  CHECK_EQ(stage->op.operator->(), self);
+  ComputeLoopNest ret;
+  // make main loop nest
+  ret.main_nest = op::MakeLoopNest(
+      stage, dom_map, 0, false, std::unordered_set<IterVar>(), &ret.main_vmap,
+      debug_keep_trivial_loop);
+  ret.main_predicates = schedule::MakeBoundCheck(
+      stage, dom_map, ret.main_vmap, false,
+      std::unordered_set<IterVar>());
+  for (auto& e : ret.main_predicates) {
+    e = likely(e);
+  }
+  if (stage->store_predicate.defined()) {
+    ret.main_predicates.push_back(stage->store_predicate);
+  }
+  if (self->reduce_axis.size() != 0) {
+    // try to find the location to insert the initialization.
+    // Fuse the initialization and provide loop when possible.
+    std::unordered_map<IterVar, int> update_state;
+    for (IterVar iv : self->reduce_axis) {
+      update_state[iv] = 2;
+    }
+    for (int i = 0; i < self->schedulable_ndim; ++i) {
+      update_state[self->axis[i]] = 1;
+    }
+    // find which iter var is related to reduction and which is related to axis.
+    schedule::PassDownBitMaskOr(stage, &update_state);
+    auto leaf_iter_vars = stage->leaf_iter_vars;
+    // find the first loop that is related to reduction.
+    size_t begin_loop = leaf_iter_vars.size();
+    for (size_t i = 0; i < leaf_iter_vars.size(); ++i) {
+      auto iv = leaf_iter_vars[i];
+      int flag = update_state.at(iv);
+      if ((flag & 2) != 0) {
+        begin_loop = i; break;
+      }
+      ret.init_vmap[iv] = ret.main_vmap.at(iv);
+    }
+    ret.num_common_loop = begin_loop;
+    // skip loops that do not relate to axis.
+    std::unordered_set<IterVar> skip_iter;
+    for (auto kv : update_state) {
+      int flag = kv.second;
+      if ((flag & 1) == 0) skip_iter.insert(kv.first);
+    }
+    ret.init_nest = op::MakeLoopNest(
+        stage, dom_map, begin_loop, true,
+        skip_iter, &(ret.init_vmap), debug_keep_trivial_loop);
+    ret.init_predicates = schedule::MakeBoundCheck(
+        stage, dom_map, ret.init_vmap, true, skip_iter);
+    for (auto& e : ret.init_predicates) {
+      e = likely(e);
+    }
+  } else {
+    CHECK_EQ(ret.main_nest.size(), stage->leaf_iter_vars.size() + 1);
+    ret.num_common_loop = stage->leaf_iter_vars.size();
+  }
+  // copy elision here.
+  return ret;
+}
+
+
+Stmt TensorComputeOpNode::BuildProvide(
+    const Stage& stage,
+    const std::unordered_map<IterVar, Range>& dom_map,
+    bool debug_keep_trivial_loop) const {
+  CHECK_EQ(stage->op.operator->(), this);
+
+  // Start bind data.
+  Stmt nop = Evaluate::make(0);
+  std::vector<Stmt> input_bind_nest, output_bind_nest;
+  Array<Tensor> inputs = this->InputTensors();
+
+  // input binding
+  size_t num_inputs = inputs.size();
+  for (size_t i = 0; i < num_inputs; ++i) {
+    Tensor tensor = inputs[i];
+    Region region = this->input_regions[i];
+    Buffer buffer = this->intrin->buffers[i];
+    Array<NodeRef> bind_spec{buffer, tensor};
+
+    Array<Expr> tuple;
+    for (size_t i = 0; i < region.size(); ++i) {
+      tuple.push_back(region[i]->min);
+      tuple.push_back(region[i]->extent);
+    }
+    input_bind_nest.emplace_back(AttrStmt::make(
+        bind_spec, ir::attr::buffer_bind_scope,
+        Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
+  }
+
+  // output binding
+  for (int i = 0; i < this->num_outputs(); ++i) {
+    Tensor tensor = stage->op.output(i);
+    Buffer buffer = this->intrin->buffers[num_inputs + i];
+    Array<NodeRef> bind_spec{buffer, tensor};
+
+    Array<Expr> tuple;
+    for (size_t i = 0; i < this->axis.size(); ++i) {
+      auto ivar = this->axis[i];
+      if (i < static_cast<size_t>(this->schedulable_ndim)) {
+        tuple.push_back(ivar->var);
+        tuple.push_back(1);
+      } else {
+        Range dom = ivar->dom;
+        tuple.push_back(dom->min);
+        tuple.push_back(dom->extent);
+      }
+    }
+
+    output_bind_nest.emplace_back(AttrStmt::make(
+        bind_spec, ir::attr::buffer_bind_scope,
+        Call::make(Handle(), ir::intrinsic::tvm_tuple, tuple, Call::Intrinsic), nop));
+  }
+
+  // Check variable remap
+  std::unordered_map<const Variable*, Expr> vmap;
+  ir::ArgBinder binder(&vmap);
+
+  size_t tloc = stage->leaf_iter_vars.size();
+  ComputeLoopNest n = MakeLoopNest(this, stage, dom_map, debug_keep_trivial_loop);
+
+  if (this->reduce_axis.size() == 0) {
+    std::vector<std::vector<Stmt> > nest(
+        n.main_nest.begin(), n.main_nest.begin() + tloc + 1);
+    nest.emplace_back(op::MakeIfNest(n.main_predicates));
+    CHECK_EQ(n.init_predicates.size(), 0U);
+    CHECK(this->intrin->body.defined())
+        << "Normal store op for intrin " << this << " is not defined";
+    Stmt body = MergeNest(output_bind_nest, this->intrin->body);
+    body = MergeNest(input_bind_nest, body);
+    body = ir::Substitute(body, vmap);
+    body = MergeNest(binder.asserts(), body);
+    body = op::Substitute(body, n.main_vmap);
+    Stmt ret = MergeNest(nest, body);
+    return ret;
+  } else {
+    // Need to split reduction
+    CHECK(this->intrin->reduce_update.defined())
+        << "Reduction update op is not defined";
+    // Need init and update steps
+    CHECK_NE(this->reduce_axis.size(), 0U);
+    std::vector<std::vector<Stmt> > common(
+        n.main_nest.begin(), n.main_nest.begin() + n.num_common_loop + 1);
+    std::vector<std::vector<Stmt> > update_nest(
+        n.main_nest.begin() + n.num_common_loop + 1, n.main_nest.begin() + tloc + 1);
+    update_nest.emplace_back(op::MakeIfNest(n.main_predicates));
+
+    if (this->intrin->reduce_init.defined()) {
+      // init nest
+      std::vector<std::vector<Stmt> > init_nest(
+          n.init_nest.begin(), n.init_nest.begin() + tloc + 1);
+      init_nest.emplace_back(op::MakeIfNest(n.init_predicates));
+      Stmt init = MergeNest(output_bind_nest, this->intrin->reduce_init);
+      init = op::Substitute(init, n.init_vmap);
+      init = MergeNest(init_nest, init);
+      // The update
+      Stmt update = MergeNest(output_bind_nest, this->intrin->reduce_update);
+      update = MergeNest(input_bind_nest, update);
+      update = ir::Substitute(update, vmap);
+      update = MergeNest(binder.asserts(), update);
+      update = op::Substitute(update, n.main_vmap);
+      update = MergeNest(update_nest, update);
+      return MergeNest(common, Block::make(init, update));
+    } else {
+      // When init op is not available, use body op for reset in the first iter.
+      CHECK(this->intrin->body.defined())
+          << "Normal body op is not defined";
+      Stmt update = TransformUpdate(stage, dom_map, n,
+                                    this->intrin->body,
+                                    this->intrin->reduce_update);
+      update = MergeNest(output_bind_nest, update);
+      update = MergeNest(input_bind_nest, update);
+      update = ir::Substitute(update, vmap);
+      update = MergeNest(binder.asserts(), update);
+      update = op::Substitute(update, n.main_vmap);
+      update = MergeNest(update_nest, update);
+      return MergeNest(common, update);
+    }
+  }
+}
+
+}  // namespace tvm
diff --git a/src/op/tensorize.cc b/src/op/tensorize.cc
index 6daaedd16de1..a61aac422284 100644
--- a/src/op/tensorize.cc
+++ b/src/op/tensorize.cc
@@ -10,7 +10,6 @@
 #include "op_util.h"
 #include "compute_op.h"
 #include "../schedule/message_passing.h"
-#include "../arithmetic/compute_expr.h"
 
 namespace tvm {
 
@@ -323,50 +322,6 @@ void VerifyTensorizeBody(
   }
 }
 
-/*!
- * \brief Transform the update part when there is no init func in tensorizing
- * \param stage The stage for tensorizing.
- * \param dom_map The range of each iter var.
- * \param n The loop nest structured used in compute.
- * \param body The body func in tensorize intrin
- * \param update The update func in tensorize intrin
- * \return Transformed result.
- */
-Stmt TransformUpdate(const Stage& stage,
-                     const std::unordered_map<IterVar, Range>& dom_map,
-                     const ComputeLoopNest& n,
-                     Stmt body,
-                     Stmt update) {
-  Array<Expr> conds;
-  std::unordered_set<const Variable*> banned;
-  for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
-    IterVar iv = stage->leaf_iter_vars[i];
-    auto iit = stage->iter_var_attrs.find(iv);
-    if (iit != stage->iter_var_attrs.end()) {
-      const IterVarAttr& attr = (*iit).second;
-      if (attr->iter_type == kTensorized) {
-        break;
-      }
-    }
-    if (iv->iter_type == kCommReduce) {
-      auto vit = dom_map.find(iv);
-      CHECK(vit != dom_map.end());
-      const Range& vrange = vit->second;
-      conds.push_back(likely(iv->var > vrange->min));
-      banned.insert(iv->var.get());
-    }
-  }
-  for (const Expr& pred : n.main_predicates) {
-    if (ir::ExprUseVar(pred, banned)) {
-      LOG(FATAL) << "Tensorize update transform failed, the condition "
-                 << pred << " has a conflict with the reset condition";
-    }
-  }
-
-  return IfThenElse::make(arith::ComputeReduce<ir::Or>(conds, const_true(1)),
-                          update, body);
-}
-
 Stmt MakeTensorize(const ComputeOpNode* self,
                    const Stage& stage,
                    const std::unordered_map<IterVar, Range>& dom_map,
diff --git a/src/pass/arg_binder.cc b/src/pass/arg_binder.cc
index 0fac313c079b..623886c31b86 100644
--- a/src/pass/arg_binder.cc
+++ b/src/pass/arg_binder.cc
@@ -91,7 +91,9 @@ void ArgBinder::BindBuffer(const Buffer& arg,
   // bind pointer and offset.
   if (is_zero(arg->elem_offset)) {
     CHECK(is_zero(value->elem_offset))
-        << "Trying to bind a Buffer with offset into one without offset";
+        << "Trying to bind a Buffer with offset into one without offset "
+        << " required elem_offset=" << arg->elem_offset
+        << ", provided elem_offset=" << value->elem_offset;
   }
 
   this->Bind(arg->data, value->data, arg_name + ".data");
diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc
index 28a6ace9bfa6..993f6294e15b 100644
--- a/src/pass/storage_flatten.cc
+++ b/src/pass/storage_flatten.cc
@@ -191,10 +191,16 @@ class StorageFlattener : public IRMutator {
       buf_map_[key].released = true;
       Stmt ret;
 
+      Type storage_type = e.buffer->dtype;
+      // specially handle bool, lower its storage
+      // type to be Int(8)(byte)
+      if (storage_type == Bool()) {
+        storage_type = Int(8);
+      }
       if (strides.size() != 0) {
         int first_dim = 0;
         ret = Allocate::make(
-            e.buffer->data, e.buffer->dtype,
+            e.buffer->data, storage_type,
            {arith::ComputeExpr<ir::Mul>(e.buffer->strides[first_dim], e.buffer->shape[first_dim])},
             make_const(Bool(e.buffer->dtype.lanes()), true), body);
       } else {
@@ -203,7 +209,7 @@ class StorageFlattener : public IRMutator {
           shape.push_back(make_const(Int(32), 1));
         }
         ret = Allocate::make(
-            e.buffer->data, e.buffer->dtype, shape,
+            e.buffer->data, storage_type, shape,
             make_const(Bool(e.buffer->dtype.lanes()), true), body);
       }
       ret = AttrStmt::make(
diff --git a/src/relay/ir/debug_printer.cc b/src/relay/ir/debug_printer.cc
index e216faa0f195..90e82d3b2dd7 100644
--- a/src/relay/ir/debug_printer.cc
+++ b/src/relay/ir/debug_printer.cc
@@ -223,7 +223,6 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr&)> {
   }
 
   Doc VisitExpr_(const CallNode* c) final {
-    auto args = DocifyExprArray(c->args);
     return Docify(c->op) + Seq("<", DocifyExprArray(c->args), ">");
   }
 
@@ -244,6 +243,10 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr&)> {
     return DocOfStr(o->name);
   }
 
+  Doc VisitExpr_(const TupleGetItemNode* g) final {
+    return Docify(g->tuple) + DocOfStr(std::string(".") + std::to_string(g->index));
+  }
+
  public:
   ExprDocifier(const Environment& env) : env(env), td(env) { }
 
@@ -291,7 +294,6 @@ std::string PrintType(const Environment& env, const Type& t) {
 TVM_REGISTER_API("relay._expr._debug_print")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     NodeRef x = args[1];
-    std::cout << x << std::endl;
     if (x.as<TypeNode>()) {
       *ret = PrintType(args[0], Downcast<Type>(x));
     } else {
diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc
index dbbb5b84fc8b..6b56cb4e844f 100644
--- a/src/relay/ir/expr.cc
+++ b/src/relay/ir/expr.cc
@@ -193,5 +193,21 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
                   << ", " << node->false_branch << ")";
   });
 
+TupleGetItem TupleGetItemNode::make(Expr tuple, int index) {
+  NodePtr<TupleGetItemNode> n = make_node<TupleGetItemNode>();
+  n->tuple = std::move(tuple);
+  n->index = index;
+  return TupleGetItem(n);
+}
+
+TVM_REGISTER_API("relay._make.TupleGetItem").set_body([](TVMArgs args, TVMRetValue* ret) {
+  *ret = TupleGetItemNode::make(args[0], args[1]);
+});
+
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+.set_dispatch<TupleGetItemNode>([](const TupleGetItemNode* node, tvm::IRPrinter* p) {
+  p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
+});
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc
index e3393bdb039b..792f99d699dd 100644
--- a/src/relay/ir/expr_functor.cc
+++ b/src/relay/ir/expr_functor.cc
@@ -150,10 +150,17 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) {
   }
 }
 
-Type ExprMutator::VisitType(const Type& t) {
-  return t;
+Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
+  auto t = this->Mutate(g->tuple);
+  if (g->tuple == t) {
+    return GetRef<Expr>(g);
+  } else {
+    return TupleGetItemNode::make(t, g->index);
+  }
 }
 
+Type ExprMutator::VisitType(const Type& t) { return t; }
+
 void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
 }
@@ -206,6 +213,10 @@ void ExprVisitor::VisitExpr_(const IfNode* op) {
 void ExprVisitor::VisitExpr_(const OpNode* op) { return; }
 
+void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
+  this->VisitExpr(op->tuple);
+}
+
 void ExprVisitor::VisitType(const Type& t) { return; }
 
 }  // namespace relay
diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc
new file mode 100644
index 000000000000..e6d60f9344a1
--- /dev/null
+++ b/src/relay/op/image/resize.cc
@@ -0,0 +1,87 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file resize.cc
+ * \brief Image operators
+ */
+#include <tvm/relay/op.h>
+#include <tvm/relay/attrs/image.h>
+#include "../nn/layout.h"
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(ResizeAttrs);
+
+bool ResizeRel(const Array<Type>& types,
+               int num_inputs,
+               const Attrs& attrs,
+               const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  static const Layout kNCHW("NCHW");
+
+  const ResizeAttrs* param = attrs.as<ResizeAttrs>();
+  CHECK(param != nullptr);
+  const Layout in_layout(param->layout);
+  CHECK(in_layout.convertible(kNCHW))
+      << "Resize only support input layouts that are convertible from NCHW."
+      << " But got " << in_layout;
+
+  auto oshape = ConvertLayout(data->shape, in_layout, kNCHW);
+  oshape[2] = param->size[0];
+  oshape[3] = param->size[1];
+
+  // assign output type
+  reporter->Assign(types[1],
+                   TensorTypeNode::make(ConvertLayout(oshape, kNCHW, in_layout),
+                                        data->dtype));
+  return true;
+}
+
+
+// Positional relay function to create image operator
+// used by frontend FFI.
+Expr MakeResize(Expr data,
+                Array<IndexExpr> size,
+                std::string layout,
+                std::string method,
+                bool align_corners) {
+  auto attrs = make_node<ResizeAttrs>();
+  attrs->size = std::move(size);
+  attrs->layout = std::move(layout);
+  attrs->method = std::move(method);
+  attrs->align_corners = align_corners;
+  static const Op& op = Op::Get("image.resize");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+
+TVM_REGISTER_API("relay.op.image._make.resize")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 5>(MakeResize, args, rv);
+  });
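`ResizeRel` infers the output type by converting the input shape to NCHW, overwriting the two spatial dims with `size`, and converting back to the original layout. The same computation in Python (a sketch, NCHW/NHWC only):

```
def resize_oshape(dshape, size, layout='NCHW'):
    # convert to NCHW, overwrite H and W, convert back
    n, c, h, w = (dshape if layout == 'NCHW'
                  else (dshape[0], dshape[3], dshape[1], dshape[2]))
    h, w = size
    return (n, c, h, w) if layout == 'NCHW' else (n, h, w, c)

assert resize_oshape((1, 3, 32, 32), (64, 64)) == (1, 3, 64, 64)
assert resize_oshape((1, 32, 32, 3), (64, 64), 'NHWC') == (1, 64, 64, 3)
```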
+RELAY_REGISTER_OP("image.resize")
+.describe(R"code(Perform resize to input array with nearest neighbour or bilinear interpolation.
+
+- **data**: data is 4D array of shape
+            (batch_size, channels, in_height, in_width) for NCHW
+            (batch_size, in_height, in_width, channels) for NHWC
+
+- **out**: Output is 4D array of shape
+           for layout NCHW
+           (batch_size, channels, size[0], size[1])
+
+           for layout NHWC
+           (batch_size, size[0], size[1], channels)
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(5)
+.add_type_rel("Resize", ResizeRel);
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index 920fc68d51e8..4717e3fe0803 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -49,9 +49,9 @@ bool Conv2DRel(const Array<Type>& types,
     CHECK_EQ(param->dilation.size(), 2);
     std::vector<IndexExpr> wshape(
        {param->channels / param->groups,
-        data->shape[1] / param->groups,
-        param->kernel_size[0],
-        param->kernel_size[1]});
+         data->shape[1] / param->groups,
+         param->kernel_size[0],
+         param->kernel_size[1]});
     wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
     wshape[kernel_layout.indexof('O')] *= param->groups;
     channels = param->channels;
@@ -154,5 +154,153 @@ with the layer input to produce a tensor of outputs.
 .set_support_level(2)
 .add_type_rel("Conv2D", Conv2DRel);
 
+
+// Conv2DTranspose
+TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs);
+
+bool Conv2DTransposeRel(const Array<Type>& types,
+                        int num_inputs,
+                        const Attrs& attrs,
+                        const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* weight = types[1].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  static const Layout kNCHW("NCHW");
+  static const Layout kOIHW("OIHW");
+
+  const Conv2DTransposeAttrs* param = attrs.as<Conv2DTransposeAttrs>();
+  CHECK(param != nullptr);
+  const Layout in_layout(param->data_layout);
+  const Layout kernel_layout(param->weight_layout);
+  CHECK(in_layout.convertible(kNCHW))
+      << "Conv only support input layouts that are convertible from NCHW."
+      << " But got " << in_layout;
+  CHECK(kernel_layout.convertible(kOIHW))
+      << "Conv only support kernel layouts that are convertible from OIHW."
+      << " But got " << kernel_layout;
+
+  IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
+  const auto dshape_nchw = ConvertLayout(data->shape, in_layout, kNCHW);
+  // infer weight if the kernel_size and channels are defined
+  if (param->kernel_size.defined() && param->channels.defined()) {
+    CHECK_EQ(param->kernel_size.size(), 2);
+    CHECK_EQ(param->dilation.size(), 2);
+
+    std::vector<IndexExpr> wshape({dshape_nchw[1],
+                                   param->channels / param->groups,
+                                   param->kernel_size[0],
+                                   param->kernel_size[1]});
+
+    wshape = ConvertLayout(wshape, kOIHW, kernel_layout);
+    dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
+    dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
+    channels = param->channels;
+
+    // assign result to reporter
+    reporter->Assign(types[1], TensorTypeNode::make(wshape, data->dtype));
+  } else {
+    // use weight to infer the conv shape.
+    if (weight == nullptr) return false;
+    auto wshape = ConvertLayout(weight->shape, kernel_layout, kOIHW);
+    if (param->kernel_size.defined()) {
+      CHECK_EQ(param->kernel_size.size(), 2);
+      // check the size
+      CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) &&
+            reporter->AssertEQ(param->kernel_size[1], wshape[3]))
+          << "Conv2D: shape of weight is inconsistent with kernel_size, "
+          << " kernel_size=" << param->kernel_size
+          << " wshape=" << Array<IndexExpr>(wshape);
+    }
+    if (param->channels.defined()) {
+      CHECK(reporter->AssertEQ(param->channels, wshape[1]))
+          << "Conv2D: shape of weight is inconsistent with channels, "
+          << " channels=" << param->channels
+          << " wshape=" << Array<IndexExpr>(wshape);
+    }
+    CHECK(reporter->AssertEQ(dshape_nchw[1] / param->groups, wshape[0]));
+    channels = wshape[1];
+    dilated_ksize_y = 1 + (wshape[2] - 1) * param->dilation[0];
+    dilated_ksize_x = 1 + (wshape[3] - 1) * param->dilation[1];
+  }
+  // dilation
+  std::vector<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
+  oshape[2] = (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
+               2 * param->padding[0] + param->output_padding[0]);
+  oshape[3] = (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
+               2 * param->padding[1] + param->output_padding[1]);
+
+  DataType out_dtype = param->out_dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = data->dtype;
+  }
+  oshape = ConvertLayout(oshape, kNCHW, in_layout);
+  reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype));
+  return true;
+}
+
+
+Expr MakeConv2DTranspose(Expr data,
+                         Expr weight,
+                         Array<IndexExpr> strides,
+                         Array<IndexExpr> padding,
+                         Array<IndexExpr> dilation,
+                         int groups,
+                         IndexExpr channels,
+                         Array<IndexExpr> kernel_size,
+                         std::string data_layout,
+                         std::string weight_layout,
+                         Array<IndexExpr> output_padding,
+                         DataType out_dtype) {
+  auto attrs = make_node<Conv2DTransposeAttrs>();
+  attrs->channels = channels;
+  attrs->kernel_size = kernel_size;
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->output_padding = std::move(output_padding);
+  attrs->dilation = std::move(dilation);
+  attrs->groups = groups;
+  attrs->data_layout = std::move(data_layout);
+  attrs->weight_layout = std::move(weight_layout);
+  attrs->out_dtype = std::move(out_dtype);
+  static const Op& op = Op::Get("nn.conv2d_transpose");
+  return CallNode::make(op, {data, weight}, Attrs(attrs), {});
+}
+
+
+TVM_REGISTER_API("relay.op.nn._make.conv2d_transpose")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 12>(MakeConv2DTranspose, args, rv);
+  });
+
+RELAY_REGISTER_OP("nn.conv2d_transpose")
+.describe(R"code(Transposed 2D convolution layer (sometimes called Deconvolution).
+
+The need for transposed convolutions generally arises
+from the desire to use a transformation going in the opposite direction
+of a normal convolution, i.e., from something that has the shape of the
+output of some convolution to something that has the shape of its input
+while maintaining a connectivity pattern that is compatible with
+said convolution.
+
+- **data**: This depends on the `layout` parameter. Input is 4D array of shape
+            (batch_size, in_channels, height, width) if `layout` is `NCHW`.
+- **weight**: (in_channels, channels, kernel_size[0], kernel_size[1])
+- **bias**: (channels,)
+- **out**: This depends on the `layout` parameter. Output is 4D array of shape
+           (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
+
+           out_height and out_width are calculated as::
+             out_height = (height-1)*strides[0]-2*padding[0]+kernel_size[0]+output_padding[0]
+             out_width = (width-1)*strides[1]-2*padding[1]+kernel_size[1]+output_padding[1]
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(2)
+.add_argument("data", "Tensor", "The input tensor.")
+.add_argument("weight", "Tensor", "The weight tensor.")
+.set_support_level(2)
+.add_type_rel("Conv2DTranspose", Conv2DTransposeRel);
+
 }  // namespace relay
 }  // namespace tvm
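The transposed-convolution output extents follow the formula in the docstring above; the same arithmetic in Python (a sketch for one spatial dim):

```
def conv2d_transpose_out_dim(in_dim, kernel, stride, padding, output_padding, dilation=1):
    dilated_kernel = (kernel - 1) * dilation + 1
    return stride * (in_dim - 1) + dilated_kernel - 2 * padding + output_padding

# e.g. upsampling a 16x16 feature map by 2 with a 4x4 kernel, pad 1:
assert conv2d_transpose_out_dim(16, 4, 2, 1, 0) == 32
```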
diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc
new file mode 100644
index 000000000000..f2439b9fb7ca
--- /dev/null
+++ b/src/relay/op/nn/nn.cc
@@ -0,0 +1,221 @@
+/*!
+ *  Copyright (c) 2018 by Contributors
+ * \file nn.cc
+ * \brief Property def of nn operators.
+ */
+
+#include <tvm/relay/op.h>
+#include <tvm/relay/attrs/nn.h>
+#include <vector>
+#include <string>
+#include "../type_relations.h"
+#include "../op_common.h"
+#include "layout.h"
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_API("relay.op.nn._make.softmax")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    auto make_func = [](Expr data, int axis) {
+      auto attrs = make_node<SoftmaxAttrs>();
+      attrs->axis = axis;
+      static const Op& op = Op::Get("nn.softmax");
+      return CallNode::make(op, {data}, Attrs(attrs), {});
+    };
+
+    runtime::detail::unpack_call<Call, 2>(make_func, args, rv);
+});
+
+RELAY_REGISTER_OP("nn.softmax")
+    .describe(R"code(Softmax layer.
+
+.. math:: \text{softmax}(x)_i = \frac{exp(x_i)}{\sum_j exp(x_j)}
+
+.. note::
+    This operator can be optimized away for inference.
+
+- **data**: The input data
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(1)
+.add_type_rel("Identity", IdentityRel);
+
+
+TVM_REGISTER_API("relay.op.nn._make.log_softmax")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    auto make_func = [](Expr data, int axis) {
+      auto attrs = make_node<SoftmaxAttrs>();
+      attrs->axis = axis;
+      static const Op& op = Op::Get("nn.log_softmax");
+      return CallNode::make(op, {data}, Attrs(attrs), {});
+    };
+
+    runtime::detail::unpack_call<Call, 2>(make_func, args, rv);
+});
+
+RELAY_REGISTER_OP("nn.log_softmax")
+    .describe(R"code(Computes log softmax.
+
+.. math:: \text{log_softmax}(x)_i = \log \frac{exp(x_i)}{\sum_j exp(x_j)}
+
+.. note::
+    This operator can be optimized away for inference.
+
+- **data**: The input data
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(1)
+.add_type_rel("Identity", IdentityRel);
+
+
+// BatchFlatten
+bool BatchFlattenRel(const Array<Type>& types,
+                     int num_inputs,
+                     const Attrs& attrs,
+                     const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+  if (data->shape.size() == 0) return false;
+
+  auto target_dim = make_const(Int(32), 1);
+
+  for (uint32_t i = 1; i < data->shape.size(); ++i) {
+    target_dim = target_dim * data->shape[i];
+  }
+
+  std::vector<IndexExpr> oshape({data->shape[0], target_dim});
+
+  // assign output type
+  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  return true;
+}
+
+Expr MakeBatchFlatten(Expr data) {
+  static const Op& op = Op::Get("nn.batch_flatten");
+  return CallNode::make(op, {data}, Attrs(), {});
+}
+
+
+TVM_REGISTER_API("relay.op.nn._make.batch_flatten")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 1>(MakeBatchFlatten, args, rv);
+  });
+
+
+RELAY_REGISTER_OP("nn.batch_flatten")
+.describe(R"code(Flattens the input into a 2-D array.
+
+For an input array with shape ``(d1, d2, ..., dk)``, `batch_flatten` operation reshapes
+the input array into an output array of shape ``(d1, d2*...*dk)``.
+
+Example::
+
+    x = [[
+        [1,2,3],
+        [4,5,6],
+        [7,8,9]
+    ],
+    [   [1,2,3],
+        [4,5,6],
+        [7,8,9]
+    ]],
+
+    batch_flatten(x) = [[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
+                        [ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.]]
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(2)
+.add_type_rel("BatchFlatten", BatchFlattenRel);
+
+RELAY_REGISTER_UNARY_OP("relay.op.nn._make.", "relu")
+.describe(R"code(Returns the relu input array, computed element-wise.
+
+.. math::
+   max(x, 0)
+
+)code" TVM_ADD_FILELINE)
+.set_support_level(1)
+.add_type_rel("Identity", IdentityRel);
+
+
+// Positional relay function to create LRN operator used by frontend FFI.
+Expr MakeLRN(Expr data,
+             IndexExpr size,
+             IndexExpr axis,
+             double alpha,
+             double beta,
+             double bias) {
+  auto attrs = make_node<LRNAttrs>();
+  attrs->size = size;
+  attrs->axis = axis;
+  attrs->alpha = alpha;
+  attrs->beta = beta;
+  attrs->bias = bias;
+  static const Op& op = Op::Get("nn.lrn");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op.nn._make.lrn")
+  .set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 6>(MakeLRN, args, rv);
+  });
+
+RELAY_REGISTER_OP("nn.lrn")
+  .describe(R"code(LRN layer.
+
+Normalize the input in a local region across or within feature maps.
+Each input value is divided by (1 + (\alpha/n) \sum_i x_i^2)^\beta,
+where n is the size of each local region, and the sum is taken over the region
+centered at that value (zero padding is added where necessary).
+
+.. math::
+
+    data / (bias + (alpha * sum_data ^2 /size))^beta
+
+- **data**: The input tensor.
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(2)
+.add_type_rel("Identity", IdentityRel);
+
+
+// Positional relay function to create L2Normalize operator used by frontend FFI.
+Expr MakeL2Normalize(Expr data,
+                     double eps,
+                     Array<IndexExpr> axis) {
+  auto attrs = make_node<L2NormalizeAttrs>();
+  attrs->eps = eps;
+  attrs->axis = std::move(axis);
+  static const Op& op = Op::Get("nn.l2_normalize");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op.nn._make.l2_normalize")
+  .set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 3>(MakeL2Normalize, args, rv);
+  });
+
+RELAY_REGISTER_OP("nn.l2_normalize")
+  .describe(R"code(L2 Normalization layer.
+
+Normalizes along dimension axis using an L2 norm
+
+.. math::
+    output = x / sqrt(max(sum(x^2), epsilon))
+
+- **data**: The input tensor.
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(2)
+.add_type_rel("Identity", IdentityRel);
+
+}  // namespace relay
+}  // namespace tvm
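The LRN docstring above divides each value by `(bias + alpha * sum(x^2) / size)^beta` over a window of `size` neighboring channels. A numpy sketch of that formula for NCHW input (illustrative only):

```
import numpy as np

def lrn(x, size=5, alpha=1e-4, beta=0.75, bias=2.0):
    """Local response normalization across channels of NCHW data (a sketch)."""
    half = size // 2
    sq = np.square(x)
    sums = np.zeros_like(x)
    for c in range(x.shape[1]):
        lo, hi = max(0, c - half), min(x.shape[1], c + half + 1)
        sums[:, c] = sq[:, lo:hi].sum(axis=1)
    return x / (bias + alpha * sums / size) ** beta

y = lrn(np.random.rand(1, 8, 4, 4).astype('float32'))
```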
+ * Copyright (c) 2018 by Contributors
+ * \file pooling.cc
+ * \brief Pooling operators
+ */
+#include <tvm/relay/op.h>
+#include <tvm/relay/attrs/nn.h>
+#include <vector>
+#include "layout.h"
+
+namespace tvm {
+namespace relay {
+
+TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs);
+
+template <typename AttrType>
+bool Pool2DRel(const Array<Type>& types,
+               int num_inputs,
+               const Attrs& attrs,
+               const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+
+  CHECK(data != nullptr);
+  const auto dshape = data->shape;
+  CHECK_NE(dshape.size(), 0);
+  CHECK_GE(dshape.size(), 2U)
+      << "Pool2D only supports input >= 2-D: input must have height and width";
+  const auto param = attrs.as<AttrType>();
+  CHECK(param != nullptr);
+
+  Layout layout(param->layout);
+  CHECK(layout.contains('H') && layout.contains('W') &&
+        !layout.contains('h') && !layout.contains('w'))
+    << "Invalid layout " << layout
+    << ". Pool2D layout must have H and W, which cannot be split";
+
+  const auto hidx = layout.indexof('H');
+  const auto widx = layout.indexof('W');
+
+  IndexExpr pad_h, pad_w;
+  if (param->padding.size() == 1) {
+    pad_h = param->padding[0] * 2;
+    pad_w = param->padding[0] * 2;
+  } else if (param->padding.size() == 2) {
+    // (top, left)
+    pad_h = param->padding[0] * 2;
+    pad_w = param->padding[1] * 2;
+  } else if (param->padding.size() == 4) {
+    // (top, left, bottom, right)
+    pad_h = param->padding[0] + param->padding[2];
+    pad_w = param->padding[1] + param->padding[3];
+  } else {
+    return false;
+  }
+
+  std::vector<IndexExpr> oshape({dshape[0], dshape[1], dshape[2], dshape[3]});
+  if (param->ceil_mode) {
+    oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0] +
+                     param->strides[0] - 1) / param->strides[0]) + 1;
+    oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1] +
+                     param->strides[1] - 1) / param->strides[1]) + 1;
+  } else {
+    oshape[hidx] = ((dshape[hidx] + pad_h - param->pool_size[0]) / param->strides[0]) + 1;
+    oshape[widx] = ((dshape[widx] + pad_w - param->pool_size[1]) / param->strides[1]) + 1;
+  }
+
+  // assign output type
+  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  return true;
+}
+
+// MaxPool2D
+Expr MakeMaxPool2D(Expr data,
+                   Array<IndexExpr> pool_size,
+                   Array<IndexExpr> strides,
+                   Array<IndexExpr> padding,
+                   std::string layout,
+                   bool ceil_mode) {
+  auto attrs = make_node<MaxPool2DAttrs>();
+  attrs->pool_size = std::move(pool_size);
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->layout = std::move(layout);
+  attrs->ceil_mode = ceil_mode;
+  static const Op& op = Op::Get("nn.max_pool2d");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+
+TVM_REGISTER_API("relay.op.nn._make.max_pool2d")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 6>(MakeMaxPool2D, args, rv);
+  });
+
+
+RELAY_REGISTER_OP("nn.max_pool2d")
+.describe(R"code(Max pooling operation for two dimensional data.
+
+- **data**: This depends on the `layout` parameter. Input is 4D array of shape
+            (batch_size, channels, height, width) if `layout` is `NCHW`.
+- **out**: This depends on the `layout` parameter. Output is 4D array of shape
+           (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
+           out_height and out_width are calculated as::
+
+               out_height = floor((height+padding[0]+padding[2]-pool_size[0])/strides[0])+1
+               out_width = floor((width+padding[1]+padding[3]-pool_size[1])/strides[1])+1
+
+           where padding will be an expanded array based on number of values passed as::
+               one int : the same padding is used on all sides.
+               two int : bottom and right use the same padding as top and left.
+               four int: padding width in the order of (top, left, bottom, right).
+
+           When `ceil_mode` is `True`, ceil will be used instead of floor in this
+           equation.
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(2)
+.add_type_rel("MaxPool2D", Pool2DRel<MaxPool2DAttrs>);
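The out_height/out_width rule in the docstring above (shared by the average-pooling operator registered next) is easy to sanity-check in plain Python. The following is an illustrative model of what `Pool2DRel` computes; the helper name is ours, not TVM code:

```python
import math

def pool2d_out_dim(size, pad_before, pad_after, pool, stride, ceil_mode=False):
    # floor((size + pad_before + pad_after - pool) / stride) + 1, ceil in ceil_mode
    rounder = math.ceil if ceil_mode else math.floor
    return rounder((size + pad_before + pad_after - pool) / stride) + 1

# height = 32, padding (top, bottom) = (1, 1), pool_size = 3, stride = 2
assert pool2d_out_dim(32, 1, 1, 3, 2) == 16                   # floor(31 / 2) + 1
assert pool2d_out_dim(32, 1, 1, 3, 2, ceil_mode=True) == 17   # ceil(31 / 2) + 1
```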
+
+
+// AvgPool2D
+Expr MakeAvgPool2D(Expr data,
+                   Array<IndexExpr> pool_size,
+                   Array<IndexExpr> strides,
+                   Array<IndexExpr> padding,
+                   std::string layout,
+                   bool ceil_mode,
+                   bool count_include_pad) {
+  auto attrs = make_node<AvgPool2DAttrs>();
+  attrs->pool_size = std::move(pool_size);
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->layout = std::move(layout);
+  attrs->ceil_mode = ceil_mode;
+  attrs->count_include_pad = count_include_pad;
+  static const Op& op = Op::Get("nn.avg_pool2d");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+
+TVM_REGISTER_API("relay.op.nn._make.avg_pool2d")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 7>(MakeAvgPool2D, args, rv);
+  });
+
+
+RELAY_REGISTER_OP("nn.avg_pool2d")
+.describe(R"code(
+Average pooling operation for two dimensional data.
+
+- **data**: This depends on the `layout` parameter. Input is 4D array of shape
+            (batch_size, channels, height, width) if `layout` is `NCHW`.
+- **out**: This depends on the `layout` parameter. Output is 4D array of shape
+           (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
+           out_height and out_width are calculated as::
+
+               out_height = floor((height+padding[0]+padding[2]-pool_size[0])/strides[0])+1
+               out_width = floor((width+padding[1]+padding[3]-pool_size[1])/strides[1])+1
+
+           where padding will be an expanded array based on number of values passed as::
+               one int : the same padding is used on all sides.
+               two int : bottom and right use the same padding as top and left.
+               four int: padding width in the order of (top, left, bottom, right).
+
+           When `ceil_mode` is `True`, ceil will be used instead of floor in this
+           equation.
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(2)
+.add_type_rel("AvgPool2D", Pool2DRel<AvgPool2DAttrs>);
+
+// Global Pool
+TVM_REGISTER_NODE_TYPE(GlobalPool2DAttrs);
+
+bool GlobalPool2DRel(const Array<Type>& types,
+                     int num_inputs,
+                     const Attrs& attrs,
+                     const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+
+  CHECK(data != nullptr);
+  const auto dshape = data->shape;
+  CHECK_NE(dshape.size(), 0);
+  CHECK_GE(dshape.size(), 2U)
+      << "Pool2D only supports input >= 2-D: input must have height and width";
+  const auto param = attrs.as<GlobalPool2DAttrs>();
+  CHECK(param != nullptr);
+
+  Layout layout(param->layout);
+  CHECK(layout.contains('H') && layout.contains('W') &&
+        !layout.contains('h') && !layout.contains('w'))
+    << "Invalid layout " << layout
+    << ". 
Pool2D layout must have H and W, which cannot be split"; + + const auto hidx = layout.indexof('H'); + const auto widx = layout.indexof('W'); + std::vector oshape({dshape[0], dshape[1], dshape[2], dshape[3]}); + oshape[hidx] = oshape[widx] = 1; + + // assign output type + reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + return true; +} + +Expr MakeGlobalAvgPool2D(Expr data, + std::string layout) { + auto attrs = make_node(); + attrs->layout = std::move(layout); + static const Op& op = Op::Get("nn.global_avg_pool2d"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + + +TVM_REGISTER_API("relay.op.nn._make.global_avg_pool2d") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeGlobalAvgPool2D, args, rv); + }); + +// GlobalAvgPool +RELAY_REGISTER_OP("nn.global_avg_pool2d") +.describe(R"code(Global average pooling operation for 2D data. + +- **data**: This depends on the `layout` parameter. Input is 4D array of shape + (batch_size, channels, height, width) if `layout` is `NCHW`. +- **out**: This depends on the `layout` parameter. Output is 4D array of shape + (batch_size, channels, 1, 1) if `layout` is `NCHW`. + +)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(2) +.add_type_rel("GlobalAvgPool2D", GlobalPool2DRel); + +// GlobalMaxPool +Expr MakeGlobalMaxPool2D(Expr data, + std::string layout) { + auto attrs = make_node(); + attrs->layout = std::move(layout); + static const Op& op = Op::Get("nn.global_max_pool2d"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op.nn._make.global_max_pool2d") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeGlobalMaxPool2D, args, rv); + }); + + +RELAY_REGISTER_OP("nn.global_max_pool2d") +.describe(R"code(Global max pooling operation for 2D data. + +- **data**: This depends on the `layout` parameter. Input is 4D array of shape + (batch_size, channels, height, width) if `layout` is `NCHW`. +- **out**: This depends on the `layout` parameter. Output is 4D array of shape + (batch_size, channels, 1, 1) if `layout` is `NCHW`. + +)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(2) +.add_type_rel("GlobalMaxPool2D", GlobalPool2DRel); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/nn/upsampling.cc b/src/relay/op/nn/upsampling.cc new file mode 100644 index 000000000000..a429a7c40e82 --- /dev/null +++ b/src/relay/op/nn/upsampling.cc @@ -0,0 +1,87 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file upsampling.cc + * \brief upsampling operator + */ +#include +#include +#include "layout.h" + +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(UpSamplingAttrs); + +bool UpSamplingRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) return false; + + static const Layout kNCHW("NCHW"); + + const UpSamplingAttrs* param = attrs.as(); + CHECK(param != nullptr); + const Layout in_layout(param->layout); + CHECK(in_layout.convertible(kNCHW)) + << "UpSampling only support input layouts that are convertible from NCHW." 
+ << " But got " << in_layout; + + auto oshape = ConvertLayout(data->shape, in_layout, kNCHW); + + oshape[2] = oshape[2] * param->scale; + oshape[3] = oshape[3] * param->scale; + + // assign output type + reporter->Assign(types[1], + TensorTypeNode::make(ConvertLayout(oshape, kNCHW, in_layout), + data->dtype)); + return true; +} + + +// Positional relay function to create upsampling operator +// used by frontend FFI. +Expr MakeUpSampling(Expr data, + int scale, + std::string layout, + std::string method) { + auto attrs = make_node(); + attrs->layout = std::move(layout); + attrs->method = std::move(method); + attrs->scale = scale; + static const Op& op = Op::Get("nn.upsampling"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + + +TVM_REGISTER_API("relay.op.nn._make.upsampling") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeUpSampling, args, rv); + }); + + +RELAY_REGISTER_OP("nn.upsampling") +.describe(R"code(Perform upsampling on input array with nearest neighbour or bilinear interpolation. + +- **data**: data is 4D array of shape + (batch_size, channels, in_height, in_width) for NCHW + (batch_size, in_height, in_width, channels) for NHWC + +- **out**: Output is 4D array of shape + for layout NCHW + (batch_size, channels, in_height*scale, in_width*scale) + + for layout NHWC + (batch_size, in_height*scale, in_width*scale, channels) + +)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input tensor.") +.set_support_level(2) +.add_type_rel("UpSampling", UpSamplingRel); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/op/op_common.h b/src/relay/op/op_common.h new file mode 100644 index 000000000000..d07b7f02cd67 --- /dev/null +++ b/src/relay/op/op_common.h @@ -0,0 +1,76 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file op_common.h + * \brief A set of utilities and common functionality + * for relay ops. + */ +#ifndef TVM_RELAY_OP_OP_COMMON_H_ +#define TVM_RELAY_OP_OP_COMMON_H_ + +#include +#include +#include + +namespace tvm { +namespace relay { + +template +std::vector AsVector(const Array &array) { + std::vector result; + result.reserve(array.size()); + for (const T& ele : array) { + result.push_back(ele); + } + return result; +} + +/*! Quick helper macro + * - Expose a positional make function to construct the node. + * - Register op to the registry. + * + * We make the decision to always only expose positional argument. + * We will do rewrapping in the frontend to support language + * sugars such as keyword arguments and default value. + * + * \param Prefix the prefix of the registry, for example, "relay.op._make.". + * + * \param OpName the name of registry. + */ +#define RELAY_REGISTER_UNARY_OP(Prefix, OpName) \ + TVM_REGISTER_API(Prefix OpName) \ + .set_body_typed([](Expr data) { \ + static const Op& op = Op::Get(OpName); \ + return CallNode::make(op, {data}, Attrs(), {}); \ + }); \ + RELAY_REGISTER_OP(OpName) \ + .set_num_inputs(1) \ + .add_argument("data", "Tensor", "The input tensor.") + +/*! Quick helper macro + * - Expose a positional make function to construct the node. + * - Register op to the registry. + * + * We make the decision to always only expose positional argument. + * We will do rewrapping in the frontend to support language + * sugars such as keyword arguments and default value. + * + * \param Prefix the prefix of the registry, for example, "relay.op._make.". + * + * \param OpName the name of registry. 
+ */
+#define RELAY_REGISTER_BINARY_OP(Prefix, OpName)                  \
+  TVM_REGISTER_API(Prefix OpName)                                 \
+  .set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) {      \
+      static const Op& op = Op::Get(OpName);                      \
+      return CallNode::make(op, {lhs, rhs}, Attrs(), {});         \
+    });                                                           \
+  RELAY_REGISTER_OP(OpName)                                       \
+  .set_num_inputs(2)                                              \
+  .add_argument("lhs", "Tensor", "The left hand side tensor.")    \
+  .add_argument("rhs", "Tensor", "The right hand side tensor.")   \
+  .add_type_rel("Broadcast", BroadcastRel)
+
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_OP_OP_COMMON_H_
diff --git a/src/relay/op/tensor/binary.cc b/src/relay/op/tensor/binary.cc
index 4c0fa657bac4..fe614aa4ea1c 100644
--- a/src/relay/op/tensor/binary.cc
+++ b/src/relay/op/tensor/binary.cc
@@ -6,55 +6,85 @@
 #include <tvm/relay/expr.h>
 #include <tvm/relay/op.h>
 #include "../type_relations.h"
+#include "../op_common.h"
 
 namespace tvm {
 namespace relay {
 
-#define RELAY_REGISTER_BINARY_OP(OpName)                          \
-  TVM_REGISTER_API("relay.op._make." OpName)                      \
-  .set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) {      \
-      static const Op& op = Op::Get(OpName);                      \
-      return CallNode::make(op, {lhs, rhs}, Attrs(), {});         \
-    });                                                           \
-  RELAY_REGISTER_OP(OpName)                                       \
-  .set_num_inputs(2)                                              \
-  .add_argument("lhs", "Tensor", "The left hand side tensor.")    \
-  .add_argument("rhs", "Tensor", "The right hand side tensor.")   \
-  .add_type_rel("Broadcast", BroadcastRel)
-
 // Addition
-RELAY_REGISTER_BINARY_OP("add")
+RELAY_REGISTER_BINARY_OP("relay.op._make.", "add")
 .describe("Elementwise add with broadcasting")
 .set_support_level(1);
 
-RELAY_REGISTER_BINARY_OP("subtract")
+// Subtraction
+RELAY_REGISTER_BINARY_OP("relay.op._make.", "subtract")
 .describe("Elementwise subtract with broadcasting")
 .set_support_level(1);
 
-RELAY_REGISTER_BINARY_OP("right_shift")
+// Right shift
+RELAY_REGISTER_BINARY_OP("relay.op._make.", "right_shift")
 .describe("Elementwise right shift with broadcasting")
 .set_support_level(4);
 
+RELAY_REGISTER_BINARY_OP("relay.op._make.", "left_shift")
+.describe("Elementwise left shift with broadcasting")
+.set_support_level(4);
+
+RELAY_REGISTER_BINARY_OP("relay.op._make.", "maximum")
+.describe("Elementwise maximum of two tensors with broadcasting")
+.set_support_level(4);
+
+RELAY_REGISTER_BINARY_OP("relay.op._make.", "minimum")
+.describe("Elementwise minimum of two tensors with broadcasting")
+.set_support_level(4);
+
+RELAY_REGISTER_BINARY_OP("relay.op._make.", "divide")
+.describe("Elementwise divide with broadcasting")
+.set_support_level(1);
+
+RELAY_REGISTER_BINARY_OP("relay.op._make.", "multiply")
+.describe("Elementwise multiply with broadcasting")
+.set_support_level(1);
+
+RELAY_REGISTER_BINARY_OP("relay.op._make.", "pow")
+.describe("Elementwise power with broadcasting")
+.set_support_level(4);
+
+RELAY_REGISTER_BINARY_OP("relay.op._make.", "mod")
+.describe("Elementwise mod with broadcasting")
+.set_support_level(1);
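All of the binary operators above share the `Broadcast` type relation. The following plain-Python helper is an illustrative model of NumPy-style shape broadcasting (our own sketch, not the TVM implementation):

```python
def broadcast_shape(lhs, rhs):
    """NumPy-style broadcast of two shapes, mirroring the Broadcast type relation."""
    out = []
    # pad the shorter shape with leading 1s, then combine dimension-by-dimension
    for a, b in zip((1,) * (len(rhs) - len(lhs)) + tuple(lhs),
                    (1,) * (len(lhs) - len(rhs)) + tuple(rhs)):
        if a != b and 1 not in (a, b):
            raise ValueError("incompatible dims %d and %d" % (a, b))
        out.append(max(a, b))
    return tuple(out)

assert broadcast_shape((2, 3), (3,)) == (2, 3)
assert broadcast_shape((4, 1, 5), (2, 5)) == (4, 2, 5)
```

Shapes are matched from the right; a dimension of 1 stretches to match the other operand.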
 
 // Comparisons
-#define RELAY_REGISTER_CMP_OP(OpName, SupportLevel)               \
+#define RELAY_REGISTER_CMP_OP(OpName)                             \
   TVM_REGISTER_API("relay.op._make." OpName)                      \
   .set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) {      \
-      static const Op& op = Op::Get(OpName);                      \
+    static const Op& op = Op::Get(OpName);                        \
     return CallNode::make(op, {lhs, rhs}, Attrs(), {});           \
   });                                                             \
   RELAY_REGISTER_OP(OpName)                                       \
   .set_num_inputs(2)                                              \
   .add_argument("lhs", "Tensor", "The left hand side tensor.")    \
   .add_argument("rhs", "Tensor", "The right hand side tensor.")   \
-  .set_support_level(SupportLevel)                                \
-  .add_type_rel("BroadcastComp", BroadcastCompRel);
-
-RELAY_REGISTER_CMP_OP("equal", 4);
-RELAY_REGISTER_CMP_OP("not_equal", 4);
-RELAY_REGISTER_CMP_OP("less", 4);
-RELAY_REGISTER_CMP_OP("less_equal", 4);
-RELAY_REGISTER_CMP_OP("greater", 4);
-RELAY_REGISTER_CMP_OP("greater_equal", 4);
+  .add_type_rel("BroadcastComp", BroadcastCompRel)
+
+RELAY_REGISTER_CMP_OP("equal")
+.describe("Elementwise equal compare with broadcasting")
+.set_support_level(4);
+RELAY_REGISTER_CMP_OP("not_equal")
+.describe("Elementwise not-equal compare with broadcasting")
+.set_support_level(4);
+RELAY_REGISTER_CMP_OP("less")
+.describe("Elementwise less-than compare with broadcasting")
+.set_support_level(4);
+RELAY_REGISTER_CMP_OP("less_equal")
+.describe("Elementwise less-than-or-equal compare with broadcasting")
+.set_support_level(4);
+RELAY_REGISTER_CMP_OP("greater")
+.describe("Elementwise greater-than compare with broadcasting")
+.set_support_level(4);
+RELAY_REGISTER_CMP_OP("greater_equal")
+.describe("Elementwise greater-than-or-equal compare with broadcasting")
+.set_support_level(4);
 
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 61db1f90ae39..663dd5c38ec5 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -5,25 +5,29 @@
  */
 #include <tvm/relay/op.h>
 #include <tvm/relay/attrs/transform.h>
+#include <tvm/ir.h>
 #include <vector>
+#include "../op_common.h"
 
 namespace tvm {
 namespace relay {
 
+/* relay.expand_dims */
+
 TVM_REGISTER_NODE_TYPE(ExpandDimsAttrs);
 
 bool ExpandDimsRel(const Array<Type>& types,
                    int num_inputs,
                    const Attrs& attrs,
                    const TypeReporter& reporter) {
-  // `types` contains: [data, output]
+  // `types` contains: [data, result]
   CHECK_EQ(types.size(), 2);
   const auto* data = types[0].as<TensorTypeNode>();
   if (data == nullptr) {
     return false;
   }
-  const ExpandDimsAttrs* param = attrs.as<ExpandDimsAttrs>();
+  const auto* param = attrs.as<ExpandDimsAttrs>();
   const int ndim = static_cast<int>(data->shape.size());
   const int axis = param->axis;
   const int num_newaxis = param->num_newaxis;
@@ -76,6 +80,423 @@ RELAY_REGISTER_OP("expand_dims")
 .set_support_level(1)
 .add_type_rel("ExpandDims", ExpandDimsRel);
 
+/* relay.concatenate */
+
+TVM_REGISTER_NODE_TYPE(ConcatenateAttrs);
+
+bool ConcatenateRel(const Array<Type>& types,
+                    int num_inputs,
+                    const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  // types: [data, result]
+  CHECK_EQ(types.size(), 2);
+  const auto* tensor_tuple = types[0].as<TupleTypeNode>();
+  if (tensor_tuple == nullptr) {
+    return false;
+  }
+  const auto* param = attrs.as<ConcatenateAttrs>();
+  const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
+  // Sanity check: ndim and dtype.
+ const int ndim = static_cast(first->shape.size()); + const DataType dtype = first->dtype; + for (const Type& ele : tensor_tuple->fields) { + const auto& e = Downcast(ele); + int e_ndim = static_cast(e->shape.size()); + const DataType& e_dtype = e->dtype; + CHECK_EQ(e_ndim, ndim) << "relay.concatenate requires all tensors have the same ndim"; + CHECK_EQ(e_dtype, dtype) << "relay.concatenate requires all tensors have the same dtype"; + } + // Sanity check: axis + int axis = param->axis; + CHECK(-ndim <= axis && axis < ndim) + << "concatenate only accepts `axis` in [-ndim, ndim)" + << ", but got axis = " << axis + << ", and ndim = " << ndim; + axis = axis < 0 ? ndim + axis : axis; + // Calculate shape + std::vector&& oshape = AsVector(first->shape); + IndexExpr &concat_dim = oshape[axis]; + for (int i = 1; i < static_cast(tensor_tuple->fields.size()); ++i) { + const auto& e = Downcast(tensor_tuple->fields[i]); + concat_dim += e->shape[axis]; + } + reporter->Assign(types[1], TensorTypeNode::make(oshape, dtype)); + return true; +} + +Expr MakeConcatenate(Expr data, + int axis) { + auto attrs = make_node(); + attrs->axis = axis; + static const Op& op = Op::Get("concatenate"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op._make.concatenate") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeConcatenate, args, rv); +}); + +RELAY_REGISTER_OP("concatenate") +.describe(R"code(Concatenate the input tensors along the given axis. + +- **data** : A list of tensors. + +- **axis** : The axis along which the tensors are concatenated. + +)code" TVM_ADD_FILELINE) +.set_num_inputs(1) +.add_argument("data", "Tensor", "The input list of tensors.") +.set_support_level(1) +.add_type_rel("Concatenate", ConcatenateRel); + +/* relay.transpose */ + +bool TransposeRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // types: [data, result] + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) { + return false; + } + const auto* param = attrs.as(); + const int ndim = data->shape.size(); + const Array& axes = param->axes; + // check dimension match + CHECK(axes.empty() || static_cast(axes.size()) == ndim) + << "Dimension mismatch: axes has " << axes.size() << " elements" + << ", but data.ndim = " << ndim; + // construct int_axes + std::vector int_axes; + int_axes.reserve(ndim); + if (axes.empty()) { + for (int i = ndim - 1; i >= 0; --i) { + int_axes.push_back(i); + } + } else { + std::vector axis_used(ndim, 0); + for (const IndexExpr& e : axes) { + const int64_t *axis_ptr = as_const_int(e); + CHECK(axis_ptr != nullptr); + int axis = *axis_ptr; + // sanity check for axis and ndim + CHECK(-ndim <= axis && axis < ndim) + << "transpose only allows each `axis` in `axes` in range [-data.ndim, data.ndim)" + << ", but got axis = " << axis + << ", and data.ndim = " << ndim; + axis = axis < 0 ? 
axis + ndim : axis;
+      // sanity check for duplication
+      CHECK(!axis_used[axis]) << "Duplicate axes in transpose: " << axis;
+      axis_used[axis] = 1;
+      int_axes.push_back(axis);
+    }
+  }
+  std::vector<IndexExpr> oshape;
+  oshape.reserve(ndim);
+  for (int axis : int_axes) {
+    oshape.push_back(data->shape[axis]);
+  }
+  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  return true;
+}
+
+Expr MakeTranspose(Expr data,
+                   Array<IndexExpr> axes) {
+  auto attrs = make_node<TransposeAttrs>();
+  attrs->axes = std::move(axes);
+  static const Op& op = Op::Get("transpose");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make.transpose")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 2>(MakeTranspose, args, rv);
+});
+
+RELAY_REGISTER_OP("transpose")
+.describe(R"code(Permutes the dimensions of an array.
+
+- **data**: The input data to the operator.
+
+- **axes**: The target axes order, reverse order if not specified.
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(3)
+.add_type_rel("Transpose", TransposeRel);
+
+/* relay.reshape */
+
+bool ReshapeRel(const Array<Type>& types,
+                int num_inputs,
+                const Attrs& attrs,
+                const TypeReporter& reporter) {
+  // types: [data, result]
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    return false;
+  }
+  const auto* param = attrs.as<ReshapeAttrs>();
+  reporter->Assign(types[1], TensorTypeNode::make(param->newshape, data->dtype));
+  return true;
+}
+
+Expr MakeReshape(Expr data,
+                 Array<IndexExpr> newshape) {
+  auto attrs = make_node<ReshapeAttrs>();
+  attrs->newshape = std::move(newshape);
+  static const Op& op = Op::Get("reshape");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make.reshape")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 2>(MakeReshape, args, rv);
+});
+
+RELAY_REGISTER_OP("reshape")
+.describe(R"code(Reshapes the input array.
+
+To give the user more convenience without manual shape inference,
+some dimensions of the shape can take special values from the set {0, -1, -2, -3, -4}.
+The significance of each is explained below:
+
+- ``0`` copy this dimension from the input to the output shape.
+
+Example::
+
+- data.shape = (2,3,4), newshape = (4,0,2), result.shape = (4,3,2)
+- data.shape = (2,3,4), newshape = (2,0,0), result.shape = (2,3,4)
+
+- ``-1`` infers the dimension of the output shape by using the remainder of the input dimensions,
+keeping the size of the new array the same as that of the input array.
+At most one dimension of shape can be -1.
+
+Example::
+
+- data.shape = (2,3,4), newshape = (6,1,-1), result.shape = (6,1,4)
+- data.shape = (2,3,4), newshape = (3,-1,8), result.shape = (3,1,8)
+- data.shape = (2,3,4), newshape = (-1,), result.shape = (24,)
+
+- ``-2`` copy all/remainder of the input dimensions to the output shape.
+
+Example::
+
+- data.shape = (2,3,4), newshape = (-2,), result.shape = (2,3,4)
+- data.shape = (2,3,4), newshape = (2,-2), result.shape = (2,3,4)
+- data.shape = (2,3,4), newshape = (-2,1,1), result.shape = (2,3,4,1,1)
+
+- ``-3`` use the product of two consecutive dimensions of the input shape as the output dimension.
+
+Example::
+
+- data.shape = (2,3,4), newshape = (-3,4), result.shape = (6,4)
+- data.shape = (2,3,4,5), newshape = (-3,-3), result.shape = (6,20)
+- data.shape = (2,3,4), newshape = (0,-3), result.shape = (2,12)
+- data.shape = (2,3,4), newshape = (-3,-2), result.shape = (6,4)
+
+- ``-4`` split one dimension of the input into the two dimensions passed subsequent to -4 in shape (can contain -1).
+
+Example::
+
+- data.shape = (2,3,4), newshape = (-4,1,2,-2), result.shape = (1,2,3,4)
+- data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4)
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(3)
+.add_type_rel("Reshape", ReshapeRel);
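A small plain-Python model of the two most common special values described above, ``0`` (copy) and ``-1`` (infer), may help. This is illustrative only and ignores the other markers:

```python
def infer_newshape(shape, newshape):
    """Resolve the 0 (copy) and -1 (infer) markers described above."""
    out = [shape[i] if d == 0 else d for i, d in enumerate(newshape)]
    if out.count(-1) == 1:
        known = 1
        for d in out:
            if d != -1:
                known *= d
        total = 1
        for d in shape:
            total *= d
        out[out.index(-1)] = total // known  # remaining elements go here
    return tuple(out)

assert infer_newshape((2, 3, 4), (4, 0, 2)) == (4, 3, 2)
assert infer_newshape((2, 3, 4), (6, 1, -1)) == (6, 1, 4)
```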
+// Take
+TVM_REGISTER_NODE_TYPE(TakeAttrs);
+
+bool TakeRel(const Array<Type>& types,
+             int num_inputs,
+             const Attrs& attrs,
+             const TypeReporter& reporter) {
+  // `types` contains: [data, indices, result]
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  CHECK(data != nullptr);
+  const auto* indices = types[1].as<TensorTypeNode>();
+  CHECK(indices != nullptr);
+  const auto param = attrs.as<TakeAttrs>();
+  CHECK(param != nullptr);
+
+  if (!param->axis.defined()) {
+    std::vector<IndexExpr>&& oshape = AsVector(indices->shape);
+    reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+    return true;
+  }
+
+  std::vector<IndexExpr> oshape;
+  const auto ndim_data = static_cast<int>(data->shape.size());
+  const auto ndim_indices = static_cast<int>(indices->shape.size());
+  auto axis = (*as_const_int(param->axis));
+  if (axis < 0) axis += ndim_data;
+  CHECK_LE(axis, ndim_data)
+    << "axis should be within the data shape"
+    << ", but got = " << axis;
+
+  oshape.reserve(ndim_data - 1 + ndim_indices);
+  for (int i = 0; i < axis; ++i) {
+    oshape.emplace_back(data->shape[i]);
+  }
+  for (int i = 0; i < ndim_indices; ++i) {
+    oshape.emplace_back(indices->shape[i]);
+  }
+  for (int i = axis+1; i < ndim_data; ++i) {
+    oshape.emplace_back(data->shape[i]);
+  }
+
+  reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype));
+  return true;
+}
+
+Expr MakeTake(Expr data,
+              Expr indices,
+              IndexExpr axis) {
+  auto attrs = make_node<TakeAttrs>();
+  attrs->axis = axis;
+  static const Op& op = Op::Get("take");
+  return CallNode::make(op, {data, indices}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make.take")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 3>(MakeTake, args, rv);
+});
+
+RELAY_REGISTER_OP("take")
+.describe(R"code(Take elements from an array along an axis.
+
+When axis is not None, this function does the same thing as 'fancy' indexing
+(indexing arrays using arrays); however, it can be easier to use if you need
+elements along a given axis.
+
+**Note** that when axis is None, the flattened input array is used.
+
+Examples::
+
+  a = [[ 1, 2],
+       [ 3, 4]]
+  indices = [3, 0, 2]
+  take(a, indices) = [ 4, 1, 3]
+
+  a = [[ 1., 2.],
+       [ 3., 4.]]
+  indices = [1, 0]
+  take(a, indices, axis=1) = [[ 2., 1.],
+                              [ 4., 3.]]
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(2)
+.add_argument("data", "Tensor", "The input tensor.")
+.add_argument("indices", "Tensor", "The indices tensor.")
+.set_support_level(2)
+.add_type_rel("Take", TakeRel);
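The shape rule implemented by `TakeRel` (replace the indexed axis of `data` with the shape of `indices`, or use the shape of `indices` outright when `axis` is None) can be sketched as follows; the helper is ours, not TVM code:

```python
def take_out_shape(data_shape, indices_shape, axis=None):
    """Output shape of take(data, indices, axis), mirroring TakeRel."""
    if axis is None:  # flattened input: the result has the indices' shape
        return tuple(indices_shape)
    axis = axis + len(data_shape) if axis < 0 else axis
    return tuple(data_shape[:axis]) + tuple(indices_shape) + tuple(data_shape[axis + 1:])

assert take_out_shape((2, 2), (3,)) == (3,)
assert take_out_shape((2, 2), (2,), axis=1) == (2, 2)
```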
+TVM_REGISTER_NODE_TYPE(FullAttrs);
+
+bool FullRel(const Array<Type>& types,
+             int num_inputs,
+             const Attrs& attrs,
+             const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const FullAttrs* param = attrs.as<FullAttrs>();
+  const auto* fill_value = types[0].as<TensorTypeNode>();
+  if (fill_value == nullptr) {
+    return false;
+  }
+
+  DataType out_dtype = param->dtype;
+  if (out_dtype.bits() == 0) {
+    out_dtype = fill_value->dtype;
+  }
+
+  CHECK_EQ(fill_value->shape.size(), 0)
+    << "Fill value should be a scalar but has dimension "
+    << fill_value->shape.size() << ".";
+
+  reporter->Assign(types[1], TensorTypeNode::make(param->shape, out_dtype));
+  return true;
+}
+
+Expr MakeFull(Expr fill_value,
+              Array<IndexExpr> shape,
+              DataType dtype) {
+  auto attrs = make_node<FullAttrs>();
+  attrs->shape = std::move(shape);
+  attrs->dtype = std::move(dtype);
+  static const Op& op = Op::Get("full");
+  return CallNode::make(op, {fill_value}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make.full")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 3>(MakeFull, args, rv);
+});
+
+RELAY_REGISTER_OP("full")
+.describe(R"code(Fill array with scalar value.
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.add_argument("fill_value", "double", "The value to fill.")
+.set_support_level(3)
+.add_type_rel("Full", FullRel);
+
+bool FullLikeRel(const Array<Type>& types,
+                 int num_inputs,
+                 const Attrs& attrs,
+                 const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    return false;
+  }
+  const auto* fill_value = types[1].as<TensorTypeNode>();
+  if (fill_value == nullptr) {
+    return false;
+  }
+
+  CHECK_EQ(fill_value->shape.size(), 0)
+    << "The fill value should be a scalar but here it has dimension "
+    << fill_value->shape.size() << ".";
+
+  reporter->Assign(types[2], TensorTypeNode::make(data->shape, data->dtype));
+  return true;
+}
+
+Expr MakeFullLike(Expr data,
+                  Expr fill_value) {
+  static const Op& op = Op::Get("full_like");
+  return CallNode::make(op, {data, fill_value}, Attrs(), {});
+}
+
+TVM_REGISTER_API("relay.op._make.full_like")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 2>(MakeFullLike, args, rv);
+  });
+
+RELAY_REGISTER_OP("full_like")
+.describe(R"code(Return an array of the same shape and type as the input array,
+filled with the given scalar value.
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(2)
+.add_argument("data", "Tensor", "The input tensor.")
+.add_argument("fill_value", "double", "Scalar value to fill.")
+.set_support_level(3)
+.add_type_rel("FullLike", FullLikeRel);
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc
index 798d4aa791ad..22f97e8f0d54 100644
--- a/src/relay/op/tensor/unary.cc
+++ b/src/relay/op/tensor/unary.cc
@@ -6,30 +6,13 @@
 #include <tvm/relay/expr.h>
 #include <tvm/relay/op.h>
 #include "../type_relations.h"
+#include "../op_common.h"
 
 namespace tvm {
 namespace relay {
 
-// Quick helper macro
-// - Expose a positional make function to construct the node.
-// - Register op to the registry.
-//
-// We make the decision to always only expose positional argument.
-// We will do rewrapping in the frontend to support language -// sugars such as keyword arguments and default value. -// -#define RELAY_REGISTER_UNARY_OP(OpName) \ - TVM_REGISTER_API("relay.op._make." OpName) \ - .set_body_typed([](Expr data) { \ - static const Op& op = Op::Get(OpName); \ - return CallNode::make(op, {data}, Attrs(), {}); \ - }); \ - RELAY_REGISTER_OP(OpName) \ - .set_num_inputs(1) \ - .add_argument("data", "Tensor", "The input tensor.") - - -RELAY_REGISTER_UNARY_OP("log") + +RELAY_REGISTER_UNARY_OP("relay.op._make.", "log") .describe(R"code(Returns the log input array, computed element-wise. .. math:: @@ -39,11 +22,7 @@ RELAY_REGISTER_UNARY_OP("log") .set_support_level(1) .add_type_rel("Identity", IdentityRel); -// data : Tensor[shape, dtype] -// result: Tensor[shape, dtype] - - -RELAY_REGISTER_UNARY_OP("exp") +RELAY_REGISTER_UNARY_OP("relay.op._make.", "exp") .describe(R"code(Returns the exp input array, computed element-wise. .. math:: @@ -54,7 +33,7 @@ RELAY_REGISTER_UNARY_OP("exp") .add_type_rel("Identity", IdentityRel); -RELAY_REGISTER_UNARY_OP("sqrt") +RELAY_REGISTER_UNARY_OP("relay.op._make.", "sqrt") .describe(R"code(Returns the sqrt input array, computed element-wise. .. math:: @@ -64,19 +43,131 @@ RELAY_REGISTER_UNARY_OP("sqrt") .set_support_level(1) .add_type_rel("Identity", IdentityRel); +RELAY_REGISTER_UNARY_OP("relay.op._make.", "zeros_like") +.describe(R"code(Returns an array of zeros, with same type and shape as the input. +)code" TVM_ADD_FILELINE) +.set_support_level(1) +.add_type_rel("Identity", IdentityRel); + +RELAY_REGISTER_UNARY_OP("relay.op._make.", "ones_like") +.describe(R"code(Returns an array of ones, with same type and shape as the input. +)code" TVM_ADD_FILELINE) +.set_support_level(1) +.add_type_rel("Identity", IdentityRel); + +RELAY_REGISTER_UNARY_OP("relay.op._make.", "sigmoid") +.describe(R"code(Returns the sigmoid input array, computed element-wise. + +.. math:: + sigmoid(x) + +)code" TVM_ADD_FILELINE) +.set_support_level(1) +.add_type_rel("Identity", IdentityRel); + +RELAY_REGISTER_UNARY_OP("relay.op._make.", "copy") +.describe(R"code(Copy a tensor. +)code" TVM_ADD_FILELINE) +.set_support_level(3) +.add_type_rel("Identity", IdentityRel); -// Concat -TVM_REGISTER_API("relay.op._make.concat") - .set_body_typed([](Expr tuple) { - static const Op& op = Op::Get("concat"); - return CallNode::make(op, { tuple }, Attrs(), {}); +// Clip +struct ClipAttrs : public tvm::AttrsNode { + double a_min; + double a_max; + + TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") { + TVM_ATTR_FIELD(a_min) + .describe("The minimum clip value."); + TVM_ATTR_FIELD(a_max) + .describe("The maximum clip value."); + } +}; + +TVM_REGISTER_API("relay.op._make.clip") + .set_body_typed([](Expr a, double a_min, double a_max) { + auto attrs = make_node(); + attrs->a_min = a_min; + attrs->a_max = a_max; + static const Op& op = Op::Get("clip"); + return CallNode::make(op, {a}, Attrs(attrs), {}); }); -RELAY_REGISTER_OP("concat") -.set_num_inputs(1) -.add_argument("tuple", "Tuple", "The tupled tensor arguments.") +RELAY_REGISTER_OP("clip") + .describe(R"code(Clip tensor values. + This function takes a tensor, a minimum value `a_min`, and a maximum value `a_max`, and returns a clipped tensor where all values below `a_min` are set to `a_min` and all values above `a_max` are set to `a_max`. `a_min` and `a_max` are cast to the tensor's dtype. 
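+
+  For illustration (values are hypothetical)::
+
+    x = [0.2, -1.1, 7.3]
+    clip(x, a_min=0., a_max=1.) = [0.2, 0., 1.]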
+  )code" TVM_ADD_FILELINE)
+  .set_num_inputs(1)
+  .add_argument("tensor", "Tensor", "The input tensor.")
+  .set_support_level(3)
+  .add_type_rel("Clip", IdentityRel);
+
+
+RELAY_REGISTER_UNARY_OP("relay.op._make.", "floor")
+.describe(R"code(Returns the floor of the input array, computed element-wise.
+)code" TVM_ADD_FILELINE)
+.set_support_level(3)
+.add_type_rel("Identity", IdentityRel);
+
+RELAY_REGISTER_UNARY_OP("relay.op._make.", "ceil")
+.describe(R"code(Returns the ceil of the input array, computed element-wise.
+
+.. math::
+  ceil(x)
+
+)code" TVM_ADD_FILELINE)
+.set_support_level(3)
+.add_type_rel("Identity", IdentityRel);
+
+RELAY_REGISTER_UNARY_OP("relay.op._make.", "trunc")
+.describe(R"code(Returns the trunc of the input array, computed element-wise.
+
+.. math::
+  trunc(x)
+
+)code" TVM_ADD_FILELINE)
+.set_support_level(3)
+.add_type_rel("Identity", IdentityRel);
+
+RELAY_REGISTER_UNARY_OP("relay.op._make.", "round")
+.describe(R"code(Returns the round of the input array, computed element-wise.
+
+.. math::
+  round(x)
+
+)code" TVM_ADD_FILELINE)
+.set_support_level(3)
+.add_type_rel("Identity", IdentityRel);
+
+RELAY_REGISTER_UNARY_OP("relay.op._make.", "abs")
+.describe(R"code(Returns the abs of the input array, computed element-wise.
+
+.. math::
+  abs(x)
+
+)code" TVM_ADD_FILELINE)
+.set_support_level(3)
+.add_type_rel("Identity", IdentityRel);
+
+RELAY_REGISTER_UNARY_OP("relay.op._make.", "tanh")
+.describe(R"code(Returns the tanh of the input array, computed element-wise.
+
+.. math::
+  Y = sinh(X) / cosh(X)
+
+)code" TVM_ADD_FILELINE)
 .set_support_level(1)
-.add_type_rel("Concat", ConcatRel);
+.add_type_rel("Identity", IdentityRel);
+
+RELAY_REGISTER_UNARY_OP("relay.op._make.", "negative")
+.describe(R"code(Returns the numeric negative of the input array, computed element-wise.
+
+.. math::
+  -(x)
+
+)code" TVM_ADD_FILELINE)
+.set_support_level(3)
+.add_type_rel("Identity", IdentityRel);
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc
index 63ce834be7cf..0ed0e3df3056 100644
--- a/src/relay/pass/alpha_eq.cc
+++ b/src/relay/pass/alpha_eq.cc
@@ -1,10 +1,11 @@
 /*!
  * Copyright (c) 2018 by Contributors
  * \file src/tvm/relay/pass/alpha_eq.cc
- * \brief The structural equivalence comparison.
+ * \brief Check that two types are syntactically equal up to alpha equivalence.
 */
 #include
 #include
+#include <tvm/runtime/ndarray.h>
 #include "./type_visitor.h"
 #include "tvm/relay/pass.h"
@@ -13,6 +14,25 @@
 namespace relay {
 
 using namespace tvm::runtime;
 
+bool SameNDArray(const NDArray& lhs, const NDArray& rhs) {
+  if (lhs.defined() != rhs.defined()) {
+    return false;
+  } else if (lhs.same_as(rhs)) {
+    return true;
+  } else {
+    auto ldt = lhs->dtype;
+    auto rdt = rhs->dtype;
+    CHECK_EQ(lhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
+    CHECK_EQ(rhs->ctx.device_type, kDLCPU) << "can only compare CPU tensor";
+    if (ldt.code == rdt.code && ldt.lanes == rdt.lanes && ldt.bits == rdt.bits) {
+      size_t s = GetDataSize(*lhs.operator->());
+      return memcmp(lhs->data, rhs->data, s) == 0;
+    } else {
+      return false;
+    }
+  }
+}
+
 struct TypeAlphaEq : TypeVisitor<const Type&> {
   tvm::Map<TypeParam, TypeParam> eq_map;
   bool equal;
@@ -38,8 +58,8 @@ struct TypeAlphaEq : TypeVisitor<const Type&> {
     }
   }
 
-  void VisitType_(const TensorTypeNode *tt1, const Type& t2) final {
-    if (const TensorTypeNode *tt2 = t2.as<TensorTypeNode>()) {
+  void VisitType_(const TensorTypeNode* tt1, const Type& t2) final {
+    if (const TensorTypeNode* tt2 = t2.as<TensorTypeNode>()) {
       DataTypeEqual(tt1->dtype, tt2->dtype);
       ShapeEqual(tt1->shape, tt2->shape);
     } else {
@@ -47,8 +67,8 @@
     }
   }
 
-  void VisitType_(const IncompleteTypeNode *bt1, const Type& t2) final {
-    if (const IncompleteTypeNode *bt2 = t2.as<IncompleteTypeNode>()) {
+  void VisitType_(const IncompleteTypeNode* bt1, const Type& t2) final {
+    if (const IncompleteTypeNode* bt2 = t2.as<IncompleteTypeNode>()) {
       equal = equal && bt1 == bt2;
       return;
     } else {
@@ -56,8 +76,8 @@
     }
   }
 
-  void VisitType_(const TypeParamNode *ti1, const Type& t2) final {
-    if (const TypeParamNode *ti2 = t2.as<TypeParamNode>()) {
+  void VisitType_(const TypeParamNode* ti1, const Type& t2) final {
+    if (const TypeParamNode* ti2 = t2.as<TypeParamNode>()) {
       auto tid1 = GetRef<TypeParam>(ti1);
       auto tid2 = GetRef<TypeParam>(ti2);
@@ -86,13 +106,25 @@
     }
   }
 
-  void VisitType_(const FuncTypeNode *op, const Type& t2) final {
-    if (const FuncTypeNode *ta2 = t2.as<FuncTypeNode>()) {
-      if (op->arg_types.size() != ta2->arg_types.size()) {
+  void VisitType_(const FuncTypeNode* op, const Type& t2) final {
+    if (const FuncTypeNode* ta2 = t2.as<FuncTypeNode>()) {
+      if (op->arg_types.size() != ta2->arg_types.size()
+          || op->type_params.size() != ta2->type_params.size()
+          || op->type_constraints.size() != ta2->type_constraints.size()) {
         equal = false;
         return;
       }
 
+      // must visit the params first so they are appropriately entered
+      // into the equality map
+      for (size_t i = 0; i < op->type_params.size(); i++) {
+        eq_map.Set(op->type_params[i], ta2->type_params[i]);
+        this->VisitType(op->type_params[i], ta2->type_params[i]);
+        if (!equal) {
+          return;
+        }
+      }
+
       for (size_t i = 0; i < op->arg_types.size(); i++) {
         this->VisitType(op->arg_types[i], ta2->arg_types[i]);
         if (!equal) {
@@ -101,21 +133,48 @@
         }
       }
 
       this->VisitType(op->ret_type, ta2->ret_type);
+      if (!equal) {
+        return;
+      }
+
+      for (size_t i = 0; i < op->type_constraints.size(); i++) {
+        this->VisitType(op->type_constraints[i], ta2->type_constraints[i]);
+        if (!equal) {
+          return;
+        }
+      }
     } else {
       equal = false;
     }
   }
 
-  void VisitType_(const TypeRelationNode *tr1, const Type& t2) final {
-    if (const TypeRelationNode *tr2 = t2.as<TypeRelationNode>()) {
-      equal = tr1 == tr2;
+  void VisitType_(const TypeRelationNode* tr1, const Type& t2) final {
+    if (const TypeRelationNode* tr2 = t2.as<TypeRelationNode>()) {
+      if (tr1->func != tr2->func
+          || tr1->num_inputs != tr2->num_inputs
+          || tr1->attrs != tr2->attrs) {
+        equal = false;
+        return;
+      }
+
+      if (tr1->args.size() != 
tr2->args.size()) { + equal = false; + return; + } + + for (size_t i = 0; i < tr1->args.size(); i++) { + this->VisitType(tr1->args[i], tr2->args[i]); + if (!equal) { + return; + } + } } else { equal = false; } } - void VisitType_(const TupleTypeNode *op, const Type& t2) final { - if (const TupleTypeNode *pt = t2.as()) { + void VisitType_(const TupleTypeNode* op, const Type& t2) final { + if (const TupleTypeNode* pt = t2.as()) { if (op->fields.size() != pt->fields.size()) { equal = false; return; @@ -146,8 +205,8 @@ struct AlphaEq : ExprFunctor { bool equal; AlphaEq() : eq_map(), equal(true) {} - void VisitExpr_(const VarNode *e1, const Expr& e2) final { - if (const VarNode *id2 = e2.as()) { + void VisitExpr_(const VarNode* e1, const Expr& e2) final { + if (const VarNode* id2 = e2.as()) { auto local1 = GetRef(e1); auto local2 = GetRef(id2); // We handle open terms with this rule assuming variables are identical. @@ -168,17 +227,17 @@ struct AlphaEq : ExprFunctor { } } - void VisitExpr_(const GlobalVarNode *g1, const Expr& e2) final { - if (const GlobalVarNode *g2 = e2.as()) { + void VisitExpr_(const GlobalVarNode* g1, const Expr& e2) final { + if (const GlobalVarNode* g2 = e2.as()) { equal = equal && g1 == g2; } else { equal = false; } } - void VisitExpr_(const TupleNode *pl1, const Expr& e2) final { + void VisitExpr_(const TupleNode* pl1, const Expr& e2) final { Tuple prod1 = GetRef(pl1); - if (const TupleNode *pl2 = e2.as()) { + if (const TupleNode* pl2 = e2.as()) { Tuple prod2 = GetRef(pl2); if (prod1->fields.size() != prod2->fields.size()) { equal = false; @@ -193,8 +252,8 @@ struct AlphaEq : ExprFunctor { } } - void VisitExpr_(const ParamNode *p1, const Expr& e2) final { - if (const ParamNode *p2 = e2.as()) { + void VisitExpr_(const ParamNode* p1, const Expr& e2) final { + if (const ParamNode* p2 = e2.as()) { eq_map.Set(p1->var, p2->var); equal = equal && AlphaEqual(p1->type, p2->type); } else { @@ -202,25 +261,42 @@ struct AlphaEq : ExprFunctor { } } - void VisitExpr_(const FunctionNode *func1, const Expr& e2) final { - if (const FunctionNode *func2 = e2.as()) { + void VisitExpr_(const FunctionNode* func1, const Expr& e2) final { + if (const FunctionNode* func2 = e2.as()) { if (func1->params.size() != func2->params.size()) { equal = false; return; } + if (func1->type_params.size() != func2->type_params.size()) { + equal = false; + return; + } + for (size_t i = 0U; i < func1->params.size(); i++) { this->VisitExpr(func1->params[i], func2->params[i]); } + for (size_t i = 0U; i < func1->type_params.size(); i++) { + equal = equal && AlphaEqual(func1->type_params[i], func2->type_params[i]); + if (!equal) { + return; + } + } + + equal = equal && AlphaEqual(func1->ret_type, func2->ret_type); + if (!equal) { + return; + } + this->VisitExpr(func1->body, func2->body); } else { equal = false; } } - void VisitExpr_(const CallNode *op, const Expr& e2) final { - if (const CallNode *call = e2.as()) { + void VisitExpr_(const CallNode* op, const Expr& e2) final { + if (const CallNode* call = e2.as()) { this->VisitExpr(op->op, call->op); if (op->args.size() != call->args.size()) { @@ -228,20 +304,86 @@ struct AlphaEq : ExprFunctor { return; } + if (op->type_args.size() != call->type_args.size()) { + equal = false; + return; + } + + // checking attrs by pointer equality for now + equal = equal && (op->attrs == call->attrs); + if (!equal) { + return; + } + for (size_t i = 0U; i < op->args.size(); i++) { this->VisitExpr(op->args[i], call->args[i]); } + for (size_t i = 0U; i < op->type_args.size(); i++) { + 
equal = equal && AlphaEqual(op->type_args[i], call->type_args[i]);
+        if (!equal) {
+          return;
+        }
+      }
     } else {
       equal = false;
     }
   }
 
-  void VisitExpr_(const LetNode *op, const Expr& e2) final {
-    if (const LetNode *let = e2.as<LetNode>()) {
+  void VisitExpr_(const LetNode* op, const Expr& e2) final {
+    if (const LetNode* let = e2.as<LetNode>()) {
       eq_map.Set(op->var, let->var);
       this->VisitExpr(op->value, let->value);
       this->VisitExpr(op->body, let->body);
+
+      // value_type should match as well (including nulls)
+      if (op->value_type.defined() != let->value_type.defined()) {
+        equal = false;
+        return;
+      }
+
+      if (op->value_type.defined()) {
+        equal = equal && AlphaEqual(op->value_type, let->value_type);
+      }
+    } else {
+      equal = false;
+    }
+  }
+
+  void VisitExpr_(const IfNode* op, const Expr& e2) final {
+    if (const IfNode* i = e2.as<IfNode>()) {
+      VisitExpr(op->cond, i->cond);
+      VisitExpr(op->true_branch, i->true_branch);
+      VisitExpr(op->false_branch, i->false_branch);
+    } else {
+      equal = false;
+    }
+  }
+
+  void VisitExpr_(const OpNode* op, const Expr& e2) final {
+    if (const OpNode* o = e2.as<OpNode>()) {
+      equal = equal && op->name == o->name;
+    } else {
+      equal = false;
+    }
+  }
+
+  void VisitExpr_(const ConstantNode* op, const Expr& e2) final {
+    if (const ConstantNode* c = e2.as<ConstantNode>()) {
+      if (AlphaEqual(op->tensor_type(), c->tensor_type())) {
+        equal = equal && SameNDArray(op->data, c->data);
+      } else {
+        equal = false;
+      }
+    } else {
+      equal = false;
+    }
+  }
+
+  void VisitExpr_(const TupleGetItemNode* op, const Expr& e2) final {
+    if (const TupleGetItemNode* proj = e2.as<TupleGetItemNode>()) {
+      this->VisitExpr(op->tuple, proj->tuple);
+      equal = equal && (op->index == proj->index);
     } else {
       equal = false;
     }
@@ -255,15 +397,15 @@ bool AlphaEqual(const Expr& e1, const Expr& e2) {
 }
 
 // TODO(@jroesch): move to correct namespace?
-TVM_REGISTER_API("relay._make._alpha_eq")
-  .set_body([](TVMArgs args, TVMRetValue *ret) {
+TVM_REGISTER_API("relay._make._alpha_equal")
+  .set_body([](TVMArgs args, TVMRetValue* ret) {
     Expr e1 = args[0];
     Expr e2 = args[1];
    *ret = AlphaEqual(e1, e2);
   });
 
-TVM_REGISTER_API("relay._make._type_alpha_eq")
-  .set_body([](TVMArgs args, TVMRetValue *ret) {
+TVM_REGISTER_API("relay._make._type_alpha_equal")
+  .set_body([](TVMArgs args, TVMRetValue* ret) {
     Type t1 = args[0];
     Type t2 = args[1];
    *ret = AlphaEqual(t1, t2);
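Alpha equivalence, as implemented above, is equality up to a consistent renaming of bound variables (free variables are compared by name, matching the open-term rule noted earlier). A tiny self-contained sketch of the idea on a toy let-language (hypothetical helper, not the TVM API):

```python
def alpha_equal(e1, e2, env=None):
    """Toy alpha-equality on nested ('let', var, value, body) tuples and str vars."""
    env = env or {}
    if isinstance(e1, str) and isinstance(e2, str):
        return env.get(e1, e1) == e2  # bound vars go through the map; free vars by name
    if isinstance(e1, tuple) and isinstance(e2, tuple) and e1[0] == e2[0] == 'let':
        _, v1, val1, body1 = e1
        _, v2, val2, body2 = e2
        # record v1 ~ v2, then compare values and bodies under that assumption
        return alpha_equal(val1, val2, env) and alpha_equal(body1, body2, {**env, v1: v2})
    return e1 == e2

assert alpha_equal(('let', 'x', 1, 'x'), ('let', 'y', 1, 'y'))
assert not alpha_equal(('let', 'x', 1, 'x'), ('let', 'y', 2, 'y'))
```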
diff --git a/src/relay/pass/dead_code.cc b/src/relay/pass/dead_code.cc
new file mode 100644
index 000000000000..05036042a635
--- /dev/null
+++ b/src/relay/pass/dead_code.cc
@@ -0,0 +1,119 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ *
+ * \file dead_code.cc
+ *
+ * \brief Remove code that does not affect the program result.
+ *
+ * The algorithm is implemented by two visitors:
+ * CalcDep turns an expr into a dependency graph of exprs,
+ * GenLet turns the dependency graph into a let list, taking only the used values.
+ */
+#include <tvm/relay/pass.h>
+#include <tvm/relay/expr_functor.h>
+#include "let_list.h"
+
+namespace tvm {
+namespace relay {
+
+bool IsBoolLit(const Expr& e, bool b) {
+  if (const ConstantNode* c = e.as<ConstantNode>()) {
+    if (c->is_scalar()) {
+      auto dt = c->tensor_type()->dtype;
+      if (dt == UInt(8)) {
+        return *reinterpret_cast<const uint8_t*>(c->data->data) == b;
+      } else if (dt == UInt(16)) {
+        return *reinterpret_cast<const uint16_t*>(c->data->data) == b;
+      } else if (dt == UInt(32)) {
+        return *reinterpret_cast<const uint32_t*>(c->data->data) == b;
+      } else if (dt == UInt(64)) {
+        return *reinterpret_cast<const uint64_t*>(c->data->data) == b;
+      } else if (dt == Int(8)) {
+        return *reinterpret_cast<const int8_t*>(c->data->data) == b;
+      } else if (dt == Int(16)) {
+        return *reinterpret_cast<const int16_t*>(c->data->data) == b;
+      } else if (dt == Int(32)) {
+        return *reinterpret_cast<const int32_t*>(c->data->data) == b;
+      } else if (dt == Int(64)) {
+        return *reinterpret_cast<const int64_t*>(c->data->data) == b;
+      }
+    }
+  }
+  return false;
+}
+
+// calculate the dependency graph from expression
+class CalcDep : private ExprMutator {
+ public:
+  static Expr Eliminate(const Expr& e) {
+    CalcDep cd;
+    auto res = cd(e);
+    GenLet gl(cd.var_map_);
+    gl(res);
+    return gl.lets_.Get(res);
+  }
+
+ private:
+  struct Binder {
+    Type t;
+    Expr e;
+    Binder(const Type& t, const Expr& e) : t(t), e(e) { }
+  };
+  using VarMap = std::unordered_map<Var, Binder, NodeHash, NodeEqual>;
+  VarMap var_map_;
+
+  Expr VisitExpr_(const IfNode* i) final {
+    auto cond = VisitExpr(i->cond);
+    if (IsBoolLit(cond, true)) {
+      return Eliminate(i->true_branch);
+    } else if (IsBoolLit(cond, false)) {
+      return Eliminate(i->false_branch);
+    } else {
+      return IfNode::make(cond, Eliminate(i->true_branch), Eliminate(i->false_branch));
+    }
+  }
+
+  Expr VisitExpr_(const LetNode* l) final {
+    var_map_.insert(std::pair<Var, Binder>(l->var,
+                                           Binder(l->value_type,
+                                                  Eliminate(l->value))));
+    return VisitExpr(l->body);
+  }
+
+  Expr VisitExpr_(const FunctionNode* f) final {
+    return FunctionNode::make(f->params, f->ret_type, Eliminate(f->body), f->type_params);
+  }
+
+  // generate the let list from the dependency graph
+  class GenLet : private ExprVisitor {
+   private:
+    LetList lets_;
+    VarMap var_map_;
+    explicit GenLet(const VarMap& var_map) : var_map_(var_map) { }
+    friend CalcDep;
+
+    void VisitExpr_(const VarNode* vn) final {
+      Var v = GetRef<Var>(vn);
+      if (var_map_.count(v) != 0) {
+        auto val = var_map_.at(v);
+        var_map_.erase(v);
+        // erase before visiting, to handle letrec
+        VisitExpr(val.e);
+        // visit the value before pushing the binding, so that a binding's
+        // dependencies are emitted before the binding itself
+        lets_.Push(v, val.t, val.e);
+      }
+    }
+  };
+};
+
+Expr DeadCodeElimination(const Expr& e) {
+  return CalcDep::Eliminate(e);
+}
+
+TVM_REGISTER_API("relay._ir_pass.dead_code_elimination")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = DeadCodeElimination(args[0]);
+  });
+
+}  // namespace relay
+}  // namespace tvm
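To make the CalcDep/GenLet flow concrete, here is a toy model of dead-let elimination on the same toy let-language of nested ('let', var, value, body) tuples with string variables. It is illustrative only; the real pass also folds `if` on constant conditions, as shown above:

```python
def free_vars(e):
    """Variables an expression actually uses."""
    if isinstance(e, str):
        return {e}
    if isinstance(e, tuple) and e[0] == 'let':
        _, v, val, body = e
        return free_vars(val) | (free_vars(body) - {v})
    return set()

def dce(e):
    """Drop let-bindings whose variable is never used."""
    if isinstance(e, tuple) and e[0] == 'let':
        _, v, val, body = e
        body = dce(body)
        return body if v not in free_vars(body) else ('let', v, dce(val), body)
    return e

# x is dead, y is live
assert dce(('let', 'x', 42, ('let', 'y', 1, 'y'))) == ('let', 'y', 1, 'y')
```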
diff --git a/src/relay/pass/type_functor.h b/src/relay/pass/type_functor.h
index a451fbe16984..70a2d9347eab 100644
--- a/src/relay/pass/type_functor.h
+++ b/src/relay/pass/type_functor.h
@@ -8,7 +8,6 @@
 #include
 #include
-#include
 #include
 
 namespace tvm {
@@ -21,11 +20,11 @@ class TypeFunctor;
 #define TYPE_FUNCTOR_DEFAULT                                      \
   { return VisitTypeDefault_(op, std::forward<Args>(args)...); }
 
-#define RELAY_TYPE_FUNCTOR_DISPATCH(OP)                                 \
-  vtable.template set_dispatch<OP>(                                     \
-    [](const NodeRef& n, TSelf* self, Args... args) {                   \
+#define RELAY_TYPE_FUNCTOR_DISPATCH(OP)                           \
+  vtable.template set_dispatch<OP>(                               \
+    [](const NodeRef& n, TSelf* self, Args... args) {             \
       return self->VisitType_(static_cast<const OP*>(n.node_.get()), \
-                              std::forward<Args>(args)...);            \
+                              std::forward<Args>(args)...);       \
     });
 
 template
diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index 1e2100fa902e..72bdaf69f061 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -119,6 +119,20 @@ class TypeInferencer : private ExprFunctor {
     return TupleTypeNode::make(fields);
   }
 
+  Type VisitExpr_(const TupleGetItemNode* op) final {
+    // TODO(M.K.)
+    // handle case where field type is not known
+    Type tuple_type = GetType(op->tuple);
+    auto tuple_ty_node = tuple_type.as<TupleTypeNode>();
+    if (!tuple_ty_node) {
+      LOG(FATAL) << "only expressions with tuple types are accepted " << GetRef<TupleGetItem>(op);
+    }
+    if (static_cast<int>(tuple_ty_node->fields.size()) <= op->index) {
+      LOG(FATAL) << "tuple index out of range " << GetRef<TupleGetItem>(op);
+    }
+    return tuple_ty_node->fields[op->index];
+  }
+
   Type VisitExpr_(const OpNode* op) final {
     return op->op_type;
   }
@@ -293,6 +307,10 @@ class TypeInferencer::Resolver : public ExprMutator {
     return AttachCheckedType(op);
   }
 
+  Expr VisitExpr_(const TupleGetItemNode* op) final {
+    return AttachCheckedType(op);
+  }
+
   Expr VisitExpr_(const ParamNode* op) final {
     return ExprMutator::VisitExpr_(op);
   }
diff --git a/src/runtime/builtin_fp16.cc b/src/runtime/builtin_fp16.cc
index 79c3cc474269..c920c9571f38 100644
--- a/src/runtime/builtin_fp16.cc
+++ b/src/runtime/builtin_fp16.cc
@@ -3,12 +3,14 @@
  * \file builtin_fp16.cc
 * \brief Functions for conversion between fp32 and fp16
 */
-
 #include
 #include
 
 extern "C" {
 
+// disable under msvc
+#ifndef _MSC_VER
+
 TVM_WEAK uint16_t __gnu_f2h_ieee(float a) {
   return __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(a);
 }
 
 TVM_WEAK float __gnu_h2f_ieee(uint16_t a) {
   return __extendXfYf2__<uint16_t, uint16_t, 10, float, uint32_t, 23>(a);
 }
 
+#endif
 }
diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc
index 04c178f25dfa..0ffa4c174544 100644
--- a/src/runtime/ndarray.cc
+++ b/src/runtime/ndarray.cc
@@ -20,20 +20,13 @@
 inline void VerifyDataType(DLDataType dtype) {
   if (dtype.code == kDLFloat) {
     CHECK_EQ(dtype.bits % 8, 0);
   } else {
+    // allow uint1 as a special flag for bool.
+ if (dtype.bits == 1 && dtype.code == kDLUInt) return; CHECK_EQ(dtype.bits % 8, 0); } CHECK_EQ(dtype.bits & (dtype.bits - 1), 0); } -inline size_t GetDataSize(const DLTensor& arr) { - size_t size = 1; - for (tvm_index_t i = 0; i < arr.ndim; ++i) { - size *= arr.shape[i]; - } - size *= (arr.dtype.bits * arr.dtype.lanes + 7) / 8; - return size; -} - inline size_t GetDataAlignment(const DLTensor& arr) { size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes; if (align < kAllocAlignment) return kAllocAlignment; @@ -129,8 +122,8 @@ DLManagedTensor* NDArray::ToDLPack() const { } NDArray NDArray::Empty(std::vector shape, - DLDataType dtype, - DLContext ctx) { + DLDataType dtype, + DLContext ctx) { NDArray ret = Internal::Create(shape, dtype, ctx); // setup memory content size_t size = GetDataSize(ret.data_->dl_tensor); diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index 8591c77bd7cc..ccf7fd617194 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -135,29 +135,29 @@ Tensor Schedule::cache_read(const Tensor& tensor, return cache; } -// Cache write and relayout the data according to loop pattern -Array CacheWriteWithReLayout(Schedule sch, - const Array& tensor_array, - const std::string& scope) { - size_t tensor_size = tensor_array.size(); - sch->InvalidateCache(); - Tensor tensor = tensor_array[0]; - Stage orig_stage = sch[tensor->op]; - const ComputeOpNode* compute = orig_stage->op.as(); - std::unordered_set red_axis; - for (IterVar iv : compute->reduce_axis) { +template +void PrepareAxisMapping(Stage orig_stage, + OpType* op, + std::unordered_set* p_red_axis, + Array* p_new_axis, + std::unordered_map* p_dom_map, + std::unordered_map* p_vsub, + std::unordered_map* p_vsub2newvar, + std::vector* p_predicates) { + auto& red_axis = *p_red_axis; + auto& new_axis = *p_new_axis; + auto& dom_map = *p_dom_map; + auto& vsub = *p_vsub; + auto& vsub2newvar = *p_vsub2newvar; + auto& predicates = *p_predicates; + + for (IterVar iv : op->reduce_axis) { red_axis.insert(iv); } - std::unordered_map dom_map; - Array new_axis; - - for (IterVar iv : compute->axis) { + for (IterVar iv : op->axis) { dom_map[iv] = iv->dom; } schedule::PassDownDomain(orig_stage, &dom_map, true); - std::unordered_map vsub; - std::unordered_map vsub2newvar; - std::vector predicates; { // The source->cache std::unordered_map value_map; @@ -178,17 +178,85 @@ Array CacheWriteWithReLayout(Schedule sch, } // skip reduction iteration. 
std::unordered_set skip_bound_check; - for (IterVar iv : compute->reduce_axis) { + for (IterVar iv : op->reduce_axis) { skip_bound_check.insert(iv); } schedule::PassUpIndex(orig_stage, dom_map, &value_map, true); predicates = schedule::MakeBoundCheck( orig_stage, dom_map, value_map, true, skip_bound_check); // The root axis - for (IterVar iv : compute->axis) { - vsub[iv->var.get()] = value_map.at(iv); + for (IterVar iv : op->axis) { + if (value_map.count(iv)) { + vsub[iv->var.get()] = value_map.at(iv); + } // to handle tensor axis } } +} + +Array ReplaceOriginalOp(Schedule sch, + Stage orig_stage, + const std::string& scope, + Operation cache_op, + Operation orig_new_op, + size_t tensor_size) { + Array cache_tensor_list; + for (size_t i = 0; i < tensor_size; i++) { + Tensor cache_tensor = cache_op.output(i); + cache_tensor_list.push_back(cache_tensor); + } + // The replace of the dataflow + std::unordered_map vmap; + std::unordered_map rvmap; + vmap[orig_stage->op.output(0)] = orig_new_op.output(0); + rvmap[orig_new_op.output(0)] = orig_stage->op.output(0); + for (size_t i = 0; i < tensor_size; i++) { + vmap[orig_stage->op.output(0)] = orig_new_op.output(0); + rvmap[orig_new_op.output(0)] = orig_stage->op.output(0); + } + ReplaceDataFlow(sch->stages, &vmap, &rvmap); + // mutate orig stage + orig_stage->op = orig_new_op; + orig_stage->all_iter_vars = orig_stage->op->root_iter_vars(); + orig_stage->leaf_iter_vars = orig_stage->all_iter_vars; + orig_stage->relations = Array(); + // create schedule for new cached stage. + ArrayNode* stages = sch->stages.CopyOnWrite(); + size_t pos = FindNodeRef(stages, orig_stage); + Stage cache_stage = Stage(cache_op); + cache_stage.set_scope(scope); + CHECK_LT(pos, stages->data.size()); + stages->data.insert(stages->data.begin() + pos, + cache_stage.node_); + sch->stage_map.Set(cache_op, cache_stage); + // Update group + cache_stage->group = orig_stage->group; + if (cache_stage->group.defined()) { + ++cache_stage->group->num_child_stages; + } + return cache_tensor_list; +} + + +// Cache write and relayout the data according to loop pattern +Array CacheWriteWithReLayout(Schedule sch, + const Array& tensor_array, + const std::string& scope) { + size_t tensor_size = tensor_array.size(); + sch->InvalidateCache(); + Tensor tensor = tensor_array[0]; + Stage orig_stage = sch[tensor->op]; + const ComputeOpNode* compute = orig_stage->op.as(); + + std::unordered_set red_axis; + Array new_axis; + std::unordered_map dom_map; + + std::unordered_map vsub; + std::unordered_map vsub2newvar; + std::vector predicates; + + PrepareAxisMapping(orig_stage, compute, + &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates); Expr body; Array body_list; @@ -198,7 +266,7 @@ Array CacheWriteWithReLayout(Schedule sch, body = InjectPredicate(predicates, body); body = VarReplacer(vsub2newvar).Mutate(body); // Reduce nodes in ONE computeOp must be the same except value_index - // This is right only if the oringinal body ensures Reduce nodes are the same + // This is right only if the original body ensures Reduce nodes are the same if (body->is_type()) { const ir::Reduce* reduce_body = body.as(); if (first_reduce != nullptr) { @@ -234,48 +302,107 @@ Array CacheWriteWithReLayout(Schedule sch, Operation cache_op = ComputeOpNode::make( compute->name + "." 
+ scope, compute->tag, compute->attrs, new_axis, body_list); - Array<Tensor> cache_tensor_list; + Array<Expr> cache_expr_list; for (size_t i = 0; i < tensor_size; i++) { Tensor cache_tensor = cache_op.output(i); - cache_tensor_list.push_back(cache_tensor); cache_expr_list.push_back(cache_tensor(args)); } Operation orig_new_op = ComputeOpNode::make( compute->name, compute->tag, compute->attrs, compute->axis, cache_expr_list); - // The replace of the dataflow - std::unordered_map<Tensor, Tensor> vmap; - std::unordered_map<Tensor, Tensor> rvmap; - vmap[orig_stage->op.output(0)] = orig_new_op.output(0); - rvmap[orig_new_op.output(0)] = orig_stage->op.output(0); - for (size_t i = 0; i < tensor_size; i++) { - vmap[orig_stage->op.output(0)] = orig_new_op.output(0); - rvmap[orig_new_op.output(0)] = orig_stage->op.output(0); + return ReplaceOriginalOp(sch, orig_stage, scope, + cache_op, orig_new_op, tensor_size); +} + + +// for tensor compute op +Array<Tensor> CacheWriteWithReLayoutTensor(Schedule sch, + const Array<Tensor>& tensor_array, + const std::string& scope) { + size_t tensor_size = tensor_array.size(); + sch->InvalidateCache(); + Tensor tensor = tensor_array[0]; + Stage orig_stage = sch[tensor->op]; + const TensorComputeOpNode* tensor_op = orig_stage->op.as<TensorComputeOpNode>(); + CHECK_EQ(tensor_op->num_outputs(), 1) + << "cache write only support single output tensor_compute_op"; + + std::unordered_set<IterVar> red_axis; + Array<IterVar> new_axis; + std::unordered_map<IterVar, Range> dom_map; + + std::unordered_map<const Variable*, Expr> vsub; + std::unordered_map<const Variable*, Expr> vsub2newvar; + std::vector<Expr> predicates; + + PrepareAxisMapping(orig_stage, tensor_op, + &red_axis, &new_axis, &dom_map, &vsub, &vsub2newvar, &predicates); + + + for (int i = tensor_op->schedulable_ndim; i < static_cast<int>(tensor_op->axis.size()); ++i) { + IterVar iv = tensor_op->axis[i]; + IterVar new_iv = IterVarNode::make( + iv->dom, iv->var.copy_with_suffix(".c"), iv->iter_type); + new_axis.push_back(new_iv); + } + Array<Region> new_regions; + for (Region old_region : tensor_op->input_regions) { + Region region; + for (Range r : old_region) { + Expr min = VarReplacer(vsub2newvar).Mutate(r->min); + Expr extent = VarReplacer(vsub2newvar).Mutate(r->extent); + region.push_back(Range::make_by_min_extent(min, extent)); + } + new_regions.push_back(region); } - ReplaceDataFlow(sch->stages, &vmap, &rvmap); - // mutate orig stage - orig_stage->op = orig_new_op; - orig_stage->all_iter_vars = orig_stage->op->root_iter_vars(); - orig_stage->leaf_iter_vars = orig_stage->all_iter_vars; - orig_stage->relations = Array<IterVarRelation>(); - // create schedule for new cached stage. - ArrayNode* stages = sch->stages.CopyOnWrite(); - size_t pos = FindNodeRef(stages, orig_stage); - Stage cache_stage = Stage(cache_op); - cache_stage.set_scope(scope); - CHECK_LT(pos, stages->data.size()); - stages->data.insert(stages->data.begin() + pos, - cache_stage.node_); - sch->stage_map.Set(cache_op, cache_stage); - // Update group - cache_stage->group = orig_stage->group; - if (cache_stage->group.defined()) { - ++cache_stage->group->num_child_stages; + + Operation cache_op = TensorComputeOpNode::make( + tensor_op->name + "."
+ scope, tensor_op->tag, new_axis, + tensor_op->reduce_axis, tensor_op->schedulable_ndim, + tensor_op->intrin, tensor_op->inputs, new_regions); + + // axis will be used in generating compute op + Array<IterVar> compute_axis = tensor_op->axis; + for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) { + IterVar iv = tensor_op->axis[i]; + IterVar aiv = IterVarNode::make(iv->dom, iv->var, kDataPar); + compute_axis.Set(i, aiv); } - return cache_tensor_list; + + // The reader args + Array<Expr> args; + { + // cache->compute + std::unordered_map<IterVar, Expr> value_map; + for (IterVar iv : compute_axis) { + value_map[iv] = iv->var; + } + schedule::PassDownIndex(orig_stage, dom_map, &value_map, true); + for (IterVar iv : orig_stage->leaf_iter_vars) { + if (red_axis.count(iv)) continue; + args.push_back(value_map.at(iv)); + } + // tensorized region axis + for (size_t i = tensor_op->schedulable_ndim; i < tensor_op->axis.size(); ++i) { + IterVar iv = compute_axis[i]; + args.push_back(value_map.at(iv)); + } + } + + Array<Expr> cache_expr_list; + for (size_t i = 0; i < tensor_size; i++) { + Tensor cache_tensor = cache_op.output(i); + cache_expr_list.push_back(cache_tensor(args)); + } + Operation orig_new_op = ComputeOpNode::make( + tensor_op->name, tensor_op->tag, {}, + compute_axis, cache_expr_list); + return ReplaceOriginalOp(sch, orig_stage, scope, + cache_op, orig_new_op, tensor_size); } + Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array, const std::string& scope) { (*this)->InvalidateCache(); @@ -291,23 +418,26 @@ Array<Tensor> Schedule::cache_write(const Array<Tensor>& tensor_array, CHECK(orig_stage.same_as(tmp_stage)) << "Input tensor list must be generated by ONE computeOp"; } - return CacheWriteWithReLayout(*this, tensor_array, scope); } + Tensor Schedule::cache_write(const Tensor& tensor, const std::string& scope) { + // support original compute and tensor compute both (*this)->InvalidateCache(); - Stage orig_stage = operator[](tensor->op); - const ComputeOpNode* compute = tensor->op.as<ComputeOpNode>(); - CHECK(compute) - << "cache write only take ComputeOp as writers"; - CHECK_EQ(compute->num_outputs(), 1) - << "cache write only support single output ComputeOp"; - - return (CacheWriteWithReLayout(*this, {tensor}, scope))[0]; + const char* type_key = tensor->op->type_key(); + if (!strcmp(type_key, "ComputeOp")) { + return (CacheWriteWithReLayout(*this, {tensor}, scope))[0]; + } else if (!strcmp(type_key, "TensorComputeOp")) { + return (CacheWriteWithReLayoutTensor(*this, {tensor}, scope))[0]; + } else { + LOG(FATAL) << "cache write only take ComputeOp or TensorComputeOp as writers"; + return Tensor(); + } } + void RebaseNonZeroMinLoop(const Schedule& sch) { std::unordered_map<IterVar, IterVar> rebase_map; for (Stage s : sch->stages) { diff --git a/tests/python/relay/test_debug_printer.py b/tests/python/relay/test_ir_debug_printer.py similarity index 92% rename from tests/python/relay/test_debug_printer.py rename to tests/python/relay/test_ir_debug_printer.py index 2ea0b7575ff8..e5f9ad2e69cd 100644 --- a/tests/python/relay/test_debug_printer.py +++ b/tests/python/relay/test_ir_debug_printer.py @@ -77,7 +77,7 @@ def test_call(): def test_let(): lv = relay.Var('x') - ty = relay.ty.TensorType((10, 20), "float32") + ty = relay.ty.TensorType((10, 20), 'float32') arr = tvm.nd.array(10) value = relay.Constant(arr) let = relay.Let(lv, value, lv, ty) @@ -90,3 +90,8 @@ def test_if(): right = relay.Var('right') ife = relay.If(cond, left, right) show(ife) + +def test_tuple_get_item(): + t = relay.Var('t') + g = relay.TupleGetItem(t, 0) + show(g) diff --git
a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py index 4505710c06cc..79883ed225e0 100644 --- a/tests/python/relay/test_ir_nodes.py +++ b/tests/python/relay/test_ir_nodes.py @@ -3,6 +3,13 @@ from tvm import relay from tvm.expr import * +def test_bad_constructor(): + try: + x = relay.ty.TensorType("xx", "xx") + except tvm.TVMError: + pass + + # Span def test_span(): span = relay.Span(None, 1, 1) @@ -65,8 +72,8 @@ def test_type_relation(): args = tvm.convert([tf, tt, tp]) num_inputs = 2 - func = None - attrs = None + func = tvm.get_env_func("tvm.relay.type_relation.Broadcast") + attrs = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) tr = relay.TypeRelation(func, args, num_inputs, attrs) assert tr.args == args @@ -168,7 +175,15 @@ def test_if(): str(ife) +def test_tuple_get_item(): + tup = relay.Var("tuple") + get = relay.TupleGetItem(tup, 1) + assert get.tuple == tup + assert get.index == 1 + str(get) + if __name__ == "__main__": + test_bad_constructor() test_span() test_tensor_type() test_type_param() @@ -184,3 +199,4 @@ def test_if(): test_call() test_let() test_if() + test_tuple_get_item() diff --git a/tests/python/relay/test_relay_op.py b/tests/python/relay/test_ir_op.py similarity index 62% rename from tests/python/relay/test_relay_op.py rename to tests/python/relay/test_ir_op.py index 3b1d914fe02c..f1d835d2b43b 100644 --- a/tests/python/relay/test_relay_op.py +++ b/tests/python/relay/test_ir_op.py @@ -14,13 +14,22 @@ def test(x): def test_op_level1(): x = relay.Var("x") - for op_name in ["log", "exp", "sqrt"]: + for op_name in ["log", "exp", "sqrt", "tanh"]: y = getattr(relay, op_name)(x) assert y.op.name == op_name assert y.op.support_level == 1 assert y.args[0] == x +def test_op_level3(): + x = relay.Var("x") + + for op_name in ["ceil", "floor", "trunc", "round", "abs", "negative"]: + y = getattr(relay, op_name)(x) + assert y.op.name == op_name + assert y.op.support_level == 3 + assert y.args[0] == x if __name__ == "__main__": test_op_attr() test_op_level1() + test_op_level3() diff --git a/tests/python/relay/test_well_formed.py b/tests/python/relay/test_ir_well_formed.py similarity index 61% rename from tests/python/relay/test_well_formed.py rename to tests/python/relay/test_ir_well_formed.py index 8bdef4d0edb5..c6cb99662bb5 100644 --- a/tests/python/relay/test_well_formed.py +++ b/tests/python/relay/test_ir_well_formed.py @@ -3,7 +3,7 @@ from tvm.relay.ir_pass import well_formed def test_well_formed(): - x = relay.Var("x") + x = relay.Var('x') assert well_formed(x) v = relay.Constant(tvm.nd.array(10)) ty = None @@ -16,3 +16,19 @@ def test_well_formed(): # but we want all binder to be distinct from each other. 
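[Editor's aside — illustrative sketch, not part of the patch] The assertion that follows exercises the rule spelled out in the comment above: every binder in a well-formed expression must be distinct. A standalone sketch of the same rule, using only constructs already present in this test file:

    import tvm
    from tvm import relay
    from tvm.relay.ir_pass import well_formed

    x = relay.Var("x")
    v = relay.Constant(tvm.nd.array(10))
    # binding x once is well formed ...
    assert well_formed(relay.Let(x, v, x, None))
    # ... but re-binding the same Var in a nested let is not
    assert not well_formed(relay.Let(x, v, relay.Let(x, v, x, None), None))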
assert not well_formed(relay.Let(relay.Var("y"), f, relay.Let(relay.Var("z"), f, v, ty), ty)) + + +def test_tuple(): + x = relay.Var('x') + assert well_formed(x) + v = relay.Constant(tvm.nd.array(10)) + ty = None + let = relay.Let(x, v, x, ty) + assert well_formed(let) + assert well_formed(relay.Tuple([v, v])) + assert not well_formed(relay.Tuple([let, let])) + + +def test_tuple_get_item(): + t = relay.Var('t') + assert well_formed(relay.TupleGetItem(t, 2)) diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index c1c8b03c1c23..a90f6eb55ae1 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -1,6 +1,31 @@ import tvm +import numpy as np from tvm import relay +from tvm.relay.ir_pass import infer_type +from tvm.relay.ir_builder import IRBuilder, func_type +from tvm.relay.ir_builder import scalar_type, convert, tensor_type +from tvm.relay.env import Environment +def assert_has_type(expr, typ, env=Environment({})): + checked_expr = infer_type(env, expr) + checked_type = checked_expr.checked_type + if checked_type != typ: + raise RuntimeError("Type mismatch %s vs %s" % ( + checked_type, typ)) + +def test_single_op(): + def check_single_op(opfunc): + "Program: fn (x : float32) { let t1 = f(x); t1 }" + b = IRBuilder() + with b.function(('x', 'float32')) as func: + x, = func.param_ids() + t1 = b.let('t1', opfunc(x)) + b.ret(t1) + assert_has_type(func.to_func(), func_type(['float32'], 'float32')) + + for opfunc in [tvm.relay.log, tvm.relay.exp, tvm.relay.sqrt, + tvm.relay.sigmoid, tvm.relay.tanh]: + check_single_op(opfunc) def test_expand_dims_infer_type(): ib = relay.ir_builder.IRBuilder() @@ -11,10 +36,171 @@ def test_expand_dims_infer_type(): ib.ret(relay.expand_dims(x, axis=2)) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type() + ftype = func.checked_type assert ftype.ret_type == relay.ty.TensorType( (n, t, 1, 100), "float32") +def test_softmax(): + ib = relay.ir_builder.IRBuilder() + n, d = tvm.var("n"), tvm.var("d") + x = ib.param("x", relay.ty.TensorType((n, d), "float32")) + with ib.function(x) as func: + ib.ret(relay.nn.softmax(x, axis=1)) + ib.ret(func) + + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType((n, d), "float32") + + +def test_log_softmax(): + ib = relay.ir_builder.IRBuilder() + n, d = tvm.var("n"), tvm.var("d") + x = ib.param("x", relay.ty.TensorType((n, d), "float32")) + with ib.function(x) as func: + ib.ret(relay.nn.log_softmax(x, axis=1)) + ib.ret(func) + + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType((n, d), "float32") + +def test_unary_op(): + for op in [relay.exp, + relay.log, + relay.sqrt, + relay.sigmoid, + relay.nn.relu]: + ib = relay.ir_builder.IRBuilder() + x = ib.param("x", relay.TensorType((10, 4), "int32")) + with ib.function(x) as func: + ib.ret(op(x.var)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.TensorType((10, 4), "int32") + +def test_binary_op(): + def check_binary_op(opfunc): + """ + Program: + fn (x, y) { + return x <op> y; + } + """ + b = IRBuilder() + + x = b.param('x', tensor_type(5, 5, 5)) + y = b.param('y', tensor_type(5, 5, 5)) + with b.function(x, y) as func: + b.ret(opfunc(x.var, y.var)) + b.ret(func) + prog, env = b.get() + ttype = tensor_type(5, 5, 5) + expected_ty =
func_type([ttype, ttype], ttype) + assert_has_type(func.to_func(), expected_ty) + + for opfunc in [relay.add, relay.subtract, relay.mod, + relay.multiply, relay.divide]: + check_binary_op(opfunc) + + +def test_binary_broadcast_op(): + def check_binary_broadcast_op(opfunc): + """ + Program: + fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] { + return x <op> y; + } + """ + b = IRBuilder() + x = b.param('x', tensor_type(10, 4)) + y = b.param('y', tensor_type(5, 10, 1)) + with b.function(x, y) as func: + b.ret(opfunc(x.var, y.var)) + b.ret(func) + prog, env = b.get() + + expected_ty = func_type([tensor_type(10, 4), tensor_type(5, 10, 1)], + tensor_type(5, 10, 4)) + assert_has_type(func.to_func(), expected_ty) + + for opfunc in [relay.add, relay.subtract, relay.mod, + relay.multiply, relay.divide]: + check_binary_broadcast_op(opfunc) + + +def test_concatenate_infer_type(): + ib = relay.ir_builder.IRBuilder() + n, t, d = tvm.var("n"), tvm.var("t"), 100 + x = ib.param("x", relay.ty.TensorType((n, t, d), "float32")) + y = ib.param("y", relay.ty.TensorType((n, t, d), "float32")) + with ib.function(x, y) as func: + ib.ret(relay.concatenate((x, y), axis=-1)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType( + (n, t, 200), "float32") + + ib = relay.ir_builder.IRBuilder() + n, t, d = tvm.var("n"), tvm.var("t"), 100 + x = ib.param("x", relay.ty.TensorType((n, t, d), "float32")) + y = ib.param("y", relay.ty.TensorType((n, t, d), "float32")) + with ib.function(x, y) as func: + ib.ret(relay.concatenate((x, y), axis=2)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType( + (n, t, 200), "float32") + + ib = relay.ir_builder.IRBuilder() + n, t, d = tvm.var("n"), tvm.var("t"), 100 + x = ib.param("x", relay.ty.TensorType((n, t, d), "float32")) + y = ib.param("y", relay.ty.TensorType((n, t, d), "float32")) + with ib.function(x, y) as func: + ib.ret(relay.concatenate((x, y), axis=1)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType( + (n, t + t, 100), "float32") + +def test_lrn(): + ib = relay.ir_builder.IRBuilder() + n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") + x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32")) + with ib.function(x) as func: + ib.ret(relay.nn.lrn(x, size=10, axis=2, bias=0.5, alpha=.00001, beta=0.75)) + ib.ret(func) + + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType((n, c , h, w), "float32") + + +def test_l2_normalize(): + ib = relay.ir_builder.IRBuilder() + n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") + x = ib.param("x", relay.ty.TensorType((n, c , h, w), "float32")) + with ib.function(x) as func: + ib.ret(relay.nn.l2_normalize(x, eps=0.001, axis=[1])) + ib.ret(func) + + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType((n, c , h, w), "float32") + if __name__ == "__main__": + test_unary_op() + test_single_op() test_expand_dims_infer_type() + test_concatenate_infer_type() + test_softmax() + test_log_softmax() + test_binary_op() + test_binary_broadcast_op() + test_lrn() + test_l2_normalize() diff --git a/tests/python/relay/test_op_level2.py
b/tests/python/relay/test_op_level2.py index d5dd64d76555..1d6d00277358 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1,7 +1,8 @@ +""" Support level2 operator test cases. +""" import tvm from tvm import relay - def test_conv2d_infer_type(): # symbolic in batch dimension ib = relay.ir_builder.IRBuilder() @@ -16,7 +17,7 @@ def test_conv2d_infer_type(): channels=2)) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type() + ftype = func.checked_type assert ftype.ret_type == relay.ty.TensorType( (n, 2, 224, 224), "float32") assert ftype.arg_types[1] == relay.ty.TensorType( @@ -31,7 +32,7 @@ def test_conv2d_infer_type(): ib.ret(relay.nn.conv2d(x.var, w.var, out_dtype="int32")) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type() + ftype = func.checked_type assert ftype.ret_type == relay.ty.TensorType( (n, 2, 222, 222), "int32") @@ -50,13 +51,155 @@ def test_conv2d_infer_type(): out_dtype="int32")) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type() + ftype = func.checked_type assert ftype.ret_type == relay.ty.TensorType( (1, 4, 224, 224, 4, 4), "int32") assert ftype.arg_types[1] == relay.ty.TensorType( (4, 8, 3, 3, 4, 4), "int8") +def test_conv2d_transpose_infer_type(): + # symbolic in batch dimension + ib = relay.ir_builder.IRBuilder() + n, c, h, w = tvm.var("n"), 10, 10, 12 + x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) + w = ib.param("w", relay.ty.IncompleteType()) + + with ib.function(x, w) as func: + ib.ret(relay.nn.conv2d_transpose(x.var, w.var, + kernel_size=(3, 3), + padding=(1, 1), + channels=15)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType( + (n, 15, 10, 12), "float32") + assert ftype.arg_types[1] == relay.ty.TensorType( + (10, 15, 3, 3), "float32") + + # infer by shape of w, mixed precision + ib = relay.ir_builder.IRBuilder() + n, c, h, w = tvm.var("n"), 10, 10, 12 + x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) + w = ib.param("w", relay.ty.TensorType((12, 11, 5, 5), "float32")) + with ib.function(x, w) as func: + ib.ret(relay.nn.conv2d_transpose(x.var, w.var, + output_padding=(1, 1), + channels=11, + data_layout="NHWC")) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType( + (n, 15, 15, 11), "float32") + +def test_upsampling_infer_type(): + ib = relay.ir_builder.IRBuilder() + n, c , h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") + x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) + with ib.function(x) as func: + ib.ret(relay.nn.upsampling(x.var, scale=2, layout="NCHW", method="BILINEAR")) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType((n, c, h*2, w*2), "float32") + + ib = relay.ir_builder.IRBuilder() + n, c = tvm.var("n"), tvm.var("c") + x = ib.param("x", relay.ty.TensorType((n, c, 100, 200), "float32")) + with ib.function(x) as func: + ib.ret(relay.nn.upsampling(x.var, scale=2, layout="NCHW", method="BILINEAR")) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType((n, c, 200, 400), "float32") + +def _test_pool2d_infer_type(opfunc): + ib = 
relay.ir_builder.IRBuilder() + n, c, h, w = tvm.var("n"), 10, 224, 224 + x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) + with ib.function(x) as func: + ib.ret(opfunc(x.var, pool_size=(1, 1))) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType((n, 10, 224, 224), "float32") + + ph, pw = tvm.var("ph"), tvm.var("pw") + sh, sw = tvm.var("sh"), tvm.var("sw") + + ib = relay.ir_builder.IRBuilder() + n, c, h, w = tvm.var("n"), 10, 224, 224 + x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) + with ib.function(x) as func: + ib.ret(opfunc(x.var, pool_size=(ph, pw), strides=(sh, sw))) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType( + (n, 10, (((224 - ph)/sh) + 1), (((224 - pw)/sw) + 1)), "float32") + +def _test_global_pool2d_infer_type(opfunc): + ib = relay.ir_builder.IRBuilder() + n, c, h, w = tvm.var("n"), tvm.var("c"), 224, 224 + x = ib.param("x", relay.ty.TensorType((n, h, w, c), "float32")) + with ib.function(x) as func: + ib.ret(opfunc(x.var, layout="NHWC")) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType((n, 1, 1, c), "float32") + + ib = relay.ir_builder.IRBuilder() + n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") + x = ib.param("x", relay.ty.TensorType((n, c, h, w), "float32")) + with ib.function(x) as func: + ib.ret(opfunc(x.var)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType((n, c, 1, 1), "float32") + +def test_pool2d_infer_type(): + _test_pool2d_infer_type(relay.nn.max_pool2d) + _test_pool2d_infer_type(relay.nn.avg_pool2d) + _test_global_pool2d_infer_type(relay.nn.global_max_pool2d) + _test_global_pool2d_infer_type(relay.nn.global_avg_pool2d) + +def test_flatten_infer_type(): + ib = relay.ir_builder.IRBuilder() + d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4") + x = ib.param("x", relay.ty.TensorType((d1, d2, d3, d4), "float32")) + + with ib.function(x) as func: + ib.ret(relay.nn.batch_flatten(x.var)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType((d1, ((d2*d3)*d4)), "float32") + + ib = relay.ir_builder.IRBuilder() + x = ib.param("x", relay.ty.TensorType((3, 2, 4, 3), "float32")) + with ib.function(x) as func: + ib.ret(relay.nn.batch_flatten(x.var)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType((3, 24), "float32") + + ib = relay.ir_builder.IRBuilder() + x = ib.param("x", relay.ty.TensorType((d1, 2, d3, 3), "float32")) + with ib.function(x) as func: + ib.ret(relay.nn.batch_flatten(x.var)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType((d1, ((2*d3)*3)), "float32") if __name__ == "__main__": test_conv2d_infer_type() + test_pool2d_infer_type() + test_upsampling_infer_type() + test_flatten_infer_type() + test_conv2d_transpose_infer_type() diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py new file mode 100644 index 000000000000..cc8973c38384 --- /dev/null +++
b/tests/python/relay/test_op_level3.py @@ -0,0 +1,172 @@ +""" Support level3 operator test cases. +""" +import tvm +import numpy as np +from tvm import relay +from tvm.relay.ir_pass import infer_type +from tvm.relay.ir_builder import IRBuilder, func_type +from tvm.relay.env import Environment + + +def test_unary_identity(): + for op in [relay.zeros_like, relay.ones_like]: + ib = relay.ir_builder.IRBuilder() + x = ib.param("x", relay.TensorType((8, 9, 4), "int32")) + with ib.function(x) as func: + ib.ret(op(x.var)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.TensorType((8, 9, 4), "int32") + + +def test_clip_type(): + ib = relay.ir_builder.IRBuilder() + a = ib.param("a", relay.TensorType((10, 4), "float32")) + with ib.function(a) as func: + ib.ret(relay.clip(a.var, 1., 4.)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.TensorType((10, 4), "float32") + + +def test_copy_infer_type(): + ib = relay.ir_builder.IRBuilder() + n, t, d = tvm.var("n"), tvm.var("t"), 100 + x = ib.param("x", relay.ty.TensorType((n, t, d), "float32")) + with ib.function(x) as func: + ib.ret(relay.copy(x)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType( + (n, t, 100), "float32") + + +def test_transpose_infer_type(): + ib = relay.ir_builder.IRBuilder() + n, t, d = tvm.var("n"), tvm.var("t"), 100 + x = ib.param("x", relay.ty.TensorType((n, t, d), "float32")) + with ib.function(x) as func: + ib.ret(relay.transpose(x, axes=(1, 0, 2))) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType( + (t, n, 100), "float32") + + +def test_reshape_infer_type(): + ib = relay.ir_builder.IRBuilder() + n, t, d1, d2 = tvm.var("n"), tvm.var("t"), 100, 20 + x = ib.param("x", relay.ty.TensorType((n, t, d1, d2), "float32")) + with ib.function(x) as func: + ib.ret(relay.reshape(x, newshape=(n, t, 2000))) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType( + (n, t, 2000), "float32") + + +def assert_has_type(expr, typ, env=Environment({})): + checked_expr = infer_type(env, expr) + checked_type = checked_expr.checked_type + if checked_type != typ: + raise RuntimeError("Type mismatch %s vs %s" % ( + checked_type, typ)) + +def test_single_op(): + def check_single_op(opfunc): + "Program: fn (x : float32) { let t1 = f(x); t1 }" + b = IRBuilder() + with b.function(('x', 'float32')) as func: + x, = func.param_ids() + t1 = b.let('t1', opfunc(x)) + b.ret(t1) + assert_has_type(func.to_func(), func_type(['float32'], 'float32')) + + for opfunc in [tvm.relay.ceil, tvm.relay.floor, tvm.relay.trunc, + tvm.relay.round, tvm.relay.abs, tvm.relay.negative]: + check_single_op(opfunc) + +def test_take_infer_type(): + def verify_take(dshape, indices_shape, oshape, axis=None): + ib = relay.ir_builder.IRBuilder() + x = ib.param("x", relay.ty.TensorType(dshape, "float32")) + indices = ib.param("indices", relay.ty.TensorType(indices_shape, "int32")) + with ib.function(x, indices) as func: + ib.ret(relay.take(x.var, indices.var, axis=axis)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType(oshape, "float32") + + d1, 
d2, d3 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3") + d4, d5, d6 = tvm.var("d4"), tvm.var("d5"), tvm.var("d6") + verify_take((d1,), (1,), (1,), 0) + verify_take((4,), (d1, d2), (d1, d2)) + verify_take((3, 3, 3), (1, d2), (1, d2)) + verify_take((d1, d2), (d3, d4, d5), (d3, d4, d5, d2), 0) + verify_take((d1, d2), (d3, d4, d5), (d1, d3, d4, d5), 1) + verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2) + + +def test_full(): + # default settings: match input dtype + ib = relay.ir_builder.IRBuilder() + x = ib.param("x", relay.TensorType((), "int8")) + with ib.function(x) as func: + ib.ret(relay.full(x.var, ())) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.TensorType((), "int8") + + # change the shape and dtype + ib = relay.ir_builder.IRBuilder() + x = ib.param("x", relay.TensorType((), "float32")) + with ib.function(x) as func: + ib.ret(relay.full(x.var, (1, 2), "int8")) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.TensorType((1, 2), "int8") + + +def test_full_like(): + # concrete shape + ib = relay.ir_builder.IRBuilder() + base = ib.param("base", relay.TensorType((1, 2, 3), "float32")) + fill = ib.param("fill", relay.TensorType((), "float32")) + with ib.function(base, fill) as func: + ib.ret(relay.full_like(base.var, fill.var)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.TensorType((1, 2, 3), "float32") + + # symbolic shape + ib = relay.ir_builder.IRBuilder() + n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w") + base = ib.param("base", relay.TensorType((n, c, h, w), "float32")) + fill = ib.param("fill", relay.TensorType((), "float32")) + with ib.function(base, fill) as func: + ib.ret(relay.full_like(base.var, fill.var)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.TensorType((n, c, h, w), "float32") + + +if __name__ == "__main__": + test_single_op() + test_unary_identity() + test_clip_type() + test_copy_infer_type() + test_transpose_infer_type() + test_reshape_infer_type() + test_take_infer_type() + test_full() + test_full_like() diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index 5009994871f7..a855b0f2caaa 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -1,6 +1,17 @@ import tvm +import numpy as np from tvm import relay +from tvm.relay.ir_pass import infer_type +from tvm.relay.ir_builder import IRBuilder, func_type +from tvm.relay.ir_builder import scalar_type, convert, tensor_type +from tvm.relay.env import Environment +def assert_has_type(expr, typ, env=Environment({})): + checked_expr = infer_type(env, expr) + checked_type = checked_expr.checked_type + if checked_type != typ: + raise RuntimeError("Type mismatch %s vs %s" % ( + checked_type, typ)) def test_cmp_type(): for op in (relay.greater, @@ -16,12 +27,14 @@ def test_cmp_type(): ib.ret(op(x.var, y.var)) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type() + ftype = func.checked_type assert ftype.ret_type == relay.TensorType((5, 10, 4), "uint1") def test_binary_broadcast(): - for op in [relay.right_shift]: + for op in [relay.right_shift, + relay.left_shift, + relay.maximum]: ib = relay.ir_builder.IRBuilder() x = ib.param("x", 
relay.TensorType((10, 4), "int32")) y = ib.param("y", relay.TensorType((5, 10, 1), "int32")) @@ -29,10 +42,91 @@ ib.ret(op(x.var, y.var)) ib.ret(func) func = relay.ir_pass.infer_type(ib.env, func.to_func()) - ftype = func.checked_type() + ftype = func.checked_type + assert ftype.ret_type == relay.TensorType((5, 10, 4), "int32") + +def test_binary_op(): + def check_binary_op(opfunc): + """ + Program: + fn (x, y) { + return x <op> y; + } + """ + b = IRBuilder() + + x = b.param('x', tensor_type(5, 5, 5)) + y = b.param('y', tensor_type(5, 5, 5)) + with b.function(x, y) as func: + b.ret(opfunc(x.var, y.var)) + b.ret(func) + prog, env = b.get() + ttype = tensor_type(5, 5, 5) + expected_ty = func_type([ttype, ttype], ttype) + assert_has_type(func.to_func(), expected_ty) + + for opfunc in [relay.pow]: + check_binary_op(opfunc) + + +def test_binary_broadcast_op(): + def check_binary_broadcast_op(opfunc): + """ + Program: + fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] { + return x <op> y; + } + """ + b = IRBuilder() + x = b.param('x', tensor_type(10, 4)) + y = b.param('y', tensor_type(5, 10, 1)) + with b.function(x, y) as func: + b.ret(opfunc(x.var, y.var)) + b.ret(func) + prog, env = b.get() + + expected_ty = func_type([tensor_type(10, 4), tensor_type(5, 10, 1)], + tensor_type(5, 10, 4)) + assert_has_type(func.to_func(), expected_ty) + + for opfunc in [relay.pow]: + check_binary_broadcast_op(opfunc) + +def test_cmp_type(): + for op in (relay.greater, + relay.greater_equal, + relay.less, + relay.less_equal, + relay.equal, + relay.not_equal): + ib = relay.ir_builder.IRBuilder() + x = ib.param("x", relay.TensorType((10, 4), "float32")) + y = ib.param("y", relay.TensorType((5, 10, 1), "float32")) + with ib.function(x, y) as func: + ib.ret(op(x.var, y.var)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.TensorType((5, 10, 4), "uint1") + +def test_binary_broadcast(): + for op in [relay.right_shift, + relay.left_shift, + relay.maximum, + relay.minimum]: + ib = relay.ir_builder.IRBuilder() + x = ib.param("x", relay.TensorType((10, 4), "int32")) + y = ib.param("y", relay.TensorType((5, 10, 1), "int32")) + with ib.function(x, y) as func: + ib.ret(op(x.var, y.var)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type assert ftype.ret_type == relay.TensorType((5, 10, 4), "int32") if __name__ == "__main__": test_cmp_type() test_binary_broadcast() + test_binary_op() + test_binary_broadcast_op() diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py new file mode 100644 index 000000000000..62da592e8249 --- /dev/null +++ b/tests/python/relay/test_op_level5.py @@ -0,0 +1,29 @@ +""" Support level5 operator test cases.
+""" +import tvm +from tvm import relay + +def test_resize_infer_type(): + ib = relay.ir_builder.IRBuilder() + n, c, h, w = tvm.var("n"), tvm.var("c"), tvm.var("h"), tvm.var("w") + x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8")) + th, tw = tvm.var("th"), tvm.var("tw") + + with ib.function(x) as func: + ib.ret(relay.image.resize(x.var, (th, tw))) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType((n, c, th, tw), "int8") + + ib = relay.ir_builder.IRBuilder() + x = ib.param("x", relay.ty.TensorType((n, c, h, w), "int8")) + with ib.function(x) as func: + ib.ret(relay.image.resize(x.var, (100, 200), "NCHW", "BILINEAR", False)) + ib.ret(func) + func = relay.ir_pass.infer_type(ib.env, func.to_func()) + ftype = func.checked_type + assert ftype.ret_type == relay.ty.TensorType((n, c, 100, 200), "int8") + +if __name__ == "__main__": + test_resize_infer_type() diff --git a/tests/python/relay/test_pass_alpha_eq.py b/tests/python/relay/test_pass_alpha_eq.py deleted file mode 100644 index 40140ea486a1..000000000000 --- a/tests/python/relay/test_pass_alpha_eq.py +++ /dev/null @@ -1,17 +0,0 @@ -import tvm -from tvm import relay - -def test_type_alpha_eq(): - t1 = relay.ty.TensorType((3, 4), "float32") - t2 = relay.ty.TensorType((3, 4), "float32") - t3 = relay.ty.TensorType((3, 4, 5), "float32") - assert t1 == t2 - assert t1 != t3 - - t1 = relay.ty.TensorType((), "float32") - t2 = relay.ty.TensorType((), "float32") - assert t1 == t2 - - -if __name__ == "__main__": - test_type_alpha_eq() diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py new file mode 100644 index 000000000000..dd722399dac4 --- /dev/null +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -0,0 +1,454 @@ +import tvm +from tvm import relay +from tvm.relay.ir_pass import alpha_equal +from tvm.relay.ir_builder import convert + +def test_tensor_type_alpha_equal(): + t1 = relay.TensorType((3, 4), "float32") + t2 = relay.TensorType((3, 4), "float32") + t3 = relay.TensorType((3, 4, 5), "float32") + assert t1 == t2 + assert t1 != t3 + + t1 = relay.TensorType((), "float32") + t2 = relay.TensorType((), "float32") + assert t1 == t2 + + +def test_incomplete_type_alpha_equal(): + t1 = relay.IncompleteType(relay.Kind.Shape) + t2 = relay.IncompleteType(relay.Kind.Type) + t3 = relay.IncompleteType(relay.Kind.Type) + + # only equal when there is pointer equality + assert t2 == t2 + assert t1 == t1 + assert t1 != t2 + assert t2 != t3 + + +def test_type_param_alpha_equal(): + t1 = relay.TypeParam("v1", relay.Kind.Type) + t2 = relay.TypeParam("v2", relay.Kind.Shape) + t3 = relay.TypeParam("v3", relay.Kind.Type) + + # only pointer equality and eq_map allow equal params + assert t1 == t1 + assert t2 == t2 + assert t1 != t2 # different kind + assert t1 != t3 # not in eq_map + + # function types are the only way to put type params + # in eq map + ft1 = relay.FuncType(tvm.convert([]), t1, tvm.convert([t1]), tvm.convert([])) + ft2 = relay.FuncType(tvm.convert([]), t3, tvm.convert([t3]), tvm.convert([])) + # actually an invalid type because t2 is wrong kind + ft3 = relay.FuncType(tvm.convert([]), t2, tvm.convert([t2]), tvm.convert([])) + + assert ft1 == ft2 + assert ft1 != ft3 # kinds still do not match + + +def test_func_type_alpha_equal(): + t1 = relay.TensorType((1, 2), "float32") + t2 = relay.TensorType((1, 2, 3), "float32") + + tp1 = relay.TypeParam("v1", relay.Kind.Type) + tp2 = relay.TypeParam("v2", 
relay.Kind.Type) + tp3 = relay.TypeParam("v3", relay.Kind.Shape) + tp4 = relay.TypeParam("v3", relay.Kind.Shape) + + broadcast = tvm.get_env_func("tvm.relay.type_relation.Broadcast") + identity = tvm.get_env_func("tvm.relay.type_relation.Identity") + + tr1 = relay.TypeRelation(broadcast, tvm.convert([tp1, tp3]), 1, None) + tr2 = relay.TypeRelation(broadcast, tvm.convert([tp2, tp4]), 1, None) + tr3 = relay.TypeRelation(identity, tvm.convert([tp1, tp3]), 1, None) + + ft = relay.FuncType(tvm.convert([t1, t2]), tp1, + tvm.convert([tp1, tp3]), + tvm.convert([tr1])) + translate_vars = relay.FuncType(tvm.convert([t1, t2]), tp1, + tvm.convert([tp2, tp4]), + tvm.convert([tr2])) + assert ft == translate_vars + + different_args = relay.FuncType(tvm.convert([t1]), tp1, + tvm.convert([tp1, tp3]), + tvm.convert([tr1])) + assert ft != different_args + + different_order = relay.FuncType(tvm.convert([t2, t1]), tp1, + tvm.convert([tp1, tp3]), + tvm.convert([tr1])) + assert ft != different_order + + no_rel = relay.FuncType(tvm.convert([t1, t2]), tp1, + tvm.convert([tp1, tp3]), + tvm.convert([])) + assert ft != no_rel + + more_vars = relay.FuncType(tvm.convert([t1, t2]), tp2, + tvm.convert([tp1, tp2, tp3]), + tvm.convert([tr1])) + assert ft != more_vars + + all_the_vars = relay.FuncType(tvm.convert([t1, t2]), tp1, + tvm.convert([tp1, tp2, tp3, tp4]), + tvm.convert([tr1, tr2])) + assert ft != all_the_vars + + different_rel = relay.FuncType(tvm.convert([t1, t2]), tp1, + tvm.convert([tp1, tp3]), + tvm.convert([tr3])) + assert ft != different_rel + + more_rels = relay.FuncType(tvm.convert([t1, t2]), tp1, + tvm.convert([tp1, tp3]), + tvm.convert([tr1, tr3])) + assert ft != more_rels + + +def test_tuple_type_alpha_equal(): + t1 = relay.TensorType((1, 2, 3), "float32") + t2 = relay.TensorType((1, 2, 3, 4), "float32") + tp1 = relay.TypeParam("v1", relay.Kind.Type) + tp2 = relay.TypeParam("v2", relay.Kind.Type) + + tup1 = relay.TupleType(tvm.convert([t1, t2, tp1])) + tup2 = relay.TupleType(tvm.convert([t1, t2, tp1])) + tup3 = relay.TupleType(tvm.convert([t2, t1, tp1])) + tup4 = relay.TupleType(tvm.convert([t1, t2, tp2])) + + # as long as types are alpha-equal and in same order, + # tuples should be alpha-equal + assert tup1 == tup2 + assert tup1 != tup3 + assert tup1 != tup4 + + +def test_type_relation_alpha_equal(): + t1 = relay.TensorType((1, 2), "float32") + t2 = relay.TensorType((1, 2, 3), "float32") + t3 = relay.TensorType((1, 2, 3, 4), "float32") + + # functions are compared only by pointer equality so + # we need to be sure to use the same pointers + broadcast = tvm.get_env_func("tvm.relay.type_relation.Broadcast") + identity = tvm.get_env_func("tvm.relay.type_relation.Identity") + + # attrs are also compared only by pointer equality + attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) + attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) + + tr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1) + same = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1) + diff_func = relay.TypeRelation(identity, tvm.convert([t1, t2]), 1, attr1) + diff_order = relay.TypeRelation(broadcast, tvm.convert([t2, t1]), 1, attr1) + diff_args = relay.TypeRelation(broadcast, tvm.convert([t2, t3]), 1, attr1) + diff_attr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr2) + + bigger = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 2, attr1) + diff_num_inputs = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 1, attr2) + + # func, number of args, input 
count, and order should be the same + assert tr == same + assert tr != diff_func + assert tr != diff_order + assert tr != diff_args + assert tr != diff_attr + assert tr != bigger + + assert bigger != diff_num_inputs + + +def test_constant_alpha_equal(): + x = convert(1) + y = convert(2) + assert alpha_equal(x, x) + assert not alpha_equal(x, y) + assert alpha_equal(x, convert(1)) + + +def test_var_alpha_equal(): + v1 = relay.Var("v1") + v2 = relay.Var("v2") + + # normally only pointer equality + assert alpha_equal(v1, v1) + assert not alpha_equal(v1, v2) + + # let node allows for setting the eq_map + l1 = relay.Let(v1, convert(1), v1, None) + l2 = relay.Let(v2, convert(1), v2, None) + l3 = relay.Let(v1, convert(1), v2, None) + + assert alpha_equal(l1, l2) + assert not alpha_equal(l1, l3) + + +def test_global_var_alpha_equal(): + v1 = relay.GlobalVar("v1") + v2 = relay.GlobalVar("v2") + + # only pointer equality suffices (smoke test) + assert alpha_equal(v1, v1) + assert not alpha_equal(v1, v2) + + +def test_tuple_alpha_equal(): + v1 = relay.Var("v1") + v2 = relay.Var("v2") + + # unit value is a valid tuple + assert alpha_equal(relay.Tuple([]), relay.Tuple([])) + + tup = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)])]) + same = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)])]) + + assert alpha_equal(tup, same) + + # use the eq_map + let_tup = relay.Let(v1, tup, v1, None) + let_mapped = relay.Let(v2, relay.Tuple([v2, convert(2), convert(3), + relay.Tuple([convert(4)])]), + v2, None) + assert alpha_equal(let_tup, let_mapped) + + more_fields = relay.Tuple([v1, convert(2), convert(3), relay.Tuple([convert(4)]), v2]) + assert not alpha_equal(tup, more_fields) + + fewer_fields = relay.Tuple([v1, convert(2), convert(3)]) + assert not alpha_equal(tup, fewer_fields) + + different_end = relay.Tuple([v1, convert(2), convert(3), + relay.Tuple([convert(5)])]) + assert not alpha_equal(tup, different_end) + + different_start = relay.Tuple([v2, convert(2), convert(3), + relay.Tuple([convert(4)])]) + assert not alpha_equal(tup, different_start) + + longer_at_end = relay.Tuple([v1, convert(2), convert(3), + relay.Tuple([convert(4), convert(5)])]) + assert not alpha_equal(tup, longer_at_end) + + +def test_tuple_get_item_alpha_equal(): + x = relay.Var('x') + y = relay.Var('y') + assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(y, 1)) + assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 2)) + assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1)) + + +def test_param_alpha_equal(): + # only checks equality of the types + v1 = relay.Var("v1") + v2 = relay.Var("v2") + + p1 = relay.Param(v1, relay.TensorType((1, 2, 3), "float32")) + p2 = relay.Param(v2, relay.TensorType((1, 2, 3), "float32")) + assert alpha_equal(p1, p2) + + p3 = relay.Param(v1, relay.TensorType((4, 5, 6), "int8")) + assert not alpha_equal(p1, p3) + + p4 = relay.Param(v1, relay.TupleType([relay.TensorType((1, 2, 3), + "float32")])) + assert not alpha_equal(p1, p4) + + +def test_function_alpha_equal(): + v1 = relay.Var("v1") + v2 = relay.Var("v2") + v3 = relay.Var("v3") + v4 = relay.Var("v4") + + tt1 = relay.TensorType((1, 2, 3), "float32") + tt2 = relay.TensorType((4, 5, 6), "int8") + tt3 = relay.TupleType([tt1, tt2]) + + tp1 = relay.TypeParam("tp1", relay.Kind.Type) + tp2 = relay.TypeParam("tp2", relay.Kind.Type) + tp3 = relay.TypeParam("tp3", relay.Kind.Shape) + tp4 = relay.TypeParam("tp4", relay.Kind.Shape) + + basic_args = [relay.Param(v3, tt1), 
relay.Param(v4, tt2)] + basic_tps = [tp1, tp2] + + func = relay.Function([relay.Param(v1, tt1), relay.Param(v2, tt2)], + tt2, v2, basic_tps) + mapped = relay.Function(basic_args, tt2, v4, basic_tps) + assert alpha_equal(func, mapped) + + fewer_params = relay.Function([relay.Param(v4, tt2)], tt2, v4, basic_tps) + assert not alpha_equal(func, fewer_params) + + more_params = relay.Function([relay.Param(v3, tt1), relay.Param(v4, tt2), + relay.Param(v2, tt2)], tt2, v4, basic_tps) + assert not alpha_equal(func, more_params) + + params_unordered = relay.Function([relay.Param(v3, tt2), + relay.Param(v4, tt1)], + tt1, v3, basic_tps) + assert not alpha_equal(func, params_unordered) + + params_mismatch = relay.Function([relay.Param(v3, tt3), + relay.Param(v4, tt2)], + tt2, v4, basic_tps) + assert not alpha_equal(func, params_mismatch) + + # also would not typecheck + ret_type_mismatch = relay.Function(basic_args, tt1, v4, basic_tps) + assert not alpha_equal(func, ret_type_mismatch) + + # also mis-typed + different_body = relay.Function(basic_args, tt2, v3, basic_tps) + assert not alpha_equal(func, different_body) + + fewer_type_params = relay.Function(basic_args, tt2, v4, [tp1]) + assert not alpha_equal(func, fewer_type_params) + + more_type_params = relay.Function(basic_args, tt2, v4, [tp1, tp2, tp3]) + assert not alpha_equal(func, more_type_params) + + type_params_unordered = relay.Function(basic_args, tt2, v4, [tp2, tp1]) + assert not alpha_equal(func, type_params_unordered) + + different_type_params = relay.Function(basic_args, tt2, v4, [tp3, tp4]) + assert not alpha_equal(func, different_type_params) + + # a well-typed example that also differs in body, ret type, and type params + tupled_example = relay.Function(basic_args, tt3, relay.Tuple([v3, v4])) + assert not alpha_equal(func, tupled_example) + + +def test_call_alpha_equal(): + v1 = relay.Var("v1") + v2 = relay.Var("v2") + + # attrs are compared only by pointer equality + attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) + attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4)) + + tt1 = relay.TensorType((1, 2, 3), "float32") + tt2 = relay.TensorType((), "int8") + + basic_args = [convert(1), convert(2), v2, relay.Tuple([])] + + # manually writing out args to ensure that args does not rely on + # pointer equality + call = relay.Call(v1, [convert(1), convert(2), v2, relay.Tuple([])], + attr1, [tt1]) + same = relay.Call(v1, basic_args, attr1, [tt1]) + assert alpha_equal(call, same) + + different_fn = relay.Call(v2, basic_args, attr1, [tt1]) + assert not alpha_equal(call, different_fn) + + fewer_args = relay.Call(v1, [convert(1), convert(2), v2], attr1, [tt1]) + assert not alpha_equal(call, fewer_args) + + reordered_args = relay.Call(v1, [convert(2), convert(1), + relay.Tuple([]), v2], attr1, [tt1]) + assert not alpha_equal(call, reordered_args) + + different_args = relay.Call(v1, [convert(1), convert(2), convert(3)], + attr1, [tt1]) + assert not alpha_equal(call, different_args) + + more_args = relay.Call(v1, [convert(1), convert(2), v2, relay.Tuple([]), + convert(3), convert(4)], attr1, [tt1]) + assert not alpha_equal(call, more_args) + + different_attrs = relay.Call(v1, basic_args, attr2, [tt1]) + assert not alpha_equal(call, different_attrs) + + no_type_args = relay.Call(v1, basic_args, attr1) + assert not alpha_equal(call, no_type_args) + + more_type_args = relay.Call(v1, basic_args, attr1, [tt1, tt2]) + assert not alpha_equal(call, more_type_args) + + different_type_arg = relay.Call(v1, basic_args, attr1, 
[tt2]) + assert not alpha_equal(call, different_type_arg) + + +def test_let_alpha_equal(): + v1 = relay.Var("v1") + v2 = relay.Var("v2") + v3 = relay.Var("v3") + + let = relay.Let(v1, convert(2), v1) + mapped = relay.Let(v2, convert(2), v2) + assert alpha_equal(let, mapped) + + mismatched_var = relay.Let(v2, convert(2), v3) + assert not alpha_equal(let, mismatched_var) + + different_value = relay.Let(v2, convert(3), v2) + assert not alpha_equal(let, different_value) + + different_body = relay.Let(v2, convert(3), convert(12)) + assert not alpha_equal(let, different_body) + + # specified types must match + tt1 = relay.TensorType((), "float32") + tt2 = relay.TensorType((), "int8") + let_with_type = relay.Let(v1, convert(2), v1, tt1) + same_type = relay.Let(v1, convert(2), v1, tt1) + assert alpha_equal(let_with_type, same_type) + assert not alpha_equal(let, let_with_type) + + different_type = relay.Let(v1, convert(2), v1, tt2) + assert not alpha_equal(let_with_type, different_type) + + +def test_if_alpha_equal(): + v1 = relay.Var("v1") + v2 = relay.Var("v2") + + if_sample = relay.If(v1, convert(1), relay.Tuple([convert(2), convert(3)])) + same = relay.If(v1, convert(1), relay.Tuple([convert(2), convert(3)])) + assert alpha_equal(if_sample, same) + + different_cond = relay.If(v2, convert(1), relay.Tuple([convert(2), convert(3)])) + assert not alpha_equal(if_sample, different_cond) + + different_true = relay.If(v1, convert(2), relay.Tuple([convert(2), convert(3)])) + assert not alpha_equal(if_sample, different_true) + + different_false = relay.If(v1, convert(1), relay.Tuple([])) + assert not alpha_equal(if_sample, different_false) + + +def test_op_alpha_equal(): + # only checks names + op1 = relay.op.get("add") + op2 = relay.op.get("add") + assert alpha_equal(op1, op2) + + op3 = relay.op.get("take") + assert not alpha_equal(op1, op3) + + +if __name__ == "__main__": + test_tensor_type_alpha_equal() + test_incomplete_type_alpha_equal() + test_constant_alpha_equal() + test_type_param_alpha_equal() + test_func_type_alpha_equal() + test_tuple_type_alpha_equal() + test_type_relation_alpha_equal() + test_constant_alpha_equal() + test_var_alpha_equal() + test_global_var_alpha_equal() + test_tuple_alpha_equal() + test_tuple_get_item_alpha_equal() + test_param_alpha_equal() + test_function_alpha_equal() + test_call_alpha_equal() + test_let_alpha_equal() + test_if_alpha_equal() + test_op_alpha_equal() diff --git a/tests/python/relay/test_check_kind.py b/tests/python/relay/test_pass_check_kind.py similarity index 100% rename from tests/python/relay/test_check_kind.py rename to tests/python/relay/test_pass_check_kind.py diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py new file mode 100644 index 000000000000..ce9bda3d254f --- /dev/null +++ b/tests/python/relay/test_pass_dead_code_elimination.py @@ -0,0 +1,93 @@ +import tvm +from tvm import relay +from tvm.relay.ir_pass import dead_code_elimination, alpha_equal +from tvm.relay.ir_builder import convert, IRBuilder +from tvm.relay.op import log, add, equal, subtract + + +class env: + def __init__(self): + self.a = relay.Var("a") + self.b = relay.Var("b") + self.c = relay.Var("c") + self.d = relay.Var("d") + self.e = relay.Var("e") + self.x = relay.Var("x") + self.y = relay.Var("y") + self.z = relay.Var("z") + self.shape = tvm.convert([1, 2, 3]) + self.tt = relay.TensorType(self.shape, "float32") + self.int32 = relay.TensorType([], "int32") + self.float32 = relay.TensorType([], "float32") + 
self.one = convert(1.0) + self.two = convert(2.0) + self.three = convert(3.0) + + +e = env() + + +def test_let(): + orig = relay.Let(e.x, e.y, e.z, e.tt) + assert alpha_equal(dead_code_elimination(orig), e.z) + + +def test_used_let(): + orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c, e.tt), e.tt) + assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c, e.tt)) + + +def test_chain_unused_let(): + orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e, e.tt), e.tt) + assert alpha_equal(dead_code_elimination(orig), e.e) + + +# make sure we dont infinite loop +def test_recursion(): + """ + Program: + let f(n: i32, data: f32) -> f32 = { + if (n == 0) { + return data; + } else { + return f(n - 1, log(data)); + } + } + f(2, 10000); + """ + f = relay.Var("f") + n = relay.Var("n") + np = relay.Param(n, e.int32) + data = relay.Var("data") + datap = relay.Param(data, e.float32) + funcbody = relay.If(equal(n, convert(0)), data, f(subtract(n, convert(1.0)), log(data))) + value = relay.Function([np, datap], e.float32, funcbody, []) + orig = relay.Let(f, funcbody, f(convert(2.0), convert(10000.0)), e.float32) + assert alpha_equal(dead_code_elimination(orig), orig) + assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three, e.float32)), e.three) + + +def test_op_let(): + assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three, e.float32), e.two)), add(e.three, e.two)) + + +def test_if(): + orig = relay.If(convert(True), e.a, e.b) + assert alpha_equal(dead_code_elimination(orig), e.a) + + +def test_tuple_get_item(): + t = relay.Var('t') + g = relay.TupleGetItem(t, 0) + assert alpha_equal(dead_code_elimination(g), g) + assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t, e.float32), 0)), g) + + +if __name__ == "__main__": + test_let() + test_used_let() + test_chain_unused_let() + test_recursion() + test_op_let() + test_if() + test_tuple_get_item() diff --git a/tests/python/relay/test_free_vars.py b/tests/python/relay/test_pass_free_vars.py similarity index 78% rename from tests/python/relay/test_free_vars.py rename to tests/python/relay/test_pass_free_vars.py index 002646ada582..989c9f8d25db 100644 --- a/tests/python/relay/test_free_vars.py +++ b/tests/python/relay/test_pass_free_vars.py @@ -15,6 +15,17 @@ def test_free_vars(): f = relay.Function([relay.Param(x, ty)], ty, x) assert len(free_vars(f)) == 0 + +def test_tuple(): + t = relay.Var('t') + fv = free_vars(relay.Tuple([t, t])) + assert len(fv) == 1 + assert fv[0] == t + fv = free_vars(relay.TupleGetItem(t, 123)) + assert len(fv) == 1 + assert fv[0] == t + + def test_free_type_vars(): tp = relay.TypeParam("") ty = relay.TupleType([tp, relay.TensorType([], "int32")]) diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 5b8375580424..77b04590df59 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -7,12 +7,13 @@ from tvm.relay.ir_builder import IRBuilder, func_type from tvm.relay.ir_builder import scalar_type, convert, tensor_type from tvm.relay.env import Environment -from tvm.relay.op import log, add, equal, subtract, concat +from tvm.relay.op import log, add, equal, subtract, concatenate from tvm.relay.expr import Function +from tvm import relay def assert_has_type(expr, typ, env=Environment({})): checked_expr = infer_type(env, expr) - checked_type = checked_expr.checked_type() + checked_type = checked_expr.checked_type if checked_type != typ: raise RuntimeError("Type mismatch %s 
vs %s" % ( checked_type, typ)) @@ -20,7 +21,7 @@ def assert_has_type(expr, typ, env=Environment({})): def assert_decl_has_type(env, name, typ): func = env[name] - assert func.checked_type() == typ + assert func.checked_type == typ def test_monomorphic_let(): @@ -32,54 +33,6 @@ def test_monomorphic_let(): prog, env = b.get() assert_has_type(prog, scalar_type('float64')) - -def test_single_op(): - "Program: fn (x : float32) { let t1 = f(x); t1 }" - b = IRBuilder() - with b.function(('x', 'float32')) as func: - x, = func.param_ids() - t1 = b.let('t1', log(x)) - b.ret(t1) - assert_has_type(func.to_func(), func_type(['float32'], 'float32')) - -def test_add_op(): - """ - Program: - fn (x, y) { - return x + y; - } - """ - b = IRBuilder() - - x = b.param('x', tensor_type(5, 5, 5)) - y = b.param('y', tensor_type(5, 5, 5)) - with b.function(x, y) as func: - b.ret(add(x.var, y.var)) - b.ret(func) - prog, env = b.get() - ttype = tensor_type(5, 5, 5) - expected_ty = func_type([ttype, ttype], ttype) - assert_has_type(func.to_func(), expected_ty) - -def test_add_broadcast_op(): - """ - Program: - fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] { - return x + y; - } - """ - b = IRBuilder() - x = b.param('x', tensor_type(10, 4)) - y = b.param('y', tensor_type(5, 10, 1)) - with b.function(x, y) as func: - b.ret(add(x.var, y.var)) - b.ret(func) - prog, env = b.get() - - expected_ty = func_type([tensor_type(10, 4), tensor_type(5, 10, 1)], - tensor_type(5, 10, 4)) - assert_has_type(func.to_func(), expected_ty) - def test_dual_op(): """Program: fn (x : Tensor[f32, (10, 10)]) { @@ -120,9 +73,9 @@ def test_recursion(): Program: def f(n: i32, data: f32) -> f32 { if (n == 0) { - return f(n - 1, log(data)); - } else { return data; + } else { + return f(n - 1, log(data)); } } f(2, 10000); @@ -133,9 +86,9 @@ def f(n: i32, data: f32) -> f32 { data = b.param('data', ty='float32') with b.decl(f, n, data): with b.if_scope(equal(n, convert(0))): - b.ret(f(subtract(n, convert(1)), log(data))) - with b.else_scope(): b.ret(data) + with b.else_scope(): + b.ret(f(subtract(n, convert(1)), log(data))) b.ret(f(convert(2.0), convert(10000.0))) assert_decl_has_type(b.env, 'f', func_type( ['int32', 'float32'], 'float32')) @@ -146,7 +99,7 @@ def test_concat(): """ Program: def try_concat2(x: Float(3, 2), y: Float(2, 2)) -> Float(5, 2) { - return concat(x, y); + return concatenate((x, y), axis=0); } """ ib = IRBuilder() @@ -154,17 +107,25 @@ def try_concat2(x: Float(3, 2), y: Float(2, 2)) -> Float(5, 2) { x = ib.param('x', ty=tensor_type(3, 2)) y = ib.param('y', ty=tensor_type(2, 2)) with ib.decl(try_concat2, x, y): - ib.ret(concat(x, y)) + ib.ret(concatenate((x, y), axis=0)) fn_ty = func_type([tensor_type(3, 2), tensor_type(2, 2)], tensor_type(5, 2)) assert_decl_has_type(ib.env, try_concat2, fn_ty) +def test_tuple(): + ib = IRBuilder() + dup = ib.global_var('dup') + x = ib.param('x') + with ib.decl(dup, x): + ib.ret(relay.Tuple([x, x])) + # todo: why is this not generalized? 
diff --git a/tests/python/unittest/test_codegen_bool.py b/tests/python/unittest/test_codegen_bool.py
new file mode 100644
index 000000000000..e2592c416345
--- /dev/null
+++ b/tests/python/unittest/test_codegen_bool.py
@@ -0,0 +1,58 @@
+"""codegen related to bool types"""
+
+import tvm
+import numpy as np
+
+def test_cmp_load_store():
+    n = 32
+    A = tvm.placeholder((n,), name='A')
+    B = tvm.placeholder((n,), name='B')
+    C = tvm.compute(A.shape, lambda *i: A(*i) > B(*i), name='C')
+    D = tvm.compute(C.shape, lambda *i: tvm.all(C(*i), A(*i) > 1), name="D")
+
+    def check_llvm():
+        if not tvm.module.enabled("llvm"):
+            return
+        s = tvm.create_schedule(D.op)
+        xo, xi = s[C].split(C.op.axis[0], factor=4)
+        xo1, xo2 = s[C].split(xo, factor=13)
+        s[C].parallel(xo2)
+        # build and invoke the kernel
+        f = tvm.build(s, [A, B, D], "llvm")
+        ctx = tvm.cpu(0)
+        a_np = np.random.uniform(size=n).astype(A.dtype)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
+        d = tvm.nd.array(np.zeros(n, dtype=D.dtype), ctx)
+        f(a, b, d)
+        np.testing.assert_equal(
+            d.asnumpy(), np.logical_and(a.asnumpy() > b.asnumpy(), a.asnumpy() > 1))
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            return
+        s = tvm.create_schedule(D.op)
+        for stage in [C, D]:
+            xo, xi = s[stage].split(stage.op.axis[0], factor=4)
+            s[stage].bind(xo, tvm.thread_axis("blockIdx.x"))
+            s[stage].bind(xi, tvm.thread_axis("threadIdx.x"))
+        f = tvm.build(s, [A, B, D], device)
+        a_np = np.random.uniform(size=n).astype(A.dtype)
+        a = tvm.nd.array(a_np, ctx)
+        b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), ctx)
+        d = tvm.nd.array(np.zeros(n, dtype=D.dtype), ctx)
+        f(a, b, d)
+        np.testing.assert_equal(
+            d.asnumpy(), np.logical_and(a.asnumpy() > b.asnumpy(), a.asnumpy() > 1))
+
+    check_llvm()
+    for device in ["vulkan", "opencl", "cuda", "rocm", "metal"]:
+        check_device(device)
+
+
+if __name__ == "__main__":
+    test_cmp_load_store()
diff --git a/tests/python/unittest/test_lang_basic.py b/tests/python/unittest/test_lang_basic.py
index bf25ca3dfc85..079123d96ca0 100644
--- a/tests/python/unittest/test_lang_basic.py
+++ b/tests/python/unittest/test_lang_basic.py
@@ -79,7 +79,7 @@ def test_dtype():
     x = tvm.var('x')
     assert x.dtype == 'int32'
     y = tvm.var('y')
-    assert (x > y).dtype == 'uint1'
+    assert (x > y).dtype == 'bool'
 
 
 def test_any():
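The `uint1` to `bool` rename asserted above is visible on any comparison expression; a minimal standalone restatement of the updated check:

```
import tvm

x = tvm.var('x')  # int32 by default
y = tvm.var('y')
# comparison nodes now report dtype 'bool' instead of 'uint1'
assert (x > y).dtype == 'bool'
```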
diff --git a/tests/python/unittest/test_lang_tensor.py b/tests/python/unittest/test_lang_tensor.py
index f562a48e44ae..50492ca41fca 100644
--- a/tests/python/unittest/test_lang_tensor.py
+++ b/tests/python/unittest/test_lang_tensor.py
@@ -85,6 +85,78 @@ def test_tensor_reduce():
     assert(isinstance(C_loaded, tvm.tensor.Tensor))
     assert(str(C_loaded) == str(C))
 
+def test_tensor_compute1():
+    m = 1024
+    factor = 16
+    dtype = 'float32'
+
+    def intrin_vadd(n):
+        x = tvm.placeholder((n,))
+        y = tvm.placeholder((n,))
+        z = tvm.compute(x.shape, lambda i: x[i] + y[i])
+
+        def intrin_func(ins, outs):
+            ib = tvm.ir_builder.create()
+            ib.emit(tvm.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr')))
+            return ib.get()
+
+        with tvm.build_config(offset_factor=n):
+            return tvm.decl_tensor_intrin(z.op, intrin_func)
+
+    vadd = intrin_vadd(factor)
+
+    A = tvm.placeholder((m//factor, factor), name="A", dtype=dtype)
+    B = tvm.placeholder((m//factor, factor), name="B", dtype=dtype)
+    C = tvm.compute((m//factor, factor),
+                    lambda i: vadd(A[i, 0:factor], B[i, 0:factor]))
+
+    s = tvm.create_schedule(C.op)
+    stmt = tvm.lower(s, [A, B, C], simple_mode=True)
+    assert isinstance(stmt.body.body, tvm.stmt.Evaluate)
+
+def test_tensor_compute2():
+    M = 2048
+    N = 1024
+    L = 1024
+    factor = 16
+    factor1 = 32
+    factor2 = 32
+    dtype = 'float32'
+
+    def intrin_gemm(m, n, l):
+        k = tvm.reduce_axis((0, l))
+        x = tvm.placeholder((m, l))
+        y = tvm.placeholder((n, l))
+        # in theory, no relation
+        z = tvm.compute((m, n), lambda i, j: tvm.sum(x[i][k] * y[j][k], axis=k))
+
+        def intrin_func(ins, outs):
+            x_ptr = ins[0].access_ptr("r")
+            y_ptr = ins[1].access_ptr("r")
+            z_ptr = outs[0].access_ptr("w")
+            body = tvm.call_packed(
+                "gemv", x_ptr, y_ptr, z_ptr, m, n, l)
+            reset = tvm.call_packed(
+                "fill_zero", z_ptr, m, n)
+            update = tvm.call_packed(
+                "gemv_add", x_ptr, y_ptr, z_ptr, m, n, l)
+            return body, reset, update
+
+        with tvm.build_config(offset_factor=n):
+            return tvm.decl_tensor_intrin(z.op, intrin_func)
+
+    vgemm = intrin_gemm(factor1, factor2, factor)
+
+    A = tvm.placeholder((M//factor1, L//factor, factor1, factor), name="A", dtype=dtype)
+    B = tvm.placeholder((N//factor2, L//factor, factor2, factor), name="B", dtype=dtype)
+    k = tvm.reduce_axis((0, L//factor), name='k')
+    C = tvm.compute((M//factor1, N//factor2, factor1, factor2),
+                    lambda i, j: vgemm(A[i, k, 0:factor1, 0:factor], B[j, k, 0:factor2, 0:factor], reduce_axis=k))
+
+    s = tvm.create_schedule(C.op)
+    stmt = tvm.lower(s, [A, B, C], simple_mode=True)
+    assert isinstance(stmt.body.body.body.first, tvm.stmt.Evaluate)
+    assert isinstance(stmt.body.body.body.rest.body, tvm.stmt.Evaluate)
 
 def test_tensor_scan():
     m = tvm.var("m")
@@ -221,6 +293,8 @@ def intrin_func(ins, outs):
     test_conv1d()
     test_tensor_slice()
     test_tensor()
+    test_tensor_compute1()
+    test_tensor_compute2()
     test_tensor_reduce()
     test_tensor_scan()
     test_scan_multi_out()
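For readers new to tensor intrinsics, the pattern exercised by the `test_tensor_compute*` cases above has two parts: `tvm.decl_tensor_intrin` pairs a compute declaration (the semantics) with a lowering callback (the code that gets emitted), and the resulting intrinsic is then called from a `tvm.compute` over tensor slices. A condensed sketch of the vector-add case, mirroring the test rather than defining any new API:

```
import tvm

def intrin_vadd(n):
    # declaration: what one intrinsic call computes
    x = tvm.placeholder((n,), name='vx')
    y = tvm.placeholder((n,), name='vy')
    z = tvm.compute((n,), lambda i: x[i] + y[i], name='z')

    def intrin_func(ins, outs):
        # lowering rule: replace the region with a call to an external 'vadd' kernel
        ib = tvm.ir_builder.create()
        ib.emit(tvm.call_extern(outs[0].dtype, 'vadd',
                                ins[0].access_ptr('r'),
                                ins[1].access_ptr('r'),
                                outs[0].access_ptr('w')))
        return ib.get()

    with tvm.build_config(offset_factor=n):
        return tvm.decl_tensor_intrin(z.op, intrin_func)

factor = 16
A = tvm.placeholder((64, factor), name='A')
B = tvm.placeholder((64, factor), name='B')
vadd = intrin_vadd(factor)
# each output row is produced by a single intrinsic call on a length-16 slice
C = tvm.compute((64, factor), lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name='C')
s = tvm.create_schedule(C.op)
print(tvm.lower(s, [A, B, C], simple_mode=True))
```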
diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py
index 8e6f4090d403..8774514cfa17 100644
--- a/tests/python/unittest/test_schedule_schedule_ops.py
+++ b/tests/python/unittest/test_schedule_schedule_ops.py
@@ -276,6 +276,133 @@ def test_schedule_bound_condition():
     stmt = tvm.ir_pass.Simplify(stmt)
     assert (isinstance(stmt.body.body.first.body.body.then_case, tvm.stmt.IfThenElse))
 
+
+def intrin_gemv(m, n):
+    w = tvm.placeholder((m, n), name='w')
+    x = tvm.placeholder((n,), name='x')
+    k = tvm.reduce_axis((0, n), name='k')
+    z = tvm.compute((m,), lambda i:
+                    tvm.sum(w[i, k] * x[k], axis=k), name='z')
+    Wb = tvm.decl_buffer(w.shape, w.dtype,
+                         name="W",
+                         offset_factor=16,
+                         strides=[tvm.var('ldw'), 1])
+    def intrin_func(ins, outs):
+        ww, xx = ins
+        zz = outs[0]
+        ww_ptr = ww.access_ptr("r")
+        xx_ptr = xx.access_ptr("r")
+        zz_ptr = zz.access_ptr("w")
+        body = tvm.call_packed(
+            "gemm", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
+        reset = tvm.call_packed(
+            "fill_zero", zz_ptr, n)
+        update = tvm.call_packed(
+            "gemv_add", ww_ptr, xx_ptr, zz_ptr, n, ww.strides[0])
+        return body, reset, update
+
+    with tvm.build_config(data_alignment=16,
+                          offset_factor=16):
+        return tvm.decl_tensor_intrin(z.op, intrin_func,
+                                      binds={w: Wb})
+
+
+def test_schedule_tensor_compute1():
+    # basic: split, reorder, tile
+    M, N, L = 2048, 1024, 512
+    factor, rfactor = 16, 16
+    A = tvm.placeholder((N//factor, L//rfactor, factor, rfactor), name='A')
+    B = tvm.placeholder((M, L//rfactor, rfactor), name='B')
+    k = tvm.reduce_axis((0, L//rfactor), name='k')
+
+    gemv = intrin_gemv(factor, rfactor)
+    C = tvm.compute((N, M//factor, factor),
+                    lambda i, j: gemv(A[i, k, 0:factor, 0:factor], B[j, k, 0:rfactor], reduce_axis=k),
+                    name='C')
+
+    s = tvm.create_schedule(C.op)
+    ai, aj, ax = s[C].op.axis
+    aio, aii = s[C].split(ai, 16)
+    s[C].reorder(aio, aj, aii)
+    aioo, ajo, aioi, aji = s[C].tile(aio, aj, 16, 4)
+
+    s = s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
+
+def intrin_vadd(n, cache_read=False, cache_write=False):
+    scope_ubuf = 'local'
+    dtype = 'float32'
+    x = tvm.placeholder((n,), dtype=dtype, name='vx')
+    y = tvm.placeholder((n,), dtype=dtype, name='vy')
+    z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')
+    s = tvm.create_schedule(z.op)
+
+    def create_buffer(t):
+        return tvm.decl_buffer(t.shape, t.dtype,
+                               name='W'+t.name,
+                               scope=scope_ubuf,
+                               offset_factor=16)
+
+    binds = {}
+    if cache_read:
+        binds[x] = create_buffer(x)
+        binds[y] = create_buffer(y)
+    if cache_write:
+        binds[z] = create_buffer(z)
+
+    def intrin_func(ins, outs):
+        ib = tvm.ir_builder.create()
+        ib.emit(tvm.call_extern(outs[0].dtype, 'vadd', ins[0].access_ptr("r"), ins[1].access_ptr('r'), outs[0].access_ptr('wr')))
+        return ib.get()
+
+    with tvm.build_config(offset_factor=16):
+        return tvm.decl_tensor_intrin(z.op, intrin_func, binds=binds)
+
+
+def test_schedule_tensor_compute2():
+    # cache_read, cache_write
+    M = 1024
+    factor = 16
+    dtype = 'float32'
+    scope_ubuf = 'local'
+
+    A = tvm.placeholder((M//factor, factor), name="A", dtype=dtype)
+    B = tvm.placeholder((M//factor, factor), name="B", dtype=dtype)
+
+    vadd = intrin_vadd(factor, True, True)
+    C = tvm.compute((M//factor, factor),
+                    lambda i: vadd(A[i, 0:factor], B[i, 0:factor]), name='C')
+
+    s = tvm.create_schedule(C.op)
+    AL = s.cache_read(A, scope_ubuf, C)
+    BL = s.cache_read(B, scope_ubuf, C)
+    CL = s.cache_write(C, scope_ubuf)
+    s = s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
+
+def test_schedule_tensor_compute3():
+    # compute_at
+    M = 1024
+    factor = 16
+    dtype = 'float32'
+    A = tvm.placeholder((M//factor, factor), name="A", dtype=dtype)
+    B = tvm.placeholder((M//factor, factor), name="B", dtype=dtype)
+    Bi = tvm.compute((M//factor, factor), lambda i, j: B[i, j] + 5, name="Bi")
+
+    vadd = intrin_vadd(factor)
+    C = tvm.compute((M//factor, factor),
+                    lambda i: vadd(A[i, 0:factor], Bi[i, 0:factor]), name='C')
+    s = tvm.create_schedule(C.op)
+    s[Bi].compute_at(s[C], C.op.axis[0])
+    s = s.normalize()
+    bounds = tvm.schedule.InferBound(s)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+
+
 if __name__ == "__main__":
     test_schedule_middle_cache()
     test_inline_multi_reduce()
@@ -294,3 +421,6 @@
     test_schedule2()
     test_schedule_cache()
     test_schedule_bound_condition()
+    test_schedule_tensor_compute1()
+    test_schedule_tensor_compute2()
+    test_schedule_tensor_compute3()
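All three new scheduling tests finish with the same manual lowering sequence instead of `tvm.lower`: normalize the schedule, infer bounds, then emit the loop nest. On a trivial schedule the sequence looks like this (a sketch; the tests above only assert that it completes):

```
import tvm

A = tvm.placeholder((1024,), name='A')
B = tvm.compute((1024,), lambda i: A[i] + 1.0, name='B')
s = tvm.create_schedule(B.op)

s = s.normalize()                           # canonicalize the schedule tree
bounds = tvm.schedule.InferBound(s)         # map each IterVar to its range
stmt = tvm.schedule.ScheduleOps(s, bounds)  # produce the lowered statement
print(stmt)
```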
diff --git a/tests/scripts/task_rust.sh b/tests/scripts/task_rust.sh
new file mode 100755
index 000000000000..8ef9a1a1556f
--- /dev/null
+++ b/tests/scripts/task_rust.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+set -e
+
+export LD_LIBRARY_PATH=lib:$LD_LIBRARY_PATH
+
+tvm_root="$(git rev-parse --show-toplevel)"
+export PYTHONPATH="$tvm_root/python":"$tvm_root/nnvm/python":"$tvm_root/topi/python"
+
+cd rust
+cargo fmt -- --check
+
+# run basic tests
+python3 tests/build_model.py
+cargo test --tests
+
+# run TVM module test
+cd tests/test_tvm_basic
+cargo run
+cd -
+
+# run NNVM graph test
+cd tests/test_nnvm
+cargo run
+cd -
diff --git a/topi/include/topi/transform.h b/topi/include/topi/transform.h
index e4e646453cca..756aa2ec3b49 100644
--- a/topi/include/topi/transform.h
+++ b/topi/include/topi/transform.h
@@ -38,10 +38,6 @@ inline Tensor expand_dims(const Tensor& x,
                           std::string name = "tensor",
                           std::string tag = kBroadcast) {
   int ndim = static_cast<int>(x->shape.size());
-  if (axis < 0) {
-    // Calculate offset from last dimension
-    axis = ndim + axis + 1;
-  }
   CHECK(-ndim - 1 <= axis && axis <= ndim)
     << "expand_dims only accepts `axis` in [-data.ndim - 1, data.ndim]"
     << ", but got axis = " << axis
@@ -49,7 +45,10 @@ inline Tensor expand_dims(const Tensor& x,
   CHECK(num_newaxis >= 0)
     << "expand_dims only accepts `num_newaxis >= 0`"
     << ", but got num_newaxis = " << num_newaxis;
-
+  if (axis < 0) {
+    // Calculate offset from last dimension
+    axis = ndim + axis + 1;
+  }
   Array<Expr> new_shape;
   for (size_t i = 0; i < static_cast<size_t>(axis); ++i) {
     new_shape.push_back(x->shape[i]);
@@ -265,8 +264,13 @@ inline Tensor concatenate(const Array<Tensor>& inputs,
                           int axis = 0,
                           std::string name = "tensor",
                           std::string tag = kInjective) {
+  int ndim = static_cast<int>(inputs[0]->shape.size());
+  CHECK(-ndim <= axis && axis < ndim)
+    << "concatenate only accepts `axis` in [-ndim, ndim)"
+    << ", but got axis = " << axis
+    << ", and ndim = " << ndim;
   if (axis < 0) {
-    axis += static_cast<int>(inputs[0]->shape.size());
+    axis += ndim;
   }
   CHECK_LT(axis, inputs[0]->shape.size()) <<
     "axis out of bounds";
diff --git a/tutorials/nnvm/from_onnx.py b/tutorials/nnvm/from_onnx.py
index df8dee8272ce..0fdef8afa98c 100644
--- a/tutorials/nnvm/from_onnx.py
+++ b/tutorials/nnvm/from_onnx.py
@@ -46,7 +46,7 @@ def download(url, path, overwrite=False):
                      'super_resolution_0.2.onnx'])
 download(model_url, 'super_resolution.onnx', True)
 # now you have super_resolution.onnx on disk
-onnx_model = onnx.load('super_resolution.onnx')
+onnx_model = onnx.load_model('super_resolution.onnx')
 # we can load the graph as NNVM compatible model
 sym, params = nnvm.frontend.from_onnx(onnx_model)
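The `transform.h` changes move the range CHECKs ahead of negative-axis wrapping, so an out-of-range axis now fails loudly instead of being silently normalized. The arithmetic, restated in plain Python rather than the C++ API:

```
def normalize_concat_axis(axis, ndim):
    # valid range is [-ndim, ndim); wrap only after the check passes
    assert -ndim <= axis < ndim
    return axis + ndim if axis < 0 else axis

def normalize_expand_dims_axis(axis, ndim):
    # expand_dims admits one extra position at the end, hence the wider range
    assert -ndim - 1 <= axis <= ndim
    return axis + ndim + 1 if axis < 0 else axis

assert normalize_concat_axis(-1, 4) == 3        # last existing axis
assert normalize_expand_dims_axis(-1, 4) == 4   # append a new trailing axis
```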