diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 63ee2d59d854..8a0c32fc6684 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -18,12 +18,11 @@ */ /*! - * Copyright (c) 2019 by Contributors * \file relay/backend/build_module.cc * \brief Code generation for TVM's graph runtime. */ - #include <tvm/build_module.h> +#include <tvm/runtime/device_api.h> #include <tvm/relay/op.h> #include <tvm/relay/expr.h> #include <tvm/relay/attrs/nn.h> @@ -40,31 +39,6 @@ namespace backend { using TargetsMap = Map<tvm::Integer, tvm::Target>; -/*! - * \brief Context index to Target - */ -struct ContextTargetMap { - static const std::unordered_map<int, tvm::Target> mask2str; - static tvm::Target Mask2Str(int mask) { - CHECK_GT(mask2str.count(mask), 0) << "Unknown mask."; - return mask2str.at(mask); - } -}; - -const std::unordered_map<int, tvm::Target> ContextTargetMap::mask2str = { - {1, tvm::Target::create("llvm")}, - {2, tvm::Target::create("cuda")}, - {4, tvm::Target::create("opencl")}, - {5, tvm::Target::create("aocl")}, - {6, tvm::Target::create("sdaccel")}, - {7, tvm::Target::create("vulkan")}, - {8, tvm::Target::create("metal")}, - {9, tvm::Target::create("vpi")}, - {10, tvm::Target::create("rocm")}, - {11, tvm::Target::create("opengl")}, - {12, tvm::Target::create("ext_dev")} -}; - /*! * \brief A data structure to map the names of specific optimizations to * numeric optimization levels @@ -310,8 +284,8 @@ class RelayBuildModule : public runtime::ModuleNode { * * \return Array<StringImm> names of params */ - Array<HalideIR::Expr> ListParamNames() { - Array<HalideIR::Expr> ret; + Array<tvm::Expr> ListParamNames() { + Array<tvm::Expr> ret; for (const auto& kv : params_) { ret.push_back(ir::StringImm::make(kv.first)); } @@ -470,12 +444,9 @@ class RelayBuildModule : public runtime::ModuleNode { if (cfg.pass_enabled("AlterOpLayout")) { if (targets.size() == 1) { func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); - auto enter_pf = GetPackedFunc("_EnterTargetScope"); - auto exit_pf = GetPackedFunc("_ExitTargetScope"); for (const auto& kv : targets) { - (*enter_pf)(kv.second); + TargetContext tctx(kv.second); func = CallPackedFunc("relay._ir_pass.AlterOpLayout", func); - (*exit_pf)(); } } else { LOG(WARNING) << "AlterOpLayout pass is not enabled for heterogeneous" @@ -487,6 +458,18 @@ class RelayBuildModule : public runtime::ModuleNode { } return func; } + + /*! + * \brief Create a default type. + * \param device_type The device type index. + * \return the default target for the device. + */ + Target CreateDefaultTarget(int device_type) { + std::string name = runtime::DeviceName(device_type); + if (name == "cpu") return Target::create("llvm"); + if (name == "gpu") return Target::create("cuda"); + return Target::create(name); + } /*! * \brief Update the target and fallback device required for heterogeneous * compilation. CPU is used as the fallback device if it wasn't provided. @@ -507,7 +490,7 @@ class RelayBuildModule : public runtime::ModuleNode { if (tmp_map.count(cfg.fallback_device) == 0) { device_target.Set( cfg.fallback_device, - ContextTargetMap::Mask2Str(cfg.fallback_device)); + CreateDefaultTarget(cfg.fallback_device)); } return device_target; } @@ -520,7 +503,8 @@ class RelayBuildModule : public runtime::ModuleNode { * \param targets_map_ptr * \return Function */ - Function RunDeviceAnnotationPass(Function func, const RelayBuildConfig& cfg, + Function RunDeviceAnnotationPass(Function func, + const RelayBuildConfig& cfg, TargetsMap* targets_map_ptr) { func = CallPackedFunc("relay._ir_pass.infer_type", func, nullptr); func = CallPackedFunc("relay._ir_pass.RewriteDeviceAnnotation", func, @@ -532,7 +516,7 @@ class RelayBuildModule : public runtime::ModuleNode { "relay._ir_pass.CollectDeviceAnnotationOps", func, nullptr); if (annotation_map.size() == 0) { targets_map_ptr->Set( - 0, ContextTargetMap::Mask2Str(cfg.fallback_device)); + 0, CreateDefaultTarget(cfg.fallback_device)); } else { int64_t dev_type = -1; for (auto kv : annotation_map) { @@ -547,7 +531,7 @@ class RelayBuildModule : public runtime::ModuleNode { << "found. Please check the " << "RewriteAnnotation pass."; } - targets_map_ptr->Set(0, ContextTargetMap::Mask2Str(dev_type)); + targets_map_ptr->Set(0, CreateDefaultTarget(dev_type)); } } return func; @@ -611,7 +595,8 @@ runtime::Module RelayBuildCreate() { return runtime::Module(exec); } -TVM_REGISTER_GLOBAL("relay.build_module._BuildModule").set_body([](TVMArgs args, TVMRetValue* rv) { +TVM_REGISTER_GLOBAL("relay.build_module._BuildModule") +.set_body([](TVMArgs args, TVMRetValue* rv) { *rv = RelayBuildCreate(); });