diff --git a/pybamm/solvers/jax_solver.py b/pybamm/solvers/jax_solver.py index f7f50e0c3a..db912d0da1 100644 --- a/pybamm/solvers/jax_solver.py +++ b/pybamm/solvers/jax_solver.py @@ -30,6 +30,10 @@ class JaxSolver(pybamm.BaseSolver): method: str 'RK45' (default) uses jax.experimental.odeint 'BDF' uses custom jax_bdf_integrate (see jax_bdf_integrate.py for details) + root_method: str, optional + Method to use to calculate consistent initial conditions. By default this uses + the newton chord method internal to the jax bdf solver, otherwise choose from + the set of default options defined in docs for pybamm.BaseSolver rtol : float, optional The relative tolerance for the solver (default is 1e-6). atol : float, optional @@ -41,10 +45,11 @@ class JaxSolver(pybamm.BaseSolver): for details. """ - def __init__(self, method='RK45', rtol=1e-6, atol=1e-6, extra_options=None): + def __init__(self, method='RK45', root_method=None, + rtol=1e-6, atol=1e-6, extra_options=None): # note: bdf solver itself calculates consistent initial conditions so can set - # root_method to none - super().__init__(method, rtol, atol, root_method=None) + # root_method to none, allow user to override this behavior + super().__init__(method, rtol, atol, root_method=root_method) method_options = ['RK45', 'BDF'] if method not in method_options: raise ValueError('method must be one of {}'.format(method_options))