diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index c49af895480a..a599dbdb97f8 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,14 +22,13 @@ * \file rocm_device_api.cc * \brief GPU specific API */ -#include - #include #include #include #include +#include #include -#include "../../../include/tvm/runtime/device_api.h" + #include "rocm_common.h" namespace tvm { @@ -55,19 +54,57 @@ class ROCMDeviceAPI final : public DeviceAPI { break; } case kMaxThreadsPerBlock: { - value = 1024; + ROCM_CALL(hipDeviceGetAttribute( + &value, hipDeviceAttributeMaxThreadsPerBlock, ctx.device_id)); break; } case kWarpSize: { - value = 64; + ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeWarpSize, + ctx.device_id)); break; } - case kMaxSharedMemoryPerBlock: return; - case kComputeVersion: - case kDeviceName: return; - case kMaxClockRate: return; - case kMultiProcessorCount: return; - case kMaxThreadDimensions: return; + case kMaxSharedMemoryPerBlock: { + ROCM_CALL(hipDeviceGetAttribute( + &value, hipDeviceAttributeMaxSharedMemoryPerBlock, ctx.device_id)); + break; + } + case kComputeVersion: { + std::ostringstream os; + ROCM_CALL(hipDeviceGetAttribute( + &value, hipDeviceAttributeComputeCapabilityMajor, ctx.device_id)); + os << value << "."; + ROCM_CALL(hipDeviceGetAttribute( + &value, hipDeviceAttributeComputeCapabilityMinor, ctx.device_id)); + os << value; + *rv = os.str(); + return; + } + case kDeviceName: + return; + case kMaxClockRate: { + ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeClockRate, + ctx.device_id)); + break; + } + case kMultiProcessorCount: { + ROCM_CALL(hipDeviceGetAttribute( + &value, hipDeviceAttributeMultiprocessorCount, ctx.device_id)); + break; + } + case kMaxThreadDimensions: { + int dims[3]; + ROCM_CALL(hipDeviceGetAttribute( + &dims[0], hipDeviceAttributeMaxBlockDimX, ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute( + &dims[1], hipDeviceAttributeMaxBlockDimY, ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute( + &dims[2], hipDeviceAttributeMaxBlockDimZ, ctx.device_id)); + + std::stringstream ss; + ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]"; + *rv = ss.str(); + return; + } case kGcnArch: { hipDeviceProp_t prop; ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id)); @@ -77,14 +114,11 @@ class ROCMDeviceAPI final : public DeviceAPI { } *rv = value; } - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, TVMType type_hint) final { ROCM_CALL(hipSetDevice(ctx.device_id)); - CHECK_EQ(256 % alignment, 0U) - << "ROCM space is aligned at 256 bytes"; - void *ret; + CHECK_EQ(256 % alignment, 0U) << "ROCM space is aligned at 256 bytes"; + void* ret; ROCM_CALL(hipMalloc(&ret, nbytes)); return ret; } @@ -94,14 +128,9 @@ class ROCMDeviceAPI final : public DeviceAPI { ROCM_CALL(hipFree(ptr)); } - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - TVMType type_hint, + void CopyDataFromTo(const void* from, size_t from_offset, void* to, + size_t to_offset, size_t size, TVMContext ctx_from, + TVMContext ctx_to, TVMType type_hint, TVMStreamHandle stream) final { hipStream_t hip_stream = static_cast(stream); from = static_cast(from) + from_offset; @@ -111,14 +140,15 @@ class ROCMDeviceAPI final : public DeviceAPI { if (ctx_from.device_id == ctx_to.device_id) { GPUCopy(from, to, size, hipMemcpyDeviceToDevice, hip_stream); } else { - hipMemcpyPeerAsync(to, ctx_to.device_id, - from, ctx_from.device_id, - size, hip_stream); + hipMemcpyPeerAsync(to, ctx_to.device_id, from, ctx_from.device_id, size, + hip_stream); } - } else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU) { + } else if (ctx_from.device_type == kDLROCM && + ctx_to.device_type == kDLCPU) { ROCM_CALL(hipSetDevice(ctx_from.device_id)); GPUCopy(from, to, size, hipMemcpyDeviceToHost, hip_stream); - } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLROCM) { + } else if (ctx_from.device_type == kDLCPU && + ctx_to.device_type == kDLROCM) { ROCM_CALL(hipSetDevice(ctx_to.device_id)); GPUCopy(from, to, size, hipMemcpyHostToDevice, hip_stream); } else { @@ -132,8 +162,7 @@ class ROCMDeviceAPI final : public DeviceAPI { } void SetStream(TVMContext ctx, TVMStreamHandle stream) final { - ROCMThreadEntry::ThreadLocal() - ->stream = static_cast(stream); + ROCMThreadEntry::ThreadLocal()->stream = static_cast(stream); } void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final { @@ -151,11 +180,8 @@ class ROCMDeviceAPI final : public DeviceAPI { } private: - static void GPUCopy(const void* from, - void* to, - size_t size, - hipMemcpyKind kind, - hipStream_t stream) { + static void GPUCopy(const void* from, void* to, size_t size, + hipMemcpyKind kind, hipStream_t stream) { if (stream != 0) { ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream)); } else { @@ -166,19 +192,16 @@ class ROCMDeviceAPI final : public DeviceAPI { typedef dmlc::ThreadLocalStore ROCMThreadStore; -ROCMThreadEntry::ROCMThreadEntry() - : pool(kDLROCM, ROCMDeviceAPI::Global()) { -} +ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {} ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { return ROCMThreadStore::Get(); } TVM_REGISTER_GLOBAL("device_api.rocm") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = ROCMDeviceAPI::Global().get(); - *rv = static_cast(ptr); - }); - + .set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = ROCMDeviceAPI::Global().get(); + *rv = static_cast(ptr); + }); } // namespace runtime } // namespace tvm