diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 3cf5566f3231..c6a2ce3d28d0 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -45,7 +45,8 @@ enum DeviceAttrKind : int { kMultiProcessorCount = 7, kMaxThreadDimensions = 8, kMaxRegistersPerBlock = 9, - kGcnArch = 10 + kGcnArch = 10, + kApiVersion = 11 }; /*! \brief Number of bytes each allocation must align to */ diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc index ccd8e91e0c5d..14444c92f620 100644 --- a/src/runtime/cuda/cuda_device_api.cc +++ b/src/runtime/cuda/cuda_device_api.cc @@ -98,6 +98,10 @@ class CUDADeviceAPI final : public DeviceAPI { } case kGcnArch: return; + case kApiVersion: { + *rv = CUDA_VERSION; + return; + } } *rv = value; } diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm index a64f35ced2c2..f2a2930810e5 100644 --- a/src/runtime/metal/metal_device_api.mm +++ b/src/runtime/metal/metal_device_api.mm @@ -69,6 +69,8 @@ return; case kGcnArch: return; + case kApiVersion: + return; } } diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc index 72d03fb6a4fc..5753c1d0f76b 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/runtime/opencl/opencl_device_api.cc @@ -111,6 +111,8 @@ void OpenCLWorkspace::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* return; case kGcnArch: return; + case kApiVersion: + return; } } diff --git a/src/runtime/rocm/rocm_common.h b/src/runtime/rocm/rocm_common.h index 2e637f5496bb..6ed9bccb1ab7 100644 --- a/src/runtime/rocm/rocm_common.h +++ b/src/runtime/rocm/rocm_common.h @@ -25,6 +25,7 @@ #define TVM_RUNTIME_ROCM_ROCM_COMMON_H_ #include +#include #include #include diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index e3dbef5ff42a..e1a14c7dcf1c 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -103,13 +103,19 @@ class ROCMDeviceAPI final : public DeviceAPI { return; } case kMaxRegistersPerBlock: - return; + ROCM_CALL( + hipDeviceGetAttribute(&value, hipDeviceAttributeMaxRegistersPerBlock, ctx.device_id)); + break; case kGcnArch: { hipDeviceProp_t prop; ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id)); *rv = prop.gcnArch; return; } + case kApiVersion: { + *rv = HIP_VERSION; + return; + } } *rv = value; } diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc index ade4ddca9376..9e730b7fd8b1 100644 --- a/src/runtime/vulkan/vulkan.cc +++ b/src/runtime/vulkan/vulkan.cc @@ -417,6 +417,8 @@ void VulkanDeviceAPI::GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* return; case kGcnArch: return; + case kApiVersion: + return; } } diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 8e6b3a2ff22c..93c94cfa4389 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -108,11 +108,13 @@ class CodeGenAMDGPU : public CodeGenLLVM { llvm::GlobalVariable* global = new llvm::GlobalVariable( *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); + if (global->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 - global->setAlignment(llvm::Align(info.alignment)); + global->setAlignment(llvm::Align(info.alignment)); #else - global->setAlignment(info.alignment); + global->setAlignment(info.alignment); #endif + } buf = global; } @@ -212,6 +214,20 @@ inline int DetectROCMComputeVersion(const std::string& target) { return 900; } +inline int DetectROCMApiVersion() { + TVMContext tvm_ctx; + tvm_ctx.device_type = kDLROCM; + tvm_ctx.device_id = 0; + tvm::runtime::DeviceAPI* api = tvm::runtime::DeviceAPI::Get(tvm_ctx, true); + if (api != nullptr) { + TVMRetValue val; + api->GetAttr(tvm_ctx, tvm::runtime::kApiVersion, &val); + return val.operator int(); + } + LOG(WARNING) << "Cannot detect ROCm version, assume >= 3.5"; + return 305; +} + runtime::Module BuildAMDGPU(IRModule mod, std::string target) { #if TVM_LLVM_VERSION < 90 LOG(FATAL) << "AMDGPU backend requires at least LLVM 9"; @@ -221,8 +237,13 @@ runtime::Module BuildAMDGPU(IRModule mod, std::string target) { InitializeLLVM(); CHECK(target.length() >= 4 && target.substr(0, 4) == "rocm"); std::ostringstream config; - config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" << DetectROCMComputeVersion(target) - << " -mattr=-code-object-v3 " << target.substr(4, target.length() - 4); + config << "-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" << DetectROCMComputeVersion(target); + if (DetectROCMApiVersion() < 305) { + // before ROCm 3.5 we needed code object v2, starting + // with 3.5 we need v3 (this argument disables v3) + config << " -mattr=-code-object-v3 "; + } + config << target.substr(4, target.length() - 4); std::unique_ptr tm = GetLLVMTargetMachine(config.str()); std::unique_ptr ctx(new llvm::LLVMContext()); // careful: cg will hold a naked pointer reference to ctx, so it should