diff --git a/CHANGELOG.md b/CHANGELOG.md index 8e35a3cc2018..bcab0b12725a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. --> ## jax 0.3.23 +* Changes + * Update Colab TPU driver version for new jaxlib release. ## jaxlib 0.3.23 diff --git a/jax/tools/colab_tpu.py b/jax/tools/colab_tpu.py index 4540de658a0d..35278a2e5e80 100644 --- a/jax/tools/colab_tpu.py +++ b/jax/tools/colab_tpu.py @@ -22,7 +22,7 @@ TPU_DRIVER_MODE = 0 -def setup_tpu(tpu_driver_version='tpu_driver-0.2'): +def setup_tpu(tpu_driver_version='tpu_driver_20221011'): """Sets up Colab to run on TPU. Note: make sure the Colab Runtime is set to Accelerator: TPU. @@ -30,8 +30,7 @@ def setup_tpu(tpu_driver_version='tpu_driver-0.2'): Args ---- tpu_driver_version : (str) specify the version identifier for the tpu driver. - Defaults to "tpu_driver-0.2", which can be used with jaxlib 0.3.20. Set to - "tpu_driver_nightly" to use the nightly tpu driver build. + Set to "tpu_driver_nightly" to use the nightly tpu driver build. """ global TPU_DRIVER_MODE