diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e615fe54a3c..c6937a5546aa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -23,7 +23,8 @@ endif() tvm_option(USE_CUDA "Build with CUDA" OFF) tvm_option(USE_OPENCL "Build with OpenCL" OFF) tvm_option(USE_METAL "Build with Metal" OFF) -tvm_option(USE_RPC "Build with RPC" OFF) +tvm_option(USE_RPC "Build with RPC" ON) +tvm_option(USE_GRAPH_RUNTIME "Build with tiny graph runtime" ON) tvm_option(USE_LLVM "Build with LLVM" OFF) tvm_option(USE_RTTI "Build with RTTI" ON) tvm_option(USE_MSVC_MT "Build with MT" OFF) diff --git a/Jenkinsfile b/Jenkinsfile index d7fa946e49a9..ef9666351ba5 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -90,6 +90,7 @@ stage('Build') { echo USE_OPENCL=1 >> config.mk echo LLVM_CONFIG=llvm-config-4.0 >> config.mk echo USE_RPC=1 >> config.mk + echo USE_GRAPH_RUNTIME=1 >> config.mk echo USE_BLAS=openblas >> config.mk rm -f lib/libtvm_runtime.so lib/libtvm.so """ diff --git a/Makefile b/Makefile index 1073e2680dfa..c67d5e845334 100644 --- a/Makefile +++ b/Makefile @@ -53,6 +53,7 @@ CUDA_SRC = $(wildcard src/runtime/cuda/*.cc) ROCM_SRC = $(wildcard src/runtime/rocm/*.cc) OPENCL_SRC = $(wildcard src/runtime/opencl/*.cc) RPC_SRC = $(wildcard src/runtime/rpc/*.cc) +GRAPH_SRC = $(wildcard src/runtime/graph/*.cc) RUNTIME_SRC = $(wildcard src/runtime/*.cc) # Objectives @@ -63,6 +64,7 @@ CUDA_OBJ = $(patsubst src/%.cc, build/%.o, $(CUDA_SRC)) ROCM_OBJ = $(patsubst src/%.cc, build/%.o, $(ROCM_SRC)) OPENCL_OBJ = $(patsubst src/%.cc, build/%.o, $(OPENCL_SRC)) RPC_OBJ = $(patsubst src/%.cc, build/%.o, $(RPC_SRC)) +GRAPH_OBJ = $(patsubst src/%.cc, build/%.o, $(GRAPH_SRC)) CC_OBJ = $(patsubst src/%.cc, build/%.o, $(CC_SRC)) $(LLVM_OBJ) RUNTIME_OBJ = $(patsubst src/%.cc, build/%.o, $(RUNTIME_SRC)) CONTRIB_OBJ = @@ -124,6 +126,10 @@ ifeq ($(USE_RPC), 1) RUNTIME_DEP += $(RPC_OBJ) endif +ifeq ($(USE_GRAPH_RUNTIME), 1) + RUNTIME_DEP += $(GRAPH_OBJ) +endif + include make/contrib/cblas.mk include make/contrib/nnpack.mk 
include make/contrib/cudnn.mk diff --git a/docs/api/python/contrib.rst b/docs/api/python/contrib.rst index 7c50fafb7166..def68e390dd9 100644 --- a/docs/api/python/contrib.rst +++ b/docs/api/python/contrib.rst @@ -22,6 +22,11 @@ tvm.contrib.rpc .. automodule:: tvm.contrib.rpc :members: +tvm.contrib.graph_runtime +~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: tvm.contrib.graph_runtime + :members: + tvm.contrib.util ~~~~~~~~~~~~~~~~ .. automodule:: tvm.contrib.util diff --git a/make/config.mk b/make/config.mk index a4e6fc7b20df..53775df1ab36 100644 --- a/make/config.mk +++ b/make/config.mk @@ -45,7 +45,10 @@ USE_OPENCL = 0 USE_METAL = 0 # Whether enable RPC during compile -USE_RPC = 0 +USE_RPC = 1 + +# Whether enable tiny embedded graph runtime. +USE_GRAPH_RUNTIME = 1 # whether build with LLVM support # Requires LLVM version >= 4.0 diff --git a/python/tvm/contrib/graph_runtime.py b/python/tvm/contrib/graph_runtime.py new file mode 100644 index 000000000000..8b7954f0ed48 --- /dev/null +++ b/python/tvm/contrib/graph_runtime.py @@ -0,0 +1,131 @@ +"""Minimum graph runtime that executes graph containing TVM PackedFunc.""" +from . import rpc +from .._ffi.base import string_types +from .._ffi.function import get_global_func +from .. import ndarray as nd + + +def create(graph_json_str, libmod, ctx): + """Create a runtime executor module given a graph and module. + + Parameters + ---------- + graph_json_str : str or graph class + The graph to be deployed in json format output by nnvm graph. + The graph can only contain one operator(tvm_op) that + points to the name of PackedFunc in the libmod. + + libmod : tvm.Module + The module of the corresponding function + + ctx : TVMContext + The context to deploy the module, can be local or remote. + + Returns + ------- + graph_module : GraphModule + Runtime graph module that can be used to execute the graph. 
class GraphModule(object):
    """Wrapper runtime module.

    This is a thin wrapper of the underlying TVM module.
    You can also directly call set_input, run, and get_output
    of the underlying module functions.

    Parameters
    ----------
    module : Module
        The internal tvm module that holds the actual graph functions.

    ctx : TVMContext
        The context this module is under.

    Attributes
    ----------
    module : Module
        The internal tvm module that holds the actual graph functions.

    ctx : TVMContext
        The context this module is under.
    """
    def __init__(self, module, ctx):
        self.module = module
        # Look the packed functions up once so each call avoids a lookup.
        self._set_input = module["set_input"]
        self._run = module["run"]
        self._get_output = module["get_output"]
        self.ctx = ctx

    def set_input(self, key=None, value=None, **params):
        """Set inputs to the module, positionally and/or via kwargs.

        Parameters
        ----------
        key : int or str, optional
            The input key: either the input index or the input name.
            When omitted (None), only ``params`` are applied.

        value : array-like, optional
            The input value for ``key``. It is converted with
            ``nd.array(value, ctx=self.ctx)`` before being passed on.

        params : dict of str to NDArray
            Additional inputs, keyed by input name.

        Returns
        -------
        self : GraphModule
            The module itself, so calls can be chained.
        """
        # Compare against None explicitly: the integer index 0 is a valid
        # key, and a plain truthiness test would silently skip it.
        if key is not None:
            self._set_input(key, nd.array(value, ctx=self.ctx))
        for k, v in params.items():
            self._set_input(k, nd.array(v, ctx=self.ctx))
        return self

    def run(self, **input_dict):
        """Run forward execution of the graph.

        Parameters
        ----------
        input_dict : dict of str to NDArray
            Input values to set (via set_input) before running.
        """
        if input_dict:
            self.set_input(**input_dict)
        self._run()

    def get_output(self, index, out):
        """Copy the index-th output into out.

        Parameters
        ----------
        index : int
            The output index.

        out : NDArray
            The output array container.

        Returns
        -------
        out : NDArray
            The same container that was passed in.
        """
        self._get_output(index, out)
        return out

    def __getitem__(self, key):
        """Get an internal module function.

        Parameters
        ----------
        key : str
            The key to the module.

        Returns
        -------
        func
            Whatever ``self.module[key]`` yields.
        """
        return self.module[key]
+ */ + PackedFunc GetFunction( + const std::string& name, + const std::shared_ptr& sptr_to_self) final; + + /*! + * \return The type key of the executor. + */ + const char* type_key() const final { + return "GraphRuntime"; + } + void Run() { + // setup the array and requirements. + for (size_t i = 0; i < op_execs_.size(); ++i) { + if (op_execs_[i]) op_execs_[i](); + } + } + /*! + * \brief Initialize the graph executor with graph and context. + * \param graph The execution graph. + * \param module The module containing the compiled functions. + * \param ctx The context where the graph should sit on + */ + void Init(const std::string& graph_json, + tvm::runtime::Module module, + TVMContext ctx) { + std::istringstream is(graph_json); + dmlc::JSONReader reader(&is); + this->Load(&reader); + module_ = module; + ctx_ = ctx; + this->SetupStorage(); + this->SetupOpExecs(); + } + /*! + * \brief Get the input index given the name of input. + * \param name The name of the input. + * \return The index of input. + */ + int GetInputIndex(const std::string& name) { + for (size_t i = 0; i< input_nodes_.size(); ++i) { + uint32_t nid = input_nodes_[i]; + if (nodes_[nid].name == name) { + return static_cast(i); + } + } + LOG(FATAL) << "cannot find " << name << " among input"; + return -1; + } + /*! + * \brief set index-th input to the graph. + * \param index The input index. + * \param data The input data. + */ + void SetInput(int index, DLTensor* data_in) { + CHECK_LT(static_cast(index), input_nodes_.size()); + uint32_t eid = this->entry_id(input_nodes_[index], 0); + TVM_CCALL(TVMArrayCopyFromTo(data_in, &data_entry_[eid], nullptr)); + } + /*! + * \brief Copy index-th output to data_out. + * \param index The output index. + * \param data_out the output data. 
+ */ + void GetOutput(int index, DLTensor* data_out) { + CHECK_LT(static_cast(index), outputs_.size()); + uint32_t eid = this->entry_id(outputs_[index]); + TVM_CCALL(TVMArrayCopyFromTo(&data_entry_[eid], data_out, nullptr)); + } + + /*! + * \brief Load parameters from binary stream + * \param strm The input stream. + */ + void LoadParams(dmlc::Stream* strm); + /*! + * \brief Load parameters from parameter blob. + * \param param_blob A binary blob of parameter. + */ + void LoadParams(const std::string& param_blob) { + dmlc::MemoryStringStream strm(const_cast(¶m_blob)); + this->LoadParams(&strm); + } + + private: + // Node entry + struct NodeEntry { + uint32_t node_id; + uint32_t index; + uint32_t version; + // JSON Loader + void Load(dmlc::JSONReader *reader) { + reader->BeginArray(); + CHECK(reader->NextArrayItem()) << "invalid json format"; + reader->Read(&node_id); + CHECK(reader->NextArrayItem()) << "invalid json format"; + reader->Read(&index); + if (reader->NextArrayItem()) { + reader->Read(&version); + CHECK(!reader->NextArrayItem()) << "invalid json format"; + } else { + version = 0; + } + } + }; + // Node + struct Node { + // operator type in string + std::string op_type; + // name of the op + std::string name; + // parameters + TVMOpParam param; + // inputs + std::vector inputs; + // control deps + std::vector control_deps; + // JSON Loader + void Load(dmlc::JSONReader *reader) { + reader->BeginObject(); + std::unordered_map dict; + int bitmask = 0; + std::string key; + while (reader->NextObjectItem(&key)) { + if (key == "op") { + reader->Read(&op_type); + bitmask |= 1; + } else if (key == "name") { + reader->Read(&name); + bitmask |= 2; + } else if (key == "inputs") { + reader->Read(&inputs); + bitmask |= 4; + } else if (key == "attr" || key == "attrs") { + reader->Read(&dict); + param.Init(dict); + } else if (key == "control_deps") { + reader->Read(&control_deps); + } else { + LOG(FATAL) << "do not support key " << key; + } + } + CHECK_EQ(bitmask, 1|2|4) 
<< "invalid format"; + } + }; + struct GraphAttr { + size_t storage_num_not_alloctaed{0}; + std::vector storage_id; + std::vector dltype; + std::vector > shape; + // The graph attribute fields. + void Load(dmlc::JSONReader *reader) { + reader->BeginObject(); + int bitmask = 0; + std::string key, type; + while (reader->NextObjectItem(&key)) { + if (key == "dltype") { + reader->BeginArray(); + CHECK(reader->NextArrayItem()); + reader->Read(&type); + CHECK_EQ(type, "list_str"); + CHECK(reader->NextArrayItem()); + reader->Read(&dltype); + CHECK(!reader->NextArrayItem()); + bitmask |= 1; + } else if (key == "storage_id") { + reader->BeginArray(); + CHECK(reader->NextArrayItem()); + reader->Read(&type); + CHECK_EQ(type, "list_int"); + CHECK(reader->NextArrayItem()); + reader->Read(&storage_id); + CHECK(!reader->NextArrayItem()); + bitmask |= 2; + } else if (key == "shape") { + reader->BeginArray(); + CHECK(reader->NextArrayItem()); + reader->Read(&type); + CHECK_EQ(type, "list_shape"); + CHECK(reader->NextArrayItem()); + reader->Read(&shape); + CHECK(!reader->NextArrayItem()); + bitmask |= 4; + } else { + reader->BeginArray(); + CHECK(reader->NextArrayItem()); + reader->Read(&type); + if (type == "list_int") { + CHECK(reader->NextArrayItem()); + std::vector temp; + reader->Read(&temp); + } else if (type == "size_t") { + CHECK(reader->NextArrayItem()); + size_t temp; + reader->Read(&temp); + } else { + LOG(FATAL) << "cannot skip graph attr " << key; + } + CHECK(!reader->NextArrayItem()); + } + } + CHECK_EQ(bitmask, 1|2|4) << "invalid format"; + } + }; + // The graph attribute fields. 
+ void Load(dmlc::JSONReader *reader) { + reader->BeginObject(); + int bitmask = 0; + std::string key; + while (reader->NextObjectItem(&key)) { + if (key == "nodes") { + reader->Read(&nodes_); + bitmask |= 1; + } else if (key == "arg_nodes") { + reader->Read(&input_nodes_); + bitmask |= 2; + } else if (key == "node_row_ptr") { + reader->Read(&node_row_ptr_); + bitmask |= 4; + } else if (key == "heads") { + reader->Read(&outputs_); + bitmask |= 8; + } else if (key == "attrs") { + reader->Read(&attrs_); + bitmask |= 16; + } + } + CHECK_EQ(bitmask, 1|2|4|8|16) << "invalid format"; + } + bool LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor); + /*! \brief Setup the temporal storage */ + void SetupStorage(); + /*! \brief Setup the executors */ + void SetupOpExecs(); + /*! + * \brief Create a executtion function given input. + * \param attrs The node attributes + * \param args The arguments to the functor, including inputs and outputs. + * \param num_inputs Number of inputs + * \return The created executor. + */ + std::function CreateTVMOp(const TVMOpParam& attrs, + const std::vector& args, + size_t num_inputs); + // Get node entry index. + uint32_t entry_id(uint32_t nid, uint32_t index) const { + return node_row_ptr_[nid] + index; + } + // Get node entry index. + uint32_t entry_id(const NodeEntry& e) const { + return entry_id(e.node_id, e.index); + } + // Number of node entries + uint32_t num_node_entries() const { + return node_row_ptr_.back(); + } + // Number of nodes. + uint32_t num_nodes() const { + return static_cast(nodes_.size()); + } + // The graph nodes. + std::vector nodes_; + // The argument nodes. + std::vector input_nodes_; + // used or quick entry indexing + std::vector node_row_ptr_; + // output entries + std::vector outputs_; + // Additional graph attributes + GraphAttr attrs_; + /*! \brief The code module */ + tvm::runtime::Module module_; + /*! \brief execution context */ + TVMContext ctx_; + /*! 
\brief common storage pool */ + std::vector storage_pool_; + /*! \brief data entry of each node */ + std::vector data_entry_; + /*! \brief operator on each node */ + std::vector > op_execs_; +}; + +DMLC_REGISTER_PARAMETER(TVMOpParam); + +bool GraphRuntime::LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor) { + uint64_t header, reserved; + CHECK(strm->Read(&header, sizeof(header))) + << "Invalid DLTensor file format"; + CHECK(strm->Read(&reserved, sizeof(reserved))) + << "Invalid DLTensor file format"; + CHECK(header == kTVMNDArrayMagic) + << "Invalid DLTensor file format"; + + CHECK(strm->Read(&tensor->ctx, sizeof(tensor->ctx))) + << "Invalid DLTensor file format"; + CHECK(strm->Read(&tensor->ndim, sizeof(tensor->ndim))) + << "Invalid DLTensor file format"; + CHECK(strm->Read(&tensor->dtype, sizeof(tensor->dtype))) + << "Invalid DLTensor file format"; + + int ndim = tensor->ndim; + CHECK(strm->Read(tensor->shape, sizeof(int64_t) * ndim)) + << "Invalid DLTensor file format"; + + int64_t size = 1; + int type_size = tensor->dtype.bits / 8; + for (int i = 0; i < ndim; ++i) { + size *= tensor->shape[i]; + } + int64_t data_byte_size; + CHECK(strm->Read(&data_byte_size, sizeof(data_byte_size))) + << "Invalid DLTensor file format"; + CHECK(data_byte_size == type_size * size) + << "Invalid DLTensor file format"; + CHECK(strm->Read(tensor->data, type_size * size)) + << "Invalid DLTensor file format"; + return true; +} + +void GraphRuntime::LoadParams(dmlc::Stream* strm) { + uint64_t header, reserved; + CHECK(strm->Read(&header)) + << "Invalid parameters file format"; + CHECK(header == kTVMNDArrayListMagic) + << "Invalid parameters file format"; + CHECK(strm->Read(&reserved)) + << "Invalid parameters file format"; + + std::vector names; + CHECK(strm->Read(&names)) + << "Invalid parameters file format"; + uint64_t sz; + strm->Read(&sz, sizeof(sz)); + size_t size = static_cast(sz); + + CHECK(size == names.size()) + << "Invalid parameters file format"; + + for (size_t i = 0; i < 
size; ++i) { + uint32_t in_idx = GetInputIndex(names[i]); + CHECK(LoadDLTensor(strm, &data_entry_[this->entry_id(input_nodes_[in_idx], 0)])) + << "Invalid parameters file format"; + } +} + +void GraphRuntime::SetupStorage() { + // Grab saved optimization plan from graph. + std::vector vtype; + for (const std::string& s_type : attrs_.dltype) { + vtype.push_back(tvm::runtime::String2TVMType(s_type)); + } + data_entry_.resize(num_node_entries()); + // Find the maximum space size. + int max_id = 0; + for (size_t i = 0; i < attrs_.shape.size(); ++i) { + max_id = std::max(attrs_.storage_id[i] + 1, max_id); + } + for (uint32_t nid : input_nodes_) { + attrs_.storage_id[this->entry_id(nid, 0)] = max_id++; + } + // size of each storage pool entry + std::vector pool_entry_bytes; + // Find the maximum space size. + for (size_t i = 0; i < attrs_.shape.size(); ++i) { + int storage_id = attrs_.storage_id[i]; + size_t size = 1; + for (int64_t sz : attrs_.shape[i]) { + size *= static_cast(sz); + } + CHECK_GE(storage_id, 0) << "Do not support runtime shape op"; + DLDataType t = vtype[i]; + size_t bits = t.bits * t.lanes; + CHECK_EQ(bits % 8U, 0U); + size_t bytes = (bits / 8U) * size; + + size_t sid = static_cast(storage_id); + if (sid >= pool_entry_bytes.size()) { + pool_entry_bytes.resize(sid + 1, 0); + } + pool_entry_bytes[sid] = std::max(pool_entry_bytes[sid], bytes); + } + // Allocate the space. + for (size_t i = 0; i < pool_entry_bytes.size(); ++i) { + int64_t shape[] = {static_cast(pool_entry_bytes[i] + 3) / 4}; + DLTensor* tensor; + TVM_CCALL(TVMArrayAlloc( + shape, 1, kFloat, 32, 1, ctx_.device_type, ctx_.device_id, &tensor)); + storage_pool_.push_back(tensor); + } + // Assign the pooled entries. 
+ for (size_t i = 0; i < data_entry_.size(); ++i) { + int storage_id = attrs_.storage_id[i]; + data_entry_[i] = *storage_pool_[storage_id]; + data_entry_[i].shape = const_cast(attrs_.shape[i].data()); + data_entry_[i].ndim = static_cast(attrs_.shape[i].size()); + data_entry_[i].dtype = vtype[i]; + } +} + +void GraphRuntime::SetupOpExecs() { + op_execs_.resize(this->num_nodes()); + // setup the array and requirements. + for (uint32_t nid = 0; nid < this->num_nodes(); ++nid) { + const auto& inode = nodes_[nid]; + if (inode.op_type == "null") continue; + std::vector args; + for (const auto& e : inode.inputs) { + args.push_back(data_entry_[this->entry_id(e)]); + } + for (uint32_t index = 0; index < inode.param.num_outputs; ++index) { + uint32_t eid = this->entry_id(nid, index); + args.push_back(data_entry_[eid]); + } + CHECK_EQ(inode.op_type, "tvm_op") + << "Can only take tvm_op as op"; + op_execs_[nid] = CreateTVMOp(inode.param, args, inode.inputs.size()); + } +} + +std::function GraphRuntime::CreateTVMOp( + const TVMOpParam& param, + const std::vector& args, + size_t num_inputs) { + struct OpArgs { + std::vector args; + std::vector arg_values; + std::vector arg_tcodes; + std::vector shape_data; + }; + std::shared_ptr arg_ptr = std::make_shared(); + // setup address. + arg_ptr->args = std::move(args); + if (param.flatten_data) { + arg_ptr->shape_data.resize(arg_ptr->args.size()); + } + for (size_t i = 0; i < arg_ptr->args.size(); ++i) { + TVMValue v; + DLTensor* t = &(arg_ptr->args[i]); + v.v_handle = t; + arg_ptr->arg_values.push_back(v); + arg_ptr->arg_tcodes.push_back(kArrayHandle); + if (param.flatten_data) { + arg_ptr->shape_data[i] = std::accumulate( + t->shape, t->shape + t->ndim, 1, std::multiplies()); + t->ndim = 1; + t->shape = &(arg_ptr->shape_data[i]); + } + } + // get compiled function from module. 
+ tvm::runtime::PackedFunc pf = module_.GetFunction(param.func_name, false); + CHECK(pf != nullptr) << "no such function in module: " << param.func_name; + auto fexec = [arg_ptr, pf] () { + TVMRetValue rv; + TVMArgs targs(arg_ptr->arg_values.data(), + arg_ptr->arg_tcodes.data(), + static_cast(arg_ptr->arg_values.size())); + pf.CallPacked(targs, &rv); + }; + return fexec; +} + +PackedFunc GraphRuntime::GetFunction( + const std::string& name, + const std::shared_ptr& sptr_to_self) { + // return member functions during query. + if (name == "set_input") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + if (args[0].type_code() == kStr) { + this->SetInput(this->GetInputIndex(args[0]), args[1]); + } else { + this->SetInput(args[0], args[1]); + } + }); + } else if (name == "get_output") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + this->GetOutput(args[0], args[1]); + }); + } else if (name == "run") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + this->Run(); + }); + } else if (name == "load_params") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + this->LoadParams(args[0].operator std::string()); + }); + } else { + return PackedFunc(); + } +} + +Module GraphRuntimeCreate(std::string sym_json, + tvm::runtime::Module m, + int device_type, + int device_id) { + TVMContext ctx; + ctx.device_type = static_cast(device_type); + ctx.device_id = device_id; + std::shared_ptr exec = std::make_shared(); + exec->Init(sym_json, m, ctx); + return Module(exec); +} + +TVM_REGISTER_GLOBAL("tvm.graph_runtime.create") +.set_body([](TVMArgs args, TVMRetValue *rv) { + *rv = GraphRuntimeCreate(args[0], args[1], args[2], args[3]); + }); + +TVM_REGISTER_GLOBAL("tvm.graph_runtime.remote_create") +.set_body([](TVMArgs args, TVMRetValue *rv) { + void* mhandle = args[1]; + *rv = GraphRuntimeCreate(args[0], + *static_cast(mhandle), + args[2], args[3]); + }); +} // namespace runtime +} 
def test_graph_simple():
    """Run a hand-written one-op graph (x -> myadd) locally and over RPC."""
    n = 4
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
    s = tvm.create_schedule(B.op)

    # Graph JSON: one input node "x" feeding one tvm_op node "add".
    shape = (4,)
    graph = json.dumps({
        "nodes": [
            {"op": "null", "name": "x", "inputs": []},
            {"op": "tvm_op", "name": "add",
             "inputs": [[0, 0, 0]],
             "attrs": {"func_name": "myadd",
                       "flatten_data": "1",
                       "num_inputs" : "1",
                       "num_outputs" : "1"}},
        ],
        "arg_nodes": [0],
        "node_row_ptr": [0, 1, 2],
        "heads": [[1, 0, 0]],
        "attrs": {
            "shape" : ["list_shape", [shape, shape]],
            "dltype" : ["list_str", ["float32", "float32"]],
            "storage_id" : ["list_int", [0, 1]],
        },
    })

    def check_verify():
        # Local execution on CPU.
        if not tvm.module.enabled("llvm"):
            print("Skip because llvm is not enabled")
            return
        lib = tvm.build(s, [A, B], "llvm", name="myadd")
        executor = graph_runtime.create(graph, lib, tvm.cpu(0))
        data = np.random.uniform(size=(n,)).astype(A.dtype)
        executor.run(x=data)
        result = executor.get_output(0, tvm.nd.empty((n,)))
        np.testing.assert_equal(result.asnumpy(), data + 1)

    def check_remote():
        # Same graph, library uploaded to and executed on an RPC server.
        if not tvm.module.enabled("llvm"):
            print("Skip because llvm is not enabled")
            return
        lib = tvm.build(s, [A, B], "llvm", name="myadd")
        server = rpc.Server("localhost")
        remote = rpc.connect(server.host, server.port)
        workdir = util.tempdir()
        ctx = remote.cpu(0)
        lib_path = workdir.relpath("dev_lib.so")
        lib.export_library(lib_path)
        remote.upload(lib_path)
        remote_lib = remote.load_module("dev_lib.so")
        executor = graph_runtime.create(graph, remote_lib, remote.cpu(0))
        data = np.random.uniform(size=(n,)).astype(A.dtype)
        executor.run(x=tvm.nd.array(data, ctx))
        result = executor.get_output(0, tvm.nd.empty((n,), ctx=ctx))
        np.testing.assert_equal(result.asnumpy(), data + 1)

    check_verify()
    check_remote()