Skip to content

Commit

Permalink
[Runtime] Utils to Stringify Device
Browse files Browse the repository at this point in the history
There exist some basic functionality to convert Device and DLDeviceType
to std::string, but they are not following the common naming convention
in TVM, and thus less discoverable. This commit makes changes
accordingly:
- `runtime::DeviceName` to `runtime::DLDeviceType2Str`
- move declaration of `operator << (std::ostream&, Device)` from
  `runtime/device_api.h` to `runtime/packed_func.h`
  • Loading branch information
junrushao committed Aug 28, 2023
1 parent 909c8fa commit 3e43778
Show file tree
Hide file tree
Showing 11 changed files with 73 additions and 68 deletions.
1 change: 1 addition & 0 deletions include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
default:
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
}
throw;
}

inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
Expand Down
50 changes: 1 addition & 49 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -245,54 +245,6 @@ class TVM_DLL DeviceAPI {
constexpr int kRPCSessMask = 128;
static_assert(kRPCSessMask >= TVMDeviceExtType_End);

/*!
* \brief The name of Device API factory.
* \param type The device type.
* \return the device name.
*/
inline const char* DeviceName(int type) {
switch (type) {
case kDLCPU:
return "cpu";
case kDLCUDA:
return "cuda";
case kDLCUDAHost:
return "cuda_host";
case kDLCUDAManaged:
return "cuda_managed";
case kDLOpenCL:
return "opencl";
case kDLSDAccel:
return "sdaccel";
case kDLAOCL:
return "aocl";
case kDLVulkan:
return "vulkan";
case kDLMetal:
return "metal";
case kDLVPI:
return "vpi";
case kDLROCM:
return "rocm";
case kDLROCMHost:
return "rocm_host";
case kDLExtDev:
return "ext_dev";
case kDLOneAPI:
return "oneapi";
case kDLWebGPU:
return "webgpu";
case kDLHexagon:
return "hexagon";
case kOpenGL:
return "opengl";
case kDLMicroDev:
return "microdev";
default:
LOG(FATAL) << "unknown type =" << type;
}
}

/*!
* \brief Return true if a Device is owned by an RPC session.
*/
Expand Down Expand Up @@ -324,7 +276,7 @@ inline std::ostream& operator<<(std::ostream& os, DLDevice dev) { // NOLINT(*)
os << "remote[" << tvm::runtime::GetRPCSessionIndex(dev) << "]-";
dev = tvm::runtime::RemoveRPCSessionMask(dev);
}
os << tvm::runtime::DeviceName(static_cast<int>(dev.device_type)) << "(" << dev.device_id << ")";
os << tvm::runtime::DLDeviceType2Str(static_cast<int>(dev.device_type)) << ":" << dev.device_id;
return os;
}

Expand Down
52 changes: 52 additions & 0 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,8 @@ class TVMArgs {
*/
inline const char* ArgTypeCode2Str(int type_code);

inline std::ostream& operator<<(std::ostream& os, DLDevice dev); // NOLINT(*)

// macro to check type code.
#define TVM_CHECK_TYPE_CODE(CODE, T) \
ICHECK_EQ(CODE, T) << "expected " << ArgTypeCode2Str(T) << " but got " << ArgTypeCode2Str(CODE)
Expand Down Expand Up @@ -1257,6 +1259,56 @@ inline const char* ArgTypeCode2Str(int type_code) {
default:
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
}
throw;
}

/*!
* \brief The name of DLDeviceType.
* \param type The device type.
* \return the device name.
*/
inline const char* DLDeviceType2Str(int type) {
switch (type) {
case kDLCPU:
return "cpu";
case kDLCUDA:
return "cuda";
case kDLCUDAHost:
return "cuda_host";
case kDLCUDAManaged:
return "cuda_managed";
case kDLOpenCL:
return "opencl";
case kDLSDAccel:
return "sdaccel";
case kDLAOCL:
return "aocl";
case kDLVulkan:
return "vulkan";
case kDLMetal:
return "metal";
case kDLVPI:
return "vpi";
case kDLROCM:
return "rocm";
case kDLROCMHost:
return "rocm_host";
case kDLExtDev:
return "ext_dev";
case kDLOneAPI:
return "oneapi";
case kDLWebGPU:
return "webgpu";
case kDLHexagon:
return "hexagon";
case kOpenGL:
return "opengl";
case kDLMicroDev:
return "microdev";
default:
LOG(FATAL) << "unknown type = " << type;
}
throw;
}

namespace detail {
Expand Down
1 change: 1 addition & 0 deletions include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -949,6 +949,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span())
return FloatImm(t, static_cast<double>(value), span);
}
LOG(FATAL) << "cannot make const for type " << t;
throw;
}

template <>
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/c_runtime_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class DeviceAPIManager {
if (api_[type] != nullptr) return api_[type];
std::lock_guard<std::mutex> lock(mutex_);
if (api_[type] != nullptr) return api_[type];
api_[type] = GetAPI(DeviceName(type), allow_missing);
api_[type] = GetAPI(DLDeviceType2Str(type), allow_missing);
return api_[type];
} else {
if (rpc_api_ != nullptr) return rpc_api_;
Expand Down
7 changes: 4 additions & 3 deletions src/runtime/contrib/papi/papi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ int component_for_device(Device dev) {
component_name = "rocm";
break;
default:
LOG(WARNING) << "PAPI does not support device " << DeviceName(dev.device_type);
LOG(WARNING) << "PAPI does not support device " << DLDeviceType2Str(dev.device_type);
return -1;
}
int cidx = PAPI_get_component_index(component_name.c_str());
Expand Down Expand Up @@ -170,8 +170,9 @@ struct PAPIMetricCollectorNode final : public MetricCollectorNode {
default:
break;
}
LOG(WARNING) << "PAPI could not initialize counters for " << DeviceName(device.device_type)
<< ": " << component->disabled_reason << "\n"
LOG(WARNING) << "PAPI could not initialize counters for "
<< DLDeviceType2Str(device.device_type) << ": " << component->disabled_reason
<< "\n"
<< help_message;
continue;
}
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/hexagon/hexagon_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ void* HexagonDeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shap
// until the AoT executor's multi-device dispatch code is mature. --cconvey 2022-08-26
CHECK(dev.device_type == kDLHexagon)
<< "dev.device_type: " << dev.device_type << " DeviceName(" << dev.device_type
<< "): " << DeviceName(dev.device_type) << "";
<< "): " << DLDeviceType2Str(dev.device_type) << "";

