diff --git a/HalideIR b/HalideIR index a40a3e2fedee..cb3c025d5b91 160000 --- a/HalideIR +++ b/HalideIR @@ -1 +1 @@ -Subproject commit a40a3e2fedee88d2f7b97ba4caf8a9d0eb25886f +Subproject commit cb3c025d5b91ab994b063a85bc935fc364d8f491 diff --git a/src/codegen/llvm/codegen_amdgpu.cc b/src/codegen/llvm/codegen_amdgpu.cc index 146880b7dd89..12218955f110 100644 --- a/src/codegen/llvm/codegen_amdgpu.cc +++ b/src/codegen/llvm/codegen_amdgpu.cc @@ -135,9 +135,21 @@ runtime::Module BuildAMDGPU(Array funcs, std::string target) { CHECK(target.length( ) >= 4 && target.substr(0, 4) == "rocm"); + + TVMContext tvmCtx; + tvmCtx.device_type = kROCM; + tvmCtx.device_id = 0; + TVMRetValue val; + tvm::runtime::DeviceAPI::Get(tvmCtx)->GetAttr(tvmCtx, tvm::runtime::kExist, &val); + if (val.operator int() == 1) { + tvm::runtime::DeviceAPI::Get(tvmCtx)->GetAttr(tvmCtx, tvm::runtime::kComputeVersion, &val); + } else { + val = 803; + } + llvm::TargetMachine* tm = \ - GetLLVMTargetMachine("-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx803" + \ - target.substr(4, target.length() - 4)); + GetLLVMTargetMachine("-mtriple=amdgcn-amd-amdhsa-hcc -mcpu=gfx" + \ + std::to_string(val.operator int())+ target.substr(4, target.length() - 4)); std::unique_ptr cg(new CodeGenAMDGPU()); std::unique_ptr ctx(new llvm::LLVMContext()); diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index 4fa91dcbeb3f..d7b4eabf01d4 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -44,7 +44,11 @@ class ROCMDeviceAPI final : public DeviceAPI { value = 64; break; } - case kComputeVersion: return; + case kComputeVersion: + hipDeviceProp_t prop; + ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id)); + *rv = prop.gcnArch; + return; } *rv = value; }