Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 1104 mass matrix for bdf #1107

Merged
merged 40 commits into from
Jul 20, 2020
Merged

Conversation

martinjrobins
Copy link
Contributor

@martinjrobins martinjrobins commented Jul 13, 2020

Description

  • adds mass matrix support for jax bdf solver, so it can solve semi-explicit dae models
  • also adds sensitivity support for ODE and semi-explicit DAE models using the adjoint method
  • limitations: currently limited to support for index 1 semi-explicit dae models, but this includes spm, spme and dfn models (that I've tested)
  • while doing this, the bdf solver has been substantially refactored to remove lax.cond operators (which are slow), which also has the benefit of making the code more readable

Fixes #1104

Type of change

  • New feature (non-breaking change which adds functionality)

Key checklist:

  • No style issues: $ flake8
  • All tests pass: $ python run-tests.py --unit
  • The documentation builds: $ cd docs and then $ make clean; make html

You can run all three at once, using $ python run-tests.py --quick.

Further checks:

  • Code is commented, particularly in hard-to-understand areas
  • Tests added that prove fix is effective or that feature works

@codecov
Copy link

codecov bot commented Jul 13, 2020

Codecov Report

Merging #1107 into develop will increase coverage by 0.00%.
The diff coverage is 98.63%.

Impacted file tree graph

@@           Coverage Diff            @@
##           develop    #1107   +/-   ##
========================================
  Coverage    97.81%   97.82%           
========================================
  Files          245      245           
  Lines        13179    13269   +90     
========================================
+ Hits         12891    12980   +89     
- Misses         288      289    +1     
Impacted Files Coverage Δ
pybamm/solvers/jax_solver.py 98.14% <95.45%> (+1.71%) ⬆️
pybamm/solvers/jax_bdf_solver.py 99.04% <98.83%> (-0.35%) ⬇️
pybamm/solvers/base_solver.py 100.00% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 4d10b27...f6c51f9. Read the comment docs.

@martinjrobins martinjrobins force-pushed the issue-1104-mass-matrix-for-bdf branch from 08d811b to ec0ab2b Compare July 13, 2020 11:39
@martinjrobins martinjrobins force-pushed the issue-1104-mass-matrix-for-bdf branch from ec0ab2b to 107f7ce Compare July 13, 2020 16:21
@martinjrobins martinjrobins marked this pull request as ready for review July 15, 2020 07:16
@martinjrobins martinjrobins removed the request for review from valentinsulzer July 15, 2020 12:01
@martinjrobins martinjrobins marked this pull request as draft July 15, 2020 12:47
@martinjrobins martinjrobins marked this pull request as ready for review July 16, 2020 11:06
Copy link
Member

@valentinsulzer valentinsulzer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @martinjrobins , this looks great! Could you add the jax solver to the compare-dae-solvers.py example?
So goes doing jax.grad do adjoint sensitivities for you?
We should coordinate how the sensitivity API will work, let's discuss this on the PR for sensitivity for the casadi solver that I will open soon

pybamm/solvers/jax_solver.py Outdated Show resolved Hide resolved
@martinjrobins
Copy link
Contributor Author

martinjrobins commented Jul 19, 2020

yes, we should iterate on the API, I just put in something temporary for now. The solver can return a pure solve function that takes in inputs and returns a solution ndarray. This can then be differentiated using jax.grad or jax.jacrev (not jax.jaxfwd). Ideally you would take the solve function, and wrap it in some sort of scalar error measure, and then jax.grad that (see sensitivity tests in test_jax_bdf_solver.py and test_jax.solver.py

I've added in the jax solver tocompare-dae-solvers.py, but be aware that it takes a long time. The first time you call the dfn solver jax needs to trace and compile the entire algorithm, along with the dfn equations. This takes 11 minutes for the dfn solver (it is significantly quicker for simpler models such as spm or spme). Subsequent calls to the solver take 4 seconds

@valentinsulzer
Copy link
Member

Ah ok, yeah 11 minutes to run an example isn't great, perhaps shouldn't include it in that light. Thanks for the info re: sensitivities :)

@martinjrobins martinjrobins merged commit b76042c into develop Jul 20, 2020
@martinjrobins martinjrobins deleted the issue-1104-mass-matrix-for-bdf branch July 20, 2020 07:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

modify JAX bdf solver to allow solution of semi-explicit DAEs
2 participants