Skip to content

Commit

Permalink
refactor build module to take IRModule
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Mar 5, 2020
1 parent 5a0f39b commit 8376093
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 49 deletions.
58 changes: 28 additions & 30 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _convert_param_map(params):


class BuildModule(object):
"""Build a Relay function to run on TVM graph runtime. This class is used
"""Build an IR module to run on TVM graph runtime. This class is used
to expose the `RelayBuildModule` APIs implemented in C++.
"""
def __init__(self):
Expand All @@ -74,12 +74,12 @@ def __init__(self):
self._set_params_func = self.mod["set_params"]
self._get_params_func = self.mod["get_params"]

def build(self, func, target=None, target_host=None, params=None):
def build(self, mod, target=None, target_host=None, params=None):
"""
Parameters
----------
func: relay.Function
The function to build.
mod : :py:class:`~tvm.IRModule`
The IRModule to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional
Expand Down Expand Up @@ -115,21 +115,21 @@ def build(self, func, target=None, target_host=None, params=None):
# Setup the params.
if params:
self._set_params(params)
# Build the function
self._build(func, target, target_host)
# Build the IR module
self._build(mod, target, target_host)
# Get artifacts
graph_json = self.get_json()
mod = self.get_module()
params = self.get_params()

return graph_json, mod, params

def optimize(self, func, target=None, params=None):
def optimize(self, mod, target=None, params=None):
"""
Parameters
----------
func: relay.Function
The function to build.
mod : :py:class:`~tvm.IRModule`
The IR module to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional
Expand All @@ -142,7 +142,7 @@ def optimize(self, func, target=None, params=None):
Returns
-------
mod : tvm.IRModule
mod : :py:class:`~tvm.IRModule`
The optimized relay module.
params : dict
Expand All @@ -153,7 +153,7 @@ def optimize(self, func, target=None, params=None):
# Setup the params.
if params:
self._set_params(params)
mod = self._optimize(func, target)
mod = self._optimize(mod, target)
# Get artifacts
params = self.get_params()

Expand Down Expand Up @@ -186,8 +186,8 @@ def build(mod, target=None, target_host=None, params=None):
Parameters
----------
mod : tvm.IRModule
The module to build. Using relay.Function is deprecated.
mod : :py:class:`~tvm.IRModule`
The IR module to build. Using relay.Function is deprecated.
target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
name) to str/tvm.target.Target, optional
Expand Down Expand Up @@ -218,16 +218,15 @@ def build(mod, target=None, target_host=None, params=None):
params : dict
The parameters of the final graph.
"""
if isinstance(mod, IRModule):
func = mod["main"]
elif isinstance(mod, _expr.Function):
func = mod
if not isinstance(mod, (IRModule, _expr.Function)):
raise ValueError("Type of input parameter mod must be tvm.IRModule")

if isinstance(mod, _expr.Function):
mod = IRModule.from_expr(mod)
warnings.warn(
"Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter func (tvm.relay.expr.Function)",
"instead of deprecated parameter mod (tvm.relay.expr.Function)",
DeprecationWarning)
else:
raise ValueError("Type of input parameter mod must be tvm.IRModule")

target = _update_target(target)

Expand All @@ -246,7 +245,7 @@ def build(mod, target=None, target_host=None, params=None):

with tophub_context:
bld_mod = BuildModule()
graph_json, mod, params = bld_mod.build(func, target, target_host, params)
graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
return graph_json, mod, params