CHECK(ndim >= 0 && ndim <= 2)
<< "Hexagon Device API supports only 1d and 2d allocations, but received ndim = " << ndim;
Expand Down
6 changes: 3 additions & 3 deletions src/runtime/profiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,13 @@ std::set<DLDeviceType> seen_devices;
std::mutex seen_devices_lock;

Timer Timer::Start(Device dev) {
auto f = Registry::Get(std::string("profiling.timer.") + DeviceName(dev.device_type));
auto f = Registry::Get(std::string("profiling.timer.") + DLDeviceType2Str(dev.device_type));
if (f == nullptr) {
{
std::lock_guard<std::mutex> lock(seen_devices_lock);
if (seen_devices.find(dev.device_type) == seen_devices.end()) {
LOG(WARNING)
<< "No timer implementation for " << DeviceName(dev.device_type)
<< "No timer implementation for " << DLDeviceType2Str(dev.device_type)
<< ", using default timer instead. It may be inaccurate or have extra overhead.";
seen_devices.insert(dev.device_type);
}
Expand Down Expand Up @@ -652,7 +652,7 @@ String ReportNode::AsTable(bool sort, bool aggregate, bool compute_col_sums) con
}

std::string DeviceString(Device dev) {
return DeviceName(dev.device_type) + std::to_string(dev.device_id);
return DLDeviceType2Str(dev.device_type) + std::to_string(dev.device_id);
}

Report Profiler::Report() {
Expand Down
1 change: 1 addition & 0 deletions src/runtime/rpc/rpc_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ class RPCModuleNode final : public ModuleNode {

String GetSource(const String& format) final {
LOG(FATAL) << "GetSource for rpc Module is not supported";
throw;
}

PackedFunc GetTimeEvaluator(const std::string& name, Device dev, int number, int repeat,
Expand Down
15 changes: 6 additions & 9 deletions src/runtime/vm/memory_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,12 @@ Allocator* MemoryManager::GetOrCreateAllocator(Device dev, AllocatorType type) {
std::unique_ptr<Allocator> alloc;
switch (type) {
case kNaive: {
VLOG(1) << "New naive allocator for " << DeviceName(dev.device_type) << "(" << dev.device_id
<< ")";
VLOG(1) << "New naive allocator for " << dev;
alloc.reset(new NaiveAllocator(dev));
break;
}
case kPooled: {
VLOG(1) << "New pooled allocator for " << DeviceName(dev.device_type) << "("
<< dev.device_id << ")";
VLOG(1) << "New pooled allocator for " << dev;
alloc.reset(new PooledAllocator(dev));
break;
}
Expand All @@ -139,9 +137,9 @@ Allocator* MemoryManager::GetOrCreateAllocator(Device dev, AllocatorType type) {
}
auto alloc = m->allocators_.at(dev).get();
if (alloc->type() != type) {
LOG(WARNING) << "The type of existing allocator for " << DeviceName(dev.device_type) << "("
<< dev.device_id << ") is different from the request type (" << alloc->type()
<< " vs " << type << ")";
LOG(WARNING) << "The type of existing allocator for " << dev
<< " is different from the request type (" << alloc->type() << " vs " << type
<< ")";
}
return alloc;
}
Expand All @@ -151,8 +149,7 @@ Allocator* MemoryManager::GetAllocator(Device dev) {
std::lock_guard<std::mutex> lock(m->mu_);
auto it = m->allocators_.find(dev);
if (it == m->allocators_.end()) {
LOG(FATAL) << "Allocator for " << DeviceName(dev.device_type) << "(" << dev.device_id
<< ") has not been created yet.";
LOG(FATAL) << "Allocator for " << dev << " has not been created yet.";
}
return it->second.get();
}
Expand Down
4 changes: 2 additions & 2 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ class BuiltinLower : public StmtExprMutator {
<< "but was instead the expression " << device_type_ << " with type "
<< device_type_.value()->GetTypeKey();

String device_name = runtime::DeviceName(as_int->value);
String device_name = runtime::DLDeviceType2Str(as_int->value);
return StringImm("device_api." + device_name + "." + method_name);
}

Expand Down Expand Up @@ -595,7 +595,7 @@ class BuiltinLower : public StmtExprMutator {
let->var->type_annotation.as<PointerTypeNode>()->element_type.as<PrimTypeNode>()->dtype;

std::string fdevapi_prefix = "device_api.";
fdevapi_prefix += runtime::DeviceName(device_type_.as<IntImmNode>()->value);
fdevapi_prefix += runtime::DLDeviceType2Str(device_type_.as<IntImmNode>()->value);

Array<PrimExpr> args = {
GetDeviceMethodName("alloc_nd"),
Expand Down

0 comments on commit 3e43778

Please sign in to comment.