Skip to content

Commit

Permalink
[RUNTIME] Support module based interface runtime (#5753)
Browse files Browse the repository at this point in the history
  • Loading branch information
FrozenGene authored Jul 15, 2020
1 parent 4ae8fd7 commit 9fcde21
Show file tree
Hide file tree
Showing 8 changed files with 927 additions and 20 deletions.
13 changes: 5 additions & 8 deletions python/tvm/contrib/debugger/debug_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ def create(graph_json_str, libmod, ctx, dump_root=None):
Parameters
----------
graph_json_str : str or graph class
graph_json_str : str
The graph to be deployed in json format output by graph compiler.
The graph can only contain one operator(tvm_op) that
points to the name of PackedFunc in the libmod.
The graph can contain operator(tvm_op) that points to the name
of PackedFunc in the libmod.
libmod : tvm.Module
The module of the corresponding function.
Expand All @@ -54,11 +54,8 @@ def create(graph_json_str, libmod, ctx, dump_root=None):
graph_module : GraphModuleDebug
Debug Runtime graph module that can be used to execute the graph.
"""
if not isinstance(graph_json_str, string_types):
try:
graph_json_str = graph_json_str._tvm_graph_json()
except AttributeError:
raise ValueError("Type %s is not supported" % type(graph_json_str))
assert isinstance(graph_json_str, string_types)

try:
ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx)
if num_rpc_ctx == len(ctx):
Expand Down
12 changes: 4 additions & 8 deletions python/tvm/contrib/graph_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def create(graph_json_str, libmod, ctx):
Parameters
----------
graph_json_str : str or graph class
graph_json_str : str
The graph to be deployed in json format output by json graph.
The graph can only contain one operator(tvm_op) that
points to the name of PackedFunc in the libmod.
The graph can contain operator(tvm_op) that points to the name
of PackedFunc in the libmod.
libmod : tvm.runtime.Module
The module of the corresponding function
Expand All @@ -48,11 +48,7 @@ def create(graph_json_str, libmod, ctx):
graph_module : GraphModule
Runtime graph module that can be used to execute the graph.
"""
if not isinstance(graph_json_str, string_types):
try:
graph_json_str = graph_json_str._tvm_graph_json()
except AttributeError:
raise ValueError("Type %s is not supported" % type(graph_json_str))
assert isinstance(graph_json_str, string_types)

ctx, num_rpc_ctx, device_type_id = get_device_ctx(libmod, ctx)

Expand Down
84 changes: 84 additions & 0 deletions python/tvm/relay/backend/graph_runtime_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Graph runtime factory."""
import warnings
from tvm._ffi.base import string_types
from tvm._ffi.registry import get_global_func
from tvm.runtime import ndarray

class GraphRuntimeFactoryModule(object):
"""Graph runtime factory module.
This is a module of graph runtime factory
Parameters
----------
graph_json_str : str
The graph to be deployed in json format output by graph compiler.
The graph can contain operator(tvm_op) that points to the name of
PackedFunc in the libmod.
libmod : tvm.Module
The module of the corresponding function
libmod_name: str
The name of module
params : dict of str to NDArray
The parameters of module
"""

def __init__(self, graph_json_str, libmod, libmod_name, params):
assert isinstance(graph_json_str, string_types)
fcreate = get_global_func("tvm.graph_runtime_factory.create")
args = []
for k, v in params.items():
args.append(k)
args.append(ndarray.array(v))
self.module = fcreate(graph_json_str, libmod, libmod_name, *args)
self.graph_json = graph_json_str
self.lib = libmod
self.libmod_name = libmod_name
self.params = params
self.iter_cnt = 0

def export_library(self, file_name, fcompile=None, addons=None, **kwargs):
return self.module.export_library(file_name, fcompile, addons, **kwargs)

# Sometimes we want to get params explicitly.
# For example, we want to save its params value to
# an independent file.
def get_params(self):
return self.params

def get_json(self):
return self.graph_json

def __getitem__(self, item):
return self.module.__getitem__(item)

def __iter__(self):
warnings.warn(
"legacy graph runtime behaviour of producing json / lib / params will be "
"removed in the next release ",
DeprecationWarning, 2)
return self

def __next__(self):
if self.iter_cnt > 2:
raise StopIteration

objs = [self.graph_json, self.lib, self.params]
obj = objs[self.iter_cnt]
self.iter_cnt += 1
return obj
9 changes: 7 additions & 2 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from . import ty as _ty
from . import expr as _expr
from . import function as _function
from .backend import graph_runtime_factory as _graph_runtime_factory
from .backend import interpreter as _interpreter
from .backend.vm import VMExecutor

Expand Down Expand Up @@ -181,7 +182,7 @@ def get_params(self):
return ret


def build(mod, target=None, target_host=None, params=None):
def build(mod, target=None, target_host=None, params=None, mod_name='default'):
"""Helper function that builds a Relay function to run on TVM graph
runtime.
Expand All @@ -208,6 +209,9 @@ def build(mod, target=None, target_host=None, params=None):
Input parameters to the graph that do not change
during inference time. Used for constant folding.
mod_name: Optional[str]
The module name we will build
Returns
-------
graph_json : str
Expand Down Expand Up @@ -249,7 +253,8 @@ def build(mod, target=None, target_host=None, params=None):
with tophub_context:
bld_mod = BuildModule()
graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
return graph_json, mod, params
mod = _graph_runtime_factory.GraphRuntimeFactoryModule(graph_json, mod, mod_name, params)
return mod


def optimize(mod, target=None, params=None):
Expand Down
175 changes: 175 additions & 0 deletions src/runtime/graph/graph_runtime_factory.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file graph_runtime_factory.cc
* \brief Graph runtime factory implementations
*/

#include "./graph_runtime_factory.h"

#include <tvm/node/container.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>

#include <iterator>
#include <vector>

namespace tvm {
namespace runtime {

GraphRuntimeFactory::GraphRuntimeFactory(
const std::string& graph_json,
const std::unordered_map<std::string, tvm::runtime::NDArray>& params,
const std::string& module_name) {
graph_json_ = graph_json;
params_ = params;
module_name_ = module_name;
}

PackedFunc GraphRuntimeFactory::GetFunction(
const std::string& name, const tvm::runtime::ObjectPtr<tvm::runtime::Object>& sptr_to_self) {
if (name == module_name_) {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::vector<TVMContext> contexts;
for (int i = 0; i < args.num_args; ++i) {
contexts.emplace_back(args[i].operator TVMContext());
}
*rv = this->RuntimeCreate(contexts);
});
} else if (name == "debug_create") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_GE(args.size(), 2);
std::string module_name = args[0].operator String();
CHECK(module_name == module_name_) << "Currently we only support single model for now.";
std::vector<TVMContext> contexts;
for (int i = 1; i < args.num_args; ++i) {
contexts.emplace_back(args[i].operator TVMContext());
}
*rv = this->DebugRuntimeCreate(contexts);
});
} else if (name == "remove_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
std::unordered_map<std::string, tvm::runtime::NDArray> empty_params{};
auto exec =
make_object<GraphRuntimeFactory>(this->graph_json_, empty_params, this->module_name_);
exec->Import(this->imports_[0]);
*rv = Module(exec);
});
} else {
return PackedFunc();
}
}