Expand All @@ -255,7 +254,7 @@ def optimize(mod, target=None, params=None):
Parameters
----------
mod : tvm.IRModule
mod : :py:class:`~tvm.IRModule`
The module to build. Using relay.Function is deprecated.
target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context
Expand All @@ -269,22 +268,21 @@ def optimize(mod, target=None, params=None):
Returns
-------
mod : tvm.IRModule
mod : :py:class:`~tvm.IRModule`
The optimized relay module.
params : dict
The parameters of the final graph.
"""
if isinstance(mod, IRModule):
func = mod["main"]
elif isinstance(mod, _expr.Function):
func = mod
if not isinstance(mod, (IRModule, _expr.Function)):
raise ValueError("Type of input parameter mod must be tvm.IRModule")

if isinstance(mod, _expr.Function):
mod = IRModule.from_expr(mod)
warnings.warn(
"Please use input parameter mod (tvm.IRModule) "
"instead of deprecated parameter func (tvm.relay.expr.Function)",
DeprecationWarning)
else:
raise ValueError("Type of input parameter mod must be tvm.IRModule")

target = _update_target(target)

Expand All @@ -297,7 +295,7 @@ def optimize(mod, target=None, params=None):

with tophub_context:
bld_mod = BuildModule()
mod, params = bld_mod.optimize(func, target, params)
mod, params = bld_mod.optimize(mod, target, params)
return mod, params


Expand Down
40 changes: 22 additions & 18 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,41 +233,43 @@ class RelayBuildModule : public runtime::ModuleNode {
}

/*!
* \brief Build relay function for graph runtime
* \brief Build relay IRModule for graph runtime
*
* \param func Relay Function
* \param mod Relay IRModule
* \param target Target device
* \param target_host Host target device
*/
void Build(Function func,
void Build(IRModule mod,
const TargetsMap& targets,
const tvm::Target& target_host) {
targets_ = targets;
target_host_ = target_host;
BuildRelay(func, params_);
BuildRelay(mod, params_);
}

protected:
/*!
* \brief Optimize a Relay Function.
* \brief Optimize a Relay IRModule.
*
* \param func The input Function where optmization will be applied on.
* \param relay_module The input IRModule where optmization will be applied on.
* \param targets The device type to `Target` mapping.
* \param params The param name to value mapping.
*
* \return relay::Module The updated Relay module after optimization.
* \return relay::IRModule The updated Relay IR module after optimization.
*/
IRModule Optimize(
Function func,
IRModule relay_module,
const TargetsMap& targets,
const std::unordered_map<std::string, runtime::NDArray>& params) {
if (params.size()) {
func = BindParamsByName(func, params);
CHECK(relay_module->ContainGlobalVar("main"))
<< "Missing the main entry function";
GlobalVar main_glb_var = relay_module->GetGlobalVar("main");
Function main_func = Downcast<Function>(relay_module->Lookup(main_glb_var));
auto new_main = BindParamsByName(main_func, params);
relay_module->Update(main_glb_var, new_main);
}

// Perform Module->Module optimizations.
IRModule relay_module = IRModule::FromExpr(func);

Array<Pass> pass_seqs;

// Run all dialect legalization passes.
Expand Down Expand Up @@ -418,18 +420,20 @@ class RelayBuildModule : public runtime::ModuleNode {
}

/*!
* \brief Compile a Relay function to runtime module.
* \brief Compile a Relay IR module to runtime module.
*
* \param func The Relay function.
* \param relay_module The Relay IR module.
* \param params The parameters.
*/
void BuildRelay(
Function func,
IRModule relay_module,
const std::unordered_map<std::string, tvm::runtime::NDArray>& params) {
// Optimize input Relay Function and returns Relay Module
IRModule relay_module = Optimize(func, targets_, params);
// Relay IRModule -> IRModule optimizations.
relay_module = Optimize(relay_module, targets_, params);
CHECK_EQ(relay_module->functions.size(), 1U)
<< "Expect one and only one function in the IR module";
// Get the updated function.
func = Downcast<Function>(relay_module->Lookup("main"));
auto func = Downcast<Function>(relay_module->Lookup("main"));

// Generate code for the updated function.
graph_codegen_ = std::unique_ptr<GraphCodegen>(new GraphCodegen());
Expand Down
4 changes: 3 additions & 1 deletion tests/cpp/relay_build_module_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <topi/broadcast.h>
#include <topi/generic/injective.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/ir/module.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>

Expand Down Expand Up @@ -115,7 +116,8 @@ TEST(Relay, BuildModule) {
Map<tvm::Integer, tvm::Target> targets;
Target llvm_tgt = Target::Create("llvm");
targets.Set(0, llvm_tgt);
build_f(func, targets, llvm_tgt);
auto relay_mod = tvm::IRModule::FromExpr(func);
build_f(relay_mod, targets, llvm_tgt);
std::string json = json_f();
tvm::runtime::Module mod = mod_f();
// run
Expand Down

0 comments on commit 8376093

Please sign in to comment.