-
-
Notifications
You must be signed in to change notification settings - Fork 553
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add parameter list support to JAX solver (permitting multithreading/GPU execution) #3121
Conversation
supercedes PR #3028 |
Codecov ReportPatch coverage:
Additional details and impacted files@@ Coverage Diff @@
## develop #3121 +/- ##
===========================================
+ Coverage 99.55% 99.57% +0.02%
===========================================
Files 253 253
Lines 19553 19570 +17
===========================================
+ Hits 19466 19487 +21
+ Misses 87 83 -4
☔ View full report in Codecov by Sentry. |
Code coverage is reduced slightly because there are several device specific pathways in the code (gpu vs cpu) that are not currently available for testing in our github CI pipelines. |
These changes make use of Jax's import pybamm
import numpy as np
model = pybamm.lithium_ion.DFN()
model.convert_to_format = 'jax'
model.events = [] # remove events (not supported in jax)
geometry = model.default_geometry
param = model.default_parameter_values
param.update({"Current function [A]": "[input]"})
param.process_geometry(geometry)
param.process_model(model)
n = 10
k = 5
values = np.linspace(0.1, 0.5, 100)
var = pybamm.standard_spatial_vars
var_pts = {var.x_n: n, var.x_s: n, var.x_p: n, var.r_n: k, var.r_p: k}
mesh = pybamm.Mesh(geometry, model.default_submesh_types, var_pts)
disc = pybamm.Discretisation(mesh, model.default_spatial_methods)
disc.process_model(model)
t_eval = np.linspace(0, 3600, 100)
solver = pybamm.JaxSolver(atol=1e-6, rtol=1e-6, method="BDF")
inputs = [{"Current function [A]": value} for value in values]
solution = solver.solve(model, t_eval, inputs=inputs) Solver times:
Further development may be possible by making use of Jax's |
@martinjrobins ready for review. as noted above, code-cov is down very slightly as we do not have github runners checking the gpu-specific implementation pathway. the CI fails seem to be due to apparently unrelated issues with lychee and docs (e.g. ubuntu/python3.11 fails installing doc deps after unit and integration tests all pass). benchmarks also don't appear to be running properly so we may need to delay merging until these are fixed. |
…uppress warning
@martinjrobins I've merged recent develop changes into the PR which has resolved most checks. In particular benchmarks are now passing; this just leaves lychee and my previous codecov comment (gh actions don't check gpu-specific pathways). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the delay in reviewing this! Thanks @jsbrittain and happy for it to be merged, I'll open a new issue for adding runners with GPU
Can this PR be merged or are we waiting for other changes? |
@brosaplanella Yes, I believe this PR is ready to merge. |
Cool, can you fix the conflict with CHANGELOG and push? Pushing again might also fix the coverage |
Note that coverage remains slightly down due to some gpu-specific pathways in the new code; martin opened an issue to support these going forwards (#3274 ). |
Description
Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.
Fixes #2644
Type of change
Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.
Key checklist:
$ pre-commit run
(see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)$ python run-tests.py --all
$ python run-tests.py --doctest
You can run unit and doctests together at once, using
$ python run-tests.py --quick
.Further checks: