Skip to content

Commit

Permalink
[Runtime] Use cudaGetDeviceCount to check if device exists (#16377)
Browse files Browse the repository at this point in the history
Using `cudaDeviceGetAttribute` will set the global error code when the
device doesn't exist and will impact subsequent CUDA API calls.
  • Loading branch information
vinx13 authored Jan 10, 2024
1 parent 3166366 commit 524ec5f
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/runtime/cuda/cuda_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ class CUDADeviceAPI final : public DeviceAPI {
int value = 0;
switch (kind) {
case kExist:
value = (cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, dev.device_id) ==
cudaSuccess);
int count;
CUDA_CALL(cudaGetDeviceCount(&count));
value = static_cast<int>(dev.device_id < count);
break;
case kMaxThreadsPerBlock: {
CUDA_CALL(cudaDeviceGetAttribute(&value, cudaDevAttrMaxThreadsPerBlock, dev.device_id));
Expand Down

0 comments on commit 524ec5f

Please sign in to comment.