diff --git a/pybamm/solvers/jax_bdf_solver.py b/pybamm/solvers/jax_bdf_solver.py
index 1ad65107d3..17562d3bc5 100644
--- a/pybamm/solvers/jax_bdf_solver.py
+++ b/pybamm/solvers/jax_bdf_solver.py
@@ -102,10 +102,10 @@ def for_body(j, y_out):
                                          _bdf_interpolate(stepper, t))
             return y_out
 
-        y_out = flax_fori_loop(i, index, for_body, y_out)
+        y_out = jax.lax.fori_loop(i, index, for_body, y_out)
 
         return [stepper, t_eval, index, y_out, n_steps + 1]
 
-    stepper, t_eval, i, y_out, n_steps = flax_while_loop(cond_fun, body_fun,
+    stepper, t_eval, i, y_out, n_steps = jax.lax.while_loop(cond_fun, body_fun,
                                                          init_state)
 
     stepper['n_steps'] = n_steps
@@ -148,13 +148,11 @@ def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol):
     state['t'] = t0
     state['y'] = y0
     f0 = fun(y0, t0)
-    print('f0 is ', f0)
     state['atol'] = atol
     state['rtol'] = rtol
     order = 1
     state['order'] = order
     state['h'] = _select_initial_step(state, fun, t0, y0, f0, h0)
-    print('h0 is ',state['h'])
     EPS = jnp.finfo(y0.dtype).eps
     state['newton_tol'] = jnp.max((10 * EPS / rtol, jnp.min((0.03, rtol ** 0.5))))
     state['n_equal_steps'] = 0
@@ -169,7 +167,6 @@ def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol):
         state['M'] = jnp.identity(len(y0), dtype=y0.dtype)
     else:
         state['M'] = mass
-    print('mass is ', state['M'])
 
     # kappa values for difference orders, taken from Table 1 of [1]
     kappa = jnp.array([0, -0.1850, -1 / 9, -0.0823, -0.0415, 0])
@@ -186,7 +183,6 @@ def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol):
 
     J = jac(y0, t0)
     state['J'] = J
-    print('J is ', J)
     state['LU'] = jax.scipy.linalg.lu_factor(state['M'] - c * J)
 
     state['U'] = _compute_R(order, 1)
@@ -307,7 +303,7 @@ def while_body(while_state):
             i -= 1
             return [i, D]
 
-        i, D = flax_while_loop(while_cond, while_body, while_state)
+        i, D = jax.lax.while_loop(while_cond, while_body, while_state)
 
         state['D'] = D
 
@@ -320,7 +316,7 @@ def update_psi_and_predict(state):
 
         return state
 
-    state = flax_cond(only_update_D == False,  # noqa: E712
+    state = jax.lax.cond(only_update_D == False,  # noqa: E712
                       state, update_psi_and_predict,
                       state, lambda x: x)
@@ -336,7 +332,6 @@ def _update_step_size(state, factor, dont_update_lu):
    - lu factorisation of (M - c * J) used in newton iteration (same equation)
    - psi term
    """
-    print('update_step_size',factor)
    order = state['order']
    h = state['h']
 
@@ -350,7 +345,7 @@ def update_lu(state):
         state['n_lu_decompositions'] += 1
         return state
 
-    state = flax_cond(dont_update_lu == False,  # noqa: E712
+    state = jax.lax.cond(dont_update_lu == False,  # noqa: E712
                       state, update_lu,
                       state, lambda x: x)
@@ -385,7 +380,6 @@ def _update_jacobian(state, jac):
    we update the jacobian using J(t_{n+1}, y^0_{n+1}) following the scipy bdf
    implementation rather than J(t_n, y_n) as per [1]
    """
-    print('update_jacobian')
    J = jac(state['y0'], state['t'] + state['h'])
    state['n_jacobian_evals'] += 1
    state['LU'] = jax.scipy.linalg.lu_factor(state['M'] - state['c'] * J)
@@ -395,7 +389,6 @@
 
 
 def _newton_iteration(state, fun):
-    print('newton iterate')
    tol = state['newton_tol']
    c = state['c']
    psi = state['psi']
@@ -420,7 +413,7 @@ def while_body(while_state):
         k, not_converged, dy_norm_old, d, y, state = while_state
         f_eval = fun(y, t)
         state['n_function_evals'] += 1
-        b = c * f_eval - M @ (psi - d)
+        b = c * f_eval - M @ (psi + d)
         dy = jax.scipy.linalg.lu_solve(LU, b)
         dy_norm = jnp.sqrt(jnp.mean((dy / scale_y0)**2))
         rate = dy_norm / dy_norm_old
@@ -430,7 +423,7 @@ def while_body(while_state):
         pred = rate >= 1
         pred += rate ** (NEWTON_MAXITER - k) / (1 - rate) * dy_norm > tol
         pred *= dy_norm_old >= 0
-        k = flax_cond(pred, k, lambda k: NEWTON_MAXITER, k, lambda k: k)
+        k = jax.lax.cond(pred, k, lambda k: NEWTON_MAXITER, k, lambda k: k)
 
         d += dy
         y = y0 + d
@@ -450,19 +443,18 @@ def not_converged_fun(not_converged):
             dy_norm_old = dy_norm
 
         not_converged = \
-            flax_cond(pred,
+            jax.lax.cond(pred,
                       not_converged, converged_fun,
                       not_converged, not_converged_fun)
         return [k + 1, not_converged, dy_norm_old, d, y, state]
 
-    k, not_converged, dy_norm_old, d, y, state = flax_while_loop(while_cond,
+    k, not_converged, dy_norm_old, d, y, state = jax.lax.while_loop(while_cond,
                                                                  while_body, while_state)
 
    return not_converged, k, y, d, state
 
 
 def _bdf_step(state, fun, jac):
-    print('bdf step', state['t'])
    # we will try and use the old jacobian unless convergence of newton iteration
    # fails
    not_updated_jacobian = True
@@ -528,7 +520,6 @@ def converged(if_state2):
 
                 def error_too_large(if_state3):
                     # error too large, reduce step size and try again
-                    print('error too large')
                     state, step_accepted = if_state3
                     state['n_error_test_failures'] += 1
                     # calculate optimal step size factor as per eq 2.46 of [2]
@@ -545,26 +536,26 @@ def accept_step(if_state3):
                     return [state, step_accepted]
 
                 state, step_accepted = \
-                    flax_cond(error_norm > 1,
+                    jax.lax.cond(error_norm > 1,
                               if_state3, error_too_large,
                               if_state3, accept_step)
                 return [state, step_accepted]
 
-            state, step_accepted = flax_cond(not_converged,
+            state, step_accepted = jax.lax.cond(not_converged,
                                              if_state2, need_to_update_step_size,
                                              if_state2, converged)
 
             return [state, not_updated_jacobian, step_accepted]
 
         state, not_updated_jacobian, step_accepted = \
-            flax_cond(pred,
+            jax.lax.cond(pred,
                       if_state, need_to_update_jacobian,
                       if_state, dont_need_to_update_jacobian)
 
         return [state, step_accepted, not_updated_jacobian, y, d, n_iter]
 
     state, step_accepted, not_updated_jacobian, y, d, n_iter = \
-        flax_while_loop(while_cond, while_body, while_state)
+        jax.lax.while_loop(while_cond, while_body, while_state)
 
     # take the accepted step
     state['y'] = y
@@ -611,7 +602,7 @@ def order_equal_one(if_state2):
         error_m_norm = jnp.inf
         return error_m_norm
 
-    error_m_norm = flax_cond(order > 1,
+    error_m_norm = jax.lax.cond(order > 1,
                              if_state2, order_greater_one,
                              if_state2, order_equal_one)
@@ -625,7 +616,7 @@ def order_max(if_state2):
         error_p_norm = jnp.inf
         return error_p_norm
 
-    error_p_norm = flax_cond(order < MAX_ORDER,
+    error_p_norm = jax.lax.cond(order < MAX_ORDER,
                              if_state2, order_less_max,
                              if_state2, order_max)
@@ -643,7 +634,7 @@ def order_max(if_state2):
 
         return state
 
-    state = flax_cond(state['n_equal_steps'] < state['order'] + 1,
+    state = jax.lax.cond(state['n_equal_steps'] < state['order'] + 1,
                       if_state, no_change_in_order,
                       if_state, order_change)
@@ -676,7 +667,7 @@ def while_body(while_state):
         j += 1
         return [j, time_factor, order_summation]
 
-    j, time_factor, order_summation = flax_while_loop(while_cond,
+    j, time_factor, order_summation = jax.lax.while_loop(while_cond,
                                                       while_body,
                                                       while_state)
     return order_summation
@@ -778,7 +769,7 @@ def _check_arg(arg):
 def flax_cond(pred, true_operand, true_fun,
               false_operand, false_fun):  # pragma: no cover
     """
-    for debugging purposes, use this instead of flax_cond
+    for debugging purposes, use this instead of jax.lax.cond
     """
     if pred:
         return true_fun(true_operand)
@@ -788,7 +779,7 @@
 
 def flax_while_loop(cond_fun, body_fun, init_val):  # pragma: no cover
     """
-    for debugging purposes, use this instead of flax_while_loop
+    for debugging purposes, use this instead of jax.lax.while_loop
     """
     val = init_val
     while cond_fun(val):
@@ -798,7 +789,7 @@
 
 def flax_fori_loop(start, stop, body_fun, init_val):  # pragma: no cover
     """
-    for debugging purposes, use this instead of flax_fori_loop
+    for debugging purposes, use this instead of jax.lax.fori_loop
     """
     val = init_val
     for i in range(start, stop):
@@ -808,7 +799,7 @@
 
 def flax_scan(f, init, xs, length=None):  # pragma: no cover
     """
-    for debugging purposes, use this instead of flax_scan
+    for debugging purposes, use this instead of jax.lax.scan
     """
     if xs is None:
         xs = [None] * length
@@ -825,7 +816,6 @@ def _bdf_odeint_wrapper(func, mass, rtol, atol, y0, ts, *args):
     y0, unravel = ravel_pytree(y0)
     if mass is not None:
         mass = block_diag(tree_flatten(mass)[0])
-        print('mass matrix is ', mass)
     func = ravel_first_arg(func, unravel)
     out = _bdf_odeint(func, mass, rtol, atol, y0, ts, *args)
     return jax.vmap(unravel)(out)
@@ -869,7 +859,7 @@ def scan_fun(carry, i):
         return (y_bar, t0_bar, args_bar), t_bar
 
     init_carry = (g[-1], 0., tree_map(jnp.zeros_like, args))
-    (y_bar, t0_bar, args_bar), rev_ts_bar = flax_scan(
+    (y_bar, t0_bar, args_bar), rev_ts_bar = jax.lax.scan(
        scan_fun, init_carry, jnp.arange(len(ts) - 1, 0, -1))
     ts_bar = jnp.concatenate([jnp.array([t0_bar]), rev_ts_bar[::-1]])
     return (y_bar, ts_bar, *args_bar)
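
Note for reviewers: the flax_* helpers kept at the bottom of the file are plain-Python stand-ins that share the calling convention of the jax.lax primitives they mirror, which is why every swap in this diff is a mechanical rename. Below is a minimal sketch of that pattern, assuming the older five-argument jax.lax.cond signature this file is written against; clip_iterations and its constant value are hypothetical, for illustration only.

    # Pure-Python stand-in, as defined in jax_bdf_solver.py: it evaluates only
    # the selected branch eagerly, so it can be stepped through in a debugger
    # (nothing is traced or jitted).
    def flax_cond(pred, true_operand, true_fun, false_operand, false_fun):
        if pred:
            return true_fun(true_operand)
        else:
            return false_fun(false_operand)

    NEWTON_MAXITER = 4  # illustrative value, not the solver's constant

    def clip_iterations(pred, k):
        # Same argument pattern as the solver's call:
        #   k = jax.lax.cond(pred, k, lambda k: NEWTON_MAXITER, k, lambda k: k)
        # so flax_cond can be substituted back in while debugging.
        return flax_cond(pred, k, lambda k: NEWTON_MAXITER, k, lambda k: k)

    print(clip_iterations(True, 2))   # -> 4 (forced to NEWTON_MAXITER)
    print(clip_iterations(False, 2))  # -> 2 (left unchanged)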