diff --git a/jax/_src/hardware_utils.py b/jax/_src/hardware_utils.py index 3f6686d05ee8..7fa1c0a5348c 100644 --- a/jax/_src/hardware_utils.py +++ b/jax/_src/hardware_utils.py @@ -36,6 +36,11 @@ '0x005e', ] +_NVIDIA_GPU_DEVICES = [ + '/dev/nvidia0', + '/dev/dxg', # WSL2 +] + def num_available_tpu_chips_and_device_id(): """Returns the device id and number of TPU chips attached through PCI.""" num_chips = 0 @@ -57,3 +62,9 @@ def tpu_enhanced_barrier_supported() -> bool: """Returns if tpu_enhanced_barrier flag is supported on this TPU version.""" _, device_id = num_available_tpu_chips_and_device_id() return device_id in _TPU_ENHANCED_BARRIER_SUPPORTED + + +def has_visible_nvidia_gpu() -> bool: + """True if there's a visible nvidia gpu available on device, False otherwise.""" + + return any(os.path.exists(d) for d in _NVIDIA_GPU_DEVICES) diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index a893f623d00b..c0903d8cb70f 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -881,6 +881,9 @@ def backends() -> dict[str, xla_client.Client]: default_priority = -1000 for platform, priority, fail_quietly in platform_registrations: try: + if platform == "cuda" and not hardware_utils.has_visible_nvidia_gpu(): + continue + backend = _init_backend(platform) _backends[platform] = backend @@ -918,12 +921,7 @@ def _suggest_missing_backends(): assert _default_backend is not None default_platform = _default_backend.platform - nvidia_gpu_devices = [ - "/dev/nvidia0", - "/dev/dxg", # WSL2 - ] - if ("cuda" not in _backends and - any(os.path.exists(d) for d in nvidia_gpu_devices)): + if "cuda" not in _backends and hardware_utils.has_visible_nvidia_gpu(): if hasattr(xla_extension, "GpuAllocatorConfig") and "cuda" in _backend_errors: err = _backend_errors["cuda"] warning_msg = f"CUDA backend failed to initialize: {err}."