Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip cuda backend initialization if no nvidia GPUs are visible #22703

Merged
merged 1 commit into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions jax/_src/hardware_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@
'0x005e',
]

# Device-node paths probed by has_visible_nvidia_gpu(); an NVIDIA GPU is
# treated as visible when at least one of these paths exists.
_NVIDIA_GPU_DEVICES = [
'/dev/nvidia0',
'/dev/dxg', # WSL2
]

def num_available_tpu_chips_and_device_id():
"""Returns the device id and number of TPU chips attached through PCI."""
num_chips = 0
Expand All @@ -57,3 +62,9 @@ def tpu_enhanced_barrier_supported() -> bool:
"""Returns if tpu_enhanced_barrier flag is supported on this TPU version."""
_, device_id = num_available_tpu_chips_and_device_id()
return device_id in _TPU_ENHANCED_BARRIER_SUPPORTED


def has_visible_nvidia_gpu() -> bool:
  """Report whether any known NVIDIA GPU device node is present.

  Returns:
    True if at least one path listed in ``_NVIDIA_GPU_DEVICES`` exists on the
    filesystem, False otherwise.
  """
  for device_path in _NVIDIA_GPU_DEVICES:
    # One existing device node is enough; stop probing at the first hit.
    if os.path.exists(device_path):
      return True
  return False
10 changes: 4 additions & 6 deletions jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,9 @@ def backends() -> dict[str, xla_client.Client]:
default_priority = -1000
for platform, priority, fail_quietly in platform_registrations:
try:
if platform == "cuda" and not hardware_utils.has_visible_nvidia_gpu():
Rifur13 marked this conversation as resolved.
Show resolved Hide resolved
continue

backend = _init_backend(platform)
_backends[platform] = backend

Expand Down Expand Up @@ -918,12 +921,7 @@ def _suggest_missing_backends():

assert _default_backend is not None
default_platform = _default_backend.platform
nvidia_gpu_devices = [
"/dev/nvidia0",
"/dev/dxg", # WSL2
]
if ("cuda" not in _backends and
any(os.path.exists(d) for d in nvidia_gpu_devices)):
if "cuda" not in _backends and hardware_utils.has_visible_nvidia_gpu():
if hasattr(xla_extension, "GpuAllocatorConfig") and "cuda" in _backend_errors:
err = _backend_errors["cuda"]
warning_msg = f"CUDA backend failed to initialize: {err}."
Expand Down