Add parameter list support to JAX solver (permitting multithreading/GPU execution) #3121

Merged
merged 12 commits into from
Sep 19, 2023

jsbrittain
Contributor

Description

Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.

Fixes #2644

Type of change

Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.

  • New feature (non-breaking change which adds functionality)
  • Optimization (back-end change that speeds up the code)
  • Bug fix (non-breaking change which fixes an issue)

Key checklist:

  • No style issues: $ pre-commit run (see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)
  • All tests pass: $ python run-tests.py --all
  • The documentation builds: $ python run-tests.py --doctest

You can run unit and doctests together at once, using $ python run-tests.py --quick.

Further checks:

  • Code is commented, particularly in hard-to-understand areas
  • Tests added that prove fix is effective or that feature works

@jsbrittain
Contributor Author

Supersedes PR #3028

@codecov

codecov bot commented Jul 7, 2023

Codecov Report

Patch coverage: 93.61% and project coverage change: +0.02% 🎉

Comparison is base (7cba890) 99.55% compared to head (8ba70fd) 99.57%.
Report is 29 commits behind head on develop.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #3121      +/-   ##
===========================================
+ Coverage    99.55%   99.57%   +0.02%     
===========================================
  Files          253      253              
  Lines        19553    19570      +17     
===========================================
+ Hits         19466    19487      +21     
+ Misses          87       83       -4     
Files Changed                                          Coverage Δ
...bamm/expression_tree/operations/evaluate_python.py  99.29% <ø> (ø)
pybamm/solvers/solution.py                             100.00% <ø> (ø)
pybamm/solvers/jax_solver.py                           90.69% <79.31%> (-6.08%) ⬇️
pybamm/experiment/experiment.py                        100.00% <100.00%> (ø)
pybamm/parameters/bpx.py                               99.49% <100.00%> (+4.73%) ⬆️
pybamm/simulation.py                                   100.00% <100.00%> (ø)
pybamm/solvers/base_solver.py                          100.00% <100.00%> (ø)
pybamm/step/_steps_util.py                             100.00% <100.00%> (ø)
pybamm/util.py                                         100.00% <100.00%> (ø)


@jsbrittain
Contributor Author

Codecov Report

Patch coverage: 84.21% and project coverage change: -0.04 ⚠️

Comparison is base (5486fce) 99.71% compared to head (efc61be) 99.68%.

❗ Current head efc61be differs from pull request most recent head 5abff1d. Consider uploading reports for the commit 5abff1d to get more accurate results


Code coverage is slightly reduced because the code contains several device-specific pathways (GPU vs. CPU) that cannot currently be tested in our GitHub CI pipelines.
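To illustrate why platform-dependent branches lower measured coverage, here is a minimal sketch of such a dispatch. The helper `choose_strategy` and its return labels are hypothetical and are not taken from the PR; in real code the platform string could come from something like `jax.default_backend()`, which is an assumption about how the check might be wired up.

```python
def choose_strategy(platform: str, n_inputs: int) -> str:
    """Hypothetical helper: pick how to run a batch of solves on a platform."""
    platform = platform.casefold()
    if n_inputs <= 1:
        # A single solve gains nothing from batching.
        return "serial"
    if platform.startswith("gpu"):
        # Batched execution; the code behind this branch needs a GPU to run.
        return "vmap"
    if platform.startswith("cpu"):
        # Spread independent solves over threads.
        return "threads"
    # Unknown platform: fall back to running solves one by one.
    return "serial"


print(choose_strategy("cpu", 100))
print(choose_strategy("gpu", 100))
```

The selection logic itself can be tested anywhere by passing the platform string explicitly, but whatever runs behind the GPU branch only executes on a GPU machine, so a CPU-only CI runner never covers those lines.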

@jsbrittain
Contributor Author

These changes use JAX's vmap function to distribute solves over a GPU (where available), and fall back to multithreading when no GPU is detected. The metrics below are for 100 solves of the following model on an Ubuntu server with an NVIDIA A40 card.

import pybamm
import numpy as np

model = pybamm.lithium_ion.DFN()
model.convert_to_format = 'jax'
model.events = []  # remove events (not supported in jax)
geometry = model.default_geometry
param = model.default_parameter_values
param.update({"Current function [A]": "[input]"})
param.process_geometry(geometry)
param.process_model(model)

n = 10
k = 5
values = np.linspace(0.1, 0.5, 100)
var = pybamm.standard_spatial_vars
var_pts = {var.x_n: n, var.x_s: n, var.x_p: n, var.r_n: k, var.r_p: k}
mesh = pybamm.Mesh(geometry, model.default_submesh_types, var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)
t_eval = np.linspace(0, 3600, 100)
solver = pybamm.JaxSolver(atol=1e-6, rtol=1e-6, method="BDF")
inputs = [{"Current function [A]": value} for value in values]
solution = solver.solve(model, t_eval, inputs=inputs)

Solver times:

cpu-multithreading: 114 secs
gpu-enabled: 19 secs
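For readers unfamiliar with vmap, the following is a minimal sketch of the batching idea on a toy problem. It is not the PyBaMM implementation: `toy_solve` is a hypothetical stand-in for a single model solve (fixed-step explicit Euler on dy/dt = -I y), and the stacking of the per-solve input dictionaries is an assumption about how a list of inputs could be turned into something vmap can map over.

```python
import jax
import jax.numpy as jnp


def toy_solve(y0, inputs):
    """Hypothetical stand-in for one model solve (not PyBaMM's solver)."""
    current = inputs["Current function [A]"]
    dt = 0.01

    def step(y, _):
        # One explicit Euler step of dy/dt = -current * y.
        y_next = y + dt * (-current * y)
        return y_next, y_next

    _, ys = jax.lax.scan(step, y0, None, length=100)
    return ys


# Same shape of input as in the example above: one dict per solve.
values = jnp.linspace(0.1, 0.5, 5)
inputs_list = [{"Current function [A]": v} for v in values]

# Stack the list of dicts into one dict of arrays so vmap can map over it.
stacked = {key: jnp.stack([d[key] for d in inputs_list]) for key in inputs_list[0]}

y0 = jnp.array(1.0)
# in_axes=(None, 0): share y0 across solves, map over the inputs' leading axis.
batched_solve = jax.jit(jax.vmap(toy_solve, in_axes=(None, 0)))
solutions = batched_solve(y0, stacked)  # shape: (number of solves, number of steps)
print(solutions.shape)
```

Each row of `solutions` corresponds to one entry of `inputs_list`; on a GPU the whole batch runs as a single vectorised computation rather than as separate solves.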

Further development may be possible using JAX's pmap functionality to express single-program multiple-data (SPMD) programs (for running on a rack of GPUs). However, JAX currently maintains only experimental support for sparse matrices, and executing pmap on PyBaMM models leads to instabilities (see also jax-ml/jax#13930).
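As a rough sketch of what the pmap approach could look like, the snippet below splits a batch of inputs across the available devices and batches within each device using vmap. The function `toy_solve` is a hypothetical dense, closed-form stand-in, so it does not touch the sparse-matrix code paths where the instabilities mentioned above were observed.

```python
import jax
import jax.numpy as jnp


def toy_solve(current):
    """Hypothetical stand-in for one solve: closed-form decay exp(-I * t)."""
    t = jnp.linspace(0.0, 1.0, 10)
    return jnp.exp(-current * t)


n_devices = jax.local_device_count()  # 1 on a CPU-only machine
per_device = 4
values = jnp.linspace(0.1, 0.5, n_devices * per_device)

# pmap maps the leading axis over devices, so it must equal the device count.
sharded = values.reshape(n_devices, per_device)

# Outer pmap: one shard per device. Inner vmap: batch the solves on that device.
solutions = jax.pmap(jax.vmap(toy_solve))(sharded)
solutions = solutions.reshape(n_devices * per_device, -1)
print(solutions.shape)
```

On a single-device machine this reduces to an ordinary vmap over all inputs; the device axis only matters when several GPUs are present.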

@jsbrittain
Contributor Author

@martinjrobins Ready for review. As noted above, Codecov coverage is down very slightly because we do not have GitHub runners checking the GPU-specific implementation pathway. The CI failures appear to be due to unrelated issues with lychee and the docs (e.g. ubuntu/python3.11 fails installing doc deps after the unit and integration tests all pass). Benchmarks also don't appear to be running properly, so we may need to delay merging until these are fixed.

@jsbrittain jsbrittain marked this pull request as ready for review July 7, 2023 16:29
agriyakhetarpal added a commit to agriyakhetarpal/PyBaMM that referenced this pull request Jul 8, 2023
@jsbrittain
Contributor Author

@martinjrobins I've merged recent develop changes into the PR, which has resolved most checks. In particular, benchmarks are now passing; this just leaves lychee and my previous Codecov comment (GitHub Actions don't check GPU-specific pathways).

Contributor

@martinjrobins martinjrobins left a comment

Sorry for the delay in reviewing this! Thanks @jsbrittain, happy for it to be merged. I'll open a new issue for adding runners with GPUs.

@brosaplanella
Member

Can this PR be merged or are we waiting for other changes?

@jsbrittain
Contributor Author

jsbrittain commented Sep 18, 2023

> Can this PR be merged or are we waiting for other changes?

@brosaplanella Yes, I believe this PR is ready to merge.

@brosaplanella
Member

> @brosaplanella Yes, I believe this PR is ready to merge.

Cool, can you fix the conflict with CHANGELOG and push? Pushing again might also fix the coverage

@jsbrittain
Contributor Author

> > @brosaplanella Yes, I believe this PR is ready to merge.
>
> Cool, can you fix the conflict with CHANGELOG and push? Pushing again might also fix the coverage

Note that coverage remains slightly down due to some GPU-specific pathways in the new code; Martin opened an issue to support these going forward (#3274).

@brosaplanella brosaplanella merged commit a68c038 into pybamm-team:develop Sep 19, 2023
30 of 32 checks passed
agriyakhetarpal added a commit to agriyakhetarpal/PyBaMM that referenced this pull request Sep 28, 2023
js1tr3 pushed a commit to js1tr3/PyBaMM that referenced this pull request Aug 12, 2024
Development

Successfully merging this pull request may close these issues.

use massively parallel sundials solvers for running many solves with different input parameters
3 participants