From efecceaea3958e184de7ef0ff6cb5f3988640afa Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 11 Mar 2022 03:56:05 +0900 Subject: [PATCH] refactor param binding --- src/relay/backend/build_module.cc | 9 +-------- .../backend/metaschedule_task_extraction.cc | 6 +----- src/relay/backend/utils.h | 19 +++++++++++++++++++ src/relay/backend/vm/compiler.cc | 9 +-------- 4 files changed, 22 insertions(+), 21 deletions(-) diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 89ee61c83f7c..87fe39c389f0 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -333,14 +333,7 @@ class RelayBuildModule : public runtime::ModuleNode { IRModule OptimizeImpl(IRModule relay_module) { ICHECK(relay_module.defined()) << "The IRModule must be defined for the Relay compiler."; - if (!params_.empty()) { - ICHECK(relay_module->ContainGlobalVar("main")) << "Missing the main entry function"; - GlobalVar main_glb_var = relay_module->GetGlobalVar("main"); - Function main_func = Downcast(relay_module->Lookup(main_glb_var)); - auto new_main = BindParamsByName(main_func, params_); - IRModuleNode* relay_module_ptr = relay_module.CopyOnWrite(); - relay_module_ptr->Update(main_glb_var, new_main); - } + backend::BindParamsInModule(relay_module, params_); Array pass_seqs = GetPassPrefix( /*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/false); diff --git a/src/relay/backend/metaschedule_task_extraction.cc b/src/relay/backend/metaschedule_task_extraction.cc index 509ef6259e86..c0c7a525f3d3 100644 --- a/src/relay/backend/metaschedule_task_extraction.cc +++ b/src/relay/backend/metaschedule_task_extraction.cc @@ -35,6 +35,7 @@ namespace metaschedule { using meta_schedule::ExtractedTask; Array ExtractTask(IRModule mod, Target target, Map params) { + // backend::BindParamsInModule(mod, params); if (params.size()) { std::unordered_map params_; BaseFunc base_func = mod->Lookup("main"); @@ -51,18 +52,13 @@ Array ExtractTask(IRModule mod, Target target, Map tasks; - LOG(INFO) << opt_mod; - LOG(INFO) << opt_mod->Lookup("main"); PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks](const Expr& exp) { if (exp->IsInstance()) { Function relay_func = Downcast(exp); if (relay_func->HasNonzeroAttr(attr::kPrimitive)) { - LOG(INFO) << relay_func; Array outputs; std::string fused_name; std::tie(outputs, fused_name) = tec::LowerTECompute(target, relay_func); - LOG(INFO) << fused_name; - LOG(INFO) << outputs; auto prim_func = tir::CreatePrimFunc(outputs); auto prim_fn_var = GlobalVar(fused_name); auto relay_mod = IRModule({{prim_fn_var, relay_func}}); diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 3b4d4c18de89..f15ae4765add 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -417,6 +417,25 @@ inline relay::Function BindParamsByName( return ret; } +inline void BindParamsInModule(IRModule mod, + const std::unordered_map& params) { + if (!params.empty()) { + BaseFunc base_func = mod->Lookup("main"); + ICHECK(base_func->IsInstance()); + auto f = relay::backend::BindParamsByName(Downcast(base_func), params); + auto gvar = mod->GetGlobalVar("main"); + mod->Add(gvar, f); + } +} + +inline void BindParamsInModule(IRModule mod, Map params) { + std::unordered_map params_tmp; + for (const auto& kv : params) { + params_tmp[kv.first] = kv.second->data; + } + BindParamsInModule(mod, params_tmp); +} + /*! * \brief Extract the shape from a Relay tensor type. * \param type The provided type. diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index e94919de7f20..130fb09e7af1 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1034,14 +1034,7 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets, IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) { VLOG_CONTEXT << "VM Optimize"; - if (params_.size()) { - BaseFunc base_func = mod->Lookup("main"); - ICHECK(base_func->IsInstance()) - << "VM compiler expects to compile relay::Function"; - auto f = relay::backend::BindParamsByName(Downcast(base_func), params_); - auto gvar = mod->GetGlobalVar("main"); - mod->Add(gvar, f); - } + backend::BindParamsInModule(mod, params_); Array pass_seqs = relay::backend::GetPassPrefix( /*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/true);