Skip to content

Commit

Permalink
Bump the minimum CUDNN version to v9.1.
Browse files Browse the repository at this point in the history
This actually was already the minimum version since we build with that version, but we needed to tighten the constraints.

Also in passing, drop mentions of CUDA builds from the Windows build instructions. jaxlib hasn't built with CUDA enabled on Windows for a very long time, so it's probably best we just don't mention it.

PiperOrigin-RevId: 657225917
  • Loading branch information
hawkinsp authored and jax authors committed Jul 29, 2024
1 parent 6127baa commit d1c0d99
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 14 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Remember to align the itemized text with the first line of an item within a list
* xmap has been deleted. Please use {func}`shard_map` as the replacement.

* Changes
* The minimum CuDNN version is v9.1. This was true in previous releases also,
but we now declare this version constraint formally.
* The minimum Python version is now 3.10. 3.10 will remain the minimum
supported version until July 2025.
* The minimum NumPy version is now 1.24. NumPy 1.24 will remain the minimum
Expand Down
13 changes: 2 additions & 11 deletions docs/developer.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,6 @@ particular before each `jaxlib` release.

On Windows, follow [Install Visual Studio](https://docs.microsoft.com/en-us/visualstudio/install/install-visual-studio?view=vs-2019)
to set up a C++ toolchain. Visual Studio 2019 version 16.5 or newer is required.
If you need to build with CUDA enabled, follow the
[CUDA Installation Guide](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html)
to set up a CUDA environment.

JAX builds use symbolic links, which require that you activate
[Developer Mode](https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development).
Expand All @@ -139,16 +136,10 @@ Once coreutils is installed, the realpath command should be present in your shel

Once everything is installed. Open PowerShell, and make sure MSYS2 is in the
path of the current session. Ensure `bazel`, `patch` and `realpath` are
accessible. Activate the conda environment. The following command builds with
CUDA enabled, adjust it to whatever suitable for you:
accessible. Activate the conda environment.

```
python .\build\build.py `
--enable_cuda `
--cuda_path='C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1' `
--cudnn_path='C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v10.1' `
--cuda_version='10.1' `
--cudnn_version='7.6.5'
python .\build\build.py
```

To build with debug information, add the flag `--bazel_options='--copt=/Z7'`.
Expand Down
2 changes: 1 addition & 1 deletion docs/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ JAX currently ships one CUDA wheel variant:
| Built with | Compatible with |
|------------|--------------------|
| CUDA 12.3 | CUDA >=12.1 |
| CUDNN 9.0 | CUDNN >=9.0, <10.0 |
| CUDNN 9.1 | CUDNN >=9.1, <10.0 |
| NCCL 2.19 | NCCL >=2.18 |

JAX checks the versions of your libraries, and will report an error if they are
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def _version_check(name: str,
# versions:
# https://docs.nvidia.com/deeplearning/cudnn/developer-guide/index.html#api-compat
scale_for_comparison=100,
min_supported_version=9000
min_supported_version=9100
)
_version_check("cuFFT", cuda_versions.cufft_get_version,
cuda_versions.cufft_build_version,
Expand Down
2 changes: 1 addition & 1 deletion jax_plugins/cuda/plugin_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def has_ext_modules(self):
"nvidia-cuda-cupti-cu12>=12.1.105",
"nvidia-cuda-nvcc-cu12>=12.1.105",
"nvidia-cuda-runtime-cu12>=12.1.105",
"nvidia-cudnn-cu12>=9.0,<10.0",
"nvidia-cudnn-cu12>=9.1,<10.0",
"nvidia-cufft-cu12>=11.0.2.54",
"nvidia-cusolver-cu12>=11.4.5.107",
"nvidia-cusparse-cu12>=12.1.0.106",
Expand Down

0 comments on commit d1c0d99

Please sign in to comment.