diff --git a/setup.py b/setup.py index cab39020a2..e55036059a 100644 --- a/setup.py +++ b/setup.py @@ -382,15 +382,26 @@ def compile_KLU(): "pandas": [ "pandas>=1.5.0", ], + # Note: jax and jaxlib must be pinned to a specific version + # to avoid upstream breaking changes. "jax": [ - "jax==0.4.8", - "jaxlib==0.4.7", + # 0.4.18 provides support for Jax on aarch64 containers + # via the PyBaMM images on Docker Hub which come with + # Python 3.11 installed. + # It also provides support for CPU-only Jax on Windows. + "jax==0.4.18; python_version >= '3.9'", + "jaxlib==0.4.18; python_version >= '3.9'", + # Jax 0.4.13 was the last version to support Python 3.8. + # Support for CPU-only Windows was added in 0.4.13, so + # this version supports Windows too. + "jax==0.4.13; python_version < '3.9'", + "jaxlib==0.4.13; python_version < '3.9'", ], "odes": ["scikits.odes"], "all": [ "autograd>=1.6.2", "scikit-fem>=8.1.0", - "pybamm[examples,plot,cite,latexify,bpx,tqdm,pandas]" + "pybamm[examples,plot,cite,latexify,bpx,tqdm,pandas]", ], }, entry_points={