diff --git a/src/codegen/llvm/codegen_amdgpu.cc b/src/codegen/llvm/codegen_amdgpu.cc
index 28b2deb3f2b76..87052aae521e1 100644
--- a/src/codegen/llvm/codegen_amdgpu.cc
+++ b/src/codegen/llvm/codegen_amdgpu.cc
@@ -36,6 +36,32 @@
 namespace tvm {
 namespace codegen {
 
+namespace {
+
+// Asks the ROCm device API (device 0) for the maximum number of threads per
+// block. Returns 1024 and logs a warning when the API is unavailable or the
+// kExist query does not report 1.
+static inline int DetectROCMmaxThreadsPerBlock() {
+  TVMContext tvm_ctx;
+  tvm_ctx.device_type = kDLROCM;
+  tvm_ctx.device_id = 0;
+  // Second argument `true`: presumably yields nullptr when ROCm is unavailable.
+  tvm::runtime::DeviceAPI* api = tvm::runtime::DeviceAPI::Get(tvm_ctx, true);
+  if (api != nullptr) {
+    TVMRetValue val;
+    api->GetAttr(tvm_ctx, tvm::runtime::kExist, &val);
+    if (val.operator int() == 1) {
+      // Reuse the handle obtained above instead of looking it up again.
+      api->GetAttr(tvm_ctx, tvm::runtime::kMaxThreadsPerBlock, &val);
+      return val.operator int();
+    }
+  }
+  LOG(WARNING) << "Cannot get maximum number of threads for AMD codegen";
+  return 1024;
+}
+
+}  // namespace
+
 // AMDGPU code generator.
 class CodeGenAMDGPU : public CodeGenLLVM {
  public:
@@ -43,6 +69,10 @@ class CodeGenAMDGPU : public CodeGenLLVM {
     // add function as void return value
     CodeGenLLVM::AddFunctionInternal(f, true);
     function_->setCallingConv(llvm::CallingConv::AMDGPU_KERNEL);
+    // The attribute value has the form "min,max": allow 1..detected maximum.
+    std::ostringstream attr;
+    attr << "1," << DetectROCMmaxThreadsPerBlock();
+    function_->addFnAttr("amdgpu-flat-work-group-size", attr.str());
   }
 
   void VisitStmt_(const Allocate* op) final {