Skip to content

Commit

Permalink
Merge pull request #12761 from skye:colab_tpu_driver
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 480499024
  • Loading branch information
jax authors committed Oct 12, 2022
2 parents 6b459f5 + e2aa939 commit 012398b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
-->

## jax 0.3.23
* Changes
* Update Colab TPU driver version for new jaxlib release.

## jaxlib 0.3.23

Expand Down
5 changes: 2 additions & 3 deletions jax/tools/colab_tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,15 @@
TPU_DRIVER_MODE = 0


def setup_tpu(tpu_driver_version='tpu_driver-0.2'):
def setup_tpu(tpu_driver_version='tpu_driver_20221011'):
"""Sets up Colab to run on TPU.
Note: make sure the Colab Runtime is set to Accelerator: TPU.
Args
----
tpu_driver_version : (str) specify the version identifier for the tpu driver.
Defaults to "tpu_driver-0.2", which can be used with jaxlib 0.3.20. Set to
"tpu_driver_nightly" to use the nightly tpu driver build.
Set to "tpu_driver_nightly" to use the nightly tpu driver build.
"""
global TPU_DRIVER_MODE

Expand Down

0 comments on commit 012398b

Please sign in to comment.