diff --git a/CHANGELOG.md b/CHANGELOG.md index 7334d46015..1090184dd7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ## Features +- Added support for index 1 semi-explicit dae equations and sensitivity calculations to JAX BDF solver ([#1107](https://github.com/pybamm-team/PyBaMM/pull/1107)) - Allowed keyword arguments to be passed to `Simulation.plot()` ([#1099](https://github.com/pybamm-team/PyBaMM/pull/1099)) ## Optimizations diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 4451a7b9fd..1b32fdf751 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -451,9 +451,12 @@ def calculate_consistent_state(self, model, time=0, inputs=None): ------- y0_consistent : array-like, same shape as y0_guess Initial conditions that are consistent with the algebraic equations (roots - of the algebraic equations) + of the algebraic equations). If self.root_method == None then returns + model.y0. """ pybamm.logger.info("Start calculating consistent states") + if self.root_method is None: + return model.y0 try: root_sol = self.root_method._integrate(model, [time], inputs) except pybamm.SolverError as e: diff --git a/pybamm/solvers/jax_bdf_solver.py b/pybamm/solvers/jax_bdf_solver.py index 8161d3ea3a..59c5459d12 100644 --- a/pybamm/solvers/jax_bdf_solver.py +++ b/pybamm/solvers/jax_bdf_solver.py @@ -1,36 +1,56 @@ -import jax -import jax.numpy as np +import operator as op +import numpy as onp +import collections +import jax +import jax.numpy as jnp +from jax import core +from jax import dtypes +from jax.util import safe_map, cache, split_list +from jax.api_util import flatten_fun_nokwargs +from jax.flatten_util import ravel_pytree +from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_multimap, partial +from jax.interpreters import partial_eval as pe +from jax import linear_util as lu from jax.config import config + config.update("jax_enable_x64", True) +MAX_ORDER = 5 +NEWTON_MAXITER = 4 +ROOT_SOLVE_MAXITER = 15 +MIN_FACTOR = 0.2 +MAX_FACTOR = 10 + -def jax_bdf_integrate(fun, y0, t_eval, jac=None, inputs=None, rtol=1e-6, atol=1e-6): +@jax.partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2, 3)) +def _bdf_odeint(fun, mass, rtol, atol, y0, t_eval, *args): """ - Backward Difference formula (BDF) implicit multistep integrator. The basic algorithm - is derived in [2]_. This particular implementation follows that implemented in the - Matlab routine ode15s described in [1]_ and the SciPy implementation [3]_, which - features the NDF formulas for improved stability, with associated differences in the - error constants, and calculates the jacobian at J(t_{n+1}, y^0_{n+1}). This - implementation was based on that implemented in the scipy library [3]_, which also - mainly follows [1]_ but uses the more standard jacobian update. + This implements a Backward Difference formula (BDF) implicit multistep integrator. + The basic algorithm is derived in [2]_. This particular implementation follows that + implemented in the Matlab routine ode15s described in [1]_ and the SciPy + implementation [3]_, which features the NDF formulas for improved stability, with + associated differences in the error constants, and calculates the jacobian at + J(t_{n+1}, y^0_{n+1}). This implementation was based on that implemented in the + scipy library [3]_, which also mainly follows [1]_ but uses the more standard + jacobian update. 
Parameters ---------- - fun: callable - function with signature (t, y, in), where t is a scalar time, y is a ndarray - with shape (n,), in is a dict of input parameters. Returns the rhs of the system - of ODE equations as an nd array with shape (n,) + func: callable + function to evaluate the time derivative of the solution `y` at time + `t` as `func(y, t, *args)`, producing the same shape/structure as `y0`. + mass: ndarray + diagonal of the mass matrix with shape (n,) y0: ndarray - initial state vector + initial state vector, has shape (n,) t_eval: ndarray time points to evaluate the solution, has shape (m,) - jac: (optional) callable - function with signature (t, y, in),returns the jacobian matrix of fun as an - ndarray with shape (n,n) - inputs: (optional) dict - dict mapping input parameter names to values + args: (optional) + tuple of additional arguments for `fun`, which must be arrays + scalars, or (nested) standard Python containers (tuples, lists, dicts, + namedtuples, i.e. pytrees) of those types. rtol: (optional) float relative tolerance for the solver atol: (optional) float @@ -41,9 +61,6 @@ def jax_bdf_integrate(fun, y0, t_eval, jac=None, inputs=None, rtol=1e-6, atol=1e y: ndarray with shape (n, m) calculated state vector at each of the m time points - stepper: dict - internal variables of the stepper object - References ---------- .. [1] L. F. Shampine, M. W. Reichelt, "THE MATLAB ODE SUITE", SIAM J. SCI. @@ -57,89 +74,80 @@ def jax_bdf_integrate(fun, y0, t_eval, jac=None, inputs=None, rtol=1e-6, atol=1e Nature methods, 17(3), 261-272. """ - y0_device = jax.device_put(y0).reshape(-1) - t_eval_device = jax.device_put(t_eval) - y_out, stepper = _bdf_odeint(fun, jac, rtol, atol, y0_device, t_eval_device, inputs) - return y_out, stepper - + def fun_bind_inputs(y, t): + return fun(y, t, *args) -MAX_ORDER = 5 -NEWTON_MAXITER = 4 -MIN_FACTOR = 0.2 -MAX_FACTOR = 10 + jac_bind_inputs = jax.jacfwd(fun_bind_inputs, argnums=0) + t0 = t_eval[0] + h0 = t_eval[1] - t0 -def flax_cond(pred, true_operand, true_fun, - false_operand, false_fun): # pragma: no cover - """ - for debugging purposes, use this instead of jax.lax.cond - """ - if pred: - return true_fun(true_operand) - else: - return false_fun(false_operand) + stepper = _bdf_init( + fun_bind_inputs, jac_bind_inputs, mass, t0, y0, h0, rtol, atol + ) + i = 0 + y_out = jnp.empty((len(t_eval), len(y0)), dtype=y0.dtype) + init_state = [stepper, t_eval, i, y_out] -def flax_while_loop(cond_fun, body_fun, init_val): # pragma: no cover - """ - for debugging purposes, use this instead of jax.lax.while_loop - """ - val = init_val - while cond_fun(val): - val = body_fun(val) - return val + def cond_fun(state): + _, t_eval, i, _ = state + return i < len(t_eval) + def body_fun(state): + stepper, t_eval, i, y_out = state + stepper = _bdf_step(stepper, fun_bind_inputs, jac_bind_inputs) + index = jnp.searchsorted(t_eval, stepper.t) -def flax_fori_loop(start, stop, body_fun, init_val): # pragma: no cover - """ - for debugging purposes, use this instead of jax.lax.fori_loop - """ - val = init_val - for i in range(start, stop): - val = body_fun(i, val) - return val + def for_body(j, y_out): + t = t_eval[j] + y_out = jax.ops.index_update(y_out, jax.ops.index[j, :], + _bdf_interpolate(stepper, t)) + return y_out + y_out = jax.lax.fori_loop(i, index, for_body, y_out) + return [stepper, t_eval, index, y_out] -def _compute_R(order, factor): - """ - computes the R matrix with entries - given by the first equation on page 8 of [1] + stepper, t_eval, i, y_out = 
jax.lax.while_loop(cond_fun, body_fun, + init_state) + return y_out - This is used to update the differences matrix when step size h is varied according - to factor = h_{n+1} / h_n - Note that the U matrix also defined in the same section can be also be - found using factor = 1, which corresponds to R with a constant step size - """ - I = np.arange(1, MAX_ORDER + 1).reshape(-1, 1) - J = np.arange(1, MAX_ORDER + 1) - M = np.empty((MAX_ORDER + 1, MAX_ORDER + 1)) - M = jax.ops.index_update(M, jax.ops.index[1:, 1:], - (I - 1 - factor * J) / I) - M = jax.ops.index_update(M, jax.ops.index[0], 1) - R = np.cumprod(M, axis=0) +BDFInternalStates = [ + 't', 'atol', 'rtol', 'M', 'newton_tol', 'order', 'h', 'n_equal_steps', 'D', + 'y0', 'scale_y0', 'kappa', 'gamma', 'alpha', 'c', 'error_const', 'J', 'LU', 'U', + 'psi', 'n_function_evals', 'n_jacobian_evals', 'n_lu_decompositions', 'n_steps', + 'consistent_y0_failed' +] +BDFState = collections.namedtuple('BDFState', BDFInternalStates) - return R +jax.tree_util.register_pytree_node( + BDFState, + lambda xs: (tuple(xs), None), + lambda _, xs: BDFState(*xs) +) -def _bdf_init(fun, jac, t0, y0, h0, rtol, atol): +def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol): """ Initiation routine for Backward Difference formula (BDF) implicit multistep integrator. - See jax_bdf_solver function above for details, this function returns a dict with the + See _bdf_odeint function above for details, this function returns a dict with the initial state of the solver Parameters ---------- fun: callable - function with signature (t, y), where t is a scalar time and y is a ndarray with + function with signature (y, t), where t is a scalar time and y is a ndarray with shape (n,), returns the rhs of the system of ODE equations as an nd array with shape (n,) jac: callable - function with signature (t, y), where t is a scalar time and y is a ndarray with + function with signature (y, t), where t is a scalar time and y is a ndarray with shape (n,), returns the jacobian matrix of fun as an ndarray with shape (n,n) + mass: ndarray + diagonal of the mass matrix with shape (n,) t0: float initial time y0: ndarray @@ -151,34 +159,39 @@ def _bdf_init(fun, jac, t0, y0, h0, rtol, atol): atol: (optional) float absolute tolerance for the solver """ + state = {} state['t'] = t0 - state['y'] = y0 - f0 = fun(t0, y0) state['atol'] = atol state['rtol'] = rtol + state['M'] = mass + EPS = jnp.finfo(y0.dtype).eps + state['newton_tol'] = jnp.max((10 * EPS / rtol, jnp.min((0.03, rtol ** 0.5)))) + + scale_y0 = atol + rtol * jnp.abs(y0) + y0, not_converged = _select_initial_conditions( + fun, mass, t0, y0, state['newton_tol'], scale_y0 + ) + state['consistent_y0_failed'] = not_converged + + f0 = fun(y0, t0) order = 1 state['order'] = order - state['h'] = _select_initial_step(state, fun, t0, y0, f0, h0) - EPS = np.finfo(y0.dtype).eps - state['newton_tol'] = np.max((10 * EPS / rtol, np.min((0.03, rtol ** 0.5)))) + state['h'] = _select_initial_step(atol, rtol, fun, t0, y0, f0, h0) state['n_equal_steps'] = 0 - D = np.empty((MAX_ORDER + 1, len(y0)), dtype=y0.dtype) + D = jnp.empty((MAX_ORDER + 1, len(y0)), dtype=y0.dtype) D = jax.ops.index_update(D, jax.ops.index[0, :], y0) D = jax.ops.index_update(D, jax.ops.index[1, :], f0 * state['h']) state['D'] = D - state['y0'] = None - state['scale_y0'] = None - state = _predict(state) - I = np.identity(len(y0), dtype=y0.dtype) - state['I'] = I + state['y0'] = y0 + state['scale_y0'] = scale_y0 # kappa values for difference orders, taken from Table 1 of [1] - kappa = 
np.array([0, -0.1850, -1 / 9, -0.0823, -0.0415, 0]) - gamma = np.hstack((0, np.cumsum(1 / np.arange(1, MAX_ORDER + 1)))) + kappa = jnp.array([0, -0.1850, -1 / 9, -0.0823, -0.0415, 0]) + gamma = jnp.hstack((0, jnp.cumsum(1 / jnp.arange(1, MAX_ORDER + 1)))) alpha = 1.0 / ((1 - kappa) * gamma) c = state['h'] * alpha[order] - error_const = kappa * gamma + 1 / np.arange(1, MAX_ORDER + 2) + error_const = kappa * gamma + 1 / jnp.arange(1, MAX_ORDER + 2) state['kappa'] = kappa state['gamma'] = gamma @@ -186,22 +199,114 @@ def _bdf_init(fun, jac, t0, y0, h0, rtol, atol): state['c'] = c state['error_const'] = error_const - J = jac(t0, y0) + J = jac(y0, t0) state['J'] = J - state['LU'] = jax.scipy.linalg.lu_factor(I - c * J) + + state['LU'] = jax.scipy.linalg.lu_factor(state['M'] - c * J) state['U'] = _compute_R(order, 1) state['psi'] = None - state = _update_psi(state) state['n_function_evals'] = 2 state['n_jacobian_evals'] = 1 state['n_lu_decompositions'] = 1 - state['n_error_test_failures'] = 0 - return state + state['n_steps'] = 0 + + tuple_state = BDFState(*[state[k] for k in BDFInternalStates]) + y0, scale_y0 = _predict(tuple_state, D) + psi = _update_psi(tuple_state, D) + return tuple_state._replace(y0=y0, scale_y0=scale_y0, psi=psi) -def _select_initial_step(state, fun, t0, y0, f0, h0): +def _compute_R(order, factor): + """ + computes the R matrix with entries + given by the first equation on page 8 of [1] + + This is used to update the differences matrix when step size h is varied according + to factor = h_{n+1} / h_n + + Note that the U matrix also defined in the same section can be also be + found using factor = 1, which corresponds to R with a constant step size + """ + I = jnp.arange(1, MAX_ORDER + 1).reshape(-1, 1) + J = jnp.arange(1, MAX_ORDER + 1) + M = jnp.empty((MAX_ORDER + 1, MAX_ORDER + 1)) + M = jax.ops.index_update(M, jax.ops.index[1:, 1:], + (I - 1 - factor * J) / I) + M = jax.ops.index_update(M, jax.ops.index[0], 1) + R = jnp.cumprod(M, axis=0) + + return R + + +def _select_initial_conditions(fun, M, t0, y0, tol, scale_y0): + # identify algebraic variables as zeros on diagonal + algebraic_variables = jnp.diag(M == 0.) + + # if all differentiable variables then return y0 (can use normal python if since M + # is static) + if not jnp.any(algebraic_variables): + return y0, False + + # calculate consistent initial conditions via a newton on -J_a @ delta = f_a This + # follows this reference: + # + # Shampine, L. F., Reichelt, M. W., & Kierzenka, J. A. (1999). Solving index-1 DAEs + # in MATLAB and Simulink. SIAM review, 41(3), 538-552. 
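+    #
+    # The iteration below is a simplified (chord) Newton method on the
+    # algebraic residual only: the Jacobian J_a = d f_a / d y_a is factorised
+    # once, each iteration solves -J_a @ dy = f_a(y_a) and accumulates the
+    # increments in d (so y_a = y0_a + d). Convergence is declared when dy is
+    # exactly zero or when the estimated error of the next iterate,
+    # rate / (1 - rate) * ||dy||, drops below the newton tolerance, up to
+    # ROOT_SOLVE_MAXITER iterations.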
+ + # calculate fun_a, function of algebraic variables + def fun_a(y_a): + y_full = jax.ops.index_update(y0, algebraic_variables, y_a) + return fun(y_full, t0)[algebraic_variables] + + y0_a = y0[algebraic_variables] + scale_y0_a = scale_y0[algebraic_variables] + + d = jnp.zeros(y0_a.shape[0], dtype=y0.dtype) + y_a = jnp.array(y0_a, copy=True) + + # calculate neg jacobian of fun_a + J_a = jax.jacfwd(fun_a)(y_a) + LU = jax.scipy.linalg.lu_factor(-J_a) + + converged = False + dy_norm_old = -1.0 + k = 0 + while_state = [k, converged, dy_norm_old, d, y_a] + + def while_cond(while_state): + k, converged, _, _, _ = while_state + return (converged == False) * (k < ROOT_SOLVE_MAXITER) # noqa: E712 + + def while_body(while_state): + k, converged, dy_norm_old, d, y_a = while_state + f_eval = fun_a(y_a) + dy = jax.scipy.linalg.lu_solve(LU, f_eval) + dy_norm = jnp.sqrt(jnp.mean((dy / scale_y0_a)**2)) + rate = dy_norm / dy_norm_old + + d += dy + y_a = y0_a + d + + # if converged then break out of iteration early + pred = dy_norm_old >= 0. + pred *= rate / (1 - rate) * dy_norm < tol + converged = (dy_norm == 0.) + pred + + dy_norm_old = dy_norm + + return [k + 1, converged, dy_norm_old, d, y_a] + + k, converged, dy_norm_old, d, y_a = jax.lax.while_loop(while_cond, + while_body, + while_state) + y_tilde = jax.ops.index_update(y0, algebraic_variables, y_a) + + return y_tilde, converged + + +def _select_initial_step(atol, rtol, fun, t0, y0, f0, h0): """ Select a good initial step by stepping forward one step of forward euler, and comparing the predicted state against that using the provided function. @@ -214,46 +319,46 @@ def _select_initial_step(state, fun, t0, y0, f0, h0): .. [1] E. Hairer, S. P. Norsett G. Wanner, "Solving Ordinary Differential Equations I: Nonstiff Problems", Sec. II.4. 
""" - scale = state['atol'] + np.abs(y0) * state['rtol'] + scale = atol + jnp.abs(y0) * rtol y1 = y0 + h0 * f0 - f1 = fun(t0 + h0, y1) - d2 = np.sqrt(np.mean(((f1 - f0) / scale)**2)) + f1 = fun(y1, t0 + h0) + d2 = jnp.sqrt(jnp.mean(((f1 - f0) / scale)**2)) order = 1 h1 = h0 * d2 ** (-1 / (order + 1)) - return np.min((100 * h0, h1)) + return jnp.min((100 * h0, h1)) -def _predict(state): +def _predict(state, D): """ predict forward to new step (eq 2 in [1]) """ - n = len(state['y']) - order = state['order'] - orders = np.repeat(np.arange(MAX_ORDER + 1).reshape(-1, 1), n, axis=1) - subD = np.where(orders <= order, state['D'], 0) - state['y0'] = np.sum(subD, axis=0) - state['scale_y0'] = state['atol'] + state['rtol'] * np.abs(state['y0']) - return state + n = len(state.y0) + order = state.order + orders = jnp.repeat(jnp.arange(MAX_ORDER + 1).reshape(-1, 1), n, axis=1) + subD = jnp.where(orders <= order, D, 0) + y0 = jnp.sum(subD, axis=0) + scale_y0 = state.atol + state.rtol * jnp.abs(state.y0) + return y0, scale_y0 -def _update_psi(state): +def _update_psi(state, D): """ update psi term as defined in second equation on page 9 of [1] """ - order = state['order'] - n = len(state['y']) - orders = np.arange(MAX_ORDER + 1) - subGamma = np.where(orders > 0, np.where(orders <= order, state['gamma'], 0), 0) - orders = np.repeat(orders.reshape(-1, 1), n, axis=1) - subD = np.where(orders > 0, np.where(orders <= order, state['D'], 0), 0) - state['psi'] = np.dot( + order = state.order + n = len(state.y0) + orders = jnp.arange(MAX_ORDER + 1) + subGamma = jnp.where(orders > 0, jnp.where(orders <= order, state.gamma, 0), 0) + orders = jnp.repeat(orders.reshape(-1, 1), n, axis=1) + subD = jnp.where(orders > 0, jnp.where(orders <= order, D, 0), 0) + psi = jnp.dot( subD.T, subGamma - ) * state['alpha'][order] - return state + ) * state.alpha[order] + return psi -def _update_difference_for_next_step(state, d, only_update_D=False): +def _update_difference_for_next_step(state, d): """ update of difference equations can be done efficiently by reusing d and D. 
@@ -266,8 +371,8 @@ def _update_difference_for_next_step(state, d, only_update_D=False): Combining these gives the following algorithm """ - order = state['order'] - D = state['D'] + order = state.order + D = state.D D = jax.ops.index_update(D, jax.ops.index[order + 2], d - D[order + 1]) D = jax.ops.index_update(D, jax.ops.index[order + 1], @@ -288,74 +393,55 @@ def while_body(while_state): i, D = jax.lax.while_loop(while_cond, while_body, while_state) - state['D'] = D - - def update_psi_and_predict(state): - # update psi (D has changed) - state = _update_psi(state) + return D - # update y0 (D has changed) - state = _predict(state) - return state +def _update_step_size_and_lu(state, factor): + state = _update_step_size(state, factor) - state = jax.lax.cond(only_update_D == False, # noqa: E712 - state, update_psi_and_predict, - state, lambda x: x) + # redo lu (c has changed) + LU = jax.scipy.linalg.lu_factor(state.M - state.c * state.J) + n_lu_decompositions = state.n_lu_decompositions + 1 - return state + return state._replace(LU=LU, n_lu_decompositions=n_lu_decompositions) -def _update_step_size(state, factor, dont_update_lu): +def _update_step_size(state, factor): """ If step size h is changed then also need to update the terms in the first equation of page 9 of [1]: - constant c = h / (1-kappa) gamma_k term - - lu factorisation of (I - c * J) used in newton iteration (same equation) + - lu factorisation of (M - c * J) used in newton iteration (same equation) - psi term """ - order = state['order'] - h = state['h'] - - h *= factor - state['n_equal_steps'] = 0 - c = h * state['alpha'][order] - - # redo lu (c has changed) - def update_lu(state): - state['LU'] = jax.scipy.linalg.lu_factor(state['I'] - c * state['J']) - state['n_lu_decompositions'] += 1 - return state - - state = jax.lax.cond(dont_update_lu == False, # noqa: E712 - state, update_lu, - state, lambda x: x) - - state['h'] = h - state['c'] = c + order = state.order + h = state.h * factor + n_equal_steps = 0 + c = h * state.alpha[order] # update D using equations in section 3.2 of [1] - RU = _compute_R(order, factor).dot(state['U']) - I = np.arange(0, MAX_ORDER + 1).reshape(-1, 1) - J = np.arange(0, MAX_ORDER + 1) + RU = _compute_R(order, factor).dot(state.U) + I = jnp.arange(0, MAX_ORDER + 1).reshape(-1, 1) + J = jnp.arange(0, MAX_ORDER + 1) # only update order+1, order+1 entries of D - RU = np.where(np.logical_and(I <= order, J <= order), - RU, np.identity(MAX_ORDER + 1)) - D = state['D'] - D = np.dot(RU.T, D) + RU = jnp.where(jnp.logical_and(I <= order, J <= order), + RU, jnp.identity(MAX_ORDER + 1)) + D = state.D + D = jnp.dot(RU.T, D) # D = jax.ops.index_update(D, jax.ops.index[:order + 1], - # np.dot(RU.T, D[:order + 1])) - state['D'] = D + # jnp.dot(RU.T, D[:order + 1])) # update psi (D has changed) - state = _update_psi(state) + psi = _update_psi(state, D) # update y0 (D has changed) - state = _predict(state) + y0, scale_y0 = _predict(state, D) - return state + return state._replace(n_equal_steps=n_equal_steps, + h=h, c=c, + D=D, psi=psi, y0=y0, scale_y0=scale_y0) def _update_jacobian(state, jac): @@ -363,41 +449,43 @@ def _update_jacobian(state, jac): we update the jacobian using J(t_{n+1}, y^0_{n+1}) following the scipy bdf implementation rather than J(t_n, y_n) as per [1] """ - J = jac(state['t'] + state['h'], state['y0']) - state['n_jacobian_evals'] += 1 - state['LU'] = jax.scipy.linalg.lu_factor(state['I'] - state['c'] * J) - state['n_lu_decompositions'] += 1 - state['J'] = J - return state + J = jac(state.y0, 
state.t + state.h) + n_jacobian_evals = state.n_jacobian_evals + 1 + LU = jax.scipy.linalg.lu_factor(state.M - state.c * J) + n_lu_decompositions = state.n_lu_decompositions + 1 + return state._replace(J=J, n_jacobian_evals=n_jacobian_evals, LU=LU, + n_lu_decompositions=n_lu_decompositions) def _newton_iteration(state, fun): - tol = state['newton_tol'] - c = state['c'] - psi = state['psi'] - y0 = state['y0'] - LU = state['LU'] - scale_y0 = state['scale_y0'] - t = state['t'] + state['h'] - d = np.zeros_like(y0) - y = np.array(y0, copy=True) - - not_converged = True + tol = state.newton_tol + c = state.c + psi = state.psi + y0 = state.y0 + LU = state.LU + M = state.M + scale_y0 = state.scale_y0 + t = state.t + state.h + d = jnp.zeros(y0.shape, dtype=y0.dtype) + y = jnp.array(y0, copy=True) + n_function_evals = state.n_function_evals + + converged = False dy_norm_old = -1.0 k = 0 - while_state = [k, not_converged, dy_norm_old, d, y, state] + while_state = [k, converged, dy_norm_old, d, y, n_function_evals] def while_cond(while_state): - k, not_converged, _, _, _, _ = while_state - return not_converged * (k < NEWTON_MAXITER) + k, converged, _, _, _, _ = while_state + return (converged == False) * (k < NEWTON_MAXITER) # noqa: E712 def while_body(while_state): - k, not_converged, dy_norm_old, d, y, state = while_state - f_eval = fun(t, y) - state['n_function_evals'] += 1 - b = c * f_eval - psi - d + k, converged, dy_norm_old, d, y, n_function_evals = while_state + f_eval = fun(y, t) + n_function_evals += 1 + b = c * f_eval - M @ (psi + d) dy = jax.scipy.linalg.lu_solve(LU, b) - dy_norm = np.sqrt(np.mean((dy / scale_y0)**2)) + dy_norm = jnp.sqrt(jnp.mean((dy / scale_y0)**2)) rate = dy_norm / dy_norm_old # if iteration is not going to converge in NEWTON_MAXITER @@ -405,220 +493,175 @@ def while_body(while_state): pred = rate >= 1 pred += rate ** (NEWTON_MAXITER - k) / (1 - rate) * dy_norm > tol pred *= dy_norm_old >= 0 - k = jax.lax.cond(pred, k, lambda k: NEWTON_MAXITER, k, lambda k: k) + k += pred * (NEWTON_MAXITER - k - 1) d += dy y = y0 + d # if converged then break out of iteration early - pred = dy_norm_old >= 0 + pred = dy_norm_old >= 0. pred *= rate / (1 - rate) * dy_norm < tol - pred += dy_norm == 0 + converged = (dy_norm == 0.) + pred - def converged_fun(not_converged): - not_converged = False - return not_converged + dy_norm_old = dy_norm - def not_converged_fun(not_converged): - return not_converged + return [k + 1, converged, dy_norm_old, d, y, n_function_evals] - dy_norm_old = dy_norm + k, converged, dy_norm_old, d, y, n_function_evals = \ + jax.lax.while_loop(while_cond, + while_body, + while_state) + return converged, k, y, d, state._replace(n_function_evals=n_function_evals) + + +def rms_norm(arg): + return jnp.sqrt(jnp.mean(arg**2)) + + +def _prepare_next_step(state, d): + D = _update_difference_for_next_step(state, d) + psi = _update_psi(state, D) + y0, scale_y0 = _predict(state, D) + return state._replace(D=D, psi=psi, y0=y0, scale_y0=scale_y0) + + +def _prepare_next_step_order_change(state, d, y, n_iter): + order = state.order + + D = _update_difference_for_next_step(state, d) + + # Note: we are recalculating these from the while loop above, could re-use? 
+ scale_y = state.atol + state.rtol * jnp.abs(y) + error = state.error_const[order] * d + error_norm = rms_norm(error / scale_y) + safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER + + n_iter) + + # similar to the optimal step size factor we calculated above for the current + # order k, we need to calculate the optimal step size factors for orders + # k-1 and k+1. To do this, we note that the error = C_k * D^{k+1} y_n + error_m_norm = jnp.where( + order > 1, + rms_norm(state.error_const[order - 1] * D[order] / scale_y), + jnp.inf + ) + error_p_norm = jnp.where( + order < MAX_ORDER, + rms_norm(state.error_const[order + 1] * D[order + 2] / scale_y), + jnp.inf + ) - not_converged = \ - jax.lax.cond(pred, - not_converged, converged_fun, - not_converged, not_converged_fun) - return [k + 1, not_converged, dy_norm_old, d, y, state] + error_norms = jnp.array([error_m_norm, error_norm, error_p_norm]) + factors = error_norms ** (-1 / (jnp.arange(3) + order)) - k, not_converged, dy_norm_old, d, y, state = jax.lax.while_loop(while_cond, - while_body, - while_state) - return not_converged, k, y, d, state + # now we have the three factors for orders k-1, k and k+1, pick the maximum in + # order to maximise the resultant step size + max_index = jnp.argmax(factors) + order += max_index - 1 + + factor = jnp.min((MAX_FACTOR, safety * factors[max_index])) + + new_state = _update_step_size_and_lu(state._replace(D=D, order=order), factor) + return new_state def _bdf_step(state, fun, jac): + #print('bdf_step', state.t, state.h) # we will try and use the old jacobian unless convergence of newton iteration # fails - not_updated_jacobian = True + updated_jacobian = False # initialise step size and try to make the step, # iterate, reducing step size until error is in bounds step_accepted = False - y = np.empty_like(state['y']) - d = np.empty_like(state['y']) + y = jnp.empty_like(state.y0) + d = jnp.empty_like(state.y0) n_iter = -1 # loop until step is accepted - while_state = [state, step_accepted, not_updated_jacobian, y, d, n_iter] + while_state = [state, step_accepted, updated_jacobian, y, d, n_iter] def while_cond(while_state): _, step_accepted, _, _, _, _ = while_state return step_accepted == False # noqa: E712 def while_body(while_state): - state, step_accepted, not_updated_jacobian, y, d, n_iter = while_state + state, step_accepted, updated_jacobian, y, d, n_iter = while_state # solve BDF equation using y0 as starting point - not_converged, n_iter, y, d, state = _newton_iteration(state, fun) - - # if not converged update the jacobian for J(t_n,y_n) and try again - pred = not_converged * not_updated_jacobian - if_state = [state, not_updated_jacobian, step_accepted] - - def need_to_update_jacobian(if_state): - # newton iteration did not converge, update the jacobian and try again - state, not_updated_jacobian, step_accepted = if_state - state = _update_jacobian(state, jac) - not_updated_jacobian = False - return [state, not_updated_jacobian, step_accepted] - - def dont_need_to_update_jacobian(if_state): - state, not_updated_jacobian, step_accepted = if_state - - if_state2 = [state, step_accepted] - - def need_to_update_step_size(if_state2): - # newton iteration did not converge, but jacobian has already been - # evaluated so reduce step size by 0.3 (as per [1]) and try again - state, step_accepted = if_state2 - state = _update_step_size(state, 0.3, False) - return [state, step_accepted] - - def converged(if_state2): - state, step_accepted = if_state2 - # yay, converged, now check error is within 
bounds - scale_y = state['atol'] + state['rtol'] * np.abs(y) - - # combine eq 3, 4 and 6 from [1] to obtain error - # Note that error = C_k * h^{k+1} y^{k+1} - # and d = D^{k+1} y_{n+1} \approx h^{k+1} y^{k+1} - error = state['error_const'][state['order']] * d - error_norm = np.sqrt(np.mean((error / scale_y)**2)) - - # calculate safety outside if since we will reuse later - safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER - + n_iter) - - if_state3 = [state, step_accepted] - - def error_too_large(if_state3): - # error too large, reduce step size and try again - state, step_accepted = if_state3 - state['n_error_test_failures'] += 1 - # calculate optimal step size factor as per eq 2.46 of [2] - factor = np.max((MIN_FACTOR, - safety * - error_norm ** (-1 / (state['order'] + 1)))) - state = _update_step_size(state, factor, False) - return [state, step_accepted] - - def accept_step(if_state3): - # if we get here we can accept the step - state, step_accepted = if_state3 - step_accepted = True - return [state, step_accepted] - - state, step_accepted = \ - jax.lax.cond(error_norm > 1, - if_state3, error_too_large, - if_state3, accept_step) - - return [state, step_accepted] - - state, step_accepted = jax.lax.cond(not_converged, - if_state2, need_to_update_step_size, - if_state2, converged) - - return [state, not_updated_jacobian, step_accepted] - - state, not_updated_jacobian, step_accepted = \ - jax.lax.cond(pred, - if_state, need_to_update_jacobian, - if_state, dont_need_to_update_jacobian) - return [state, step_accepted, not_updated_jacobian, y, d, n_iter] - - state, step_accepted, not_updated_jacobian, y, d, n_iter = \ + converged, n_iter, y, d, state = _newton_iteration(state, fun) + not_converged = converged == False # noqa: E712 + + # newton iteration did not converge, but jacobian has already been + # evaluated so reduce step size by 0.3 (as per [1]) and try again + state = tree_multimap( + partial(jnp.where, not_converged * updated_jacobian), + _update_step_size_and_lu(state, 0.3), + state + ) + + #if not_converged * updated_jacobian: + # print('not converged, update step size by 0.3') + #if not_converged * (updated_jacobian == False): + # print('not converged, update jacobian') + + # if not converged and jacobian not updated, then update the jacobian and try + # again + (state, updated_jacobian) = tree_multimap( + partial( + jnp.where, + not_converged * (updated_jacobian == False) # noqa: E712 + ), + (_update_jacobian(state, jac), True), + (state, False + updated_jacobian) + ) + + safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER + n_iter) + scale_y = state.atol + state.rtol * jnp.abs(y) + + # combine eq 3, 4 and 6 from [1] to obtain error + # Note that error = C_k * h^{k+1} y^{k+1} + # and d = D^{k+1} y_{n+1} \approx h^{k+1} y^{k+1} + error = state.error_const[state.order] * d + + error_norm = rms_norm(error / scale_y) + + # calculate optimal step size factor as per eq 2.46 of [2] + factor = jnp.max((MIN_FACTOR, + safety * + error_norm ** (-1 / (state.order + 1)))) + + #if converged * (error_norm > 1): + # print('converged, but error is too large',error_norm, factor, d, scale_y) + + (state, step_accepted) = tree_multimap( + partial( + jnp.where, + converged * (error_norm > 1) # noqa: E712 + ), + (_update_step_size_and_lu(state, factor), False), + (state, converged) + ) + + return [state, step_accepted, updated_jacobian, y, d, n_iter] + + state, step_accepted, updated_jacobian, y, d, n_iter = \ jax.lax.while_loop(while_cond, while_body, while_state) # take the 
accepted step - state['y'] = y - state['t'] += state['h'] + n_steps = state.n_steps + 1 + t = state.t + state.h # a change in order is only done after running at order k for k + 1 steps # (see page 83 of [2]) - state['n_equal_steps'] += 1 - - if_state = [state, d, y, n_iter] - - def no_change_in_order(if_state): - # no change in order this step, update differences D and exit - state, d, _, _ = if_state - state = _update_difference_for_next_step(state, d, False) - return state - - def order_change(if_state): - state, d, y, n_iter = if_state - order = state['order'] - - # Note: we are recalculating these from the while loop above, could re-use? - scale_y = state['atol'] + state['rtol'] * np.abs(y) - error = state['error_const'][order] * d - error_norm = np.sqrt(np.mean((error / scale_y)**2)) - safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER - + n_iter) - - # don't need to update psi and y0 yet as we will be changing D again soon - state = _update_difference_for_next_step(state, d, True) - - if_state2 = [state, scale_y, order] - # similar to the optimal step size factor we calculated above for the current - # order k, we need to calculate the optimal step size factors for orders - # k-1 and k+1. To do this, we note that the error = C_k * D^{k+1} y_n - - def order_greater_one(if_state2): - state, scale_y, order = if_state2 - error_m = state['error_const'][order - 1] * state['D'][order] - error_m_norm = np.sqrt(np.mean((error_m / scale_y)**2)) - return error_m_norm - - def order_equal_one(if_state2): - error_m_norm = np.inf - return error_m_norm - - error_m_norm = jax.lax.cond(order > 1, - if_state2, order_greater_one, - if_state2, order_equal_one) - - def order_less_max(if_state2): - state, scale_y, order = if_state2 - error_p = state['error_const'][order + 1] * state['D'][order + 2] - error_p_norm = np.sqrt(np.mean((error_p / scale_y)**2)) - return error_p_norm - - def order_max(if_state2): - error_p_norm = np.inf - return error_p_norm - - error_p_norm = jax.lax.cond(order < MAX_ORDER, - if_state2, order_less_max, - if_state2, order_max) - - error_norms = np.array([error_m_norm, error_norm, error_p_norm]) - factors = error_norms ** (-1 / (np.arange(3) + order)) + n_equal_steps = state.n_equal_steps + 1 - # now we have the three factors for orders k-1, k and k+1, pick the maximum in - # order to maximise the resultant step size - max_index = np.argmax(factors) - order += max_index - 1 - state['order'] = order + state = state._replace(n_equal_steps=n_equal_steps, t=t, n_steps=n_steps) - factor = np.min((MAX_FACTOR, safety * factors[max_index])) - state = _update_step_size(state, factor, False) - - return state - - state = jax.lax.cond(state['n_equal_steps'] < state['order'] + 1, - if_state, no_change_in_order, - if_state, order_change) + state = tree_multimap( + partial(jnp.where, n_equal_steps < state.order + 1), + _prepare_next_step(state, d), + _prepare_next_step_order_change(state, d, y, n_iter) + ) return state @@ -629,10 +672,10 @@ def _bdf_interpolate(state, t_eval): definition of the interpolating polynomial can be found on page 7 of [1] """ - order = state['order'] - t = state['t'] - h = state['h'] - D = state['D'] + order = state.order + t = state.t + h = state.h + D = state.D j = 0 time_factor = 1.0 order_summation = D[0] @@ -655,52 +698,295 @@ def while_body(while_state): return order_summation -@jax.partial(jax.jit, static_argnums=(0, 1, 2, 3)) -def _bdf_odeint(fun, jac, rtol, atol, y0, t_eval, inputs): - """ - main solver loop - creates a stepper object and steps 
through time, interpolating to - the time points in t_eval +def block_diag(lst): + def block_fun(i, j, Ai, Aj): + if i == j: + return Ai + else: + return jnp.zeros( + ( + Ai.shape[0] if Ai.ndim > 1 else 1, + Aj.shape[1] if Aj.ndim > 1 else 1, + ), + dtype=Ai.dtype + ) + + blocks = [ + [block_fun(i, j, Ai, Aj) for j, Aj in enumerate(lst)] + for i, Ai in enumerate(lst) + ] + + return jnp.block(blocks) + +# NOTE: the code below (except the docstring on jax_bdf_integrate and other minor +# edits), has been modified from the JAX library at https://github.com/google/jax. +# The main difference is the addition of support for semi-explicit dae index 1 problems +# via the addition of a mass matrix. +# This is under an Apache license, a short form of which is given here: +# +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this +# file except in compliance with the License. You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under +# the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific language +# governing permissions and limitations under the License. + + +def jax_bdf_integrate(func, y0, t_eval, *args, rtol=1e-6, atol=1e-6, mass=None): """ + Backward Difference formula (BDF) implicit multistep integrator. The basic algorithm + is derived in [2]_. This particular implementation follows that implemented in the + Matlab routine ode15s described in [1]_ and the SciPy implementation [3]_, which + features the NDF formulas for improved stability, with associated differences in the + error constants, and calculates the jacobian at J(t_{n+1}, y^0_{n+1}). This + implementation was based on that implemented in the scipy library [3]_, which also + mainly follows [1]_ but uses the more standard jacobian update. - def fun_bind_inputs(t, y): - return fun(t, y, inputs) + Parameters + ---------- - if jac is None: - jac_bind_inputs = jax.jacfwd(fun_bind_inputs, argnums=1) - else: - def jac_bind_inputs(t, y): - jac(t, y, inputs) + func: callable + function to evaluate the time derivative of the solution `y` at time + `t` as `func(y, t, *args)`, producing the same shape/structure as `y0`. + y0: ndarray + initial state vector + t_eval: ndarray + time points to evaluate the solution, has shape (m,) + args: (optional) + tuple of additional arguments for `fun`, which must be arrays + scalars, or (nested) standard Python containers (tuples, lists, dicts, + namedtuples, i.e. pytrees) of those types. + rtol: (optional) float + relative tolerance for the solver + atol: (optional) float + absolute tolerance for the solver + mass: (optional) ndarray + diagonal of the mass matrix with shape (n,) - t0 = t_eval[0] - h0 = t_eval[1] - t0 + Returns + ------- + y: ndarray with shape (n, m) + calculated state vector at each of the m time points - stepper = _bdf_init(fun_bind_inputs, jac_bind_inputs, t0, y0, h0, rtol, atol) - i = 0 - y_out = np.empty((len(y0), len(t_eval)), dtype=y0.dtype) + References + ---------- + .. [1] L. F. Shampine, M. W. Reichelt, "THE MATLAB ODE SUITE", SIAM J. SCI. + COMPUTE., Vol. 18, No. 1, pp. 1-22, January 1997. + .. [2] G. D. Byrne, A. C. Hindmarsh, "A Polyalgorithm for the Numerical + Solution of Ordinary Differential Equations", ACM Transactions on + Mathematical Software, Vol. 1, No. 1, pp. 71-96, March 1975. 
+ .. [3] Virtanen, P., Gommers, R., Oliphant, T. E., Haberland, M., Reddy, + T., Cournapeau, D., ... & van der Walt, S. J. (2020). SciPy 1.0: + fundamental algorithms for scientific computing in Python. + Nature methods, 17(3), 261-272. + """ + def _check_arg(arg): + if not isinstance(arg, core.Tracer) and not core.valid_jaxtype(arg): + msg = ("The contents of odeint *args must be arrays or scalars, but got " + "\n{}.") + raise TypeError(msg.format(arg)) - init_state = [stepper, t_eval, i, y_out, 0] + flat_args, in_tree = tree_flatten((y0, t_eval[0], *args)) + in_avals = tuple(safe_map(abstractify, flat_args)) + converted, consts = closure_convert(func, in_tree, in_avals) + return _bdf_odeint_wrapper(converted, mass, rtol, atol, y0, t_eval, *consts, *args) - def cond_fun(state): - _, t_eval, i, _, _ = state - return i < len(t_eval) - def body_fun(state): - stepper, t_eval, i, y_out, n_steps = state - stepper = _bdf_step(stepper, fun_bind_inputs, jac_bind_inputs) - index = np.searchsorted(t_eval, stepper['t']) +def flax_while_loop(cond_fun, body_fun, init_val): # pragma: no cover + """ + for debugging purposes, use this instead of jax.lax.while_loop + """ + val = init_val + while cond_fun(val): + val = body_fun(val) + return val - def for_body(j, y_out): - t = t_eval[j] - y_out = jax.ops.index_update(y_out, jax.ops.index[:, j], - _bdf_interpolate(stepper, t)) - return y_out - y_out = jax.lax.fori_loop(i, index, for_body, y_out) - return [stepper, t_eval, index, y_out, n_steps + 1] +def flax_fori_loop(start, stop, body_fun, init_val): # pragma: no cover + """ + for debugging purposes, use this instead of jax.lax.fori_loop + """ + val = init_val + for i in range(start, stop): + val = body_fun(i, val) + return val - stepper, t_eval, i, y_out, n_steps = jax.lax.while_loop(cond_fun, body_fun, - init_state) - stepper['n_steps'] = n_steps +def flax_scan(f, init, xs, length=None): # pragma: no cover + """ + for debugging purposes, use this instead of jax.lax.scan + """ + if xs is None: + xs = [None] * length + carry = init + ys = [] + for x in xs: + carry, y = f(carry, x) + ys.append(y) + return carry, onp.stack(ys) - return y_out, stepper + +@jax.partial(jax.jit, static_argnums=(0, 1, 2, 3)) +def _bdf_odeint_wrapper(func, mass, rtol, atol, y0, ts, *args): + y0, unravel = ravel_pytree(y0) + if mass is None: + mass = onp.identity(y0.shape[0], dtype=y0.dtype) + else: + mass = block_diag(tree_flatten(mass)[0]) + func = ravel_first_arg(func, unravel) + out = _bdf_odeint(func, mass, rtol, atol, y0, ts, *args) + return jax.vmap(unravel)(out) + + +def _bdf_odeint_fwd(func, mass, rtol, atol, y0, ts, *args): + ys = _bdf_odeint(func, mass, rtol, atol, y0, ts, *args) + return ys, (ys, ts, args) + + +def _bdf_odeint_rev(func, mass, rtol, atol, res, g): + ys, ts, args = res + + def aug_dynamics(augmented_state, t, *args): + """Original system augmented with vjp_y, vjp_t and vjp_args.""" + y, y_bar, *_ = augmented_state + # `t` here is negative time, so we need to negate again to get back to + # normal time. See the `odeint` invocation in `scan_fun` below. + y_dot, vjpfun = jax.vjp(func, y, -t, *args) + + # Adjoint equations for semi-explicit dae index 1 system from + # + # [1] Cao, Y., Li, S., Petzold, L., & Serban, R. (2003). Adjoint sensitivity + # analysis for differential-algebraic equations: The adjoint DAE system and its + # numerical solution. SIAM journal on scientific computing, 24(3), 1076-1089. 
+ # + # y_bar_dot_d = -J_dd^T y_bar_d - J_ad^T y_bar_a + # 0 = J_da^T y_bar_d + J_aa^T y_bar_d + + y_bar_dot, *rest = vjpfun(y_bar) + + return (-y_dot, y_bar_dot, *rest) + + algebraic_variables = jnp.diag(mass) == 0. + differentiable_variables = algebraic_variables == False # noqa: E712 + mass_is_I = (mass == jnp.eye(mass.shape[0])).all() + is_dae = jnp.any(algebraic_variables) + + if not mass_is_I: + M_dd = mass[onp.ix_(differentiable_variables, differentiable_variables)] + LU_invM_dd = jax.scipy.linalg.lu_factor(M_dd) + + def initialise(g0, y0, t0): + # [1] gives init conditions for y_bar_a = g_d - J_ad^T (J_aa^T)^-1 g_a + if mass_is_I: + y_bar = g0 + elif is_dae: + J = jax.jacfwd(func)(y0, t0, *args) + + # boolean arguments not implemented in jnp.ix_ + J_aa = J[onp.ix_(algebraic_variables, algebraic_variables)] + J_ad = J[onp.ix_(algebraic_variables, differentiable_variables)] + LU = jax.scipy.linalg.lu_factor(J_aa) + g0_a = g0[algebraic_variables] + invJ_aa = jax.scipy.linalg.lu_solve(LU, g0_a) + y_bar = jax.ops.index_update( + g0, differentiable_variables, + jax.scipy.linalg.lu_solve(LU_invM_dd, + g0_a - J_ad @ invJ_aa) + ) + else: + y_bar = jax.scipy.linalg.lu_solve(LU_invM_dd, g0) + return y_bar + + y_bar = initialise(g[-1], ys[-1], ts[-1]) + ts_bar = [] + t0_bar = 0. + + def arg_to_identity(arg): + return onp.identity(arg.shape[0] if arg.ndim > 0 else 1, dtype=arg.dtype) + + aug_mass = (mass, mass, jnp.array(1.), tree_map(arg_to_identity, args)) + + def scan_fun(carry, i): + y_bar, t0_bar, args_bar = carry + # Compute effect of moving measurement time + t_bar = jnp.dot(func(ys[i], ts[i], *args), g[i]) + t0_bar = t0_bar - t_bar + # Run augmented system backwards to previous observation + _, y_bar, t0_bar, args_bar = jax_bdf_integrate( + aug_dynamics, (ys[i], y_bar, t0_bar, args_bar), + jnp.array([-ts[i], -ts[i - 1]]), + *args, mass=aug_mass, + rtol=rtol, atol=atol) + y_bar, t0_bar, args_bar = tree_map(op.itemgetter(1), (y_bar, t0_bar, args_bar)) + # Add gradient from current output + y_bar = y_bar + initialise(g[i - 1], ys[i - 1], ts[i - 1]) + return (y_bar, t0_bar, args_bar), t_bar + + init_carry = (y_bar, t0_bar, tree_map(jnp.zeros_like, args)) + (y_bar, t0_bar, args_bar), rev_ts_bar = jax.lax.scan( + scan_fun, init_carry, jnp.arange(len(ts) - 1, 0, -1)) + ts_bar = jnp.concatenate([jnp.array([t0_bar]), rev_ts_bar[::-1]]) + return (y_bar, ts_bar, *args_bar) + + +_bdf_odeint.defvjp(_bdf_odeint_fwd, _bdf_odeint_rev) + + +@cache() +def closure_convert(fun, in_tree, in_avals): + in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals] + wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) + with core.initial_style_staging(): + jaxpr, _, consts = pe.trace_to_jaxpr( + wrapped_fun, in_pvals, instantiate=True, stage_out=False + ) + out_tree = out_tree() + + # We only want to closure convert for constants with respect to which we're + # differentiating. As a proxy for that, we hoist consts with float dtype. 
+ # TODO(mattjj): revise this approach + (closure_consts, hoisted_consts), merge = partition_list( + lambda c: dtypes.issubdtype(dtypes.dtype(c), jnp.inexact), + consts + ) + num_consts = len(hoisted_consts) + + def converted_fun(y, t, *hconsts_args): + hoisted_consts, args = split_list(hconsts_args, [num_consts]) + consts = merge(closure_consts, hoisted_consts) + all_args, _ = tree_flatten((y, t, *args)) + out_flat = core.eval_jaxpr(jaxpr, consts, *all_args) + return tree_unflatten(out_tree, out_flat) + + return converted_fun, hoisted_consts + + +def partition_list(choice, lst): + out = [], [] + which = [out[choice(elt)].append(elt) or choice(elt) for elt in lst] + + def merge(l1, l2): + i1, i2 = iter(l1), iter(l2) + return [next(i2 if snd else i1) for snd in which] + return out, merge + + +def abstractify(x): + return core.raise_to_shaped(core.get_aval(x)) + + +def ravel_first_arg(f, unravel): + return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped + + +@lu.transformation +def ravel_first_arg_(unravel, y_flat, *args): + y = unravel(y_flat) + ans = yield (y,) + args, {} + ans_flat, _ = ravel_pytree(ans) + yield ans_flat diff --git a/pybamm/solvers/jax_solver.py b/pybamm/solvers/jax_solver.py index 5b9fb6bc0e..db912d0da1 100644 --- a/pybamm/solvers/jax_solver.py +++ b/pybamm/solvers/jax_solver.py @@ -5,7 +5,7 @@ import jax from jax.experimental.ode import odeint -import jax.numpy as np +import jax.numpy as jnp import numpy as onp @@ -30,6 +30,10 @@ class JaxSolver(pybamm.BaseSolver): method: str 'RK45' (default) uses jax.experimental.odeint 'BDF' uses custom jax_bdf_integrate (see jax_bdf_integrate.py for details) + root_method: str, optional + Method to use to calculate consistent initial conditions. By default this uses + the newton chord method internal to the jax bdf solver, otherwise choose from + the set of default options defined in docs for pybamm.BaseSolver rtol : float, optional The relative tolerance for the solver (default is 1e-6). atol : float, optional @@ -41,14 +45,19 @@ class JaxSolver(pybamm.BaseSolver): for details. """ - def __init__(self, method='RK45', rtol=1e-6, atol=1e-6, extra_options=None): - super().__init__(method, rtol, atol) - self.ode_solver = True + def __init__(self, method='RK45', root_method=None, + rtol=1e-6, atol=1e-6, extra_options=None): + # note: bdf solver itself calculates consistent initial conditions so can set + # root_method to none, allow user to override this behavior + super().__init__(method, rtol, atol, root_method=root_method) method_options = ['RK45', 'BDF'] if method not in method_options: raise ValueError('method must be one of {}'.format(method_options)) + self.ode_solver = False + if method == 'RK45': + self.ode_solver = True self.extra_options = extra_options or {} - self.name = "JAX solver" + self.name = "JAX solver ({})".format(method) self._cached_solves = dict() def get_solve(self, model, t_eval): @@ -74,11 +83,11 @@ def get_solve(self, model, t_eval): raise RuntimeError("Model is not set up for solving, run" "`solver.solve(model)` first") - self._cached_solves[model] = self._create_solve(model, t_eval) + self._cached_solves[model] = self.create_solve(model, t_eval) return self._cached_solves[model] - def _create_solve(self, model, t_eval): + def create_solve(self, model, t_eval): """ Return a compiled JAX function that solves an ode model with input arguments. 
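The compiled function returned by `create_solve`/`get_solve` is a pure JAX function of the inputs dict, so it composes with `jax.jit` and `jax.grad` for sensitivity calculations. A minimal sketch of that workflow, mirroring the tests further below (`model` and `t_eval` are assumed to be already defined and discretised, and the `loss` helper is illustrative):

import jax
import pybamm

solver = pybamm.JaxSolver(method='BDF', rtol=1e-8, atol=1e-8)
solver.solve(model, t_eval, inputs={'rate': 0.1})  # first solve sets the model up
solve = solver.get_solve(model, t_eval)            # jit-compiled, takes an inputs dict

def loss(rate):
    # scalar summary of the solution, differentiable w.r.t. the input parameter
    return jax.numpy.sum(solve({'rate': rate}))

dloss_drate = jax.jit(jax.grad(loss))(0.1)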
@@ -109,15 +118,24 @@ def _create_solve(self, model, t_eval): " re-solve using no events and a fixed" " end-time".format(model.events)) - # Initial conditions - y0 = model.y0 + # Initial conditions, make sure they are an 0D array + y0 = jnp.array(model.y0).reshape(-1) + mass = None + if self.method == 'BDF': + mass = model.mass_matrix.entries.toarray() - def rhs_odeint(y, t, inputs): - return model.rhs_eval(t, y, inputs) + def rhs_ode(y, t, inputs): + return model.rhs_eval(t, y, inputs), + + def rhs_dae(y, t, inputs): + return jnp.concatenate([ + model.rhs_eval(t, y, inputs), + model.algebraic_eval(t, y, inputs), + ]) def solve_model_rk45(inputs): y = odeint( - rhs_odeint, + rhs_ode, y0, t_eval, inputs, @@ -125,19 +143,20 @@ def solve_model_rk45(inputs): atol=self.atol, **self.extra_options ) - return np.transpose(y), None + return jnp.transpose(y) def solve_model_bdf(inputs): - y, stepper = pybamm.jax_bdf_integrate( - model.rhs_eval, + y = pybamm.jax_bdf_integrate( + rhs_dae, y0, t_eval, - inputs=inputs, + inputs, rtol=self.rtol, atol=self.atol, + mass=mass, **self.extra_options ) - return y, stepper + return jnp.transpose(y) if self.method == 'RK45': return jax.jit(solve_model_rk45) @@ -165,27 +184,13 @@ def _integrate(self, model, t_eval, inputs=None): """ if model not in self._cached_solves: - self._cached_solves[model] = self._create_solve(model, t_eval) + self._cached_solves[model] = self.create_solve(model, t_eval) - y, stepper = self._cached_solves[model](inputs) + y = self._cached_solves[model](inputs) # note - the actual solve is not done until this line! y = onp.array(y) - if stepper is not None: - sstring = '' - sstring += 'JAX {} solver - stats\n'.format(self.method) - sstring += '\tNumber of steps: {}\n'.format(stepper['n_steps']) - sstring += '\tnumber of function evaluations: {}\n'.format( - stepper['n_function_evals']) - sstring += '\tnumber of jacobian evaluations: {}\n'.format( - stepper['n_jacobian_evals']) - sstring += '\tnumber of LU decompositions: {}\n'.format( - stepper['n_lu_decompositions']) - sstring += '\tnumber of error test failures: {}'.format( - stepper['n_error_test_failures']) - pybamm.logger.info(sstring) - termination = "final time" t_event = None y_event = onp.array(None) diff --git a/tests/unit/test_solvers/test_jax_bdf_solver.py b/tests/unit/test_solvers/test_jax_bdf_solver.py index a91a722c13..1ec810c09d 100644 --- a/tests/unit/test_solvers/test_jax_bdf_solver.py +++ b/tests/unit/test_solvers/test_jax_bdf_solver.py @@ -5,6 +5,8 @@ import time import numpy as np from platform import system +if system() != "Windows": + import jax @unittest.skipIf(system() == "Windows", "JAX not supported on windows") @@ -26,33 +28,158 @@ def test_solver(self): disc.process_model(model) # Solve - t_eval = np.linspace(0, 1, 80) - y0 = model.concatenated_initial_conditions.evaluate() + t_eval = np.linspace(0.0, 1.0, 80) + y0 = model.concatenated_initial_conditions.evaluate().reshape(-1) rhs = pybamm.EvaluatorJax(model.concatenated_rhs) - def fun(t, y, inputs): - return rhs.evaluate(t=t, y=y, inputs=inputs).reshape(-1) + def fun(y, t): + return rhs.evaluate(t=t, y=y).reshape(-1) t0 = time.perf_counter() - y, _ = pybamm.jax_bdf_integrate( - fun, y0, t_eval, inputs=None, rtol=1e-8, atol=1e-8) + y = pybamm.jax_bdf_integrate(fun, y0, t_eval, rtol=1e-8, atol=1e-8) t1 = time.perf_counter() - t0 # test accuracy - np.testing.assert_allclose(y[0, :], np.exp(0.1 * t_eval), + np.testing.assert_allclose(y[:, 0], np.exp(0.1 * t_eval), + rtol=1e-6, atol=1e-6) + + t0 = 
time.perf_counter() + y = pybamm.jax_bdf_integrate(fun, y0, t_eval, rtol=1e-8, atol=1e-8) + t2 = time.perf_counter() - t0 + + # second run should be much quicker + self.assertLess(t2, t1) + + # test second run is accurate + np.testing.assert_allclose(y[:, 0], np.exp(0.1 * t_eval), + rtol=1e-6, atol=1e-6) + + def test_mass_matrix(self): + # Solve + t_eval = np.linspace(0.0, 1.0, 80) + + def fun(y, t): + return jax.numpy.stack([ + 0.1 * y[0], + y[1] - 2.0 * y[0], + ]) + + mass = jax.numpy.array([ + [2.0, 0.0], + [0.0, 0.0], + ]) + + # give some bad initial conditions, solver should calculate correct ones using + # this as a guess + y0 = jax.numpy.array([1.0, 1.5]) + + t0 = time.perf_counter() + y = pybamm.jax_bdf_integrate(fun, y0, t_eval, mass=mass, rtol=1e-8, atol=1e-8) + t1 = time.perf_counter() - t0 + + # test accuracy + soln = np.exp(0.05 * t_eval) + np.testing.assert_allclose(y[:, 0], soln, + rtol=1e-7, atol=1e-7) + np.testing.assert_allclose(y[:, 1], 2.0 * soln, rtol=1e-7, atol=1e-7) t0 = time.perf_counter() - y, _ = pybamm.jax_bdf_integrate(fun, y0, t_eval, rtol=1e-8, atol=1e-8) + y = pybamm.jax_bdf_integrate(fun, y0, t_eval, mass=mass, rtol=1e-8, atol=1e-8) t2 = time.perf_counter() - t0 # second run should be much quicker self.assertLess(t2, t1) # test second run is accurate - np.testing.assert_allclose(y[0, :], np.exp(0.1 * t_eval), + np.testing.assert_allclose(y[:, 0], np.exp(0.05 * t_eval), rtol=1e-7, atol=1e-7) + def test_solver_sensitivities(self): + # Create model + model = pybamm.BaseModel() + model.convert_to_format = "jax" + domain = ["negative electrode", "separator", "positive electrode"] + var = pybamm.Variable("var", domain=domain) + model.rhs = {var: -pybamm.InputParameter("rate") * var} + model.initial_conditions = {var: 1} + + # create discretisation + mesh = get_mesh_for_testing(xpts=10) + spatial_methods = {"macroscale": pybamm.FiniteVolume()} + disc = pybamm.Discretisation(mesh, spatial_methods) + disc.process_model(model) + + # Solve + t_eval = np.linspace(0, 10, 4) + y0 = model.concatenated_initial_conditions.evaluate().reshape(-1) + rhs = pybamm.EvaluatorJax(model.concatenated_rhs) + + def fun(y, t, inputs): + return rhs.evaluate(t=t, y=y, inputs=inputs).reshape(-1) + + h = 0.0001 + rate = 0.1 + + # create a dummy "model" where we calculate the sum of the time series + @jax.jit + def solve_bdf(rate): + return jax.numpy.sum( + pybamm.jax_bdf_integrate(fun, y0, t_eval, + {'rate': rate}, + rtol=1e-9, atol=1e-9) + ) + + # check answers with finite difference + eval_plus = solve_bdf(rate + h) + eval_neg = solve_bdf(rate - h) + grad_num = (eval_plus - eval_neg) / (2 * h) + + grad_solve_bdf = jax.jit(jax.grad(solve_bdf)) + grad_bdf = grad_solve_bdf(rate) + + self.assertAlmostEqual(grad_bdf, grad_num, places=3) + + def test_mass_matrix_with_sensitivities(self): + # Solve + t_eval = np.linspace(0.0, 1.0, 80) + + def fun(y, t, inputs): + return jax.numpy.stack([ + inputs['rate'] * y[0], + y[1] - 2.0 * y[0], + ]) + + mass = jax.numpy.array([ + [2.0, 0.0], + [0.0, 0.0], + ]) + + y0 = jax.numpy.array([1.0, 2.0]) + + h = 0.0001 + rate = 0.1 + + # create a dummy "model" where we calculate the sum of the time series + @jax.jit + def solve_bdf(rate): + return jax.numpy.sum( + pybamm.jax_bdf_integrate(fun, y0, t_eval, + {'rate': rate}, + mass=mass, + rtol=1e-9, atol=1e-9) + ) + + # check answers with finite difference + eval_plus = solve_bdf(rate + h) + eval_neg = solve_bdf(rate - h) + grad_num = (eval_plus - eval_neg) / (2 * h) + + grad_solve_bdf = jax.jit(jax.grad(solve_bdf)) 
+ grad_bdf = grad_solve_bdf(rate) + + self.assertAlmostEqual(grad_bdf, grad_num, places=3) + def test_solver_with_inputs(self): # Create model model = pybamm.BaseModel() @@ -69,17 +196,17 @@ def test_solver_with_inputs(self): disc.process_model(model) # Solve - t_eval = np.linspace(0, 10, 100) - y0 = model.concatenated_initial_conditions.evaluate() + t_eval = np.linspace(0, 10, 80) + y0 = model.concatenated_initial_conditions.evaluate().reshape(-1) rhs = pybamm.EvaluatorJax(model.concatenated_rhs) - def fun(t, y, inputs): + def fun(y, t, inputs): return rhs.evaluate(t=t, y=y, inputs=inputs).reshape(-1) - y, _ = pybamm.jax_bdf_integrate(fun, y0, t_eval, inputs={ + y = pybamm.jax_bdf_integrate(fun, y0, t_eval, { "rate": 0.1}, rtol=1e-9, atol=1e-9) - np.testing.assert_allclose(y[0, :], np.exp(-0.1 * t_eval)) + np.testing.assert_allclose(y[:, 0].reshape(-1), np.exp(-0.1 * t_eval)) if __name__ == "__main__": diff --git a/tests/unit/test_solvers/test_jax_solver.py b/tests/unit/test_solvers/test_jax_solver.py index 47ad041ea8..d30f3f76b6 100644 --- a/tests/unit/test_solvers/test_jax_solver.py +++ b/tests/unit/test_solvers/test_jax_solver.py @@ -5,6 +5,8 @@ import time import numpy as np from platform import system +if system() != "Windows": + import jax @unittest.skipIf(system() == "Windows", "JAX not supported on windows") @@ -25,7 +27,7 @@ def test_model_solver(self): disc = pybamm.Discretisation(mesh, spatial_methods) disc.process_model(model) - for method in ['BDF', 'RK45']: + for method in ['RK45', 'BDF']: # Solve solver = pybamm.JaxSolver( method=method, rtol=1e-8, atol=1e-8 @@ -36,7 +38,7 @@ def test_model_solver(self): t_first_solve = time.perf_counter() - t0 np.testing.assert_array_equal(solution.t, t_eval) np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t), - rtol=1e-7, atol=1e-7) + rtol=1e-6, atol=1e-6) # Test time self.assertEqual( @@ -51,6 +53,97 @@ def test_model_solver(self): self.assertLess(t_second_solve, t_first_solve) np.testing.assert_array_equal(second_solution.y, solution.y) + def test_semi_explicit_model(self): + # Create model + model = pybamm.BaseModel() + model.convert_to_format = "jax" + domain = ["negative electrode", "separator", "positive electrode"] + var = pybamm.Variable("var", domain=domain) + var2 = pybamm.Variable("var2", domain=domain) + model.rhs = {var: 0.1 * var} + model.algebraic = {var2: var2 - 2.0 * var} + # give inconsistent initial conditions, should calculate correct ones + model.initial_conditions = {var: 1.0, var2: 1.0} + # No need to set parameters; can use base discretisation (no spatial operators) + + # create discretisation + mesh = get_mesh_for_testing() + spatial_methods = {"macroscale": pybamm.FiniteVolume()} + disc = pybamm.Discretisation(mesh, spatial_methods) + disc.process_model(model) + + # Solve + solver = pybamm.JaxSolver( + method='BDF', rtol=1e-8, atol=1e-8 + ) + t_eval = np.linspace(0, 1, 80) + t0 = time.perf_counter() + solution = solver.solve(model, t_eval) + t_first_solve = time.perf_counter() - t0 + np.testing.assert_array_equal(solution.t, t_eval) + soln = np.exp(0.1 * solution.t) + np.testing.assert_allclose(solution.y[0], soln, + rtol=1e-7, atol=1e-7) + np.testing.assert_allclose(solution.y[-1], 2 * soln, + rtol=1e-7, atol=1e-7) + + # Test time + self.assertEqual( + solution.total_time, solution.solve_time + solution.set_up_time + ) + self.assertEqual(solution.termination, "final time") + + t0 = time.perf_counter() + second_solution = solver.solve(model, t_eval) + t_second_solve = time.perf_counter() - t0 + + 
self.assertLess(t_second_solve, t_first_solve) + np.testing.assert_array_equal(second_solution.y, solution.y) + + def test_solver_sensitivities(self): + # Create model + model = pybamm.BaseModel() + model.convert_to_format = "jax" + domain = ["negative electrode", "separator", "positive electrode"] + var = pybamm.Variable("var", domain=domain) + model.rhs = {var: -pybamm.InputParameter("rate") * var} + model.initial_conditions = {var: 1.0} + # No need to set parameters; can use base discretisation (no spatial operators) + + # create discretisation + mesh = get_mesh_for_testing() + spatial_methods = {"macroscale": pybamm.FiniteVolume()} + disc = pybamm.Discretisation(mesh, spatial_methods) + disc.process_model(model) + + for method in ['RK45', 'BDF']: + # Solve + solver = pybamm.JaxSolver( + method=method, rtol=1e-8, atol=1e-8 + ) + t_eval = np.linspace(0, 1, 80) + + h = 0.0001 + rate = 0.1 + + # need to solve the model once to get it set up by the base solver + solver.solve(model, t_eval, inputs={'rate': rate}) + solve = solver.get_solve(model, t_eval) + + # create a dummy "model" where we calculate the sum of the time series + def solve_model(rate): + return jax.numpy.sum(solve({'rate': rate})) + + # check answers with finite difference + eval_plus = solve_model(rate + h) + eval_neg = solve_model(rate - h) + grad_num = (eval_plus - eval_neg) / (2 * h) + + grad_solve = jax.jit(jax.grad(solve_model)) + grad = grad_solve(rate) + + self.assertAlmostEqual(grad, grad_num, places=1) + def test_solver_only_works_with_jax(self): model = pybamm.BaseModel() var = pybamm.Variable("var") @@ -117,7 +210,7 @@ def test_model_solver_with_inputs(self): disc.process_model(model) # Solve solver = pybamm.JaxSolver(rtol=1e-8, atol=1e-8) - t_eval = np.linspace(0, 5, 100) + t_eval = np.linspace(0, 5, 80) t0 = time.perf_counter() solution = solver.solve(model, t_eval, inputs={"rate": 0.1}) @@ -151,21 +244,26 @@ def test_get_solve(self): spatial_methods = {"macroscale": pybamm.FiniteVolume()} disc = pybamm.Discretisation(mesh, spatial_methods) disc.process_model(model) + + # test that another method string gives error + with self.assertRaises(ValueError): + solver = pybamm.JaxSolver(method='not_real') + # Solve solver = pybamm.JaxSolver(rtol=1e-8, atol=1e-8) - t_eval = np.linspace(0, 5, 100) + t_eval = np.linspace(0, 5, 80) with self.assertRaisesRegex(RuntimeError, "Model is not set up for solving"): solver.get_solve(model, t_eval) solver.solve(model, t_eval, inputs={"rate": 0.1}) solver = solver.get_solve(model, t_eval) - y, _ = solver({"rate": 0.1}) + y = solver({"rate": 0.1}) np.testing.assert_allclose(y[0], np.exp(-0.1 * t_eval), rtol=1e-6, atol=1e-6) - y, _ = solver({"rate": 0.2}) + y = solver({"rate": 0.2}) np.testing.assert_allclose(y[0], np.exp(-0.2 * t_eval), rtol=1e-6, atol=1e-6)
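In summary, the new `jax_bdf_integrate` entry point takes `func(y, t, *args)`, an optional mass matrix for semi-explicit index-1 DAEs, and returns a solution that can be differentiated through. A condensed sketch of the DAE-with-sensitivities usage, with values lifted from the tests above (the `summed_solution` helper is illustrative):

import jax
import jax.numpy as jnp
import pybamm

def fun(y, t, inputs):
    # semi-explicit index-1 DAE: one differential and one algebraic equation
    return jnp.stack([inputs['rate'] * y[0], y[1] - 2.0 * y[0]])

mass = jnp.array([[2.0, 0.0], [0.0, 0.0]])  # zero row marks the algebraic variable
y0 = jnp.array([1.0, 2.0])
t_eval = jnp.linspace(0.0, 1.0, 80)

def summed_solution(rate):
    y = pybamm.jax_bdf_integrate(fun, y0, t_eval, {'rate': rate},
                                 mass=mass, rtol=1e-9, atol=1e-9)
    return jnp.sum(y)

# gradient of the summed time series with respect to the input parameter
grad = jax.jit(jax.grad(summed_solution))(0.1)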