From d3f9d3dfd76f032baf3cd70ef0281f7fdb176240 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 27 Aug 2023 17:52:42 +0000 Subject: [PATCH] [Runtime] Utils to Stringify Device (#15630) There exist some basic functionality to convert Device and DLDeviceType to std::string, but they are not following the common naming convention in TVM, and thus less discoverable. This commit makes changes accordingly: - `runtime::DeviceName` to `runtime::DLDeviceType2Str` - move declaration of `operator << (std::ostream&, Device)` from `runtime/device_api.h` to `runtime/packed_func.h` --- include/tvm/runtime/data_type.h | 1 + include/tvm/runtime/device_api.h | 50 +---------- include/tvm/runtime/packed_func.h | 107 +++++++++++++++++++++--- include/tvm/tir/op.h | 1 + src/runtime/c_runtime_api.cc | 2 +- src/runtime/profiling.cc | 6 +- src/runtime/relax_vm/memory_manager.cc | 18 ++-- src/runtime/rpc/rpc_module.cc | 1 + src/runtime/vm/memory_manager.cc | 15 ++-- src/tir/transforms/lower_tvm_builtin.cc | 4 +- 10 files changed, 120 insertions(+), 85 deletions(-) diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index 9fb113f56b2c..ac7e879a644d 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -339,6 +339,7 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) { default: LOG(FATAL) << "unknown type_code=" << static_cast(type_code); } + throw; } inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*) diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h index 654018565716..cb0eb7c21f11 100644 --- a/include/tvm/runtime/device_api.h +++ b/include/tvm/runtime/device_api.h @@ -245,54 +245,6 @@ class TVM_DLL DeviceAPI { constexpr int kRPCSessMask = 128; static_assert(kRPCSessMask >= TVMDeviceExtType_End); -/*! - * \brief The name of Device API factory. - * \param type The device type. - * \return the device name. - */ -inline const char* DeviceName(int type) { - switch (type) { - case kDLCPU: - return "cpu"; - case kDLCUDA: - return "cuda"; - case kDLCUDAHost: - return "cuda_host"; - case kDLCUDAManaged: - return "cuda_managed"; - case kDLOpenCL: - return "opencl"; - case kDLSDAccel: - return "sdaccel"; - case kDLAOCL: - return "aocl"; - case kDLVulkan: - return "vulkan"; - case kDLMetal: - return "metal"; - case kDLVPI: - return "vpi"; - case kDLROCM: - return "rocm"; - case kDLROCMHost: - return "rocm_host"; - case kDLExtDev: - return "ext_dev"; - case kDLOneAPI: - return "oneapi"; - case kDLWebGPU: - return "webgpu"; - case kDLHexagon: - return "hexagon"; - case kOpenGL: - return "opengl"; - case kDLMicroDev: - return "microdev"; - default: - LOG(FATAL) << "unknown type =" << type; - } -} - /*! * \brief Return true if a Device is owned by an RPC session. */ @@ -324,7 +276,7 @@ inline std::ostream& operator<<(std::ostream& os, DLDevice dev) { // NOLINT(*) os << "remote[" << tvm::runtime::GetRPCSessionIndex(dev) << "]-"; dev = tvm::runtime::RemoveRPCSessionMask(dev); } - os << tvm::runtime::DeviceName(static_cast(dev.device_type)) << "(" << dev.device_id << ")"; + os << tvm::runtime::DLDeviceType2Str(static_cast(dev.device_type)) << ":" << dev.device_id; return os; } diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 660c24284b8d..e63e92835cc5 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -418,6 +418,8 @@ class TVMArgs { */ inline const char* ArgTypeCode2Str(int type_code); +inline std::ostream& operator<<(std::ostream& os, DLDevice dev); // NOLINT(*) + // macro to check type code. #define TVM_CHECK_TYPE_CODE(CODE, T) \ ICHECK_EQ(CODE, T) << "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE) @@ -1257,6 +1259,56 @@ inline const char* ArgTypeCode2Str(int type_code) { default: LOG(FATAL) << "unknown type_code=" << static_cast(type_code); } + throw; +} + +/*! + * \brief The name of DLDeviceType. + * \param type The device type. + * \return the device name. + */ +inline const char* DLDeviceType2Str(int type) { + switch (type) { + case kDLCPU: + return "cpu"; + case kDLCUDA: + return "cuda"; + case kDLCUDAHost: + return "cuda_host"; + case kDLCUDAManaged: + return "cuda_managed"; + case kDLOpenCL: + return "opencl"; + case kDLSDAccel: + return "sdaccel"; + case kDLAOCL: + return "aocl"; + case kDLVulkan: + return "vulkan"; + case kDLMetal: + return "metal"; + case kDLVPI: + return "vpi"; + case kDLROCM: + return "rocm"; + case kDLROCMHost: + return "rocm_host"; + case kDLExtDev: + return "ext_dev"; + case kDLOneAPI: + return "oneapi"; + case kDLWebGPU: + return "webgpu"; + case kDLHexagon: + return "hexagon"; + case kOpenGL: + return "opengl"; + case kDLMicroDev: + return "microdev"; + default: + LOG(FATAL) << "unknown type = " << type; + } + throw; } namespace detail { @@ -1284,13 +1336,27 @@ namespace parameter_pack { template struct EnumeratedParamPack { - struct Invoke { - template