diff --git a/setup.py b/setup.py index 144f2b4957bf..2eab9cd46bae 100644 --- a/setup.py +++ b/setup.py @@ -22,13 +22,13 @@ project_name = 'jax' -_current_jaxlib_version = '0.4.19' +_current_jaxlib_version = '0.4.20' # The following should be updated with each new jaxlib release. _latest_jaxlib_version_on_pypi = '0.4.19' _available_cuda11_cudnn_versions = ['86'] _default_cuda11_cudnn_version = '86' _default_cuda12_cudnn_version = '89' -_libtpu_version = '0.1.dev20231018' +_libtpu_version = '0.1.dev20231102' def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location(