Commit
#1104 fix scaling of newton iteration
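
Replaces the pure-Python flax_* debugging stand-ins with the real jax.lax primitives (cond, while_loop, fori_loop, scan), removes leftover debug print statements, and fixes a sign error in the newton residual: the mass-matrix term is M @ (psi + d), not M @ (psi - d).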
martinjrobins committed Jul 13, 2020
1 parent a42f4e5 commit 5f2db7c
Showing 1 changed file with 22 additions and 32 deletions.
pybamm/solvers/jax_bdf_solver.py (54 changes: 22 additions & 32 deletions)
@@ -102,10 +102,10 @@ def for_body(j, y_out):
                                          _bdf_interpolate(stepper, t))
             return y_out
 
-        y_out = flax_fori_loop(i, index, for_body, y_out)
+        y_out = jax.lax.fori_loop(i, index, for_body, y_out)
         return [stepper, t_eval, index, y_out, n_steps + 1]
 
-    stepper, t_eval, i, y_out, n_steps = flax_while_loop(cond_fun, body_fun,
+    stepper, t_eval, i, y_out, n_steps = jax.lax.while_loop(cond_fun, body_fun,
                                                             init_state)
 
     stepper['n_steps'] = n_steps
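
For orientation, the two jax.lax loops above implement step-and-interpolate output collection: advance the solver one BDF step at a time, then fill in every requested output time the step passed over. A rough eager-Python sketch, where cond_fun and the stepper fields are assumptions filled in for lines this diff elides:

    # eager sketch only; the real code keeps everything traceable for JIT
    stepper, t_eval, i, y_out, n_steps = init_state
    while stepper['t'] < t_eval[-1]:                    # assumed cond_fun
        stepper = _bdf_step(stepper, fun, jac)          # one BDF step
        index = jnp.searchsorted(t_eval, stepper['t'])
        for j in range(i, index):                       # the fori_loop body
            y_out = y_out.at[j, :].set(_bdf_interpolate(stepper, t_eval[j]))
        i = index
        n_steps += 1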
@@ -148,13 +148,11 @@ def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol):
     state['t'] = t0
     state['y'] = y0
     f0 = fun(y0, t0)
-    print('f0 is ', f0)
     state['atol'] = atol
     state['rtol'] = rtol
     order = 1
     state['order'] = order
     state['h'] = _select_initial_step(state, fun, t0, y0, f0, h0)
-    print('h0 is ',state['h'])
     EPS = jnp.finfo(y0.dtype).eps
     state['newton_tol'] = jnp.max((10 * EPS / rtol, jnp.min((0.03, rtol ** 0.5))))
     state['n_equal_steps'] = 0
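
As a worked example of the newton_tol line above (float64, with an illustrative rtol = 1e-6):

    # EPS = jnp.finfo(jnp.float64).eps ~ 2.22e-16
    # 10 * EPS / rtol        ~ 2.2e-9
    # min(0.03, rtol ** 0.5)  = min(0.03, 1e-3) = 1e-3
    # newton_tol = max(2.2e-9, 1e-3) = 1e-3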
@@ -169,7 +167,6 @@ def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol):
         state['M'] = jnp.identity(len(y0), dtype=y0.dtype)
     else:
         state['M'] = mass
-    print('mass is ', state['M'])
 
     # kappa values for difference orders, taken from Table 1 of [1]
     kappa = jnp.array([0, -0.1850, -1 / 9, -0.0823, -0.0415, 0])
@@ -186,7 +183,6 @@ def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol):
 
     J = jac(y0, t0)
     state['J'] = J
-    print('J is ', J)
     state['LU'] = jax.scipy.linalg.lu_factor(state['M'] - c * J)
 
     state['U'] = _compute_R(order, 1)
@@ -307,7 +303,7 @@ def while_body(while_state):
         i -= 1
         return [i, D]
 
-    i, D = flax_while_loop(while_cond, while_body, while_state)
+    i, D = jax.lax.while_loop(while_cond, while_body, while_state)
 
     state['D'] = D
 
@@ -320,7 +316,7 @@ def update_psi_and_predict(state):
 
         return state
 
-    state = flax_cond(only_update_D == False,  # noqa: E712
+    state = jax.lax.cond(only_update_D == False,  # noqa: E712
                          state, update_psi_and_predict,
                          state, lambda x: x)
 
@@ -336,7 +332,6 @@ def _update_step_size(state, factor, dont_update_lu):
     - lu factorisation of (M - c * J) used in newton iteration (same equation)
     - psi term
     """
-    print('update_step_size',factor)
     order = state['order']
     h = state['h']
 
@@ -350,7 +345,7 @@ def update_lu(state):
         state['n_lu_decompositions'] += 1
         return state
 
-    state = flax_cond(dont_update_lu == False,  # noqa: E712
+    state = jax.lax.cond(dont_update_lu == False,  # noqa: E712
                          state, update_lu,
                          state, lambda x: x)
 
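A minimal sketch of what update_lu above has to redo when the step size changes, assuming c scales linearly with h (as in scipy's BDF; the exact relation between c and h is defined outside this diff):

    import jax.scipy.linalg

    def refresh_iteration_matrix(M, J, c, factor):
        # scaling h by `factor` scales c the same way (assumption), so the
        # newton iteration matrix M - c * J must be re-factorised
        c_new = c * factor
        return jax.scipy.linalg.lu_factor(M - c_new * J), c_new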
@@ -385,7 +380,6 @@ def _update_jacobian(state, jac):
     we update the jacobian using J(t_{n+1}, y^0_{n+1})
     following the scipy bdf implementation rather than J(t_n, y_n) as per [1]
     """
-    print('update_jacobian')
     J = jac(state['y0'], state['t'] + state['h'])
     state['n_jacobian_evals'] += 1
     state['LU'] = jax.scipy.linalg.lu_factor(state['M'] - state['c'] * J)
@@ -395,7 +389,6 @@ def _update_jacobian(state, jac):
 
 
 def _newton_iteration(state, fun):
-    print('newton iterate')
     tol = state['newton_tol']
     c = state['c']
     psi = state['psi']
@@ -420,7 +413,7 @@ def while_body(while_state):
         k, not_converged, dy_norm_old, d, y, state = while_state
         f_eval = fun(y, t)
         state['n_function_evals'] += 1
-        b = c * f_eval - M @ (psi - d)
+        b = c * f_eval - M @ (psi + d)
         dy = jax.scipy.linalg.lu_solve(LU, b)
         dy_norm = jnp.sqrt(jnp.mean((dy / scale_y0)**2))
         rate = dy_norm / dy_norm_old
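
This hunk is the scaling fix the commit title refers to: for a mass-matrix system M @ y' = f(y, t), the BDF equations reduce to M @ (psi + d) = c * f(y0 + d, t) with d = y - y0, so the residual must add d rather than subtract it. A self-contained sketch of one corrected newton step, with names mirroring the surrounding code:

    import jax.scipy.linalg

    def newton_step(d, y0, t, c, psi, M, LU, fun):
        # residual of the BDF system M @ (psi + d) = c * f(y0 + d, t);
        # before this commit the sign of d here was flipped (psi - d)
        b = c * fun(y0 + d, t) - M @ (psi + d)
        dy = jax.scipy.linalg.lu_solve(LU, b)  # solve (M - c * J) @ dy = b
        return d + dy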
@@ -430,7 +423,7 @@ def while_body(while_state):
         pred = rate >= 1
         pred += rate ** (NEWTON_MAXITER - k) / (1 - rate) * dy_norm > tol
         pred *= dy_norm_old >= 0
-        k = flax_cond(pred, k, lambda k: NEWTON_MAXITER, k, lambda k: k)
+        k = jax.lax.cond(pred, k, lambda k: NEWTON_MAXITER, k, lambda k: k)
 
         d += dy
         y = y0 + d
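
The pred expression above is the usual early-exit test for a modified newton iteration: give up if it is diverging, or if a geometric-series forecast of the error after the remaining iterations still misses the tolerance. An eager sketch, treating a negative dy_norm_old as the first-iteration sentinel (an assumption; its initial value is elided from this diff):

    def should_abandon(rate, dy_norm, dy_norm_old, k, tol):
        if dy_norm_old < 0:   # first iteration: no rate estimate yet
            return False
        if rate >= 1:         # diverging outright
            return True
        # geometric-series bound on the error still remaining after the
        # final allowed iteration; if even this misses tol, stop early
        return rate ** (NEWTON_MAXITER - k) / (1 - rate) * dy_norm > tol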
@@ -450,19 +443,18 @@ def not_converged_fun(not_converged):
         dy_norm_old = dy_norm
 
         not_converged = \
-            flax_cond(pred,
+            jax.lax.cond(pred,
                          not_converged, converged_fun,
                          not_converged, not_converged_fun)
         return [k + 1, not_converged, dy_norm_old, d, y, state]
 
-    k, not_converged, dy_norm_old, d, y, state = flax_while_loop(while_cond,
+    k, not_converged, dy_norm_old, d, y, state = jax.lax.while_loop(while_cond,
                                                                     while_body,
                                                                     while_state)
     return not_converged, k, y, d, state
 
 
 def _bdf_step(state, fun, jac):
-    print('bdf step', state['t'])
     # we will try and use the old jacobian unless convergence of newton iteration
     # fails
     not_updated_jacobian = True
@@ -528,7 +520,6 @@ def converged(if_state2):
 
         def error_too_large(if_state3):
             # error too large, reduce step size and try again
-            print('error too large')
             state, step_accepted = if_state3
             state['n_error_test_failures'] += 1
             # calculate optimal step size factor as per eq 2.46 of [2]
@@ -545,26 +536,26 @@ def accept_step(if_state3):
                 return [state, step_accepted]
 
             state, step_accepted = \
-                flax_cond(error_norm > 1,
+                jax.lax.cond(error_norm > 1,
                              if_state3, error_too_large,
                              if_state3, accept_step)
 
             return [state, step_accepted]
 
-        state, step_accepted = flax_cond(not_converged,
+        state, step_accepted = jax.lax.cond(not_converged,
                                             if_state2, need_to_update_step_size,
                                             if_state2, converged)
 
        return [state, not_updated_jacobian, step_accepted]
 
        state, not_updated_jacobian, step_accepted = \
-            flax_cond(pred,
+            jax.lax.cond(pred,
                          if_state, need_to_update_jacobian,
                          if_state, dont_need_to_update_jacobian)
        return [state, step_accepted, not_updated_jacobian, y, d, n_iter]
 
    state, step_accepted, not_updated_jacobian, y, d, n_iter = \
-        flax_while_loop(while_cond, while_body, while_state)
+        jax.lax.while_loop(while_cond, while_body, while_state)
 
    # take the accepted step
    state['y'] = y
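
Unrolled into eager Python, the nest of jax.lax.cond calls above implements this retry ladder (the step-shrink factor and the error-norm computation are elided from this diff, so they appear here as labelled stand-ins):

    def bdf_step_sketch(state, fun, jac, shrink_factor):   # factor: stand-in
        not_updated_jacobian = True
        step_accepted = False
        while not step_accepted:
            not_converged, n_iter, y, d, state = _newton_iteration(state, fun)
            if not_converged and not_updated_jacobian:
                state = _update_jacobian(state, jac)       # retry with fresh J
                not_updated_jacobian = False
            elif not_converged:
                state = _update_step_size(state, shrink_factor, False)
                not_updated_jacobian = True                # retry, smaller h
            elif estimate_error_norm(state, d) > 1:        # stand-in helper
                state = _update_step_size(state, shrink_factor, False)
            else:
                step_accepted = True
        return state, y, d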
@@ -611,7 +602,7 @@ def order_equal_one(if_state2):
            error_m_norm = jnp.inf
            return error_m_norm
 
-        error_m_norm = flax_cond(order > 1,
+        error_m_norm = jax.lax.cond(order > 1,
                                     if_state2, order_greater_one,
                                     if_state2, order_equal_one)
 
@@ -625,7 +616,7 @@ def order_max(if_state2):
            error_p_norm = jnp.inf
            return error_p_norm
 
-        error_p_norm = flax_cond(order < MAX_ORDER,
+        error_p_norm = jax.lax.cond(order < MAX_ORDER,
                                     if_state2, order_less_max,
                                     if_state2, order_max)
 
@@ -643,7 +634,7 @@ def order_max(if_state2):
 
        return state
 
-    state = flax_cond(state['n_equal_steps'] < state['order'] + 1,
+    state = jax.lax.cond(state['n_equal_steps'] < state['order'] + 1,
                         if_state, no_change_in_order,
                         if_state, order_change)
 
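The error_m_norm / error_p_norm values above feed the usual BDF order selection. An illustrative version in the style of scipy's BDF; the exact expressions are elided from this diff, so treat the constants and error_norm here as assumptions:

    import jax.numpy as jnp

    error_norms = jnp.array([error_m_norm, error_norm, error_p_norm])
    orders = jnp.array([order - 1, order, order + 1])
    # eq 2.46-style step factors; pick the neighbouring order
    # that permits the largest step
    factors = error_norms ** (-1.0 / (orders + 1))
    new_order = orders[jnp.argmax(factors)]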
@@ -676,7 +667,7 @@ def while_body(while_state):
         j += 1
         return [j, time_factor, order_summation]
 
-    j, time_factor, order_summation = flax_while_loop(while_cond,
+    j, time_factor, order_summation = jax.lax.while_loop(while_cond,
                                                          while_body,
                                                          while_state)
     return order_summation
@@ -778,7 +769,7 @@ def _check_arg(arg):
 def flax_cond(pred, true_operand, true_fun,
               false_operand, false_fun):  # pragma: no cover
     """
-    for debugging purposes, use this instead of flax_cond
+    for debugging purposes, use this instead of jax.lax.cond
     """
     if pred:
         return true_fun(true_operand)
@@ -788,7 +779,7 @@ def flax_cond(pred, true_operand, true_fun,
 
 def flax_while_loop(cond_fun, body_fun, init_val):  # pragma: no cover
     """
-    for debugging purposes, use this instead of flax_while_loop
+    for debugging purposes, use this instead of jax.lax.while_loop
     """
     val = init_val
     while cond_fun(val):
@@ -798,7 +789,7 @@ def flax_while_loop(cond_fun, body_fun, init_val):  # pragma: no cover
 
 def flax_fori_loop(start, stop, body_fun, init_val):  # pragma: no cover
     """
-    for debugging purposes, use this instead of flax_fori_loop
+    for debugging purposes, use this instead of jax.lax.fori_loop
     """
     val = init_val
     for i in range(start, stop):
@@ -808,7 +799,7 @@ def flax_fori_loop(start, stop, body_fun, init_val):  # pragma: no cover
 
 def flax_scan(f, init, xs, length=None):  # pragma: no cover
     """
-    for debugging purposes, use this instead of flax_scan
+    for debugging purposes, use this instead of jax.lax.scan
     """
     if xs is None:
         xs = [None] * length
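
These flax_* helpers are eager, pure-Python twins of the jax.lax primitives, matching the call signatures used in this file, so a one-line swap turns any of the loops above back into ordinary Python where prints and breakpoints fire. For example:

    # counts to 10 eagerly; swap back to jax.lax.while_loop for production
    val = flax_while_loop(lambda v: v < 10, lambda v: v + 1, 0)
    assert val == 10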
@@ -825,7 +816,6 @@ def _bdf_odeint_wrapper(func, mass, rtol, atol, y0, ts, *args):
     y0, unravel = ravel_pytree(y0)
     if mass is not None:
         mass = block_diag(tree_flatten(mass)[0])
-        print('mass matrix is ', mass)
     func = ravel_first_arg(func, unravel)
     out = _bdf_odeint(func, mass, rtol, atol, y0, ts, *args)
     return jax.vmap(unravel)(out)
@@ -869,7 +859,7 @@ def scan_fun(carry, i):
         return (y_bar, t0_bar, args_bar), t_bar
 
     init_carry = (g[-1], 0., tree_map(jnp.zeros_like, args))
-    (y_bar, t0_bar, args_bar), rev_ts_bar = flax_scan(
+    (y_bar, t0_bar, args_bar), rev_ts_bar = jax.lax.scan(
         scan_fun, init_carry, jnp.arange(len(ts) - 1, 0, -1))
     ts_bar = jnp.concatenate([jnp.array([t0_bar]), rev_ts_bar[::-1]])
     return (y_bar, ts_bar, *args_bar)
