Skip to content

Commit

Permalink
Merge pull request #3163 from arjxn-py/make-jax-optional
Browse files Browse the repository at this point in the history
Make `jax` & `odes` optional
  • Loading branch information
Saransh-cpp authored Jul 28, 2023
2 parents d588ec1 + fe704bf commit 695917e
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 15 deletions.
9 changes: 4 additions & 5 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# [Unreleased](https://github.com/pybamm-team/PyBaMM/)

## Features
## Breaking changes

- PyBaMM now has optional dependencies that can be installed with the pattern `pip install pybamm[option]` e.g. `pybamm[plot]` ([#3044](https://github.com/pybamm-team/PyBaMM/pull/3044))
- `pybamm_install_jax` is deprecated. It is now replaced with `pip install pybamm[jax]` ([#3163](https://github.com/pybamm-team/PyBaMM/pull/3163))
- Double-layer capacity can now be provided as a function of temperature ([#3174](https://github.com/pybamm-team/PyBaMM/pull/3174))

## Bug fixes
Expand All @@ -10,10 +12,7 @@
- Parameters in `Prada2013` have been updated to better match those given in the paper, which is a 2.3 Ah cell, instead of the mix-and-match with the 1.1 Ah cell from Lain2019.
- Error generated when invalid parameter values are passed.
- Thevenin() model is now constructed with standard variables: `Time [s], Time [min], Time [h]` ([#3143](https://github.com/pybamm-team/PyBaMM/pull/3143))

## Breaking changes

- PyBaMM now has optional dependencies that can be installed with the pattern `pip install pybamm[option]` e.g. `pybamm[plot]` ([#3044](https://github.com/pybamm-team/PyBaMM/pull/3044))
- Fix SEI Example Notebook ([#3166](https://github.com/pybamm-team/PyBaMM/pull/3166))

# [v23.5](https://github.com/pybamm-team/PyBaMM/tree/v23.5) - 2023-06-18

Expand Down
6 changes: 3 additions & 3 deletions docs/source/user_guide/installation/GNU-linux.rst
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ Currently, only GNU/Linux and macOS are supported.
pybamm_install_odes

The ``pybamm_install_odes`` command is installed with PyBaMM. It automatically downloads and installs the SUNDIALS library on your
system (under ``~/.local``), before installing ``scikits.odes`` (by running ``pip install scikits.odes``).
system (under ``~/.local``), before installing ``scikits.odes``. (Alternatively, one can install SUNDIALS without this script and run ``pip install pybamm[odes]`` to install ``pybamm`` with ``scikits.odes``.)

.. tab:: macOS

Expand All @@ -141,9 +141,9 @@ GNU/Linux and macOS

.. code:: bash
pybamm_install_jax
pip install "pybamm[jax]"
The ``pybamm_install_jax`` command is installed with PyBaMM. It automatically downloads and installs jax and jaxlib on your system.
The ``pip install "pybamm[jax]"`` command automatically downloads and installs ``pybamm`` and the compatible versions of ``jax`` and ``jaxlib`` on your system. (``pybamm_install_jax`` is deprecated.)

Developer install
-----------------
Expand Down
31 changes: 31 additions & 0 deletions docs/source/user_guide/installation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,37 @@ Dependency Minimum Version p
`tqdm <https://tqdm.github.io/>`__ \- tqdm For logging loops.
=========================================================== ================== ================== ==================

.. _install.jax_dependencies:

Jax dependencies
^^^^^^^^^^^^^^^^^

Installable with ``pip install "pybamm[jax]"``

========================================================================= ================== ================== =======================
Dependency Minimum Version pip extra Notes
========================================================================= ================== ================== =======================
`JAX <https://jax.readthedocs.io/en/latest/notebooks/quickstart.html>`__ 0.4.8 jax For JAX solvers
`jaxlib <https://pypi.org/project/jaxlib/>`__ 0.4.7 jax Support library for JAX
========================================================================= ================== ================== =======================

.. _install.odes_dependencies:

odes dependencies
^^^^^^^^^^^^^^^^^

Installable with ``pip install "pybamm[odes]"``

================================================================================================================================ ================== ================== =============================
Dependency Minimum Version pip extra Notes
================================================================================================================================ ================== ================== =============================
`scikits.odes <https://docs.pybamm.org/en/latest/source/user_guide/installation/GNU-linux.html#optional-scikits-odes-solver>`__ \- odes For scikits ODE & DAE solvers
================================================================================================================================ ================== ================== =============================

.. note::

Before running ``pip install "pybamm[odes]"``, make sure to install ``scikits.odes`` build-time requirements as described `here <https://docs.pybamm.org/en/latest/source/user_guide/installation/GNU-linux.html#optional-scikits-odes-solver>`_ .

Full installation guide
-----------------------

Expand Down
14 changes: 7 additions & 7 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def run_coverage(session):
session.install("coverage")
session.install("-e", ".[all]")
if sys.platform != "win32":
session.install("scikits.odes")
session.run("pybamm_install_jax")
session.install("-e", ".[odes]")
session.install("-e", ".[jax]")
session.run("coverage", "run", "--rcfile=.coveragerc", "run-tests.py", "--nosub")
session.run("coverage", "combine")
session.run("coverage", "xml")
Expand All @@ -76,7 +76,7 @@ def run_integration(session):
set_environment_variables(PYBAMM_ENV, session=session)
session.install("-e", ".[all]")
if sys.platform == "linux":
session.install("scikits.odes")
session.install("-e", ".[odes]")
session.run("python", "run-tests.py", "--integration")


Expand All @@ -93,8 +93,8 @@ def run_unit(session):
set_environment_variables(PYBAMM_ENV, session=session)
session.install("-e", ".[all]")
if sys.platform == "linux":
session.install("scikits.odes")
session.run("pybamm_install_jax")
session.install("-e", ".[odes]")
session.install("-e", ".[jax]")
session.run("python", "run-tests.py", "--unit")


Expand Down Expand Up @@ -129,8 +129,8 @@ def run_tests(session):
set_environment_variables(PYBAMM_ENV, session=session)
session.install("-e", ".[all]")
if sys.platform == "linux" or sys.platform == "darwin":
session.install("scikits.odes")
session.run("pybamm_install_jax")
session.install("-e", ".[odes]")
session.install("-e", ".[jax]")
session.run("python", "run-tests.py", "--all")


Expand Down
1 change: 1 addition & 0 deletions pybamm/install_odes.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def update_LD_LIBRARY_PATH(install_dir):


def main(arguments=None):

log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logger = logging.getLogger("scikits.odes setup")

Expand Down
6 changes: 6 additions & 0 deletions pybamm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import timeit
from platform import system
import difflib
from warnings import warn

import numpy as np
import pkg_resources
Expand Down Expand Up @@ -329,6 +330,11 @@ def install_jax(arguments=None): # pragma: no cover
" following command: \npybamm_install_jax --force"
)

msg = (
"pybamm_install_jax is deprecated,"
" use 'pip install pybamm[jax]' to install jax & jaxlib"
)
warn(msg, DeprecationWarning)
subprocess.check_call(
[
sys.executable,
Expand Down
5 changes: 5 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,11 @@ def compile_KLU():
"pre-commit", # For code style checking
"ruff", # For code style auto-formatting
],
"jax": [
"jax==0.4.8",
"jaxlib==0.4.7",
],
"odes": ["scikits.odes"],
"all": [
"anytree>=2.4.3",
"autograd>=1.2",
Expand Down

0 comments on commit 695917e

Please sign in to comment.