
[Bug]: JAX BDF solver sensitivities bug #4455

Closed
BradyPlanden opened this issue Sep 23, 2024 · 0 comments · Fixed by #4456
Labels
bug Something isn't working

@BradyPlanden
Member

PyBaMM Version

develop

Python Version

3.12

Describe the bug

Sensitivities are not currently supported by the JAX solvers, but passing the calculate_sensitivities argument still results in an internal error. The mass matrix is extended for the sensitivity equations, but the corresponding initial conditions are not, so the matrix sizes no longer agree.
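The shape mismatch can be reproduced outside PyBaMM with a toy NumPy sketch. It assumes the sizes from the traceback (42 states, extended to 84, which is consistent with one sensitivity parameter), a block-diagonal mass-matrix extension, and a hypothetical jacobian_at helper standing in for the model Jacobian. None of these names are PyBaMM internals, and c is an arbitrary stand-in for the step-dependent coefficient in M - c * J.

```python
import numpy as np

# Toy sizes taken from the traceback: 42 states extended to 84,
# i.e. (assumed) one sensitivity parameter.
n_states = 42
n_params = 1
c = 0.1  # arbitrary stand-in for the coefficient in M - c * J

# Assumption: the mass matrix is extended block-diagonally, with one copy of M
# for the states plus one copy per sensitivity parameter.
M = np.eye(n_states)
M_ext = np.kron(np.eye(1 + n_params), M)  # shape (84, 84)


def jacobian_at(y):
    """Hypothetical stand-in for the RHS Jacobian: an identity of size len(y)."""
    return np.eye(y.size)


# Buggy path: y0 is NOT extended, so the Jacobian is only (42, 42).
y0 = np.ones(n_states)
try:
    M_ext - c * jacobian_at(y0)
    mismatch = False
except ValueError as err:
    # NumPy raises ValueError here; JAX raises the TypeError seen in the log.
    mismatch = True
    print(err)

# Consistent path: append (assumed zero) initial sensitivities to y0,
# so the Jacobian has the same size as the extended mass matrix.
y0_ext = np.concatenate([y0, np.zeros(n_states * n_params)])
iteration_matrix = M_ext - c * jacobian_at(y0_ext)
print(iteration_matrix.shape)  # (84, 84)
```

The second half only shows that the sizes agree once y0 is extended alongside the mass matrix. Since the JAX solvers do not support sensitivities, the actual fix in #4456 may instead avoid extending the mass matrix at all; this sketch does not claim either approach.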

Steps to Reproduce

Pass calculate_sensitivities=True to the solve() method of the JAX BDF solver, e.g.:

import pybamm

...  # model and t_eval setup omitted
solver = pybamm.JaxSolver(atol=1e-6, rtol=1e-6, method="BDF")
sol = solver.solve(model, t_eval, calculate_sensitivities=True)

Relevant log output

Traceback (most recent call last):
  File "/Users/Documents/Git/forks/PyBaMM/examples/scripts/compare_lithium_ion.py", line 31, in <module>
    solution = solver.solve(model, t_eval, inputs=inputs, calculate_sensitivities=True)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/base_solver.py", line 905, in solve
    new_solutions = self._integrate(
                    ^^^^^^^^^^^^^^^^
  File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/jax_solver.py", line 231, in _integrate
    y = asyncio.run(solve_model_for_inputs())
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/.pyenv/versions/3.12.2/lib/python3.12/asyncio/runners.py", line 194, in run
    return runner.run(main)
           ^^^^^^^^^^^^^^^^
  File "/Users/.pyenv/versions/3.12.2/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/.pyenv/versions/3.12.2/lib/python3.12/asyncio/base_events.py", line 685, in run_until_complete
    return future.result()
           ^^^^^^^^^^^^^^^
  File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/jax_solver.py", line 229, in solve_model_for_inputs
    return await asyncio.gather(*coro)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/jax_solver.py", line 224, in solve_model_async
    return self._cached_solves[model](inputs_v)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/jax_solver.py", line 171, in solve_model_bdf
    y = pybamm.jax_bdf_integrate(
        ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/jax_bdf_solver.py", line 1030, in jax_bdf_integrate
    return _bdf_odeint_wrapper(converted, mass, rtol, atol, y0, t_eval, *consts, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/jax_bdf_solver.py", line 63, in caller
    return callee(*args)
           ^^^^^^^^^^^^^
  File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/jax_bdf_solver.py", line 57, in callee
    return fun(*args)
           ^^^^^^^^^^
  File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/jax_bdf_solver.py", line 812, in _bdf_odeint_wrapper
    out = _bdf_odeint(func, mass, rtol, atol, y0, ts, *args)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/jax_bdf_solver.py", line 117, in _bdf_odeint
    stepper = _bdf_init(
              ^^^^^^^^^^
  File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/jax_bdf_solver.py", line 256, in _bdf_init
    state["LU"] = jax.scipy.linalg.lu_factor(state["M"] - c * J)
                                             ~~~~~~~~~~~^~~~~~~
  File "/Users/.pyenv/versions/pybamm-v24.5/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 737, in op
    return getattr(self.aval, f"_{name}")(self, *args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/.pyenv/versions/pybamm-v24.5/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 265, in deferring_binary_op
    return binary_op(*args)
           ^^^^^^^^^^^^^^^^
  File "/Users/.pyenv/versions/pybamm-v24.5/lib/python3.12/site-packages/jax/_src/numpy/ufuncs.py", line 86, in <lambda>
    fn = lambda x1, x2, /: lax_fn(*promote_args(numpy_fn.__name__, x1, x2))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: sub got incompatible shapes for broadcasting: (84, 84), (42, 42).