Skip to content

Commit

Permalink
[Runtime] Utils to Stringify Device (#15630)
Browse files Browse the repository at this point in the history
There exist some basic functionality to convert Device and DLDeviceType
to std::string, but they are not following the common naming convention
in TVM, and thus less discoverable. This commit makes changes
accordingly:
- `runtime::DeviceName` to `runtime::DLDeviceType2Str`
- move declaration of `operator << (std::ostream&, Device)` from
  `runtime/device_api.h` to `runtime/packed_func.h`
  • Loading branch information
junrushao committed Aug 27, 2023
1 parent bebf590 commit d3f9d3d
Show file tree
Hide file tree
Showing 10 changed files with 120 additions and 85 deletions.
1 change: 1 addition & 0 deletions include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
default:
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
}
throw;
}

inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
Expand Down
50 changes: 1 addition & 49 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,54 +245,6 @@ class TVM_DLL DeviceAPI {
constexpr int kRPCSessMask = 128;
static_assert(kRPCSessMask >= TVMDeviceExtType_End);

/*!
* \brief The name of Device API factory.
* \param type The device type.
* \return the device name.
*/
inline const char* DeviceName(int type) {
switch (type) {
case kDLCPU:
return "cpu";
case kDLCUDA:
return "cuda";
case kDLCUDAHost:
return "cuda_host";
case kDLCUDAManaged:
return "cuda_managed";
case kDLOpenCL:
return "opencl";
case kDLSDAccel:
return "sdaccel";
case kDLAOCL:
return "aocl";
case kDLVulkan:
return "vulkan";
case kDLMetal:
return "metal";
case kDLVPI:
return "vpi";
case kDLROCM:
return "rocm";
case kDLROCMHost:
return "rocm_host";
case kDLExtDev:
return "ext_dev";
case kDLOneAPI:
return "oneapi";
case kDLWebGPU:
return "webgpu";
case kDLHexagon:
return "hexagon";
case kOpenGL:
return "opengl";
case kDLMicroDev:
return "microdev";
default:
LOG(FATAL) << "unknown type =" << type;
}
}

/*!
* \brief Return true if a Device is owned by an RPC session.
*/
Expand Down Expand Up @@ -324,7 +276,7 @@ inline std::ostream& operator<<(std::ostream& os, DLDevice dev) { // NOLINT(*)
os << "remote[" << tvm::runtime::GetRPCSessionIndex(dev) << "]-";
dev = tvm::runtime::RemoveRPCSessionMask(dev);
}
os << tvm::runtime::DeviceName(static_cast<int>(dev.device_type)) << "(" << dev.device_id << ")";
os << tvm::runtime::DLDeviceType2Str(static_cast<int>(dev.device_type)) << ":" << dev.device_id;
return os;
}

Expand Down
107 changes: 96 additions & 11 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,8 @@ class TVMArgs {
*/
inline const char* ArgTypeCode2Str(int type_code);

inline std::ostream& operator<<(std::ostream& os, DLDevice dev); // NOLINT(*)

// macro to check type code.
#define TVM_CHECK_TYPE_CODE(CODE, T) \
ICHECK_EQ(CODE, T) << "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE)
Expand Down Expand Up @@ -1257,6 +1259,56 @@ inline const char* ArgTypeCode2Str(int type_code) {
default:
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
}
throw;
}

/*!
* \brief The name of DLDeviceType.
* \param type The device type.
* \return the device name.
*/
inline const char* DLDeviceType2Str(int type) {
switch (type) {
case kDLCPU:
return "cpu";
case kDLCUDA:
return "cuda";
case kDLCUDAHost:
return "cuda_host";
case kDLCUDAManaged:
return "cuda_managed";
case kDLOpenCL:
return "opencl";
case kDLSDAccel:
return "sdaccel";
case kDLAOCL:
return "aocl";
case kDLVulkan:
return "vulkan";
case kDLMetal:
return "metal";
case kDLVPI:
return "vpi";
case kDLROCM:
return "rocm";
case kDLROCMHost:
return "rocm_host";
case kDLExtDev:
return "ext_dev";
case kDLOneAPI:
return "oneapi";
case kDLWebGPU:
return "webgpu";
case kDLHexagon:
return "hexagon";
case kOpenGL:
return "opengl";
case kDLMicroDev:
return "microdev";
default:
LOG(FATAL) << "unknown type = " << type;
}
throw;
}

