Skip to content

Commit

Permalink
Two small fixes to AMDCPU codegen for LLVM 10+ and ROCm 3.5+
Browse files Browse the repository at this point in the history
- For LLVM 10+ we need to avoid calling Align with 0, or else
  we get a crash.
- For ROCm 3.5+ we need to use code object 3 (the default in LLVM 9+)
  but for ROCm < 3.5 we want the code object 2.
- As we want to separate codegen from the API, we need to add
  a device api query for the version.
  But every one else wants now one, too. (But I only filled it
  in for CUDA for now.)
  • Loading branch information
t-vi committed Jun 25, 2020
1 parent 074a07e commit 3206394
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 5 deletions.
3 changes: 2 additions & 1 deletion include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ enum DeviceAttrKind : int {
kMultiProcessorCount = 7,
kMaxThreadDimensions = 8,
kMaxRegistersPerBlock = 9,
kGcnArch = 10
kGcnArch = 10,
kApiVersion = 11
};

/*! \brief Number of bytes each allocation must align to */
Expand Down
4 changes: 4 additions & 0 deletions src/runtime/cuda/cuda_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ class CUDADeviceAPI final : public DeviceAPI {
}
case kGcnArch:
return;
case kApiVersion: {
*rv = CUDA_VERSION;
return;
}
}
*rv = value;
}
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/metal/metal_device_api.mm
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@
return;
case kGcnArch:
return;
case kApiVersion:
return;
}
}

Expand Down
2 changes: 2 additions & 0 deletions src/runtime/opencl/opencl_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ void OpenCLWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
return;
case kGcnArch:
return;
case kApiVersion:
return;
}
}

Expand Down
1 change: 1 addition & 0 deletions src/runtime/rocm/rocm_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#define TVM_RUNTIME_ROCM_ROCM_COMMON_H_

#include <hip/hip_runtime_api.h>
#include <hip/hip_version.h>
#include <tvm/runtime/packed_func.h>

#include <string>
Expand Down
4 changes: 4 additions & 0 deletions src/runtime/rocm/rocm_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ class ROCMDeviceAPI final : public DeviceAPI {
*rv = prop.gcnArch;
return;
}
case kApiVersion: {
*rv = HIP_VERSION;
return;
}
}
*rv = value;
}
Expand Down
2 changes: 2 additions & 0 deletions src/runtime/vulkan/vulkan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,8 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue*
return;
case kGcnArch:
return;
case kApiVersion:
return;
}
}

Expand Down
29 changes: 25 additions & 4 deletions src/target/llvm/codegen_amdgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,13 @@ class CodeGenAMDGPU : public CodeGenLLVM {
llvm::GlobalVariable* global = new llvm::GlobalVariable(
*module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", nullptr,
llvm::GlobalValue::NotThreadLocal, shared_address_space);
if (global->getAlignment() < static_cast<uint32_t>(info.alignment)) {
#if TVM_LLVM_VERSION >= 100
global->setAlignment(llvm::Align(info.alignment));
global->setAlignment(llvm::Align(info.alignment));
#else
global->setAlignment(info.alignment);
global->setAlignment(info.alignment);
#endif
}
buf = global;
}

Expand Down Expand Up @@ -212,6 +214,20 @@ inline int DetectROCMComputeVersion(const std::string& target) {
return 900;
}

inline int DetectROCMApiVersion() {
TVMContext tvm_ctx;
tvm_ctx.device_type = kDLROCM;
tvm_ctx.device_id = 0;
tvm::runtime::DeviceAPI* api = tvm::runtime::DeviceAPI::Get(tvm_ctx, true);
if (api != nullptr) {
TVMRetValue val;
api->GetAttr(tvm_ctx, tvm::runtime::kApiVersion, &val);
return val.operator int();
}
LOG(WARNING) << "Cannot detect ROCm version, assume >= 3.5";
return 305;
}

runtime::Module BuildAMDGPU(IRModule mod, std::string target) {
#if TVM_LLVM_VERSION < 90
LOG(FATAL) << "AMDGPU backend requires at least LLVM 9";
Expand All @@ -221,8 +237,13 @@ runtime::Module BuildAMDGPU(IRModule mod, std::string target) {
InitializeLLVM();
CHECK(target.length() >= 4 && target.substr(0, 4) == "rocm");
std::ostringstream config;
config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" << DetectROCMComputeVersion(target)
<< " -mattr=-code-object-v3 " << target.substr(4, target.length() - 4);
config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" << DetectROCMComputeVersion(target);
if (DetectROCMApiVersion() < 305) {
// before ROCm 3.5 we needed code object v2, starting
// with 3.5 we need v3 (this argument disables v3)
config << " -mattr=-code-object-v3 ";
}
config << target.substr(4, target.length() - 4);
std::unique_ptr<llvm::TargetMachine> tm = GetLLVMTargetMachine(config.str());
std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext());
// careful: cg will hold a naked pointer reference to ctx, so it should
Expand Down

0 comments on commit 3206394

Please sign in to comment.