diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 22e0b916e69ac..e894933b24160 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -62,7 +62,7 @@ def _convert_param_map(params): class BuildModule(object): - """Build a Relay function to run on TVM graph runtime. This class is used + """Build an IR module to run on TVM graph runtime. This class is used to expose the `RelayBuildModule` APIs implemented in C++. """ def __init__(self): @@ -74,12 +74,12 @@ def __init__(self): self._set_params_func = self.mod["set_params"] self._get_params_func = self.mod["get_params"] - def build(self, func, target=None, target_host=None, params=None): + def build(self, mod, target=None, target_host=None, params=None): """ Parameters ---------- - func: relay.Function - The function to build. + mod : :py:class:`~tvm.IRModule` + The IRModule to build. target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context name) to str/tvm.target.Target, optional @@ -115,8 +115,8 @@ def build(self, func, target=None, target_host=None, params=None): # Setup the params. if params: self._set_params(params) - # Build the function - self._build(func, target, target_host) + # Build the IR module + self._build(mod, target, target_host) # Get artifacts graph_json = self.get_json() mod = self.get_module() @@ -124,12 +124,12 @@ def build(self, func, target=None, target_host=None, params=None): return graph_json, mod, params - def optimize(self, func, target=None, params=None): + def optimize(self, mod, target=None, params=None): """ Parameters ---------- - func: relay.Function - The function to build. + mod : :py:class:`~tvm.IRModule` + The IR module to build. target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context name) to str/tvm.target.Target, optional @@ -142,7 +142,7 @@ def optimize(self, func, target=None, params=None): Returns ------- - mod : tvm.IRModule + mod : :py:class:`~tvm.IRModule` The optimized relay module. params : dict @@ -153,7 +153,7 @@ def optimize(self, func, target=None, params=None): # Setup the params. if params: self._set_params(params) - mod = self._optimize(func, target) + mod = self._optimize(mod, target) # Get artifacts params = self.get_params() @@ -186,8 +186,8 @@ def build(mod, target=None, target_host=None, params=None): Parameters ---------- - mod : tvm.IRModule - The module to build. Using relay.Function is deprecated. + mod : :py:class:`~tvm.IRModule` + The IR module to build. Using relay.Function is deprecated. target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context name) to str/tvm.target.Target, optional @@ -218,16 +218,15 @@ def build(mod, target=None, target_host=None, params=None): params : dict The parameters of the final graph. """ - if isinstance(mod, IRModule): - func = mod["main"] - elif isinstance(mod, _expr.Function): - func = mod + if not isinstance(mod, (IRModule, _expr.Function)): + raise ValueError("Type of input parameter mod must be tvm.IRModule") + + if isinstance(mod, _expr.Function): + mod = IRModule.from_expr(mod) warnings.warn( "Please use input parameter mod (tvm.IRModule) " - "instead of deprecated parameter func (tvm.relay.expr.Function)", + "instead of deprecated parameter mod (tvm.relay.expr.Function)", DeprecationWarning) - else: - raise ValueError("Type of input parameter mod must be tvm.IRModule") target = _update_target(target) @@ -246,7 +245,7 @@ def build(mod, target=None, target_host=None, params=None): with tophub_context: bld_mod = BuildModule() - graph_json, mod, params = bld_mod.build(func, target, target_host, params) + graph_json, mod, params = bld_mod.build(mod, target, target_host, params) return graph_json, mod, params @@ -255,7 +254,7 @@ def optimize(mod, target=None, params=None): Parameters ---------- - mod : tvm.IRModule + mod : :py:class:`~tvm.IRModule` The module to build. Using relay.Function is deprecated. target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context @@ -269,22 +268,21 @@ def optimize(mod, target=None, params=None): Returns ------- - mod : tvm.IRModule + mod : :py:class:`~tvm.IRModule` The optimized relay module. params : dict The parameters of the final graph. """ - if isinstance(mod, IRModule): - func = mod["main"] - elif isinstance(mod, _expr.Function): - func = mod + if not isinstance(mod, (IRModule, _expr.Function)): + raise ValueError("Type of input parameter mod must be tvm.IRModule") + + if isinstance(mod, _expr.Function): + mod = IRModule.from_expr(mod) warnings.warn( "Please use input parameter mod (tvm.IRModule) " "instead of deprecated parameter func (tvm.relay.expr.Function)", DeprecationWarning) - else: - raise ValueError("Type of input parameter mod must be tvm.IRModule") target = _update_target(target) @@ -297,7 +295,7 @@ def optimize(mod, target=None, params=None): with tophub_context: bld_mod = BuildModule() - mod, params = bld_mod.optimize(func, target, params) + mod, params = bld_mod.optimize(mod, target, params) return mod, params diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 0c0a8b8cbfa82..5cc4eaf584f0d 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -233,41 +233,43 @@ class RelayBuildModule : public runtime::ModuleNode { } /*! - * \brief Build relay function for graph runtime + * \brief Build relay IRModule for graph runtime * - * \param func Relay Function + * \param mod Relay IRModule * \param target Target device * \param target_host Host target device */ - void Build(Function func, + void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) { targets_ = targets; target_host_ = target_host; - BuildRelay(func, params_); + BuildRelay(mod, params_); } protected: /*! - * \brief Optimize a Relay Function. + * \brief Optimize a Relay IRModule. * - * \param func The input Function where optmization will be applied on. + * \param relay_module The input IRModule where optmization will be applied on. * \param targets The device type to `Target` mapping. * \param params The param name to value mapping. * - * \return relay::Module The updated Relay module after optimization. + * \return relay::IRModule The updated Relay IR module after optimization. */ IRModule Optimize( - Function func, + IRModule relay_module, const TargetsMap& targets, const std::unordered_map& params) { if (params.size()) { - func = BindParamsByName(func, params); + CHECK(relay_module->ContainGlobalVar("main")) + << "Missing the main entry function"; + GlobalVar main_glb_var = relay_module->GetGlobalVar("main"); + Function main_func = Downcast(relay_module->Lookup(main_glb_var)); + auto new_main = BindParamsByName(main_func, params); + relay_module->Update(main_glb_var, new_main); } - // Perform Module->Module optimizations. - IRModule relay_module = IRModule::FromExpr(func); - Array pass_seqs; // Run all dialect legalization passes. @@ -418,18 +420,20 @@ class RelayBuildModule : public runtime::ModuleNode { } /*! - * \brief Compile a Relay function to runtime module. + * \brief Compile a Relay IR module to runtime module. * - * \param func The Relay function. + * \param relay_module The Relay IR module. * \param params The parameters. */ void BuildRelay( - Function func, + IRModule relay_module, const std::unordered_map& params) { - // Optimize input Relay Function and returns Relay Module - IRModule relay_module = Optimize(func, targets_, params); + // Relay IRModule -> IRModule optimizations. + relay_module = Optimize(relay_module, targets_, params); + CHECK_EQ(relay_module->functions.size(), 1U) + << "Expect one and only one function in the IR module"; // Get the updated function. - func = Downcast(relay_module->Lookup("main")); + auto func = Downcast(relay_module->Lookup("main")); // Generate code for the updated function. graph_codegen_ = std::unique_ptr(new GraphCodegen()); diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index b9a8f8f96f8b3..a94dce6fe4967 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -115,7 +116,8 @@ TEST(Relay, BuildModule) { Map targets; Target llvm_tgt = Target::Create("llvm"); targets.Set(0, llvm_tgt); - build_f(func, targets, llvm_tgt); + auto relay_mod = tvm::IRModule::FromExpr(func); + build_f(relay_mod, targets, llvm_tgt); std::string json = json_f(); tvm::runtime::Module mod = mod_f(); // run