namespace detail {
Expand Down Expand Up @@ -1284,13 +1336,27 @@ namespace parameter_pack {

template <typename... EnumArgs>
struct EnumeratedParamPack {
struct Invoke {
template <template <size_t i, typename TArgument> class Functor, typename... ExtraParams>
static void F(ExtraParams&&... extra_params) {
struct InvokeWithoutArg {
template <template <size_t i, typename TArgument> class Functor, typename ExtraParams>
static void F(ExtraParams&& extra_params) {
using TExpander = int[];
(void)TExpander{
0,
(Functor<EnumArgs::i, typename EnumArgs::T>::F(extra_params...), 0)...,
(Functor<EnumArgs::i, typename EnumArgs::T>::F(std::forward<ExtraParams>(extra_params)),
0)...,
};
}
};
struct InvokeWithArg {
template <template <size_t i, typename TArgument> class Functor, typename ExtraParams,
typename... Params>
static void F(ExtraParams&& extra_params, Params&&... params) {
using TExpander = int[];
(void)TExpander{
0,
(Functor<EnumArgs::i, typename EnumArgs::T>::F(std::forward<ExtraParams>(extra_params),
std::forward<Params>(params)),
0)...,
};
}
};
Expand All @@ -1310,22 +1376,27 @@ struct EnumerateImpl {

template <std::size_t... id>
struct Zipper<std::integer_sequence<std::size_t, id...>> {
using T = EnumeratedParamPack<Item<id, Args>...>;
using WithoutArg = typename EnumeratedParamPack<Item<id, Args>...>::InvokeWithoutArg;
using WithArg = typename EnumeratedParamPack<Item<id, Args>...>::InvokeWithArg;
};

public:
using T = typename Zipper<std::index_sequence_for<Args...>>::T;
using WithoutArg = typename Zipper<std::index_sequence_for<Args...>>::WithoutArg;
using WithArg = typename Zipper<std::index_sequence_for<Args...>>::WithArg;
};

template <typename... Args>
using Enumerate = typename EnumerateImpl<Args...>::T;
using EnumerateWithoutArg = typename EnumerateImpl<Args...>::WithoutArg;

template <typename... Args>
using EnumerateWithArg = typename EnumerateImpl<Args...>::WithArg;

template <typename... Args>
struct ParamPack {
template <template <size_t i, typename TArgument> class Functor, typename... ExtraParams>
static void InvokeWithoutArg(ExtraParams&&... extra_params) {
Enumerate<Args...>::Invoke::template F<Functor, ExtraParams...>(
std::forward<ExtraParams>(extra_params)...);
template <template <size_t i, typename TArgument> class Functor, typename ExtraParams>
static void InvokeWithoutArg(ExtraParams&& extra_params) {
EnumerateWithoutArg<Args...>::template F<Functor, ExtraParams>(
std::forward<ExtraParams>(extra_params));
}
};

Expand Down Expand Up @@ -1622,6 +1693,20 @@ inline TVMRetValue PackedFunc::operator()(Args&&... args) const {
return rv;
}

template <int i, typename T>
struct TVMArgsSetterApply {
static TVM_ALWAYS_INLINE void F(TVMArgsSetter* setter, T&& value) {
(*setter)(i, std::forward<T>(value));
}
};

template <typename... Args>
void TVM_ALWAYS_INLINE PackArgs(TVMValue* values, int* type_codes, Args&&... args) {
TVMArgsSetter setter(values, type_codes);
detail::parameter_pack::EnumerateWithArg<Args...>::template F<TVMArgsSetterApply>(
&setter, std::forward<Args>(args)...);
}

namespace detail {
template <typename R, int nleft, int index, typename F>
struct unpack_call_dispatcher {
Expand Down
1 change: 1 addition & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -949,6 +949,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
return FloatImm(t, static_cast<double>(value), span);
}
LOG(FATAL) << "cannot make const for type " << t;
throw;
}

template <>
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/c_runtime_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class DeviceAPIManager {
if (api_[type] != nullptr) return api_[type];
std::lock_guard<std::mutex> lock(mutex_);
if (api_[type] != nullptr) return api_[type];
api_[type] = GetAPI(DeviceName(type), allow_missing);
api_[type] = GetAPI(DLDeviceType2Str(type), allow_missing);
return api_[type];
} else {
if (rpc_api_ != nullptr) return rpc_api_;
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/profiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,13 @@ std::set<DLDeviceType> seen_devices;
std::mutex seen_devices_lock;

Timer Timer::Start(Device dev) {
auto f = Registry::Get(std::string("profiling.timer.") + DeviceName(dev.device_type));
auto f = Registry::Get(std::string("profiling.timer.") + DLDeviceType2Str(dev.device_type));
if (f == nullptr) {
{
std::lock_guard<std::mutex> lock(seen_devices_lock);
if (seen_devices.find(dev.device_type) == seen_devices.end()) {
LOG(WARNING)
<< "No timer implementation for " << DeviceName(dev.device_type)
<< "No timer implementation for " << DLDeviceType2Str(dev.device_type)
<< ", using default timer instead. It may be inaccurate or have extra overhead.";
seen_devices.insert(dev.device_type);
}
Expand Down Expand Up @@ -652,7 +652,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con
}

std::string DeviceString(Device dev) {
return DeviceName(dev.device_type) + std::to_string(dev.device_id);
return DLDeviceType2Str(dev.device_type) + std::to_string(dev.device_id);
}

Report Profiler::Report() {
Expand Down
18 changes: 8 additions & 10 deletions src/runtime/relax_vm/memory_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
* \file tvm/runtime/relax_vm/memory_manager.cc
* \brief Allocate and manage memory for the Relay VM.
*/
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/relax_vm/memory_manager.h>

Expand All @@ -29,7 +31,6 @@

#include "naive_allocator.h"
#include "pooled_allocator.h"
#include "tvm/runtime/memory.h"

namespace tvm {
namespace runtime {
Expand Down Expand Up @@ -132,14 +133,12 @@ Allocator* MemoryManager::GetOrCreateAllocator(Device dev, AllocatorType type) {
std::unique_ptr<Allocator> alloc;
switch (type) {
case kNaive: {
DLOG(INFO) << "New naive allocator for " << runtime::DeviceName(dev.device_type) << "("
<< dev.device_id << ")";
DLOG(INFO) << "New naive allocator for " << dev;
alloc.reset(new NaiveAllocator(dev));
break;
}
case kPooled: {
DLOG(INFO) << "New pooled allocator for " << runtime::DeviceName(dev.device_type) << "("
<< dev.device_id << ")";
DLOG(INFO) << "New pooled allocator for " << dev;
alloc.reset(new PooledAllocator(dev));
break;
}
Expand All @@ -152,9 +151,9 @@ Allocator* MemoryManager::GetOrCreateAllocator(Device dev, AllocatorType type) {
}
auto alloc = m->allocators_.at(dev).get();
if (alloc->type() != type) {
LOG(WARNING) << "The type of existing allocator for " << runtime::DeviceName(dev.device_type)
<< "(" << dev.device_id << ") is different from the request type ("
<< alloc->type() << " vs " << type << ")";
LOG(WARNING) << "The type of existing allocator for " << dev
<< " is different from the request type (" << alloc->type() << " vs " << type
<< ")";
}
return alloc;
}
Expand All @@ -164,8 +163,7 @@ Allocator* MemoryManager::GetAllocator(Device dev) {
std::lock_guard<std::mutex> lock(m->mutex_);
auto it = m->allocators_.find(dev);
if (it == m->allocators_.end()) {
LOG(FATAL) << "Allocator for " << runtime::DeviceName(dev.device_type) << "(" << dev.device_id
<< ") has not been created yet.";
LOG(FATAL) << "Allocator for " << dev << " has not been created yet.";
}
return it->second.get();
}
Expand Down
1 change: 1 addition & 0 deletions src/runtime/rpc/rpc_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ class RPCModuleNode final : public ModuleNode {

String GetSource(const String& format) final {
LOG(FATAL) << "GetSource for rpc Module is not supported";
throw;
}

PackedFunc GetTimeEvaluator(const std::string& name, Device dev, int number, int repeat,
Expand Down
15 changes: 6 additions & 9 deletions src/runtime/vm/memory_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,12 @@ Allocator* MemoryManager::GetOrCreateAllocator(Device dev, AllocatorType type) {
std::unique_ptr<Allocator> alloc;
switch (type) {
case kNaive: {
VLOG(1) << "New naive allocator for " << DeviceName(dev.device_type) << "(" << dev.device_id
<< ")";
VLOG(1) << "New naive allocator for " << dev;
alloc.reset(new NaiveAllocator(dev));
break;
}
case kPooled: {
VLOG(1) << "New pooled allocator for " << DeviceName(dev.device_type) << "("
<< dev.device_id << ")";
VLOG(1) << "New pooled allocator for " << dev;
alloc.reset(new PooledAllocator(dev));
break;
}
Expand All @@ -139,9 +137,9 @@ Allocator* MemoryManager::GetOrCreateAllocator(Device dev, AllocatorType type) {
}
auto alloc = m->allocators_.at(dev).get();
if (alloc->type() != type) {
LOG(WARNING) << "The type of existing allocator for " << DeviceName(dev.device_type) << "("
<< dev.device_id << ") is different from the request type (" << alloc->type()
<< " vs " << type << ")";
LOG(WARNING) << "The type of existing allocator for " << dev
<< " is different from the request type (" << alloc->type() << " vs " << type
<< ")";
}
return alloc;
}
Expand All @@ -151,8 +149,7 @@ Allocator* MemoryManager::GetAllocator(Device dev) {
std::lock_guard<std::mutex> lock(m->mu_);
auto it = m->allocators_.find(dev);
if (it == m->allocators_.end()) {
LOG(FATAL) << "Allocator for " << DeviceName(dev.device_type) << "(" << dev.device_id
<< ") has not been created yet.";
LOG(FATAL) << "Allocator for " << dev << " has not been created yet.";
}
return it->second.get();
}
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ class BuiltinLower : public StmtExprMutator {
<< "but was instead the expression " << device_type_ << " with type "
<< device_type_.value()->GetTypeKey();

String device_name = runtime::DeviceName(as_int->value);
String device_name = runtime::DLDeviceType2Str(as_int->value);
return StringImm("device_api." + device_name + "." + method_name);
}

Expand Down Expand Up @@ -602,7 +602,7 @@ class BuiltinLower : public StmtExprMutator {
let->var->type_annotation.as<PointerTypeNode>()->element_type.as<PrimTypeNode>()->dtype;

std::string fdevapi_prefix = "device_api.";
fdevapi_prefix += runtime::DeviceName(device_type_.as<IntImmNode>()->value);
fdevapi_prefix += runtime::DLDeviceType2Str(device_type_.as<IntImmNode>()->value);

Array<PrimExpr> args = {
GetDeviceMethodName("alloc_nd"),
Expand Down

0 comments on commit d3f9d3d

Please sign in to comment.