Skip to content

Commit

Permalink
refactor param binding
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Mar 11, 2022
1 parent 109187f commit efeccea
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 21 deletions.
9 changes: 1 addition & 8 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,14 +333,7 @@ class RelayBuildModule : public runtime::ModuleNode {
IRModule OptimizeImpl(IRModule relay_module) {
ICHECK(relay_module.defined()) << "The IRModule must be defined for the Relay compiler.";

if (!params_.empty()) {
ICHECK(relay_module->ContainGlobalVar("main")) << "Missing the main entry function";
GlobalVar main_glb_var = relay_module->GetGlobalVar("main");
Function main_func = Downcast<Function>(relay_module->Lookup(main_glb_var));
auto new_main = BindParamsByName(main_func, params_);
IRModuleNode* relay_module_ptr = relay_module.CopyOnWrite();
relay_module_ptr->Update(main_glb_var, new_main);
}
backend::BindParamsInModule(relay_module, params_);

Array<Pass> pass_seqs = GetPassPrefix(
/*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/false);
Expand Down
6 changes: 1 addition & 5 deletions src/relay/backend/metaschedule_task_extraction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ namespace metaschedule {
using meta_schedule::ExtractedTask;

Array<ExtractedTask> ExtractTask(IRModule mod, Target target, Map<String, Constant> params) {
// backend::BindParamsInModule(mod, params);
if (params.size()) {
std::unordered_map<std::string, runtime::NDArray> params_;
BaseFunc base_func = mod->Lookup("main");
Expand All @@ -51,18 +52,13 @@ Array<ExtractedTask> ExtractTask(IRModule mod, Target target, Map<String, Consta
auto opt_mod = seq(std::move(mod));

Array<ExtractedTask> tasks;
LOG(INFO) << opt_mod;
LOG(INFO) << opt_mod->Lookup("main");
PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks](const Expr& exp) {
if (exp->IsInstance<FunctionNode>()) {
Function relay_func = Downcast<Function>(exp);
if (relay_func->HasNonzeroAttr(attr::kPrimitive)) {
LOG(INFO) << relay_func;
Array<te::Tensor> outputs;
std::string fused_name;
std::tie(outputs, fused_name) = tec::LowerTECompute(target, relay_func);
LOG(INFO) << fused_name;
LOG(INFO) << outputs;
auto prim_func = tir::CreatePrimFunc(outputs);
auto prim_fn_var = GlobalVar(fused_name);
auto relay_mod = IRModule({{prim_fn_var, relay_func}});
Expand Down
19 changes: 19 additions & 0 deletions src/relay/backend/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,25 @@ inline relay::Function BindParamsByName(
return ret;
}

inline void BindParamsInModule(IRModule mod,
const std::unordered_map<std::string, runtime::NDArray>& params) {
if (!params.empty()) {
BaseFunc base_func = mod->Lookup("main");
ICHECK(base_func->IsInstance<FunctionNode>());
auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params);
auto gvar = mod->GetGlobalVar("main");
mod->Add(gvar, f);
}
}

inline void BindParamsInModule(IRModule mod, Map<String, Constant> params) {
std::unordered_map<std::string, runtime::NDArray> params_tmp;
for (const auto& kv : params) {
params_tmp[kv.first] = kv.second->data;
}
BindParamsInModule(mod, params_tmp);
}

/*!
* \brief Extract the shape from a Relay tensor type.
* \param type The provided type.
Expand Down
9 changes: 1 addition & 8 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1034,14 +1034,7 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetMap& targets,

IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
VLOG_CONTEXT << "VM Optimize";
if (params_.size()) {
BaseFunc base_func = mod->Lookup("main");
ICHECK(base_func->IsInstance<FunctionNode>())
<< "VM compiler expects to compile relay::Function";
auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params_);
auto gvar = mod->GetGlobalVar("main");
mod->Add(gvar, f);
}
backend::BindParamsInModule(mod, params_);

Array<Pass> pass_seqs = relay::backend::GetPassPrefix(
/*is_homogenous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/true);
Expand Down

0 comments on commit efeccea

Please sign in to comment.