diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index c49af895480ac..657fe58924059 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -22,14 +22,12 @@ * \file rocm_device_api.cc * \brief GPU specific API */ -#include - #include #include #include #include +#include #include -#include "../../../include/tvm/runtime/device_api.h" #include "rocm_common.h" namespace tvm { @@ -55,130 +53,153 @@ class ROCMDeviceAPI final : public DeviceAPI { break; } case kMaxThreadsPerBlock: { - value = 1024; + ROCM_CALL(hipDeviceGetAttribute( + &value, hipDeviceAttributeMaxThreadsPerBlock, ctx.device_id)); break; } case kWarpSize: { - value = 64; + ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeWarpSize, + ctx.device_id)); + break; + } + case kMaxSharedMemoryPerBlock: { + ROCM_CALL(hipDeviceGetAttribute( + &value, hipDeviceAttributeMaxSharedMemoryPerBlock, ctx.device_id)); + break; + } + case kComputeVersion: { + std::ostringstream os; + ROCM_CALL(hipDeviceGetAttribute( + &value, hipDeviceAttributeComputeCapabilityMajor, ctx.device_id)); + os << value << "."; + ROCM_CALL(hipDeviceGetAttribute( + &value, hipDeviceAttributeComputeCapabilityMinor, ctx.device_id)); + os << value; + *rv = os.str(); + return; + } + case kDeviceName: + return; + case kMaxClockRate: { + ROCM_CALL(hipDeviceGetAttribute(&value, hipDeviceAttributeClockRate, + ctx.device_id)); break; } - case kMaxSharedMemoryPerBlock: return; - case kComputeVersion: - case kDeviceName: return; - case kMaxClockRate: return; - case kMultiProcessorCount: return; - case kMaxThreadDimensions: return; + case kMultiProcessorCount: { + ROCM_CALL(hipDeviceGetAttribute( + &value, hipDeviceAttributeMultiprocessorCount, ctx.device_id)); + break; + } + case kMaxThreadDimensions: { + int dims[3]; + ROCM_CALL(hipDeviceGetAttribute( + &dims[0], hipDeviceAttributeMaxBlockDimX, ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute( + &dims[1], hipDeviceAttributeMaxBlockDimY, ctx.device_id)); + ROCM_CALL(hipDeviceGetAttribute( + &dims[2], hipDeviceAttributeMaxBlockDimZ, ctx.device_id)); + + std::stringstream ss; + ss << "[" << dims[0] << ", " << dims[1] << ", " << dims[2] << "]"; + *rv = ss.str(); + return; + } case kGcnArch: { hipDeviceProp_t prop; ROCM_CALL(hipGetDeviceProperties(&prop, ctx.device_id)); *rv = prop.gcnArch; return; } + *rv = value; + } + void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, + TVMType type_hint) final { + ROCM_CALL(hipSetDevice(ctx.device_id)); + CHECK_EQ(256 % alignment, 0U) << "ROCM space is aligned at 256 bytes"; + void* ret; + ROCM_CALL(hipMalloc(&ret, nbytes)); + return ret; } - *rv = value; - } - void* AllocDataSpace(TVMContext ctx, - size_t nbytes, - size_t alignment, - TVMType type_hint) final { - ROCM_CALL(hipSetDevice(ctx.device_id)); - CHECK_EQ(256 % alignment, 0U) - << "ROCM space is aligned at 256 bytes"; - void *ret; - ROCM_CALL(hipMalloc(&ret, nbytes)); - return ret; - } - void FreeDataSpace(TVMContext ctx, void* ptr) final { - ROCM_CALL(hipSetDevice(ctx.device_id)); - ROCM_CALL(hipFree(ptr)); - } + void FreeDataSpace(TVMContext ctx, void* ptr) final { + ROCM_CALL(hipSetDevice(ctx.device_id)); + ROCM_CALL(hipFree(ptr)); + } - void CopyDataFromTo(const void* from, - size_t from_offset, - void* to, - size_t to_offset, - size_t size, - TVMContext ctx_from, - TVMContext ctx_to, - TVMType type_hint, - TVMStreamHandle stream) final { - hipStream_t hip_stream = static_cast(stream); - from = static_cast(from) + from_offset; - to = static_cast(to) + to_offset; - if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLROCM) { - ROCM_CALL(hipSetDevice(ctx_from.device_id)); - if (ctx_from.device_id == ctx_to.device_id) { - GPUCopy(from, to, size, hipMemcpyDeviceToDevice, hip_stream); + void CopyDataFromTo(const void* from, size_t from_offset, void* to, + size_t to_offset, size_t size, TVMContext ctx_from, + TVMContext ctx_to, TVMType type_hint, + TVMStreamHandle stream) final { + hipStream_t hip_stream = static_cast(stream); + from = static_cast(from) + from_offset; + to = static_cast(to) + to_offset; + if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLROCM) { + ROCM_CALL(hipSetDevice(ctx_from.device_id)); + if (ctx_from.device_id == ctx_to.device_id) { + GPUCopy(from, to, size, hipMemcpyDeviceToDevice, hip_stream); + } else { + hipMemcpyPeerAsync(to, ctx_to.device_id, from, ctx_from.device_id, + size, hip_stream); + } + } else if (ctx_from.device_type == kDLROCM && + ctx_to.device_type == kDLCPU) { + ROCM_CALL(hipSetDevice(ctx_from.device_id)); + GPUCopy(from, to, size, hipMemcpyDeviceToHost, hip_stream); + } else if (ctx_from.device_type == kDLCPU && + ctx_to.device_type == kDLROCM) { + ROCM_CALL(hipSetDevice(ctx_to.device_id)); + GPUCopy(from, to, size, hipMemcpyHostToDevice, hip_stream); } else { - hipMemcpyPeerAsync(to, ctx_to.device_id, - from, ctx_from.device_id, - size, hip_stream); + LOG(FATAL) << "expect copy from/to GPU or between GPU"; } - } else if (ctx_from.device_type == kDLROCM && ctx_to.device_type == kDLCPU) { - ROCM_CALL(hipSetDevice(ctx_from.device_id)); - GPUCopy(from, to, size, hipMemcpyDeviceToHost, hip_stream); - } else if (ctx_from.device_type == kDLCPU && ctx_to.device_type == kDLROCM) { - ROCM_CALL(hipSetDevice(ctx_to.device_id)); - GPUCopy(from, to, size, hipMemcpyHostToDevice, hip_stream); - } else { - LOG(FATAL) << "expect copy from/to GPU or between GPU"; } - } - - void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { - ROCM_CALL(hipSetDevice(ctx.device_id)); - ROCM_CALL(hipStreamSynchronize(static_cast(stream))); - } - void SetStream(TVMContext ctx, TVMStreamHandle stream) final { - ROCMThreadEntry::ThreadLocal() - ->stream = static_cast(stream); - } + void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { + ROCM_CALL(hipSetDevice(ctx.device_id)); + ROCM_CALL(hipStreamSynchronize(static_cast(stream))); + } - void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final { - return ROCMThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size); - } + void SetStream(TVMContext ctx, TVMStreamHandle stream) final { + ROCMThreadEntry::ThreadLocal()->stream = static_cast(stream); + } - void FreeWorkspace(TVMContext ctx, void* data) final { - ROCMThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data); - } + void* AllocWorkspace(TVMContext ctx, size_t size, TVMType type_hint) final { + return ROCMThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size); + } - static const std::shared_ptr& Global() { - static std::shared_ptr inst = - std::make_shared(); - return inst; - } + void FreeWorkspace(TVMContext ctx, void* data) final { + ROCMThreadEntry::ThreadLocal()->pool.FreeWorkspace(ctx, data); + } - private: - static void GPUCopy(const void* from, - void* to, - size_t size, - hipMemcpyKind kind, - hipStream_t stream) { - if (stream != 0) { - ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream)); - } else { - ROCM_CALL(hipMemcpy(to, from, size, kind)); + static const std::shared_ptr& Global() { + static std::shared_ptr inst = + std::make_shared(); + return inst; } - } -}; -typedef dmlc::ThreadLocalStore ROCMThreadStore; + private: + static void GPUCopy(const void* from, void* to, size_t size, + hipMemcpyKind kind, hipStream_t stream) { + if (stream != 0) { + ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream)); + } else { + ROCM_CALL(hipMemcpy(to, from, size, kind)); + } + } + }; -ROCMThreadEntry::ROCMThreadEntry() - : pool(kDLROCM, ROCMDeviceAPI::Global()) { -} + typedef dmlc::ThreadLocalStore ROCMThreadStore; -ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { - return ROCMThreadStore::Get(); -} + ROCMThreadEntry::ROCMThreadEntry() : pool(kDLROCM, ROCMDeviceAPI::Global()) {} -TVM_REGISTER_GLOBAL("device_api.rocm") -.set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = ROCMDeviceAPI::Global().get(); - *rv = static_cast(ptr); - }); + ROCMThreadEntry* ROCMThreadEntry::ThreadLocal() { + return ROCMThreadStore::Get(); + } + TVM_REGISTER_GLOBAL("device_api.rocm") + .set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = ROCMDeviceAPI::Global().get(); + *rv = static_cast(ptr); + }); } // namespace runtime } // namespace tvm