diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1897b4c94f..084a6f7dbd 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,7 +1,9 @@
# [Unreleased](https://github.com/pybamm-team/PyBaMM/)
-## Features
+## Breaking changes
+- PyBaMM now has optional dependencies that can be installed with the pattern `pip install pybamm[option]` e.g. `pybamm[plot]` ([#3044](https://github.com/pybamm-team/PyBaMM/pull/3044))
+- `pybamm_install_jax` is deprecated. It is now replaced with `pip install pybamm[jax]` ([#3163](https://github.com/pybamm-team/PyBaMM/pull/3163))
- Double-layer capacity can now be provided as a function of temperature ([#3174](https://github.com/pybamm-team/PyBaMM/pull/3174))
## Bug fixes
@@ -10,10 +12,7 @@
- Parameters in `Prada2013` have been updated to better match those given in the paper, which is a 2.3 Ah cell, instead of the mix-and-match with the 1.1 Ah cell from Lain2019.
- Error generated when invalid parameter values are passed.
- Thevenin() model is now constructed with standard variables: `Time [s], Time [min], Time [h]` ([#3143](https://github.com/pybamm-team/PyBaMM/pull/3143))
-
-## Breaking changes
-
-- PyBaMM now has optional dependencies that can be installed with the pattern `pip install pybamm[option]` e.g. `pybamm[plot]` ([#3044](https://github.com/pybamm-team/PyBaMM/pull/3044))
+- Fix SEI Example Notebook ([#3166](https://github.com/pybamm-team/PyBaMM/pull/3166))
# [v23.5](https://github.com/pybamm-team/PyBaMM/tree/v23.5) - 2023-06-18
diff --git a/docs/source/user_guide/installation/GNU-linux.rst b/docs/source/user_guide/installation/GNU-linux.rst
index d0a18bacec..8f1ee50dbc 100644
--- a/docs/source/user_guide/installation/GNU-linux.rst
+++ b/docs/source/user_guide/installation/GNU-linux.rst
@@ -118,7 +118,7 @@ Currently, only GNU/Linux and macOS are supported.
pybamm_install_odes
The ``pybamm_install_odes`` command is installed with PyBaMM. It automatically downloads and installs the SUNDIALS library on your
- system (under ``~/.local``), before installing ``scikits.odes`` (by running ``pip install scikits.odes``).
+ system (under ``~/.local``), before installing ``scikits.odes``. (Alternatively, one can install SUNDIALS without this script and run ``pip install pybamm[odes]`` to install ``pybamm`` with ``scikits.odes``.)
.. tab:: macOS
@@ -141,9 +141,9 @@ GNU/Linux and macOS
.. code:: bash
- pybamm_install_jax
+ pip install "pybamm[jax]"
-The ``pybamm_install_jax`` command is installed with PyBaMM. It automatically downloads and installs jax and jaxlib on your system.
+The ``pip install "pybamm[jax]"`` command automatically downloads and installs ``pybamm`` and the compatible versions of ``jax`` and ``jaxlib`` on your system. (``pybamm_install_jax`` is deprecated.)
Developer install
-----------------
diff --git a/docs/source/user_guide/installation/index.rst b/docs/source/user_guide/installation/index.rst
index 1e88ce2780..166bef64d7 100644
--- a/docs/source/user_guide/installation/index.rst
+++ b/docs/source/user_guide/installation/index.rst
@@ -194,6 +194,37 @@ Dependency Minimum Version p
`tqdm `__ \- tqdm For logging loops.
=========================================================== ================== ================== ==================
+.. _install.jax_dependencies:
+
+Jax dependencies
+^^^^^^^^^^^^^^^^^
+
+Installable with ``pip install "pybamm[jax]"``
+
+========================================================================= ================== ================== =======================
+Dependency Minimum Version pip extra Notes
+========================================================================= ================== ================== =======================
+`JAX `__ 0.4.8 jax For JAX solvers
+`jaxlib `__ 0.4.7 jax Support library for JAX
+========================================================================= ================== ================== =======================
+
+.. _install.odes_dependencies:
+
+odes dependencies
+^^^^^^^^^^^^^^^^^
+
+Installable with ``pip install "pybamm[odes]"``
+
+================================================================================================================================ ================== ================== =============================
+Dependency Minimum Version pip extra Notes
+================================================================================================================================ ================== ================== =============================
+`scikits.odes `__ \- odes For scikits ODE & DAE solvers
+================================================================================================================================ ================== ================== =============================
+
+.. note::
+
+ Before running ``pip install "pybamm[odes]"``, make sure to install ``scikits.odes`` build-time requirements as described `here `_ .
+
Full installation guide
-----------------------
diff --git a/noxfile.py b/noxfile.py
index aa07b1524f..2f1231214c 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -63,8 +63,8 @@ def run_coverage(session):
session.install("coverage")
session.install("-e", ".[all]")
if sys.platform != "win32":
- session.install("scikits.odes")
- session.run("pybamm_install_jax")
+ session.install("-e", ".[odes]")
+ session.install("-e", ".[jax]")
session.run("coverage", "run", "--rcfile=.coveragerc", "run-tests.py", "--nosub")
session.run("coverage", "combine")
session.run("coverage", "xml")
@@ -76,7 +76,7 @@ def run_integration(session):
set_environment_variables(PYBAMM_ENV, session=session)
session.install("-e", ".[all]")
if sys.platform == "linux":
- session.install("scikits.odes")
+ session.install("-e", ".[odes]")
session.run("python", "run-tests.py", "--integration")
@@ -93,8 +93,8 @@ def run_unit(session):
set_environment_variables(PYBAMM_ENV, session=session)
session.install("-e", ".[all]")
if sys.platform == "linux":
- session.install("scikits.odes")
- session.run("pybamm_install_jax")
+ session.install("-e", ".[odes]")
+ session.install("-e", ".[jax]")
session.run("python", "run-tests.py", "--unit")
@@ -129,8 +129,8 @@ def run_tests(session):
set_environment_variables(PYBAMM_ENV, session=session)
session.install("-e", ".[all]")
if sys.platform == "linux" or sys.platform == "darwin":
- session.install("scikits.odes")
- session.run("pybamm_install_jax")
+ session.install("-e", ".[odes]")
+ session.install("-e", ".[jax]")
session.run("python", "run-tests.py", "--all")
diff --git a/pybamm/install_odes.py b/pybamm/install_odes.py
index 424367c6ba..4bf310a0f2 100644
--- a/pybamm/install_odes.py
+++ b/pybamm/install_odes.py
@@ -108,6 +108,7 @@ def update_LD_LIBRARY_PATH(install_dir):
def main(arguments=None):
+
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logger = logging.getLogger("scikits.odes setup")
diff --git a/pybamm/util.py b/pybamm/util.py
index 9cef01523a..5f84f37e0a 100644
--- a/pybamm/util.py
+++ b/pybamm/util.py
@@ -15,6 +15,7 @@
import timeit
from platform import system
import difflib
+from warnings import warn
import numpy as np
import pkg_resources
@@ -329,6 +330,11 @@ def install_jax(arguments=None): # pragma: no cover
" following command: \npybamm_install_jax --force"
)
+ msg = (
+ "pybamm_install_jax is deprecated,"
+ " use 'pip install pybamm[jax]' to install jax & jaxlib"
+ )
+ warn(msg, DeprecationWarning)
subprocess.check_call(
[
sys.executable,
diff --git a/setup.py b/setup.py
index d175e7ae80..a391d42fc8 100644
--- a/setup.py
+++ b/setup.py
@@ -255,6 +255,11 @@ def compile_KLU():
"pre-commit", # For code style checking
"ruff", # For code style auto-formatting
],
+ "jax": [
+ "jax==0.4.8",
+ "jaxlib==0.4.7",
+ ],
+ "odes": ["scikits.odes"],
"all": [
"anytree>=2.4.3",
"autograd>=1.2",