Skip to content

Commit

Permalink
[Target][Codegen] Use target class in all codegens (#6347)
Browse files Browse the repository at this point in the history
* [Target][Codegen] Make all code generator use Target class instead of target string

* Remove dep to TargetNode::str() in LLVM module

* Allow  for llvm nvptx codegen

* ...

* Address comments from Cody

* Rename UpdateTargetConfig => UpdateTargetConfigKeyValueEntry
  • Loading branch information
junrushao authored Aug 29, 2020
1 parent b368f9d commit d9450f8
Show file tree
Hide file tree
Showing 23 changed files with 276 additions and 200 deletions.
2 changes: 1 addition & 1 deletion include/tvm/target/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ using runtime::TVMRetValue;
* \param target The target to be built.
* \return The result runtime::Module.
*/
runtime::Module Build(IRModule mod, const Target& target);
runtime::Module Build(IRModule mod, Target target);

/*!
* \brief Pack imported device library to a C file.
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ class TargetNode : public Object {
Array<String> keys;
/*! \brief Collection of attributes */
Map<String, ObjectRef> attrs;

/*! \return the full device string to pass to codegen::Build */
TVM_DLL const std::string& str() const;
/*! \return Export target to JSON-like configuration */
TVM_DLL Map<String, ObjectRef> Export() const;

void VisitAttrs(AttrVisitor* v) {
v->Visit("kind", &kind);
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def __enter__(self):
def __exit__(self, ptype, value, trace):
_ffi_api.ExitTargetScope(self)

def export(self):
return _ffi_api.TargetExport(self)

@staticmethod
def current(allow_none=True):
"""Returns the current target.
Expand Down
17 changes: 17 additions & 0 deletions src/target/build_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,23 @@ inline std::unordered_map<std::string, runtime::FunctionInfo> ExtractFuncInfo(co
}
return fmap;
}

inline void UpdateTargetConfigKeyValueEntry(const String& key, const String& value,
Map<String, ObjectRef>* target_config,
bool error_if_inconsistent) {
if (target_config->count(key)) {
const ObjectRef& obj = (*target_config)[key];
CHECK(obj->IsInstance<StringObj>()) << "TypeError: Expect target key \"" << key
<< "\" to be String, but gets type: " << obj->GetTypeKey();
if (error_if_inconsistent) {
String old_value = Downcast<String>(obj);
CHECK_EQ(old_value, value) << "ValueError: Target key \"" << key << "\" has been set to \""
<< old_value << "\", and cannot be reset to \"" << value << "\"";
}
}
target_config->Set(key, value);
}

} // namespace codegen
} // namespace tvm
#endif // TVM_TARGET_BUILD_COMMON_H_
6 changes: 3 additions & 3 deletions src/target/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
namespace tvm {
namespace codegen {

runtime::Module Build(IRModule mod, const Target& target) {
runtime::Module Build(IRModule mod, Target target) {
if (transform::PassContext::Current()
->GetConfig<Bool>("tir.disable_assert", Bool(false))
.value()) {
Expand All @@ -55,8 +55,8 @@ runtime::Module Build(IRModule mod, const Target& target) {
}
// the build function.
const PackedFunc* bf = runtime::Registry::Get(build_f_name);
CHECK(bf != nullptr) << "target.build." << target << " is not enabled";
return (*bf)(mod, target->str());
CHECK(bf != nullptr) << build_f_name << " is not enabled";
return (*bf)(mod, target);
}

/*! \brief Helper class to serialize module */
Expand Down
50 changes: 33 additions & 17 deletions src/target/llvm/codegen_amdgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,17 @@ class CodeGenAMDGPU : public CodeGenLLVM {
}
};

inline int DetectROCMComputeVersion(const std::string& target) {
size_t pos = target.find("=gfx");
if (pos != std::string::npos) {
int value;
std::stringstream is(target.substr(pos + 4));
if (is >> value) return value;
inline int DetectROCMComputeVersion(const Target& target) {
if (const Optional<String> mcpu = target->GetAttr<String>("mcpu")) {
std::string gfx = mcpu.value();
if (gfx.length() >= 3 && gfx.substr(0, 3) == "gfx") {
int version;
std::stringstream is(gfx.substr(3));
if (is >> version) {
return version;
}
}
LOG(FATAL) << "ValueError: Unrecognized -mcpu value: " << mcpu;
}
TVMContext tvm_ctx;
tvm_ctx.device_type = kDLROCM;
Expand Down Expand Up @@ -228,23 +233,34 @@ inline int DetectROCMApiVersion() {
return 305;
}

runtime::Module BuildAMDGPU(IRModule mod, std::string target) {
Target UpdateTarget(const Target& original_target) {
Map<String, ObjectRef> target_config = original_target->Export();
UpdateTargetConfigKeyValueEntry("mtriple", "amdgcn-amd-amdhsa-hcc", &target_config, true);
UpdateTargetConfigKeyValueEntry("mcpu",
"gfx" + std::to_string(DetectROCMComputeVersion(original_target)),
&target_config, false);
if (DetectROCMApiVersion() < 305) {
// before ROCm 3.5 we needed code object v2, starting
// with 3.5 we need v3 (this argument disables v3)
Array<String> mattr;
if (target_config.count("mattr")) {
mattr = Downcast<Array<String>>(target_config["mattr"]);
}
mattr.push_back("-code-object-v3");
target_config.Set("mattr", mattr);
}
return Target::FromConfig(target_config);
}

runtime::Module BuildAMDGPU(IRModule mod, Target original_target) {
#if TVM_LLVM_VERSION < 90
LOG(FATAL) << "AMDGPU backend requires at least LLVM 9";
// Lower versions will crash when loading the bitcode, see
// issue #4087 for a discussion
#endif
InitializeLLVM();
CHECK(target.length() >= 4 && target.substr(0, 4) == "rocm");
std::ostringstream config;
config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" << DetectROCMComputeVersion(target);
if (DetectROCMApiVersion() < 305) {
// before ROCm 3.5 we needed code object v2, starting
// with 3.5 we need v3 (this argument disables v3)
config << " -mattr=-code-object-v3 ";
}
config << target.substr(4, target.length() - 4);
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
Target target = UpdateTarget(original_target);
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(target);
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
// careful: cg will hold a naked pointer reference to ctx, so it should
// have a shorter lifetime than the ctx.
Expand Down
7 changes: 4 additions & 3 deletions src/target/llvm/codegen_blob.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "codegen_blob.h"

#include <tvm/runtime/module.h>
#include <tvm/target/target.h>

#include <cstring>

Expand All @@ -33,8 +34,8 @@ namespace codegen {
std::pair<std::unique_ptr<llvm::Module>, std::shared_ptr<llvm::LLVMContext>> CodeGenBlob(
const std::string& data, bool system_lib, const std::string& target_triple) {
InitializeLLVM();
std::string full_target_triple = std::string("-mtriple ") + target_triple;
auto tm = GetLLVMTargetMachine(full_target_triple);
Target target = Target::Create("llvm -mtriple " + target_triple);
auto tm = GetLLVMTargetMachine(target);
auto triple = tm->getTargetTriple();
auto ctx = std::make_shared<llvm::LLVMContext>();
std::string module_name = "devc";
Expand All @@ -43,7 +44,7 @@ std::pair<std::unique_ptr<llvm::Module>, std::shared_ptr<llvm::LLVMContext>> Cod
// Store full target string in metadata, because flags such as -mfloat-abi must be preserved for
// ModulePackImportsToLLVM.
module->addModuleFlag(llvm::Module::ModFlagBehavior::Override, "tvm_target",
llvm::MDString::get(*ctx, full_target_triple));
llvm::MDString::get(*ctx, LLVMTargetToString(target)));
module->setDataLayout(tm->createDataLayout());
auto* blob_value = llvm::ConstantDataArray::getString(*ctx, data, false);
auto* tvm_dev_mblob = new llvm::GlobalVariable(
Expand Down
31 changes: 8 additions & 23 deletions src/target/llvm/codegen_hexagon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -658,11 +658,7 @@ void ProcessLLVMOptions(const std::vector<std::string>& llvm_vec) {

} // namespace

runtime::Module BuildHexagon(IRModule mod, std::string target_str) {
if (target_str.empty()) {
LOG(FATAL) << "Unknown or invalid target.";
}

runtime::Module BuildHexagon(IRModule mod, Target target) {
// Make sure all targets are registered. InitializeLLVM can be called
// multiple times, after the first call all subsequent calls are no-ops.
InitializeLLVM();
Expand All @@ -675,21 +671,12 @@ runtime::Module BuildHexagon(IRModule mod, std::string target_str) {
}
return vec;
};
auto starts_with = [](const std::string& s, const std::string& p) {
return !s.compare(0, p.size(), p);
};

std::vector<std::string> flags = split(target_str);
std::string llvm_target_str, llvm_options_str = "llvm";

for (const auto& s : flags) {
if (starts_with(s, "-mattr=") || starts_with(s, "-mtriple=") || starts_with(s, "-mcpu=")) {
llvm_target_str += " " + s;
} else if (starts_with(s, "-llvm-options=")) {
llvm_options_str += "," + s.substr(14 /*length of -llvm-options=*/);
}
std::string llvm_options_str;
if (const Optional<String> llvm_options = target->GetAttr<String>("llvm-options")) {
llvm_options_str = "llvm," + llvm_options.value();
} else {
llvm_options_str = "llvm";
}

// Postprocess the LLVM options string: replace '@' with '=', and ',' with ' '.
for (int i = 0, e = llvm_options_str.size(); i != e; ++i) {
switch (llvm_options_str[i]) {
Expand All @@ -716,7 +703,7 @@ runtime::Module BuildHexagon(IRModule mod, std::string target_str) {
static bool CallOnce = (ProcessLLVMOptions(llvm_options_vec), true);
(void)CallOnce;

std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(target_str);
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(target);
std::unique_ptr<CodeGenHexagon> cg(new CodeGenHexagon());
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
cg->Init("TVMHexagonModule", tm.get(), ctx.get(), false, false, false);
Expand Down Expand Up @@ -802,9 +789,7 @@ runtime::Module BuildHexagon(IRModule mod, std::string target_str) {
export_abi);
}

TVM_REGISTER_GLOBAL("target.build.hexagon").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildHexagon(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("target.build.hexagon").set_body_typed(BuildHexagon);

} // namespace codegen
} // namespace tvm
Expand Down
17 changes: 11 additions & 6 deletions src/target/llvm/codegen_nvptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,14 +254,19 @@ inline int DetectCUDAComputeVersion() {
}
}

runtime::Module BuildNVPTX(IRModule mod, std::string target) {
Target UpdateTarget(const Target& original_target, int compute_ver) {
Map<String, ObjectRef> target_config = original_target->Export();
UpdateTargetConfigKeyValueEntry("mtriple", "nvptx64-nvidia-cuda", &target_config, true);
UpdateTargetConfigKeyValueEntry("mcpu", "sm_" + std::to_string(compute_ver), &target_config,
false);
return Target::FromConfig(target_config);
}

runtime::Module BuildNVPTX(IRModule mod, Target original_target) {
InitializeLLVM();
CHECK(target.length() >= 5 && target.substr(0, 5) == "nvptx");
int compute_ver = DetectCUDAComputeVersion();
std::ostringstream config;
config << "-mtriple=nvptx64-nvidia-cuda -mcpu=sm_" << compute_ver
<< target.substr(5, target.length() - 5);
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
Target target = UpdateTarget(original_target, compute_ver);
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(target);
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
// careful: cg will hold a naked pointer reference to ctx, so it should
// have a shorter lifetime than the ctx.
Expand Down
Loading

0 comments on commit d9450f8

Please sign in to comment.