From 9ccc89d22c94ab02cc7bcb130626ae3ff77c94a0 Mon Sep 17 00:00:00 2001 From: Thomas Viehmann Date: Thu, 14 Nov 2019 22:00:05 +0100 Subject: [PATCH] Add workgroup size attribute to AMDGPU functions in codegen When we did not set the workgroup size, LLVM will use too many registers for kernel launches with many threads. This resulted in "invalid ISA" errors. Here we set the maximum workgroup size to the maximum threads per block from the device API. Of course, one might look into allowing configurations with fewer threads at runtime to use more registers. --- src/codegen/llvm/codegen_amdgpu.cc | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/codegen/llvm/codegen_amdgpu.cc b/src/codegen/llvm/codegen_amdgpu.cc index 28b2deb3f2b76..87052aae521e1 100644 --- a/src/codegen/llvm/codegen_amdgpu.cc +++ b/src/codegen/llvm/codegen_amdgpu.cc @@ -36,6 +36,29 @@ namespace tvm { namespace codegen { +namespace { + +// calls the device api to get the max threads per block +static inline int DetectROCMmaxThreadsPerBlock() { + TVMContext tvm_ctx; + tvm_ctx.device_type = kDLROCM; + tvm_ctx.device_id = 0; + tvm::runtime::DeviceAPI* api = tvm::runtime::DeviceAPI::Get(tvm_ctx, true); + if (api != nullptr) { + TVMRetValue val; + api->GetAttr(tvm_ctx, tvm::runtime::kExist, &val); + if (val.operator int() == 1) { + tvm::runtime::DeviceAPI::Get(tvm_ctx)-> + GetAttr(tvm_ctx, tvm::runtime::kMaxThreadsPerBlock, &val); + return val.operator int(); + } + } + LOG(WARNING) << "Cannot get maximum number of threads for AMD codegen"; + return 1024; +} + +} // namespace + // AMDGPU code generator. class CodeGenAMDGPU : public CodeGenLLVM { public: @@ -43,6 +66,9 @@ class CodeGenAMDGPU : public CodeGenLLVM { // add function as void return value CodeGenLLVM::AddFunctionInternal(f, true); function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL); + std::ostringstream attr; + attr << "1," << DetectROCMmaxThreadsPerBlock(); + function_->addFnAttr("amdgpu-flat-work-group-size", attr.str()); } void VisitStmt_(const Allocate* op) final {