From d1c0d993fc97107d7e1cffa0c7bb8f3f2217095f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 29 Jul 2024 09:27:54 -0700 Subject: [PATCH] Bump the minimum CUDNN version to v9.1. This actually was already the minimum version since we build with that version, but we needed to tighten the constraints. Also in passing, drop mentions of CUDA builds from the Windows build instructions. jaxlib hasn't built with CUDA enabled on Windows for a very long time, so it's probably best we just don't mention it. PiperOrigin-RevId: 657225917 --- CHANGELOG.md | 2 ++ docs/developer.md | 13 ++----------- docs/installation.md | 2 +- jax/_src/xla_bridge.py | 2 +- jax_plugins/cuda/plugin_setup.py | 2 +- 5 files changed, 7 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f6819f0f78af..ad310b9330b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ Remember to align the itemized text with the first line of an item within a list * xmap has been deleted. Please use {func}`shard_map` as the replacement. * Changes + * The minimum CuDNN version is v9.1. This was true in previous releases also, + but we now declare this version constraint formally. * The minimum Python version is now 3.10. 3.10 will remain the minimum supported version until July 2025. * The minimum NumPy version is now 1.24. NumPy 1.24 will remain the minimum diff --git a/docs/developer.md b/docs/developer.md index e9ded0b59ffb..e2850d2a94e7 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -114,9 +114,6 @@ particular before each `jaxlib` release. On Windows, follow [Install Visual Studio](https://docs.microsoft.com/en-us/visualstudio/install/install-visual-studio?view=vs-2019) to set up a C++ toolchain. Visual Studio 2019 version 16.5 or newer is required. -If you need to build with CUDA enabled, follow the -[CUDA Installation Guide](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) -to set up a CUDA environment. JAX builds use symbolic links, which require that you activate [Developer Mode](https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development). @@ -139,16 +136,10 @@ Once coreutils is installed, the realpath command should be present in your shel Once everything is installed. Open PowerShell, and make sure MSYS2 is in the path of the current session. Ensure `bazel`, `patch` and `realpath` are -accessible. Activate the conda environment. The following command builds with -CUDA enabled, adjust it to whatever suitable for you: +accessible. Activate the conda environment. ``` -python .\build\build.py ` - --enable_cuda ` - --cuda_path='C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1' ` - --cudnn_path='C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1' ` - --cuda_version='10.1' ` - --cudnn_version='7.6.5' +python .\build\build.py ``` To build with debug information, add the flag `--bazel_options='--copt=/Z7'`. diff --git a/docs/installation.md b/docs/installation.md index fa77d1fc29f6..20ffe436ff8a 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -138,7 +138,7 @@ JAX currently ships one CUDA wheel variant: | Built with | Compatible with | |------------|--------------------| | CUDA 12.3 | CUDA >=12.1 | -| CUDNN 9.0 | CUDNN >=9.0, <10.0 | +| CUDNN 9.1 | CUDNN >=9.1, <10.0 | | NCCL 2.19 | NCCL >=2.18 | JAX checks the versions of your libraries, and will report an error if they are diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index a893f623d00b..243705490f81 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -381,7 +381,7 @@ def _version_check(name: str, # versions: # https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat scale_for_comparison=100, - min_supported_version=9000 + min_supported_version=9100 ) _version_check("cuFFT", cuda_versions.cufft_get_version, cuda_versions.cufft_build_version, diff --git a/jax_plugins/cuda/plugin_setup.py b/jax_plugins/cuda/plugin_setup.py index cd26731aa629..468c0c48709f 100644 --- a/jax_plugins/cuda/plugin_setup.py +++ b/jax_plugins/cuda/plugin_setup.py @@ -57,7 +57,7 @@ def has_ext_modules(self): "nvidia-cuda-cupti-cu12>=12.1.105", "nvidia-cuda-nvcc-cu12>=12.1.105", "nvidia-cuda-runtime-cu12>=12.1.105", - "nvidia-cudnn-cu12>=9.0,<10.0", + "nvidia-cudnn-cu12>=9.1,<10.0", "nvidia-cufft-cu12>=11.0.2.54", "nvidia-cusolver-cu12>=11.4.5.107", "nvidia-cusparse-cu12>=12.1.0.106",