From 8ab158574da0823165e29981aeb554ddfa524815 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Tue, 7 Feb 2023 15:45:28 -0800 Subject: [PATCH] Update WORKSPACE and setup.py for jax/jaxlib 0.4.3 release --- WORKSPACE | 6 +++--- jax/tools/colab_tpu.py | 2 +- setup.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 981ca11fd39a..4f6fea9202bf 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -7,10 +7,10 @@ load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") # and update the sha256 with the result. http_archive( name = "org_tensorflow", - sha256 = "eee3984f41375d438c18fdfae71f66673a35cefe883ef309f4c6a9a4ffc16218", - strip_prefix = "tensorflow-bbabe41b20327819d2472425d43178c35ce2b1bf", + sha256 = "8d85c2a03b60fdb7152e810ac9e7380f9ebd677f382f7ea93f2e88af6f23ed26", + strip_prefix = "tensorflow-5e52982c02c0f5d88eaa449408abe3bdcac9ee05", urls = [ - "https://github.com/tensorflow/tensorflow/archive/bbabe41b20327819d2472425d43178c35ce2b1bf.tar.gz", + "https://github.com/tensorflow/tensorflow/archive/5e52982c02c0f5d88eaa449408abe3bdcac9ee05.tar.gz", ], ) diff --git a/jax/tools/colab_tpu.py b/jax/tools/colab_tpu.py index 6c3060aac059..d9077f3f4b68 100644 --- a/jax/tools/colab_tpu.py +++ b/jax/tools/colab_tpu.py @@ -22,7 +22,7 @@ TPU_DRIVER_MODE = 0 -def setup_tpu(tpu_driver_version='tpu_driver_20230120'): +def setup_tpu(tpu_driver_version='tpu_driver_20230207'): """Sets up Colab to run on TPU. Note: make sure the Colab Runtime is set to Accelerator: TPU. diff --git a/setup.py b/setup.py index c7a4e4b54ae4..ce20b1455f2e 100644 --- a/setup.py +++ b/setup.py @@ -19,14 +19,14 @@ from setuptools import setup, find_packages -_current_jaxlib_version = '0.4.2' +_current_jaxlib_version = '0.4.3' # The following should be updated with each new jaxlib release. _latest_jaxlib_version_on_pypi = '0.4.2' _available_cuda_versions = ['11'] _default_cuda_version = '11' _available_cudnn_versions = ['82', '86'] _default_cudnn_version = '86' -_libtpu_version = '0.1.dev20230124' +_libtpu_version = '0.1.dev20230207' _dct = {} with open('jax/version.py', encoding='utf-8') as f: