From d9450f8c9d8338aadf2632bad3aedb0b9416a8ae Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 28 Aug 2020 22:59:35 -0700 Subject: [PATCH] [Target][Codegen] Use target class in all codegens (#6347) * [Target][Codegen] Make all code generator use Target class instead of target string * Remove dep to TargetNode::str() in LLVM module * Allow for llvm nvptx codegen * ... * Address comments from Cody * Rename UpdateTargetConfig => UpdateTargetConfigKeyValueEntry --- include/tvm/target/codegen.h | 2 +- include/tvm/target/target.h | 3 +- python/tvm/target/target.py | 3 + src/target/build_common.h | 17 ++++ src/target/codegen.cc | 6 +- src/target/llvm/codegen_amdgpu.cc | 50 +++++---- src/target/llvm/codegen_blob.cc | 7 +- src/target/llvm/codegen_hexagon.cc | 31 ++---- src/target/llvm/codegen_nvptx.cc | 17 ++-- src/target/llvm/llvm_common.cc | 101 ++++++++++-------- src/target/llvm/llvm_common.h | 21 +++- src/target/llvm/llvm_module.cc | 141 +++++++++++++------------- src/target/opt/build_cuda_on.cc | 2 +- src/target/source/codegen_aocl.cc | 17 ++-- src/target/source/codegen_c_host.cc | 7 +- src/target/source/codegen_metal.cc | 6 +- src/target/source/codegen_opencl.cc | 2 +- src/target/source/codegen_vhls.cc | 3 +- src/target/spirv/build_vulkan.cc | 6 +- src/target/stackvm/codegen_stackvm.cc | 2 +- src/target/target.cc | 20 +++- src/target/target_kind.cc | 1 + tests/cpp/build_module_test.cc | 11 +- 23 files changed, 276 insertions(+), 200 deletions(-) diff --git a/include/tvm/target/codegen.h b/include/tvm/target/codegen.h index e89d44dd4eb1..b2cab0e4bc45 100644 --- a/include/tvm/target/codegen.h +++ b/include/tvm/target/codegen.h @@ -45,7 +45,7 @@ using runtime::TVMRetValue; * \param target The target to be built. * \return The result runtime::Module. */ -runtime::Module Build(IRModule mod, const Target& target); +runtime::Module Build(IRModule mod, Target target); /*! * \brief Pack imported device library to a C file. diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h index 258b2d83ee72..4d5fba39f506 100644 --- a/include/tvm/target/target.h +++ b/include/tvm/target/target.h @@ -52,9 +52,10 @@ class TargetNode : public Object { Array keys; /*! \brief Collection of attributes */ Map attrs; - /*! \return the full device string to pass to codegen::Build */ TVM_DLL const std::string& str() const; + /*! \return Export target to JSON-like configuration */ + TVM_DLL Map Export() const; void VisitAttrs(AttrVisitor* v) { v->Visit("kind", &kind); diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 9dcc164be78a..986caa165791 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -54,6 +54,9 @@ def __enter__(self): def __exit__(self, ptype, value, trace): _ffi_api.ExitTargetScope(self) + def export(self): + return _ffi_api.TargetExport(self) + @staticmethod def current(allow_none=True): """Returns the current target. diff --git a/src/target/build_common.h b/src/target/build_common.h index ec5b522397ed..531bd629bbdb 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -62,6 +62,23 @@ inline std::unordered_map ExtractFuncInfo(co } return fmap; } + +inline void UpdateTargetConfigKeyValueEntry(const String& key, const String& value, + Map* target_config, + bool error_if_inconsistent) { + if (target_config->count(key)) { + const ObjectRef& obj = (*target_config)[key]; + CHECK(obj->IsInstance()) << "TypeError: Expect target key \"" << key + << "\" to be String, but gets type: " << obj->GetTypeKey(); + if (error_if_inconsistent) { + String old_value = Downcast(obj); + CHECK_EQ(old_value, value) << "ValueError: Target key \"" << key << "\" has been set to \"" + << old_value << "\", and cannot be reset to \"" << value << "\""; + } + } + target_config->Set(key, value); +} + } // namespace codegen } // namespace tvm #endif // TVM_TARGET_BUILD_COMMON_H_ diff --git a/src/target/codegen.cc b/src/target/codegen.cc index 0ac4993efe16..47603e404635 100644 --- a/src/target/codegen.cc +++ b/src/target/codegen.cc @@ -41,7 +41,7 @@ namespace tvm { namespace codegen { -runtime::Module Build(IRModule mod, const Target& target) { +runtime::Module Build(IRModule mod, Target target) { if (transform::PassContext::Current() ->GetConfig("tir.disable_assert", Bool(false)) .value()) { @@ -55,8 +55,8 @@ runtime::Module Build(IRModule mod, const Target& target) { } // the build function. const PackedFunc* bf = runtime::Registry::Get(build_f_name); - CHECK(bf != nullptr) << "target.build." << target << " is not enabled"; - return (*bf)(mod, target->str()); + CHECK(bf != nullptr) << build_f_name << " is not enabled"; + return (*bf)(mod, target); } /*! \brief Helper class to serialize module */ diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 758a4f6be7e6..c19c01b3acdf 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -191,12 +191,17 @@ class CodeGenAMDGPU : public CodeGenLLVM { } }; -inline int DetectROCMComputeVersion(const std::string& target) { - size_t pos = target.find("=gfx"); - if (pos != std::string::npos) { - int value; - std::stringstream is(target.substr(pos + 4)); - if (is >> value) return value; +inline int DetectROCMComputeVersion(const Target& target) { + if (const Optional mcpu = target->GetAttr("mcpu")) { + std::string gfx = mcpu.value(); + if (gfx.length() >= 3 && gfx.substr(0, 3) == "gfx") { + int version; + std::stringstream is(gfx.substr(3)); + if (is >> version) { + return version; + } + } + LOG(FATAL) << "ValueError: Unrecognized -mcpu value: " << mcpu; } TVMContext tvm_ctx; tvm_ctx.device_type = kDLROCM; @@ -228,23 +233,34 @@ inline int DetectROCMApiVersion() { return 305; } -runtime::Module BuildAMDGPU(IRModule mod, std::string target) { +Target UpdateTarget(const Target& original_target) { + Map target_config = original_target->Export(); + UpdateTargetConfigKeyValueEntry("mtriple", "amdgcn-amd-amdhsa-hcc", &target_config, true); + UpdateTargetConfigKeyValueEntry("mcpu", + "gfx" + std::to_string(DetectROCMComputeVersion(original_target)), + &target_config, false); + if (DetectROCMApiVersion() < 305) { + // before ROCm 3.5 we needed code object v2, starting + // with 3.5 we need v3 (this argument disables v3) + Array mattr; + if (target_config.count("mattr")) { + mattr = Downcast>(target_config["mattr"]); + } + mattr.push_back("-code-object-v3"); + target_config.Set("mattr", mattr); + } + return Target::FromConfig(target_config); +} + +runtime::Module BuildAMDGPU(IRModule mod, Target original_target) { #if TVM_LLVM_VERSION < 90 LOG(FATAL) << "AMDGPU backend requires at least LLVM 9"; // Lower versions will crash when loading the bitcode, see // issue #4087 for a discussion #endif InitializeLLVM(); - CHECK(target.length() >= 4 && target.substr(0, 4) == "rocm"); - std::ostringstream config; - config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" << DetectROCMComputeVersion(target); - if (DetectROCMApiVersion() < 305) { - // before ROCm 3.5 we needed code object v2, starting - // with 3.5 we need v3 (this argument disables v3) - config << " -mattr=-code-object-v3 "; - } - config << target.substr(4, target.length() - 4); - std::unique_ptr tm = GetLLVMTargetMachine(config.str()); + Target target = UpdateTarget(original_target); + std::unique_ptr tm = GetLLVMTargetMachine(target); std::unique_ptr ctx(new llvm::LLVMContext()); // careful: cg will hold a naked pointer reference to ctx, so it should // have a shorter lifetime than the ctx. diff --git a/src/target/llvm/codegen_blob.cc b/src/target/llvm/codegen_blob.cc index 6df481730548..5d8a7697e0e7 100644 --- a/src/target/llvm/codegen_blob.cc +++ b/src/target/llvm/codegen_blob.cc @@ -24,6 +24,7 @@ #include "codegen_blob.h" #include +#include #include @@ -33,8 +34,8 @@ namespace codegen { std::pair, std::shared_ptr> CodeGenBlob( const std::string& data, bool system_lib, const std::string& target_triple) { InitializeLLVM(); - std::string full_target_triple = std::string("-mtriple ") + target_triple; - auto tm = GetLLVMTargetMachine(full_target_triple); + Target target = Target::Create("llvm -mtriple " + target_triple); + auto tm = GetLLVMTargetMachine(target); auto triple = tm->getTargetTriple(); auto ctx = std::make_shared(); std::string module_name = "devc"; @@ -43,7 +44,7 @@ std::pair, std::shared_ptr> Cod // Store full target string in metadata, because flags such as -mfloat-abi must be preserved for // ModulePackImportsToLLVM. module->addModuleFlag(llvm::Module::ModFlagBehavior::Override, "tvm_target", - llvm::MDString::get(*ctx, full_target_triple)); + llvm::MDString::get(*ctx, LLVMTargetToString(target))); module->setDataLayout(tm->createDataLayout()); auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false); auto* tvm_dev_mblob = new llvm::GlobalVariable( diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index c77215dec74b..c52f9b06929e 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -658,11 +658,7 @@ void ProcessLLVMOptions(const std::vector& llvm_vec) { } // namespace -runtime::Module BuildHexagon(IRModule mod, std::string target_str) { - if (target_str.empty()) { - LOG(FATAL) << "Unknown or invalid target."; - } - +runtime::Module BuildHexagon(IRModule mod, Target target) { // Make sure all targets are registered. InitializeLLVM can be called // multiple times, after the first call all subsequent calls are no-ops. InitializeLLVM(); @@ -675,21 +671,12 @@ runtime::Module BuildHexagon(IRModule mod, std::string target_str) { } return vec; }; - auto starts_with = [](const std::string& s, const std::string& p) { - return !s.compare(0, p.size(), p); - }; - - std::vector flags = split(target_str); - std::string llvm_target_str, llvm_options_str = "llvm"; - - for (const auto& s : flags) { - if (starts_with(s, "-mattr=") || starts_with(s, "-mtriple=") || starts_with(s, "-mcpu=")) { - llvm_target_str += " " + s; - } else if (starts_with(s, "-llvm-options=")) { - llvm_options_str += "," + s.substr(14 /*length of -llvm-options=*/); - } + std::string llvm_options_str; + if (const Optional llvm_options = target->GetAttr("llvm-options")) { + llvm_options_str = "llvm," + llvm_options.value(); + } else { + llvm_options_str = "llvm"; } - // Postprocess the LLVM options string: replace '@' with '=', and ',' with ' '. for (int i = 0, e = llvm_options_str.size(); i != e; ++i) { switch (llvm_options_str[i]) { @@ -716,7 +703,7 @@ runtime::Module BuildHexagon(IRModule mod, std::string target_str) { static bool CallOnce = (ProcessLLVMOptions(llvm_options_vec), true); (void)CallOnce; - std::unique_ptr tm = GetLLVMTargetMachine(target_str); + std::unique_ptr tm = GetLLVMTargetMachine(target); std::unique_ptr cg(new CodeGenHexagon()); std::unique_ptr ctx(new llvm::LLVMContext()); cg->Init("TVMHexagonModule", tm.get(), ctx.get(), false, false, false); @@ -802,9 +789,7 @@ runtime::Module BuildHexagon(IRModule mod, std::string target_str) { export_abi); } -TVM_REGISTER_GLOBAL("target.build.hexagon").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildHexagon(args[0], args[1]); -}); +TVM_REGISTER_GLOBAL("target.build.hexagon").set_body_typed(BuildHexagon); } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index e2690b96a106..fe409ba0a0cd 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -254,14 +254,19 @@ inline int DetectCUDAComputeVersion() { } } -runtime::Module BuildNVPTX(IRModule mod, std::string target) { +Target UpdateTarget(const Target& original_target, int compute_ver) { + Map target_config = original_target->Export(); + UpdateTargetConfigKeyValueEntry("mtriple", "nvptx64-nvidia-cuda", &target_config, true); + UpdateTargetConfigKeyValueEntry("mcpu", "sm_" + std::to_string(compute_ver), &target_config, + false); + return Target::FromConfig(target_config); +} + +runtime::Module BuildNVPTX(IRModule mod, Target original_target) { InitializeLLVM(); - CHECK(target.length() >= 5 && target.substr(0, 5) == "nvptx"); int compute_ver = DetectCUDAComputeVersion(); - std::ostringstream config; - config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_" << compute_ver - << target.substr(5, target.length() - 5); - std::unique_ptr tm = GetLLVMTargetMachine(config.str()); + Target target = UpdateTarget(original_target, compute_ver); + std::unique_ptr tm = GetLLVMTargetMachine(target); std::unique_ptr ctx(new llvm::LLVMContext()); // careful: cg will hold a naked pointer reference to ctx, so it should // have a shorter lifetime than the ctx. diff --git a/src/target/llvm/llvm_common.cc b/src/target/llvm/llvm_common.cc index 3a1036b3b0b5..e8225ab5b6e4 100644 --- a/src/target/llvm/llvm_common.cc +++ b/src/target/llvm/llvm_common.cc @@ -25,6 +25,7 @@ #include "llvm_common.h" #include +#include #include #include @@ -58,53 +59,44 @@ void InitializeLLVM() { } } -void ParseLLVMTargetOptions(const std::string& target_str, std::string* triple, std::string* mcpu, +void ParseLLVMTargetOptions(const Target& target, std::string* triple, std::string* mcpu, std::string* mattr, llvm::TargetOptions* options) { - // setup target triple - size_t start = 0; - if (target_str.length() >= 4 && target_str.substr(0, 4) == "llvm") { - start = 4; - } // simple parser triple->resize(0); mcpu->resize(0); mattr->resize(0); - bool soft_float_abi = false; - std::string key, value; - std::istringstream is(target_str.substr(start, target_str.length() - start)); - while (is >> key) { - if (key == "-system-lib" || key == "-system-lib=0" || key == "-system-lib=1") { - continue; - } - size_t pos = key.find('='); - if (pos != std::string::npos) { - CHECK_GE(key.length(), pos + 1) << "invalid argument " << key; - value = key.substr(pos + 1, key.length() - 1); - key = key.substr(0, pos); - } else { - CHECK(is >> value) << "Unspecified value for option " << key; + if (const Optional& v = target->GetAttr("mtriple")) { + *triple = v.value(); + } + if (const Optional& v = target->GetAttr("mcpu")) { + *mcpu = v.value(); + } + if (const Optional>& v = target->GetAttr>("mattr")) { + std::ostringstream os; + bool is_first = true; + for (const String& s : v.value()) { + if (!is_first) { + os << ','; + } + is_first = false; + os << s; } - if (key == "-mtriple") { - *triple = value; - } else if (key == "-mcpu") { - *mcpu = value; - } else if (key == "-mattr") { - *mattr = value; - } else if (key == "-mfloat-abi") { - if (value == "hard") { + *mattr = os.str(); + } + if (const Optional& v = target->GetAttr("mfloat-abi")) { + String value = v.value(); + if (value == "hard") { #if TVM_LLVM_VERSION < 60 - LOG(FATAL) << "-mfloat-abi hard is only supported for LLVM > 6.0"; + LOG(FATAL) << "-mfloat-abi hard is only supported for LLVM > 6.0"; #endif - soft_float_abi = false; - } else if (value == "soft") { - soft_float_abi = true; - } else { - LOG(FATAL) << "invalid -mfloat-abi option " << value; - } + soft_float_abi = false; + } else if (value == "soft") { + soft_float_abi = true; + } else { + LOG(FATAL) << "invalid -mfloat-abi option " << value; } } - if (triple->length() == 0 || *triple == "default") { *triple = llvm::sys::getDefaultTargetTriple(); } @@ -125,12 +117,11 @@ void ParseLLVMTargetOptions(const std::string& target_str, std::string* triple, } } -std::unique_ptr GetLLVMTargetMachine(const std::string& target_str, - bool allow_null) { +std::unique_ptr GetLLVMTargetMachine(const Target& target, bool allow_null) { std::string target_triple, mcpu, mattr; llvm::TargetOptions opt; - ParseLLVMTargetOptions(target_str, &target_triple, &mcpu, &mattr, &opt); + ParseLLVMTargetOptions(target, &target_triple, &mcpu, &mattr, &opt); if (target_triple.length() == 0 || target_triple == "default") { target_triple = llvm::sys::getDefaultTargetTriple(); @@ -140,16 +131,42 @@ std::unique_ptr GetLLVMTargetMachine(const std::string& tar } std::string err; - const llvm::Target* target = llvm::TargetRegistry::lookupTarget(target_triple, err); - if (target == nullptr) { + const llvm::Target* llvm_target = llvm::TargetRegistry::lookupTarget(target_triple, err); + if (llvm_target == nullptr) { CHECK(allow_null) << err << " target_triple=" << target_triple; return nullptr; } llvm::TargetMachine* tm = - target->createTargetMachine(target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_); + llvm_target->createTargetMachine(target_triple, mcpu, mattr, opt, llvm::Reloc::PIC_); return std::unique_ptr(tm); } +std::string LLVMTargetToString(const Target& target) { + std::ostringstream os; + os << "llvm"; + if (Optional mtriple = target->GetAttr("mtriple")) { + os << " -mtriple=" << mtriple.value(); + } + if (Optional mcpu = target->GetAttr("mcpu")) { + os << " -mcpu=" << mcpu.value(); + } + if (Optional> mattr = target->GetAttr>("mattr")) { + bool is_first = true; + os << " -mattr="; + for (const String& attr : mattr.value()) { + if (!is_first) { + os << ","; + } + is_first = false; + os << attr; + } + } + if (Optional mfloat_abo = target->GetAttr("mfloat-abi")) { + os << " -mfloat-abi=" << mfloat_abo.value(); + } + return os.str(); +} + } // namespace codegen } // namespace tvm #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/llvm_common.h b/src/target/llvm/llvm_common.h index 738e0558da85..42cb9db44a2d 100644 --- a/src/target/llvm/llvm_common.h +++ b/src/target/llvm/llvm_common.h @@ -79,6 +79,10 @@ #include namespace tvm { + +// The TVM target +class Target; + namespace codegen { /*! @@ -89,24 +93,31 @@ void InitializeLLVM(); /*! * \brief Parse target options - * \param target_str Target string, in format "llvm -mtriple=xxx -mcpu=xxx" + * \param target The TVM target * \param triple Target triple * \param mcpu cpu info * \param options the options * \param mattr The attributes */ -void ParseLLVMTargetOptions(const std::string& target_str, std::string* triple, std::string* mcpu, +void ParseLLVMTargetOptions(const Target& target, std::string* triple, std::string* mcpu, std::string* mattr, llvm::TargetOptions* options); /*! - * \brief Get target machine from target_str string. - * \param target_str Target string, in format "llvm -mtriple=xxx -mcpu=xxx" + * \brief Get target machine from TVM target. + * \param target The TVM target * \param allow_null Whether allow null to be returned. * \return target machine */ -std::unique_ptr GetLLVMTargetMachine(const std::string& target_str, +std::unique_ptr GetLLVMTargetMachine(const Target& target, bool allow_null = false); +/*! + * \brief Convert the TVM's LLVM target to string by extracting only relevant fields + * \param target The TVM target to be extracted + * \return The raw string format for the TVM LLVM target + */ +std::string LLVMTargetToString(const Target& target); + } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index de2dadf9bb16..b3d448aee77f 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -189,10 +189,9 @@ class LLVMModuleNode final : public runtime::ModuleNode { return ""; } - void Init(const IRModule& mod, std::string target_str) { + void Init(const IRModule& mod, const Target& target) { InitializeLLVM(); - tm_ = GetLLVMTargetMachine(target_str); - auto target = Target::Create(target_str); + tm_ = GetLLVMTargetMachine(target); bool system_lib = target->GetAttr("system-lib").value_or(Bool(false)); bool target_c_runtime = (target->GetAttr("runtime").value_or("") == kTvmRuntimeCrt); ctx_ = std::make_shared(); @@ -225,7 +224,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { module_ = cg->Finish(); module_->addModuleFlag(llvm::Module::Warning, "tvm_target", - llvm::MDString::get(*ctx_, target_str)); + llvm::MDString::get(*ctx_, LLVMTargetToString(target))); module_->addModuleFlag(llvm::Module::Override, "Debug Info Version", llvm::DEBUG_METADATA_VERSION); @@ -238,7 +237,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { LOG_IF(FATAL, llvm::verifyModule(*module_, &verify_errors)) << "LLVM module verification failed with the following errors: \n" << verify_errors.str(); - target_ = target_str; + target_ = target; mptr_ = module_.get(); } @@ -251,19 +250,22 @@ class LLVMModuleNode final : public runtime::ModuleNode { std::string msg = std::string(err.getMessage()); LOG(FATAL) << "Fail to load module: " << msg; } - std::string target_; - llvm::Metadata* mtarget = module_->getModuleFlag("tvm_target"); - if (mtarget != nullptr) { - llvm::MDString* pstr = llvm::dyn_cast(mtarget); + std::string target_metadata; + llvm::Metadata* tvm_target = module_->getModuleFlag("tvm_target"); + if (tvm_target != nullptr) { + llvm::MDString* pstr = llvm::dyn_cast(tvm_target); CHECK(pstr != nullptr); - target_ = pstr->getString().str(); + target_metadata = pstr->getString().str(); + if (!(target_metadata.length() >= 4 && target_metadata.substr(0, 4) == "llvm")) { + target_metadata = "llvm " + target_metadata; + } } else { std::ostringstream os; os << "llvm -mtriple " << module_->getTargetTriple(); - target_ = os.str(); + target_metadata = os.str(); } mptr_ = module_.get(); - tm_ = GetLLVMTargetMachine(target_); + tm_ = GetLLVMTargetMachine(Target::Create(target_metadata)); } void LoadIR(const std::string& file_name) { @@ -284,6 +286,9 @@ class LLVMModuleNode final : public runtime::ModuleNode { if (ee_) { return; } + if (!target_.defined()) { + target_ = Target::Create("llvm"); + } llvm::EngineBuilder builder(std::move(module_)); std::string triple, mcpu, mattr; llvm::TargetOptions opt; @@ -299,7 +304,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { } builder.setTargetOptions(opt); auto tm = std::unique_ptr(builder.selectTarget()); - std::unique_ptr tm_sys = GetLLVMTargetMachine("llvm"); + std::unique_ptr tm_sys = GetLLVMTargetMachine(Target::Create("llvm")); if (tm_sys->getTargetTriple().getArch() != tm->getTargetTriple().getArch()) { LOG(FATAL) << "Cannot run module, architecture mismatch " << " module=" << tm->getTargetTriple().str() @@ -340,7 +345,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { } // The target configuration string - std::string target_; + Target target_; // JIT lock std::mutex mutex_; // execution engine @@ -355,64 +360,62 @@ class LLVMModuleNode final : public runtime::ModuleNode { std::shared_ptr ctx_; }; -unsigned LookupLLVMIntrinsic(const std::string& name) { - return llvm::Function::lookupIntrinsicID(name); -} - -TVM_REGISTER_GLOBAL("target.build.llvm").set_body_typed([](IRModule mod, std::string target) { - auto n = make_object(); - n->Init(mod, target); - return runtime::Module(n); +TVM_REGISTER_GLOBAL("target.build.llvm") + .set_body_typed([](IRModule mod, Target target) -> runtime::Module { + auto n = make_object(); + n->Init(mod, target); + return runtime::Module(n); + }); + +TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate") + .set_body_typed([](std::string target_str, std::string module_name) -> runtime::Module { + Target target = Target::Create(target_str); + auto n = make_object(); + // Generate a LLVM module from an input target string + InitializeLLVM(); + auto tm = GetLLVMTargetMachine(target); + auto ctx = std::make_shared(); + std::unique_ptr module(new llvm::Module(module_name, *ctx)); + // Use a default data layout and target triple + auto triple = tm->getTargetTriple(); + module->setTargetTriple(triple.str()); + module->setDataLayout(tm->createDataLayout()); + n->Init(std::move(module), ctx); + return runtime::Module(n); + }); + +TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id") + .set_body_typed([](std::string name) -> int64_t { + return static_cast(llvm::Function::lookupIntrinsicID(name)); + }); + +TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body_typed([]() -> int { + return TVM_LLVM_VERSION / 10; }); -TVM_REGISTER_GLOBAL("codegen.LLVMModuleCreate").set_body([](TVMArgs args, TVMRetValue* rv) { - auto n = make_object(); - auto target = args[0].operator std::string(); - auto module_name = args[1].operator std::string(); - - // Generate a LLVM module from an input target string - InitializeLLVM(); - auto tm = GetLLVMTargetMachine(target); - auto ctx = std::make_shared(); - std::unique_ptr module(new llvm::Module(module_name, *ctx)); - - // Use a default data layout and target triple - auto triple = tm->getTargetTriple(); - module->setTargetTriple(triple.str()); - module->setDataLayout(tm->createDataLayout()); +TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll") + .set_body_typed([](std::string filename, std::string fmt) -> runtime::Module { + auto n = make_object(); + n->LoadIR(filename); + return runtime::Module(n); + }); + +TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled") + .set_body_typed([](std::string target_str) -> bool { + InitializeLLVM(); + Target target = Target::Create(target_str); + return (GetLLVMTargetMachine(target, true) != nullptr); + }); + +TVM_REGISTER_GLOBAL("codegen.codegen_blob") + .set_body_typed([](std::string data, bool system_lib, + std::string target_triple) -> runtime::Module { + auto n = make_object(); + auto p = CodeGenBlob(data, system_lib, target_triple); + n->Init(std::move(p.first), p.second); + return runtime::Module(n); + }); - n->Init(std::move(module), ctx); - - *rv = runtime::Module(n); -}); - -TVM_REGISTER_GLOBAL("target.llvm_lookup_intrinsic_id").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = static_cast(LookupLLVMIntrinsic(args[0])); -}); - -TVM_REGISTER_GLOBAL("target.llvm_version_major").set_body([](TVMArgs args, TVMRetValue* rv) { - int major = TVM_LLVM_VERSION / 10; - *rv = major; -}); - -TVM_REGISTER_GLOBAL("runtime.module.loadfile_ll").set_body([](TVMArgs args, TVMRetValue* rv) { - auto n = make_object(); - n->LoadIR(args[0]); - *rv = runtime::Module(n); -}); - -TVM_REGISTER_GLOBAL("codegen.llvm_target_enabled").set_body([](TVMArgs args, TVMRetValue* rv) { - InitializeLLVM(); - *rv = (GetLLVMTargetMachine(args[0], true) != nullptr); -}); - -TVM_REGISTER_GLOBAL("codegen.codegen_blob").set_body([](TVMArgs args, TVMRetValue* rv) { - auto n = make_object(); - auto p = CodeGenBlob(args[0].operator std::string(), args[1].operator bool(), - args[2].operator std::string()); - n->Init(std::move(p.first), p.second); - *rv = runtime::Module(n); -}); } // namespace codegen } // namespace tvm #endif // TVM_LLVM_VERSION diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index c9471d1bfa8d..780829c256ce 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -121,7 +121,7 @@ std::string NVRTCCompile(const std::string& code, bool include_path = false) { return ptx; } -runtime::Module BuildCUDA(IRModule mod, std::string target) { +runtime::Module BuildCUDA(IRModule mod, Target target) { using tvm::runtime::Registry; bool output_ssa = false; CodeGenCUDA cg; diff --git a/src/target/source/codegen_aocl.cc b/src/target/source/codegen_aocl.cc index 597fd37f6774..e90b7d4f8b2c 100644 --- a/src/target/source/codegen_aocl.cc +++ b/src/target/source/codegen_aocl.cc @@ -33,7 +33,7 @@ namespace tvm { namespace codegen { -runtime::Module BuildAOCL(IRModule mod, std::string target_str, bool emulation) { +runtime::Module BuildAOCL(IRModule mod, Target target, bool emulation) { // Get code. using tvm::runtime::Registry; bool output_ssa = false; @@ -61,7 +61,6 @@ runtime::Module BuildAOCL(IRModule mod, std::string target_str, bool emulation) std::string cmd = "aoc aocl.cl"; // AOCL supports fp64. cmd += " -Dcl_khr_fp64"; - Target target = Target::Create(target_str); Optional device = target->GetAttr("device"); if (device.defined()) { cmd += " -board=" + device.value(); @@ -80,13 +79,15 @@ runtime::Module BuildAOCL(IRModule mod, std::string target_str, bool emulation) return AOCLModuleCreate(aocxbin, "aocx", ExtractFuncInfo(mod), code); } -TVM_REGISTER_GLOBAL("target.build.aocl").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildAOCL(args[0], args[1], false); -}); +TVM_REGISTER_GLOBAL("target.build.aocl") + .set_body_typed([](IRModule mod, Target target) -> runtime::Module { + return BuildAOCL(mod, target, false); + }); -TVM_REGISTER_GLOBAL("target.build.build.aocl_sw_emu").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildAOCL(args[0], args[1], true); -}); +TVM_REGISTER_GLOBAL("target.build.build.aocl_sw_emu") + .set_body_typed([](IRModule mod, Target target) -> runtime::Module { + return BuildAOCL(mod, target, true); + }); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index f4aa7281f1f9..5bd7b2e91c31 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -298,12 +298,11 @@ void CodeGenCHost::GenerateCrtSystemLib() { << "}\n"; } -runtime::Module BuildCHost(IRModule mod, const std::string& target_str) { +runtime::Module BuildCHost(IRModule mod, Target target) { using tvm::runtime::Registry; bool output_ssa = false; bool emit_asserts = false; CodeGenCHost cg; - auto target = Target::Create(target_str); cg.Init(output_ssa, emit_asserts); for (auto kv : mod->functions) { @@ -323,8 +322,6 @@ runtime::Module BuildCHost(IRModule mod, const std::string& target_str) { return CSourceModuleCreate(code, "c"); } -TVM_REGISTER_GLOBAL("target.build.c").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildCHost(args[0], args[1]); -}); +TVM_REGISTER_GLOBAL("target.build.c").set_body_typed(BuildCHost); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 1c4256c5a166..fb235d2d785d 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -282,7 +282,7 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT } } -runtime::Module BuildMetal(IRModule mod) { +runtime::Module BuildMetal(IRModule mod, Target target) { using tvm::runtime::Registry; bool output_ssa = false; CodeGenMetal cg; @@ -308,8 +308,6 @@ runtime::Module BuildMetal(IRModule mod) { return MetalModuleCreate(code, fmt, ExtractFuncInfo(mod), source); } -TVM_REGISTER_GLOBAL("target.build.metal").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = BuildMetal(args[0]); -}); +TVM_REGISTER_GLOBAL("target.build.metal").set_body_typed(BuildMetal); } // namespace codegen } // namespace tvm diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 21e5ed66403f..10cc007c4572 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -280,7 +280,7 @@ void CodeGenOpenCL::VisitExpr_(const FloatImmNode* op, std::ostream& os) { // N } } -runtime::Module BuildOpenCL(IRModule mod, std::string target) { +runtime::Module BuildOpenCL(IRModule mod, Target target) { using tvm::runtime::Registry; bool output_ssa = false; CodeGenOpenCL cg; diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc index 3d77ddadd29c..9401f0682db8 100644 --- a/src/target/source/codegen_vhls.cc +++ b/src/target/source/codegen_vhls.cc @@ -137,7 +137,7 @@ void CodeGenVivadoHLS::VisitExpr_(const MaxNode* op, std::ostream& os) { // NOL PrintBinaryExpr(op, opstr, os, this); } -runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { +runtime::Module BuildSDAccel(IRModule mod, Target target) { using tvm::runtime::Registry; bool output_ssa = false; CodeGenVivadoHLS cg; @@ -178,7 +178,6 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { std::string xclbin; if (const auto* f = Registry::Get("tvm_callback_sdaccel_compile")) { - Target target = Target::Create(target_str); String device = target->GetAttr("device", "").value(); xclbin = (*f)(kernel_info, device).operator std::string(); } else { diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index 86d1614dc863..1eef2f8f88e5 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -63,7 +63,7 @@ class SPIRVTools { spv_context ctx_; }; -runtime::Module BuildSPIRV(IRModule mod, std::string target, bool webgpu_restriction) { +runtime::Module BuildSPIRV(IRModule mod, Target target, bool webgpu_restriction) { using tvm::runtime::Registry; using tvm::runtime::VulkanShader; @@ -116,11 +116,11 @@ runtime::Module BuildSPIRV(IRModule mod, std::string target, bool webgpu_restric return runtime::VulkanModuleCreate(smap, ExtractFuncInfo(mod), code_data.str()); } -TVM_REGISTER_GLOBAL("target.build.vulkan").set_body_typed([](IRModule mod, std::string target) { +TVM_REGISTER_GLOBAL("target.build.vulkan").set_body_typed([](IRModule mod, Target target) { return BuildSPIRV(mod, target, false); }); -TVM_REGISTER_GLOBAL("target.build.webgpu").set_body_typed([](IRModule mod, std::string target) { +TVM_REGISTER_GLOBAL("target.build.webgpu").set_body_typed([](IRModule mod, Target target) { return BuildSPIRV(mod, target, true); }); diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index 9cad92dfdacc..ac3ba78fa4d5 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -510,7 +510,7 @@ void CodeGenStackVM::VisitExpr_(const LetNode* op) { this->Push(op->body); } -runtime::Module BuildStackVM(const IRModule& mod, const std::string& target) { +runtime::Module BuildStackVM(IRModule mod, Target target) { std::unordered_map fmap; std::string entry_func; diff --git a/src/target/target.cc b/src/target/target.cc index 47b405430ead..ccc0023378df 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -276,6 +276,18 @@ std::unordered_set TargetNode::GetLibs() const { return result; } +Map TargetNode::Export() const { + Map result = { + {"kind", this->kind->name}, + {"tag", this->tag}, + {"keys", this->keys}, + }; + for (const auto& kv : attrs) { + result.Set(kv.first, kv.second); + } + return result; +} + const std::string& TargetNode::str() const { if (str_repr_.empty()) { std::ostringstream os; @@ -527,10 +539,14 @@ TVM_REGISTER_GLOBAL("target.TargetFromString").set_body_typed(Target::Create); TVM_REGISTER_GLOBAL("target.TargetFromConfig").set_body_typed(Target::FromConfig); +TVM_REGISTER_GLOBAL("target.TargetExport") + .set_body_typed([](Target target) -> Map { return target->Export(); }); + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast(node.get()); - p->stream << op->str(); + const auto* target = node.as(); + CHECK(target); + p->stream << target->str(); }); namespace target { diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 40ade4de96a0..29f16925968d 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -106,6 +106,7 @@ TVM_REGISTER_TARGET_KIND("nvptx") .add_attr_option("max_num_threads", Integer(1024)) .add_attr_option("thread_warp_size", Integer(32)) .add_attr_option("mcpu") + .add_attr_option("mtriple") .set_default_keys({"cuda", "gpu"}) .set_device_type(kDLGPU); diff --git a/tests/cpp/build_module_test.cc b/tests/cpp/build_module_test.cc index 2462fd1e733f..48edfcd024f5 100644 --- a/tests/cpp/build_module_test.cc +++ b/tests/cpp/build_module_test.cc @@ -56,9 +56,14 @@ TEST(BuildModule, Basic) { auto module = build(lowered, target, Target()); auto mali_target = Target::Create("opencl -model=Mali-T860MP4@800Mhz -device=mali"); - CHECK_EQ( - mali_target->str(), - "opencl -keys=mali,opencl,gpu -device=mali -max_num_threads=256 -model=Mali-T860MP4@800Mhz"); + CHECK_EQ(mali_target->kind->name, "opencl"); + CHECK_EQ(mali_target->keys.size(), 3); + CHECK_EQ(mali_target->keys[0], "mali"); + CHECK_EQ(mali_target->keys[1], "opencl"); + CHECK_EQ(mali_target->keys[2], "gpu"); + CHECK_EQ(mali_target->GetAttr("device").value(), "mali"); + CHECK_EQ(mali_target->GetAttr("model").value(), "Mali-T860MP4@800Mhz"); + CHECK_EQ(mali_target->GetAttr("max_num_threads").value(), 256); } TEST(BuildModule, Heterogeneous) {