From c91f123ecc18bb4c15a6153c2abfa4bbe58225a6 Mon Sep 17 00:00:00 2001 From: Martin Robinson Date: Sun, 12 Jul 2020 09:20:27 +0100 Subject: [PATCH] #1104 debugging for incorporation of mass matrix --- pybamm/solvers/jax_bdf_solver.py | 139 ++++++++++++------ .../unit/test_solvers/test_jax_bdf_solver.py | 40 +++++ 2 files changed, 135 insertions(+), 44 deletions(-) diff --git a/pybamm/solvers/jax_bdf_solver.py b/pybamm/solvers/jax_bdf_solver.py index 717dd7c6af..47db4f366e 100644 --- a/pybamm/solvers/jax_bdf_solver.py +++ b/pybamm/solvers/jax_bdf_solver.py @@ -25,8 +25,8 @@ MAX_FACTOR = 10 -@jax.partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2)) -def _bdf_odeint(fun, rtol, atol, y0, t_eval, *args): +@jax.partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2, 3)) +def _bdf_odeint(fun, mass, rtol, atol, y0, t_eval, *args): """ This implements a Backward Difference formula (BDF) implicit multistep integrator. The basic algorithm is derived in [2]_. This particular implementation follows that @@ -43,8 +43,10 @@ def _bdf_odeint(fun, rtol, atol, y0, t_eval, *args): func: callable function to evaluate the time derivative of the solution `y` at time `t` as `func(y, t, *args)`, producing the same shape/structure as `y0`. + mass: ndarray + constant mass matrix with shape (n,n) y0: ndarray - initial state vector + initial state vector, has shape (n,) t_eval: ndarray time points to evaluate the solution, has shape (m,) args: (optional) @@ -82,7 +84,7 @@ def fun_bind_inputs(y, t): t0 = t_eval[0] h0 = t_eval[1] - t0 - stepper = _bdf_init(fun_bind_inputs, jac_bind_inputs, t0, y0, h0, rtol, atol) + stepper = _bdf_init(fun_bind_inputs, jac_bind_inputs, mass, t0, y0, h0, rtol, atol) i = 0 y_out = jnp.empty((len(t_eval), len(y0)), dtype=y0.dtype) @@ -103,10 +105,10 @@ def for_body(j, y_out): _bdf_interpolate(stepper, t)) return y_out - y_out = jax.lax.fori_loop(i, index, for_body, y_out) + y_out = flax_fori_loop(i, index, for_body, y_out) return [stepper, t_eval, index, y_out, n_steps + 1] - stepper, t_eval, i, y_out, n_steps = jax.lax.while_loop(cond_fun, body_fun, + stepper, t_eval, i, y_out, n_steps = flax_while_loop(cond_fun, body_fun, init_state) stepper['n_steps'] = n_steps @@ -114,7 +116,7 @@ def for_body(j, y_out): return y_out -def _bdf_init(fun, jac, t0, y0, h0, rtol, atol): +def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol): """ Initiation routine for Backward Difference formula (BDF) implicit multistep integrator. 
@@ -132,6 +134,8 @@ def _bdf_init(fun, jac, t0, y0, h0, rtol, atol): jac: callable function with signature (y, t), where t is a scalar time and y is a ndarray with shape (n,), returns the jacobian matrix of fun as an ndarray with shape (n,n) + mass: ndarray + constant mass matrix with shape (n,n) t0: float initial time y0: ndarray @@ -147,29 +151,34 @@ def _bdf_init(fun, jac, t0, y0, h0, rtol, atol): state['t'] = t0 state['y'] = y0 f0 = fun(y0, t0) + print('f0 is ', f0) state['atol'] = atol state['rtol'] = rtol order = 1 state['order'] = order state['h'] = _select_initial_step(state, fun, t0, y0, f0, h0) + print('h0 is ',state['h']) EPS = jnp.finfo(y0.dtype).eps state['newton_tol'] = jnp.max((10 * EPS / rtol, jnp.min((0.03, rtol ** 0.5)))) state['n_equal_steps'] = 0 D = jnp.empty((MAX_ORDER + 1, len(y0)), dtype=y0.dtype) D = jax.ops.index_update(D, jax.ops.index[0, :], y0) - D = jax.ops.index_update(D, jax.ops.index[1, :], f0 * h0) + D = jax.ops.index_update(D, jax.ops.index[1, :], f0 * state['h']) state['D'] = D state['y0'] = None state['scale_y0'] = None state = _predict(state) - I = jnp.identity(len(y0), dtype=y0.dtype) - state['I'] = I + if mass is None: + state['M'] = jnp.identity(len(y0), dtype=y0.dtype) + else: + state['M'] = mass + print('mass is ', state['M']) # kappa values for difference orders, taken from Table 1 of [1] kappa = jnp.array([0, -0.1850, -1 / 9, -0.0823, -0.0415, 0]) gamma = jnp.hstack((0, jnp.cumsum(1 / jnp.arange(1, MAX_ORDER + 1)))) alpha = 1.0 / ((1 - kappa) * gamma) - c = h0 * alpha[order] + c = state['h'] * alpha[order] error_const = kappa * gamma + 1 / jnp.arange(1, MAX_ORDER + 2) state['kappa'] = kappa @@ -180,7 +189,8 @@ def _bdf_init(fun, jac, t0, y0, h0, rtol, atol): J = jac(y0, t0) state['J'] = J - state['LU'] = jax.scipy.linalg.lu_factor(I - c * J) + print('J is ', J) + state['LU'] = jax.scipy.linalg.lu_factor(state['M'] - c * J) state['U'] = _compute_R(order, 1) state['psi'] = None @@ -300,7 +310,7 @@ def while_body(while_state): i -= 1 return [i, D] - i, D = jax.lax.while_loop(while_cond, while_body, while_state) + i, D = flax_while_loop(while_cond, while_body, while_state) state['D'] = D @@ -313,7 +323,7 @@ def update_psi_and_predict(state): return state - state = jax.lax.cond(only_update_D == False, # noqa: E712 + state = flax_cond(only_update_D == False, # noqa: E712 state, update_psi_and_predict, state, lambda x: x) @@ -326,9 +336,10 @@ def _update_step_size(state, factor, dont_update_lu): the first equation of page 9 of [1]: - constant c = h / (1-kappa) gamma_k term - - lu factorisation of (I - c * J) used in newton iteration (same equation) + - lu factorisation of (M - c * J) used in newton iteration (same equation) - psi term """ + print('update_step_size',factor) order = state['order'] h = state['h'] @@ -338,11 +349,11 @@ def _update_step_size(state, factor, dont_update_lu): # redo lu (c has changed) def update_lu(state): - state['LU'] = jax.scipy.linalg.lu_factor(state['I'] - c * state['J']) + state['LU'] = jax.scipy.linalg.lu_factor(state['M'] - c * state['J']) state['n_lu_decompositions'] += 1 return state - state = jax.lax.cond(dont_update_lu == False, # noqa: E712 + state = flax_cond(dont_update_lu == False, # noqa: E712 state, update_lu, state, lambda x: x) @@ -377,20 +388,23 @@ def _update_jacobian(state, jac): we update the jacobian using J(t_{n+1}, y^0_{n+1}) following the scipy bdf implementation rather than J(t_n, y_n) as per [1] """ + print('update_jacobian') J = jac(state['y0'], state['t'] + state['h']) state['n_jacobian_evals'] 
+= 1 - state['LU'] = jax.scipy.linalg.lu_factor(state['I'] - state['c'] * J) + state['LU'] = jax.scipy.linalg.lu_factor(state['M'] - state['c'] * J) state['n_lu_decompositions'] += 1 state['J'] = J return state def _newton_iteration(state, fun): + print('newton iterate') tol = state['newton_tol'] c = state['c'] psi = state['psi'] y0 = state['y0'] LU = state['LU'] + M = state['M'] scale_y0 = state['scale_y0'] t = state['t'] + state['h'] d = jnp.zeros_like(y0) @@ -409,7 +423,7 @@ def while_body(while_state): k, not_converged, dy_norm_old, d, y, state = while_state f_eval = fun(y, t) state['n_function_evals'] += 1 - b = c * f_eval - psi - d + b = c * f_eval - M @ (psi - d) dy = jax.scipy.linalg.lu_solve(LU, b) dy_norm = jnp.sqrt(jnp.mean((dy / scale_y0)**2)) rate = dy_norm / dy_norm_old @@ -419,7 +433,7 @@ def while_body(while_state): pred = rate >= 1 pred += rate ** (NEWTON_MAXITER - k) / (1 - rate) * dy_norm > tol pred *= dy_norm_old >= 0 - k = jax.lax.cond(pred, k, lambda k: NEWTON_MAXITER, k, lambda k: k) + k = flax_cond(pred, k, lambda k: NEWTON_MAXITER, k, lambda k: k) d += dy y = y0 + d @@ -439,18 +453,19 @@ def not_converged_fun(not_converged): dy_norm_old = dy_norm not_converged = \ - jax.lax.cond(pred, + flax_cond(pred, not_converged, converged_fun, not_converged, not_converged_fun) return [k + 1, not_converged, dy_norm_old, d, y, state] - k, not_converged, dy_norm_old, d, y, state = jax.lax.while_loop(while_cond, + k, not_converged, dy_norm_old, d, y, state = flax_while_loop(while_cond, while_body, while_state) return not_converged, k, y, d, state def _bdf_step(state, fun, jac): + print('bdf step', state['t']) # we will try and use the old jacobian unless convergence of newton iteration # fails not_updated_jacobian = True @@ -516,6 +531,7 @@ def converged(if_state2): def error_too_large(if_state3): # error too large, reduce step size and try again + print('error too large') state, step_accepted = if_state3 state['n_error_test_failures'] += 1 # calculate optimal step size factor as per eq 2.46 of [2] @@ -532,26 +548,26 @@ def accept_step(if_state3): return [state, step_accepted] state, step_accepted = \ - jax.lax.cond(error_norm > 1, + flax_cond(error_norm > 1, if_state3, error_too_large, if_state3, accept_step) return [state, step_accepted] - state, step_accepted = jax.lax.cond(not_converged, + state, step_accepted = flax_cond(not_converged, if_state2, need_to_update_step_size, if_state2, converged) return [state, not_updated_jacobian, step_accepted] state, not_updated_jacobian, step_accepted = \ - jax.lax.cond(pred, + flax_cond(pred, if_state, need_to_update_jacobian, if_state, dont_need_to_update_jacobian) return [state, step_accepted, not_updated_jacobian, y, d, n_iter] state, step_accepted, not_updated_jacobian, y, d, n_iter = \ - jax.lax.while_loop(while_cond, while_body, while_state) + flax_while_loop(while_cond, while_body, while_state) # take the accepted step state['y'] = y @@ -598,7 +614,7 @@ def order_equal_one(if_state2): error_m_norm = jnp.inf return error_m_norm - error_m_norm = jax.lax.cond(order > 1, + error_m_norm = flax_cond(order > 1, if_state2, order_greater_one, if_state2, order_equal_one) @@ -612,7 +628,7 @@ def order_max(if_state2): error_p_norm = jnp.inf return error_p_norm - error_p_norm = jax.lax.cond(order < MAX_ORDER, + error_p_norm = flax_cond(order < MAX_ORDER, if_state2, order_less_max, if_state2, order_max) @@ -630,7 +646,7 @@ def order_max(if_state2): return state - state = jax.lax.cond(state['n_equal_steps'] < state['order'] + 1, + state = 
flax_cond(state['n_equal_steps'] < state['order'] + 1, if_state, no_change_in_order, if_state, order_change) @@ -663,12 +679,27 @@ def while_body(while_state): j += 1 return [j, time_factor, order_summation] - j, time_factor, order_summation = jax.lax.while_loop(while_cond, + j, time_factor, order_summation = flax_while_loop(while_cond, while_body, while_state) return order_summation +def block_diag(lst): + def block_fun(i, j, Ai, Aj): + if i == j: + return Ai + else: + return jnp.zeros(Ai.shape[0], Aj.shape[1]) + + blocks = [ + [ block_fun(i, j, Ai, Aj) for j, Aj in enumerate(lst)] + for i, Ai in enumerate(lst) + ] + + return jnp.block(blocks) + + # NOTE: all code below (except the docstring on jax_bdf_integrate), to define the API of # the jax solver and the ability to solve the adjoint sensitivities, has been copied # from the JAX library at https://github.com/google/jax. This is under an Apache @@ -687,7 +718,7 @@ def while_body(while_state): # governing permissions and limitations under the License. -def jax_bdf_integrate(func, y0, t_eval, *args, rtol=1e-6, atol=1e-6): +def jax_bdf_integrate(func, y0, t_eval, *args, rtol=1e-6, atol=1e-6, mass=None): """ Backward Difference formula (BDF) implicit multistep integrator. The basic algorithm is derived in [2]_. This particular implementation follows that implemented in the @@ -715,6 +746,8 @@ def jax_bdf_integrate(func, y0, t_eval, *args, rtol=1e-6, atol=1e-6): relative tolerance for the solver atol: (optional) float absolute tolerance for the solver + mass: (optional) ndarray + constant mass matrix with shape (n,n) Returns ------- @@ -743,13 +776,13 @@ def _check_arg(arg): in_avals = tuple(map(abstractify, flat_args)) converted, consts = closure_convert(func, in_tree, in_avals) - return _bdf_odeint_wrapper(converted, rtol, atol, y0, t_eval, *consts, *args) + return _bdf_odeint_wrapper(converted, mass, rtol, atol, y0, t_eval, *consts, *args) def flax_cond(pred, true_operand, true_fun, false_operand, false_fun): # pragma: no cover """ - for debugging purposes, use this instead of jax.lax.cond + for debugging purposes, use this instead of flax_cond """ if pred: return true_fun(true_operand) @@ -759,7 +792,7 @@ def flax_cond(pred, true_operand, true_fun, def flax_while_loop(cond_fun, body_fun, init_val): # pragma: no cover """ - for debugging purposes, use this instead of jax.lax.while_loop + for debugging purposes, use this instead of flax_while_loop """ val = init_val while cond_fun(val): @@ -769,7 +802,7 @@ def flax_while_loop(cond_fun, body_fun, init_val): # pragma: no cover def flax_fori_loop(start, stop, body_fun, init_val): # pragma: no cover """ - for debugging purposes, use this instead of jax.lax.fori_loop + for debugging purposes, use this instead of flax_fori_loop """ val = init_val for i in range(start, stop): @@ -779,7 +812,7 @@ def flax_fori_loop(start, stop, body_fun, init_val): # pragma: no cover def flax_scan(f, init, xs, length=None): # pragma: no cover """ - for debugging purposes, use this instead of jax.lax.scan + for debugging purposes, use this instead of flax_scan """ if xs is None: xs = [None] * length @@ -791,20 +824,23 @@ def flax_scan(f, init, xs, length=None): # pragma: no cover return carry, onp.stack(ys) -@partial(jax.jit, static_argnums=(0, 1, 2)) -def _bdf_odeint_wrapper(func, rtol, atol, y0, ts, *args): +#@partial(jax.jit, static_argnums=(0, 1, 2, 3)) +def _bdf_odeint_wrapper(func, mass, rtol, atol, y0, ts, *args): y0, unravel = ravel_pytree(y0) + if mass is not None: + mass = 
block_diag(tree_flatten(mass)[0]) + print('mass matrix is ', mass) func = ravel_first_arg(func, unravel) - out = _bdf_odeint(func, rtol, atol, y0, ts, *args) + out = _bdf_odeint(func, mass, rtol, atol, y0, ts, *args) return jax.vmap(unravel)(out) -def _bdf_odeint_fwd(func, rtol, atol, y0, ts, *args): - ys = _bdf_odeint(func, rtol, atol, y0, ts, *args) +def _bdf_odeint_fwd(func, mass, rtol, atol, y0, ts, *args): + ys = _bdf_odeint(func, mass, rtol, atol, y0, ts, *args) return ys, (ys, ts, args) -def _bdf_odeint_rev(func, rtol, atol, res, g): +def _bdf_odeint_rev(func, mass, rtol, atol, res, g): ys, ts, args = res def aug_dynamics(augmented_state, t, *args): @@ -818,6 +854,7 @@ def aug_dynamics(augmented_state, t, *args): y_bar = g[-1] ts_bar = [] t0_bar = 0. + aug_mass = (mass, mass, jnp.ones((1, 1)), tree_map(jnp.ones_like, args)) def scan_fun(carry, i): y_bar, t0_bar, args_bar = carry @@ -828,14 +865,15 @@ def scan_fun(carry, i): _, y_bar, t0_bar, args_bar = jax_bdf_integrate( aug_dynamics, (ys[i], y_bar, t0_bar, args_bar), jnp.array([-ts[i], -ts[i - 1]]), - *args, rtol=rtol, atol=atol) + *args, mass=aug_mass, + rtol=rtol, atol=atol) y_bar, t0_bar, args_bar = tree_map(op.itemgetter(1), (y_bar, t0_bar, args_bar)) # Add gradient from current output y_bar = y_bar + g[i - 1] return (y_bar, t0_bar, args_bar), t_bar init_carry = (g[-1], 0., tree_map(jnp.zeros_like, args)) - (y_bar, t0_bar, args_bar), rev_ts_bar = jax.lax.scan( + (y_bar, t0_bar, args_bar), rev_ts_bar = flax_scan( scan_fun, init_carry, jnp.arange(len(ts) - 1, 0, -1)) ts_bar = jnp.concatenate([jnp.array([t0_bar]), rev_ts_bar[::-1]]) return (y_bar, ts_bar, *args_bar) @@ -888,6 +926,19 @@ def abstractify(x): return core.raise_to_shaped(core.get_aval(x)) +def ravel_2d_pytree(pytree): + leaves, treedef = tree_flatten(pytree) + flat, unravel_list = jax.api.vjp(ravel_2d_list, *leaves) + + def unravel_pytree(flat): + return tree_unflatten(treedef, unravel_list(flat)) + return flat, unravel_pytree + + +def ravel_2d_list(*lst): + return jnp.concatenate([jnp.ravel(elt) for elt in lst]) if lst else jnp.array([]) + + def ravel_first_arg(f, unravel): return ravel_first_arg_(lu.wrap_init(f), unravel).call_wrapped diff --git a/tests/unit/test_solvers/test_jax_bdf_solver.py b/tests/unit/test_solvers/test_jax_bdf_solver.py index 26a935461d..41d8288ed8 100644 --- a/tests/unit/test_solvers/test_jax_bdf_solver.py +++ b/tests/unit/test_solvers/test_jax_bdf_solver.py @@ -54,6 +54,46 @@ def fun(y, t): np.testing.assert_allclose(y[:, 0], np.exp(0.1 * t_eval), rtol=1e-7, atol=1e-7) + def test_mass_matrix(self): + # Solve + t_eval = np.linspace(0.0, 1.0, 80) + + def fun(y, t): + return jax.numpy.stack([ + 0.1 * y[0], + y[1] - 2.0 * y[0], + ]) + + mass = jax.numpy.array([ + [1.0, 0.0], + [0.0, 0.0], + ]) + + y0 = jax.numpy.array([1.0, 2.0]) + + t0 = time.perf_counter() + y = pybamm.jax_bdf_integrate(fun, y0, t_eval, mass=mass, rtol=1e-8, atol=1e-8) + t1 = time.perf_counter() - t0 + + # test accuracy + soln = np.exp(0.1 * t_eval) + np.testing.assert_allclose(y[:, 0], soln, + rtol=1e-7, atol=1e-7) + np.testing.assert_allclose(y[:, 1], 2.0 * soln, + rtol=1e-7, atol=1e-7) + + t0 = time.perf_counter() + y = pybamm.jax_bdf_integrate(fun, y0, t_eval, mass=mass, rtol=1e-8, atol=1e-8) + t2 = time.perf_counter() - t0 + + # second run should be much quicker + self.assertLess(t2, t1) + + # test second run is accurate + np.testing.assert_allclose(y[:, 0], np.exp(0.1 * t_eval), + rtol=1e-7, atol=1e-7) + + def test_solver_sensitivities(self): # Create model model = 
pybamm.BaseModel()
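
For reference, a minimal usage sketch of the new `mass` keyword on `jax_bdf_integrate`, mirroring the `test_mass_matrix` test added above: a semi-explicit DAE with one ODE and one algebraic constraint is solved by passing a constant, singular mass matrix. This is a sketch based on the test, not part of the patch itself.

# Usage sketch (mirrors test_mass_matrix): solve the semi-explicit DAE
#   dy0/dt = 0.1 * y0,    0 = y1 - 2 * y0
# by passing the singular mass matrix M = diag(1, 0) to jax_bdf_integrate.
import jax.numpy as jnp
import numpy as np
import pybamm

def fun(y, t):
    return jnp.stack([
        0.1 * y[0],           # ODE part
        y[1] - 2.0 * y[0],    # algebraic constraint
    ])

mass = jnp.array([[1.0, 0.0],
                  [0.0, 0.0]])
y0 = jnp.array([1.0, 2.0])          # consistent initial conditions
t_eval = np.linspace(0.0, 1.0, 80)

y = pybamm.jax_bdf_integrate(fun, y0, t_eval, mass=mass, rtol=1e-8, atol=1e-8)
# expected: y[:, 0] ~ exp(0.1 * t_eval) and y[:, 1] ~ 2 * exp(0.1 * t_eval)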
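
On the solver internals: before this patch the Newton residual in `_newton_iteration` was formed as b = c * f_eval - psi - d, i.e. c*f - (psi + d). With a mass matrix, the scipy-style BDF formulation that the docstring cites generalises this residual to c*f - M @ (psi + d), with the Newton matrix I - c*J becoming M - c*J (as in the updated `lu_factor` calls in the patch). The following is a self-contained sketch of one such Newton step under that formulation; the `newton_step` name and the driver values are illustrative, not taken from the patch.

# Sketch of a single mass-matrix BDF Newton step, assuming the scipy-style
# residual generalised as  b = c*f(y, t) - M @ (psi + d)  and Newton matrix
# (M - c*J).
import jax.numpy as jnp
import jax.scipy.linalg

def newton_step(fun, t, y, d, psi, c, M, LU):
    f_eval = fun(y, t)
    b = c * f_eval - M @ (psi + d)        # residual of the BDF correction
    dy = jax.scipy.linalg.lu_solve(LU, b)  # solve (M - c*J) dy = b
    return y + dy, d + dy

# example: one step on dy/dt = -y at t = 0, with M = I and c = 0.1
M = jnp.eye(1)
J = jnp.array([[-1.0]])
c = 0.1
LU = jax.scipy.linalg.lu_factor(M - c * J)
y, d = newton_step(lambda y, t: -y, 0.0, jnp.array([1.0]), jnp.zeros(1),
                   jnp.zeros(1), c, M, LU)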
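
The patch also adds a small `block_diag` helper, used in `_bdf_odeint_wrapper` to assemble one block-diagonal mass matrix from the flattened leaves of a mass-matrix pytree (for example the `aug_mass` tuple built for the adjoint solve in `_bdf_odeint_rev`). Below is a standalone sketch of that assembly idea, written with `jnp.zeros` taking its shape as a tuple; the `block_diag_sketch` name is illustrative.

# Sketch of the block-diagonal assembly performed by the new block_diag helper:
# diagonal blocks are the pytree leaves, off-diagonal blocks are zeros.
import jax.numpy as jnp

def block_diag_sketch(leaves):
    def block(i, j, Ai, Aj):
        # jnp.zeros expects the shape as a tuple
        return Ai if i == j else jnp.zeros((Ai.shape[0], Aj.shape[1]))
    return jnp.block([[block(i, j, Ai, Aj) for j, Aj in enumerate(leaves)]
                      for i, Ai in enumerate(leaves)])

M = block_diag_sketch([jnp.eye(2), jnp.zeros((1, 1))])
# M has shape (3, 3): the identity in the top-left block, zeros elsewhere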