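Sensitivities are not currently supported by the JAX solvers, but passing the calculate_sensitivities argument does not produce a clear error. Instead, it fails with an internal error: the mass matrix is extended to account for the sensitivities, but the corresponding initial conditions are not. The BDF initialisation then tries to combine matrices of different sizes ((84, 84) versus (42, 42) in the traceback below).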
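The mismatch can be illustrated outside PyBaMM with a minimal NumPy sketch. It assumes a toy linear model with 42 states and a single sensitivity parameter (chosen only to reproduce the 84 vs 42 shapes), and it assumes the mass matrix is extended block-diagonally; neither detail is taken from PyBaMM's implementation. The BDF iteration matrix M - c*J can only be formed once the initial conditions, and hence the Jacobian, are extended as well.

```python
import numpy as np

# Toy sizes chosen to match the traceback: 42 states and (assumed)
# one sensitivity parameter, so the extended system has 84 rows.
n_states = 42
n_params = 1

M = np.eye(n_states)   # hypothetical mass matrix of the model
A = -np.eye(n_states)  # Jacobian of a toy linear RHS f(y) = A y + p b
c = 0.1                # stand-in for the BDF step coefficient

# Assumption: one diagonal copy of M per extra block (state + sensitivities).
M_ext = np.kron(np.eye(1 + n_params), M)
print(M_ext.shape)  # (84, 84)

# If y0 is NOT extended, the Jacobian only covers the original 42 states.
y0 = np.ones(n_states)
J = A
assert J.shape == (y0.size, y0.size)

try:
    M_ext - c * J
except ValueError as err:
    # NumPy raises ValueError here; jax.numpy raises the TypeError in the traceback.
    print("shape mismatch:", err)

# Extending y0 with initial sensitivities (zero here, because this toy y0
# does not depend on p) gives a Jacobian the same size as M_ext.
y0_ext = np.concatenate([y0, np.zeros(n_states * n_params)])
J_ext = np.kron(np.eye(1 + n_params), A)  # block-diagonal for this linear toy model
assert J_ext.shape == (y0_ext.size, y0_ext.size)

iteration_matrix = M_ext - c * J_ext
print(iteration_matrix.shape)  # (84, 84)
```

For the linear toy model the sensitivity s = dy/dp obeys M s' = A s + b, so the augmented Jacobian is exactly block-diagonal; a nonlinear model would have off-diagonal coupling blocks, but the size requirement is the same.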
Steps to Reproduce
Pass calculate_sensitivities=True to solver.solve() when using the JAX BDF solver. This fails with the following traceback:
Traceback (most recent call last):
File "/Users/Documents/Git/forks/PyBaMM/examples/scripts/compare_lithium_ion.py", line 31, in <module>
solution = solver.solve(model, t_eval, inputs=inputs, calculate_sensitivities=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/base_solver.py", line 905, in solve
new_solutions = self._integrate(
^^^^^^^^^^^^^^^^
File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/jax_solver.py", line 231, in _integrate
y = asyncio.run(solve_model_for_inputs())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/.pyenv/versions/3.12.2/lib/python3.12/asyncio/runners.py", line 194, in run
return runner.run(main)
^^^^^^^^^^^^^^^^
File "/Users/.pyenv/versions/3.12.2/lib/python3.12/asyncio/runners.py", line 118, in run
return self._loop.run_until_complete(task)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/.pyenv/versions/3.12.2/lib/python3.12/asyncio/base_events.py", line 685, in run_until_complete
return future.result()
^^^^^^^^^^^^^^^
File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/jax_solver.py", line 229, in solve_model_for_inputs
return await asyncio.gather(*coro)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/jax_solver.py", line 224, in solve_model_async
return self._cached_solves[model](inputs_v)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/jax_solver.py", line 171, in solve_model_bdf
y = pybamm.jax_bdf_integrate(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/jax_bdf_solver.py", line 1030, in jax_bdf_integrate
return _bdf_odeint_wrapper(converted, mass, rtol, atol, y0, t_eval, *consts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/jax_bdf_solver.py", line 63, in caller
return callee(*args)
^^^^^^^^^^^^^
File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/jax_bdf_solver.py", line 57, in callee
return fun(*args)
^^^^^^^^^^
File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/jax_bdf_solver.py", line 812, in _bdf_odeint_wrapper
out = _bdf_odeint(func, mass, rtol, atol, y0, ts, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/jax_bdf_solver.py", line 117, in _bdf_odeint
stepper = _bdf_init(
^^^^^^^^^^
File "/Users/Documents/Git/forks/PyBaMM/src/pybamm/solvers/jax_bdf_solver.py", line 256, in _bdf_init
state["LU"] = jax.scipy.linalg.lu_factor(state["M"] - c * J)
~~~~~~~~~~~^~~~~~~
File "/Users/.pyenv/versions/pybamm-v24.5/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 737, in op
return getattr(self.aval, f"_{name}")(self, *args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/.pyenv/versions/pybamm-v24.5/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py", line 265, in deferring_binary_op
return binary_op(*args)
^^^^^^^^^^^^^^^^
File "/Users/.pyenv/versions/pybamm-v24.5/lib/python3.12/site-packages/jax/_src/numpy/ufuncs.py", line 86, in <lambda>
fn = lambda x1, x2, /: lax_fn(*promote_args(numpy_fn.__name__, x1, x2))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: sub got incompatible shapes for broadcasting: (84, 84), (42, 42).
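Because sensitivities are unsupported, a clearer failure mode would be to reject the argument before any matrices are extended. The sketch below is hypothetical: JaxSolverSketch and its supports_sensitivities flag are illustrative names, not PyBaMM API, and the check simply treats any truthy calculate_sensitivities value as a request for sensitivities.

```python
class JaxSolverSketch:
    """Hypothetical stand-in for a solver that does not support sensitivities."""

    # Hypothetical capability flag; not an attribute of PyBaMM's solvers.
    supports_sensitivities = False

    def solve(self, model, t_eval, inputs=None, calculate_sensitivities=False):
        # Validate up front, before the mass matrix or y0 are touched, so the
        # user gets a clear message instead of a shape error inside the BDF stepper.
        if calculate_sensitivities and not self.supports_sensitivities:
            raise NotImplementedError(
                "Sensitivities are not supported by the JAX solvers; "
                "set calculate_sensitivities=False or use a solver that supports them."
            )
        # Placeholder: the actual integration would happen here.
        return None
```

Calling solve(model, t_eval, calculate_sensitivities=True) on this sketch raises NotImplementedError immediately, while the default call proceeds as before.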
PyBaMM Version
develop
Python Version
3.12