void GraphRuntimeFactory::SaveToBinary(dmlc::Stream* stream) {
stream->Write(graph_json_);
std::vector<std::string> names;
std::vector<DLTensor*> arrays;
for (const auto& v : params_) {
names.emplace_back(v.first);
arrays.emplace_back(const_cast<DLTensor*>(v.second.operator->()));
}
uint64_t sz = arrays.size();
CHECK(sz == names.size());
stream->Write(sz);
stream->Write(names);
for (size_t i = 0; i < sz; ++i) {
tvm::runtime::SaveDLTensor(stream, arrays[i]);
}
stream->Write(module_name_);
}

Module GraphRuntimeFactory::RuntimeCreate(const std::vector<TVMContext>& ctxs) {
auto exec = make_object<GraphRuntime>();
exec->Init(this->graph_json_, this->imports_[0], ctxs);
// set params
SetParams(exec.get(), this->params_);
return Module(exec);
}

Module GraphRuntimeFactory::DebugRuntimeCreate(const std::vector<TVMContext>& ctxs) {
const PackedFunc* pf = tvm::runtime::Registry::Get("tvm.graph_runtime_debug.create");
CHECK(pf != nullptr) << "Cannot find function tvm.graph_runtime_debug.create in registry. "
"Do you enable debug graph runtime build?";
// Debug runtime create packed function will call GetAllContexs, so we unpack the ctxs.
std::vector<int> unpacked_ctxs;
for (const auto& ctx : ctxs) {
unpacked_ctxs.emplace_back(ctx.device_type);
unpacked_ctxs.emplace_back(ctx.device_id);
}
size_t args_size = unpacked_ctxs.size() + 2;
std::vector<TVMValue> values(args_size);
std::vector<int> codes(args_size);
runtime::TVMArgsSetter setter(values.data(), codes.data());
setter(0, this->graph_json_);
setter(1, this->imports_[0]);
for (size_t i = 0; i < unpacked_ctxs.size(); ++i) {
setter(i + 2, unpacked_ctxs[i]);
}
TVMRetValue rv;
pf->CallPacked(TVMArgs(values.data(), codes.data(), args_size), &rv);
Module mod = rv.operator Module();
// debug graph runtime is one child class of graph runtime.
SetParams(const_cast<GraphRuntime*>(mod.as<GraphRuntime>()), this->params_);
return mod;
}

Module GraphRuntimeFactoryModuleLoadBinary(void* strm) {
dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
std::string graph_json;
std::unordered_map<std::string, tvm::runtime::NDArray> params;
std::string module_name;
CHECK(stream->Read(&graph_json));
uint64_t sz;
CHECK(stream->Read(&sz));
std::vector<std::string> names;
CHECK(stream->Read(&names));
CHECK(sz == names.size());
for (size_t i = 0; i < sz; ++i) {
tvm::runtime::NDArray temp;
temp.Load(stream);
params[names[i]] = temp;
}
CHECK(stream->Read(&module_name));
auto exec = make_object<GraphRuntimeFactory>(graph_json, params, module_name);
return Module(exec);
}

TVM_REGISTER_GLOBAL("tvm.graph_runtime_factory.create").set_body([](TVMArgs args, TVMRetValue* rv) {
CHECK_GE(args.num_args, 3) << "The expected number of arguments for "
"graph_runtime_factory.create needs at least 3, "
"but it has "
<< args.num_args;
// The argument order is graph_json, module, module_name, params.
CHECK_EQ((args.size() - 3) % 2, 0);
std::unordered_map<std::string, tvm::runtime::NDArray> params;
for (size_t i = 3; i < static_cast<size_t>(args.size()); i += 2) {
std::string name = args[i].operator String();
params[name] = args[i + 1].operator tvm::runtime::NDArray();
}
auto exec = make_object<GraphRuntimeFactory>(args[0], params, args[2]);
exec->Import(args[1]);
*rv = Module(exec);
});

TVM_REGISTER_GLOBAL("runtime.module.loadbinary_GraphRuntimeFactory")
.set_body_typed(GraphRuntimeFactoryModuleLoadBinary);

} // namespace runtime
} // namespace tvm
Loading

0 comments on commit 9fcde21

Please sign in to comment.