From 69de85845f4c7a4b6d0dcd4e787d605222a4775b Mon Sep 17 00:00:00 2001 From: Zhi Chen Date: Wed, 4 Mar 2020 23:52:39 +0000 Subject: [PATCH] refactor build module to take IRModule --- include/tvm/relay/transform.h | 10 +++++ python/tvm/relay/build_module.py | 58 ++++++++++++++-------------- src/relay/backend/build_module.cc | 40 ++++++++++--------- src/relay/backend/vm/compiler.cc | 1 - tests/cpp/relay_build_module_test.cc | 4 +- 5 files changed, 63 insertions(+), 50 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 2837c1ff7f25..0a2c77a3af45 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -332,6 +332,16 @@ TVM_DLL Pass PartitionGraph(); */ TVM_DLL Pass Inline(); +/*! + * \brief Remove the unused functions in the Relay IRModule. + * + * \param entry_functions The entry functions used to search the functions that + * are being used. + * + * \return The pass. + */ +TVM_DLL Pass RemoveUnusedFunctions(Array entry_functions); + } // namespace transform /*! diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 22e0b916e69a..e894933b2416 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -62,7 +62,7 @@ def _convert_param_map(params): class BuildModule(object): - """Build a Relay function to run on TVM graph runtime. This class is used + """Build an IR module to run on TVM graph runtime. This class is used to expose the `RelayBuildModule` APIs implemented in C++. """ def __init__(self): @@ -74,12 +74,12 @@ def __init__(self): self._set_params_func = self.mod["set_params"] self._get_params_func = self.mod["get_params"] - def build(self, func, target=None, target_host=None, params=None): + def build(self, mod, target=None, target_host=None, params=None): """ Parameters ---------- - func: relay.Function - The function to build. + mod : :py:class:`~tvm.IRModule` + The IRModule to build. target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context name) to str/tvm.target.Target, optional @@ -115,8 +115,8 @@ def build(self, func, target=None, target_host=None, params=None): # Setup the params. if params: self._set_params(params) - # Build the function - self._build(func, target, target_host) + # Build the IR module + self._build(mod, target, target_host) # Get artifacts graph_json = self.get_json() mod = self.get_module() @@ -124,12 +124,12 @@ def build(self, func, target=None, target_host=None, params=None): return graph_json, mod, params - def optimize(self, func, target=None, params=None): + def optimize(self, mod, target=None, params=None): """ Parameters ---------- - func: relay.Function - The function to build. + mod : :py:class:`~tvm.IRModule` + The IR module to build. target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context name) to str/tvm.target.Target, optional @@ -142,7 +142,7 @@ def optimize(self, func, target=None, params=None): Returns ------- - mod : tvm.IRModule + mod : :py:class:`~tvm.IRModule` The optimized relay module. params : dict @@ -153,7 +153,7 @@ def optimize(self, func, target=None, params=None): # Setup the params. if params: self._set_params(params) - mod = self._optimize(func, target) + mod = self._optimize(mod, target) # Get artifacts params = self.get_params() @@ -186,8 +186,8 @@ def build(mod, target=None, target_host=None, params=None): Parameters ---------- - mod : tvm.IRModule - The module to build. Using relay.Function is deprecated. + mod : :py:class:`~tvm.IRModule` + The IR module to build. Using relay.Function is deprecated. target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context name) to str/tvm.target.Target, optional @@ -218,16 +218,15 @@ def build(mod, target=None, target_host=None, params=None): params : dict The parameters of the final graph. """ - if isinstance(mod, IRModule): - func = mod["main"] - elif isinstance(mod, _expr.Function): - func = mod + if not isinstance(mod, (IRModule, _expr.Function)): + raise ValueError("Type of input parameter mod must be tvm.IRModule") + + if isinstance(mod, _expr.Function): + mod = IRModule.from_expr(mod) warnings.warn( "Please use input parameter mod (tvm.IRModule) " - "instead of deprecated parameter func (tvm.relay.expr.Function)", + "instead of deprecated parameter mod (tvm.relay.expr.Function)", DeprecationWarning) - else: - raise ValueError("Type of input parameter mod must be tvm.IRModule") target = _update_target(target) @@ -246,7 +245,7 @@ def build(mod, target=None, target_host=None, params=None): with tophub_context: bld_mod = BuildModule() - graph_json, mod, params = bld_mod.build(func, target, target_host, params) + graph_json, mod, params = bld_mod.build(mod, target, target_host, params) return graph_json, mod, params @@ -255,7 +254,7 @@ def optimize(mod, target=None, params=None): Parameters ---------- - mod : tvm.IRModule + mod : :py:class:`~tvm.IRModule` The module to build. Using relay.Function is deprecated. target : str, :any:`tvm.target.Target`, or dict of str(i.e. device/context @@ -269,22 +268,21 @@ def optimize(mod, target=None, params=None): Returns ------- - mod : tvm.IRModule + mod : :py:class:`~tvm.IRModule` The optimized relay module. params : dict The parameters of the final graph. """ - if isinstance(mod, IRModule): - func = mod["main"] - elif isinstance(mod, _expr.Function): - func = mod + if not isinstance(mod, (IRModule, _expr.Function)): + raise ValueError("Type of input parameter mod must be tvm.IRModule") + + if isinstance(mod, _expr.Function): + mod = IRModule.from_expr(mod) warnings.warn( "Please use input parameter mod (tvm.IRModule) " "instead of deprecated parameter func (tvm.relay.expr.Function)", DeprecationWarning) - else: - raise ValueError("Type of input parameter mod must be tvm.IRModule") target = _update_target(target) @@ -297,7 +295,7 @@ def optimize(mod, target=None, params=None): with tophub_context: bld_mod = BuildModule() - mod, params = bld_mod.optimize(func, target, params) + mod, params = bld_mod.optimize(mod, target, params) return mod, params diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 0c0a8b8cbfa8..61ec28179eee 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -233,42 +233,46 @@ class RelayBuildModule : public runtime::ModuleNode { } /*! - * \brief Build relay function for graph runtime + * \brief Build relay IRModule for graph runtime * - * \param func Relay Function + * \param mod Relay IRModule * \param target Target device * \param target_host Host target device */ - void Build(Function func, + void Build(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) { targets_ = targets; target_host_ = target_host; - BuildRelay(func, params_); + BuildRelay(mod, params_); } protected: /*! - * \brief Optimize a Relay Function. + * \brief Optimize a Relay IRModule. * - * \param func The input Function where optmization will be applied on. + * \param relay_module The input IRModule where optmization will be applied on. * \param targets The device type to `Target` mapping. * \param params The param name to value mapping. * - * \return relay::Module The updated Relay module after optimization. + * \return relay::IRModule The updated Relay IR module after optimization. */ IRModule Optimize( - Function func, + IRModule relay_module, const TargetsMap& targets, const std::unordered_map& params) { if (params.size()) { - func = BindParamsByName(func, params); + CHECK(relay_module->ContainGlobalVar("main")) + << "Missing the main entry function"; + GlobalVar main_glb_var = relay_module->GetGlobalVar("main"); + Function main_func = Downcast(relay_module->Lookup(main_glb_var)); + auto new_main = BindParamsByName(main_func, params); + relay_module->Update(main_glb_var, new_main); } - // Perform Module->Module optimizations. - IRModule relay_module = IRModule::FromExpr(func); - Array pass_seqs; + Array entry_functions{tvm::PrimExpr{"main"}}; + pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); // Run all dialect legalization passes. pass_seqs.push_back(relay::qnn::transform::Legalize()); @@ -418,18 +422,18 @@ class RelayBuildModule : public runtime::ModuleNode { } /*! - * \brief Compile a Relay function to runtime module. + * \brief Compile a Relay IR module to runtime module. * - * \param func The Relay function. + * \param relay_module The Relay IR module. * \param params The parameters. */ void BuildRelay( - Function func, + IRModule relay_module, const std::unordered_map& params) { - // Optimize input Relay Function and returns Relay Module - IRModule relay_module = Optimize(func, targets_, params); + // Relay IRModule -> IRModule optimizations. + relay_module = Optimize(relay_module, targets_, params); // Get the updated function. - func = Downcast(relay_module->Lookup("main")); + auto func = Downcast(relay_module->Lookup("main")); // Generate code for the updated function. graph_codegen_ = std::unique_ptr(new GraphCodegen()); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 73a6450c16ec..2129b64a8b44 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -51,7 +51,6 @@ namespace transform { Pass LambdaLift(); Pass InlinePrimitives(); -Pass RemoveUnusedFunctions(Array entry_functions); Pass ManifestAlloc(Target target_host) { auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc"); diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index b9a8f8f96f8b..a94dce6fe496 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -115,7 +116,8 @@ TEST(Relay, BuildModule) { Map targets; Target llvm_tgt = Target::Create("llvm"); targets.Set(0, llvm_tgt); - build_f(func, targets, llvm_tgt); + auto relay_mod = tvm::IRModule::FromExpr(func); + build_f(relay_mod, targets, llvm_tgt); std::string json = json_f(); tvm::runtime::Module mod = mod_f(); // run