From bea40bf0fa086622c94017b40e0793467ead7936 Mon Sep 17 00:00:00 2001 From: Akshita Bhagia Date: Mon, 2 Oct 2023 10:47:14 -0700 Subject: [PATCH] jax, flax versions --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5548dfd91..933688f17 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,8 +89,8 @@ fairscale = [ ] flax = [ "datasets>=1.12,<3.0", - "jax>=0.3.13", - "flax>=0.5.0", + "jax>=0.4.1,<=0.4.13", + "flax>=0.5.0,<=0.7.0", "optax>=0.1.2", "tensorflow-cpu>=2.9.1" ]