diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index eb323a94d38a..2b8e168310d6 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -265,30 +265,6 @@ class RelayBuildModule : public runtime::ModuleNode { return ret; } - /*! - * \brief Optimize a Relay function. - * - * \param func The input Relay function for optimization. - * \param targets The device type to `Target` mapping. - * \param params The param name to value mapping. - * - * \return func The updated Relay function after optimization. - */ - relay::Function Optimize( - relay::Function func, - const TargetsMap& targets, - const std::unordered_map& params) { - if (params.size()) { - func = BindParamsByName(func, params); - } - - relay::Module relay_module = relay::ModuleNode::FromExpr(func); - relay_module = Optimize(relay_module, targets_, params); - CHECK(relay_module.defined()); - GlobalVar var = relay_module->entry_func; - return relay_module->Lookup(var->name_hint); - } - /*! * \brief Optimize a Relay module. * @@ -340,6 +316,18 @@ class RelayBuildModule : public runtime::ModuleNode { } else { relay_module = seq(relay_module); } + + // Handle heterogeneous compilation. + transform::PassContext pass_ctx = PassContext::Current(); + if (targets_.size() > 1) { + relay_module = + RunDeviceAnnotationPass(relay_module, pass_ctx->fallback_device); + } + + // Fuse the operations if it is needed. + relay_module = transform::FuseOps()(relay_module); + relay_module = transform::InferType()(relay_module); + return relay_module; } @@ -354,55 +342,58 @@ class RelayBuildModule : public runtime::ModuleNode { if (name == "gpu") return Target::Create("cuda"); return Target::Create(name); } + /*! * \brief Update the target and fallback device required for heterogeneous * compilation. CPU is used as the fallback device if it wasn't provided. * Meanwhile, a CPU device type and "llvm" pair will be added to the target * dictionary in this case. * - * \param targets dictionary * \param fallback_device The fallback device for heterogeneous execution. - * \return Map */ - TargetsMap UpdateHeterogeneousInputs(const TargetsMap& targets, - int fallback_device) { - TargetsMap device_target = targets; + void UpdateHeterogeneousInputs(int fallback_device) { std::unordered_map tmp_map; - for (const auto& kv : targets) { + for (const auto& kv : targets_) { tmp_map[kv.first->value] = kv.second; } if (tmp_map.count(fallback_device) == 0) { - device_target.Set( - fallback_device, - CreateDefaultTarget(fallback_device)); + targets_.Set(fallback_device, CreateDefaultTarget(fallback_device)); } - return device_target; } + /*! * \brief Execute the device annotation passes to update the input program and * target information. * - * \param func The input Relay function. + * \param relay_module The input Relay module. * \param fallback_device The fallback device for heterogeneous execution. - * \param targets_map_ptr The device type to `Target` map pointer. * - * \return func The updated function after device annotation. + * \return updated_module The updated module after device annotation. */ - Function RunDeviceAnnotationPass(Function func, - int fallback_device, - TargetsMap* targets_map_ptr) { - relay::Module relay_module = relay::ModuleNode::FromExpr(func); + relay::Module RunDeviceAnnotationPass(const relay::Module& relay_module, + int fallback_device) { + UpdateHeterogeneousInputs(fallback_device); auto rewrite = transform::RewriteAnnotatedOps(fallback_device); - relay_module = rewrite(relay_module); - CHECK(relay_module.defined()); - func = relay_module->Lookup(relay_module->entry_func->name_hint); - CHECK(func.defined()); - auto device_map = relay::CollectDeviceInfo(func); - if (device_map.size() == 0) { - auto annotation_map = relay::CollectDeviceAnnotationOps(func); - if (annotation_map.size() == 0) { - targets_map_ptr->Set(0, CreateDefaultTarget(fallback_device)); + auto updated_module = rewrite(relay_module); + CHECK(updated_module.defined()); + + tvm::Map device_map; + for (const auto& it : updated_module->functions) { + device_map = relay::CollectDeviceInfo(it.second); + if (!device_map.empty()) break; + } + + if (device_map.empty()) { + tvm::Map annotation_map; + for (const auto& it : relay_module->functions) { + annotation_map = relay::CollectDeviceAnnotationOps(it.second); + if (!annotation_map.empty()) break; + } + // None op is annotated but they are fallen back to the default device. + if (annotation_map.empty()) { + targets_.Set(0, CreateDefaultTarget(fallback_device)); } else { + // All ops are annotated to the same device type. int64_t dev_type = -1; for (auto kv : annotation_map) { dev_type = kv.second->value; @@ -416,10 +407,10 @@ class RelayBuildModule : public runtime::ModuleNode { << "found. Please check the " << "RewriteAnnotation pass."; } - targets_map_ptr->Set(0, CreateDefaultTarget(dev_type)); + targets_.Set(0, CreateDefaultTarget(dev_type)); } } - return func; + return updated_module; } /*! @@ -431,38 +422,27 @@ class RelayBuildModule : public runtime::ModuleNode { void BuildRelay( Function func, const std::unordered_map& params) { - transform::PassContext pass_ctx = PassContext::Current(); - - // convert - tvm_cfg_ = BuildConfig::Create(); - TargetsMap device_target; - if (targets_.size() > 1) { - device_target = - UpdateHeterogeneousInputs(targets_, pass_ctx->fallback_device); - } else { - device_target = targets_; - } - func = Optimize(func, targets_, params); - if (device_target.size() > 1) { - func = RunDeviceAnnotationPass(func, pass_ctx->fallback_device, - &device_target); + if (params.size()) { + func = BindParamsByName(func, params); } + // Perform Module->Module optimizations. relay::Module relay_module = relay::ModuleNode::FromExpr(func); - relay_module = transform::InferType()(relay_module); - relay_module = transform::FuseOps()(relay_module); - relay_module = transform::InferType()(relay_module); + relay_module = Optimize(relay_module, targets_, params); CHECK(relay_module.defined()); + // Get the updated function. func = relay_module->Lookup(relay_module->entry_func->name_hint); + // Generate code for the updated function. graph_codegen_ = std::unique_ptr(new GraphCodegen()); - graph_codegen_->Init(nullptr, device_target); + graph_codegen_->Init(nullptr, targets_); graph_codegen_->Codegen(func); ret_.graph_json = graph_codegen_->GetJSON(); ret_.params = graph_codegen_->GetParams(); - ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_, tvm_cfg_); + ret_.mod = tvm::build(graph_codegen_->GetLoweredFunc(), target_host_, + BuildConfig::Current()); } protected: @@ -475,8 +455,6 @@ class RelayBuildModule : public runtime::ModuleNode { std::unordered_map params_; /*! \brief building output */ BuildOutput ret_; - /*! \brief tvm building cfg */ - BuildConfig tvm_cfg_; }; runtime::Module RelayBuildCreate() {