Skip to content

Commit

Permalink
#1104 fix flake8 and some minor bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jul 14, 2020
1 parent 0ca87c4 commit 3e31c5f
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 38 deletions.
85 changes: 48 additions & 37 deletions pybamm/solvers/jax_bdf_solver.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from functools import partial
import operator as op
import numpy as onp
import collections
Expand Down Expand Up @@ -83,8 +82,10 @@ def fun_bind_inputs(y, t):
t0 = t_eval[0]
h0 = t_eval[1] - t0

stepper = _bdf_init(fun_bind_inputs, jac_bind_inputs, mass, t0, y0, h0, rtol, atol)
i = 0
stepper, failed = _bdf_init(
fun_bind_inputs, jac_bind_inputs, mass, t0, y0, h0, rtol, atol
)
i = failed * len(t_eval)
y_out = jnp.empty((len(t_eval), len(y0)), dtype=y0.dtype)

init_state = [stepper, t_eval, i, y_out]
Expand All @@ -108,21 +109,24 @@ def for_body(j, y_out):
return [stepper, t_eval, index, y_out]

stepper, t_eval, i, y_out = jax.lax.while_loop(cond_fun, body_fun,
init_state)
init_state)

return y_out


BDFInternalStates = [
't', 'atol', 'rtol', 'M', 'newton_tol', 'order', 'h', 'n_equal_steps', 'D',
'y0', 'scale_y0', 'kappa', 'gamma', 'alpha', 'c', 'error_const', 'J', 'LU', 'U',
'psi', 'n_function_evals', 'n_jacobian_evals', 'n_lu_decompositions', 'n_steps']
't', 'atol', 'rtol', 'M', 'newton_tol', 'order', 'h', 'n_equal_steps', 'D',
'y0', 'scale_y0', 'kappa', 'gamma', 'alpha', 'c', 'error_const', 'J', 'LU', 'U',
'psi', 'n_function_evals', 'n_jacobian_evals', 'n_lu_decompositions', 'n_steps'
]
BDFState = collections.namedtuple('BDFState', BDFInternalStates)

jax.tree_util.register_pytree_node(
BDFState,
lambda xs: (tuple(xs), None),
lambda _, xs: BDFState(*xs))
BDFState,
lambda xs: (tuple(xs), None),
lambda _, xs: BDFState(*xs)
)


def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol):
"""
Expand Down Expand Up @@ -165,7 +169,9 @@ def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol):
state['newton_tol'] = jnp.max((10 * EPS / rtol, jnp.min((0.03, rtol ** 0.5))))

scale_y0 = atol + rtol * jnp.abs(y0)
y0 = _select_initial_conditions(fun, mass, t0, y0, state['newton_tol'], scale_y0)
y0, not_converged = _select_initial_conditions(
fun, mass, t0, y0, state['newton_tol'], scale_y0
)

f0 = fun(y0, t0)
order = 1
Expand Down Expand Up @@ -207,7 +213,7 @@ def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol):
tuple_state = BDFState(*[state[k] for k in BDFInternalStates])
y0, scale_y0 = _predict(tuple_state, D)
psi = _update_psi(tuple_state, D)
return tuple_state._replace(y0=y0, scale_y0=scale_y0, psi=psi)
return tuple_state._replace(y0=y0, scale_y0=scale_y0, psi=psi), not_converged


def _compute_R(order, factor):
Expand Down Expand Up @@ -239,7 +245,7 @@ def _select_initial_conditions(fun, M, t0, y0, tol, scale_y0):
# if all differentiable variables then return y0 (can use normal python if since M
# is static)
if not jnp.any(algebraic_variables):
return y0
return y0, False

# calculate consistent initial conditions via a newton on -J_a @ delta = f_a This
# follows this reference:
Expand All @@ -256,7 +262,7 @@ def fun_a(y_a):
scale_y0_a = scale_y0[algebraic_variables]

d = jnp.zeros(y0_a.shape[0], dtype=y0.dtype)
y_a = jnp.array(y0_a)
y_a = jnp.array(y0_a, copy=True)

# calculate neg jacobian of fun_a
J_a = jax.jacfwd(fun_a)(y_a)
Expand Down Expand Up @@ -290,13 +296,12 @@ def while_body(while_state):

return [k + 1, not_converged, dy_norm_old, d, y_a]


k, not_converged, dy_norm_old, d, y_a = jax.lax.while_loop(while_cond,
while_body,
while_state)
y_tilde = jax.ops.index_update(y0, algebraic_variables, y_a)

return y_tilde
return y_tilde, not_converged


def _select_initial_step(atol, rtol, fun, t0, y0, f0, h0):
Expand Down Expand Up @@ -399,9 +404,7 @@ def _update_step_size(state, factor):
- psi term
"""
order = state.order
h = state.h

h *= factor
h = state.h * factor
n_equal_steps = 0
c = h * state.alpha[order]

Expand Down Expand Up @@ -432,6 +435,7 @@ def _update_step_size(state, factor):
n_lu_decompositions=n_lu_decompositions, h=h, c=c,
D=D, psi=psi, y0=y0, scale_y0=scale_y0)


def _update_jacobian(state, jac):
"""
we update the jacobian using J(t_{n+1}, y^0_{n+1})
Expand Down Expand Up @@ -481,7 +485,7 @@ def while_body(while_state):
pred = rate >= 1
pred += rate ** (NEWTON_MAXITER - k) / (1 - rate) * dy_norm > tol
pred *= dy_norm_old >= 0
k += pred * (NEWTON_MAXITER - k)
k += pred * (NEWTON_MAXITER - k - 1)

d += dy
y = y0 + d
Expand All @@ -495,11 +499,13 @@ def while_body(while_state):

return [k + 1, not_converged, dy_norm_old, d, y, n_function_evals]

k, not_converged, dy_norm_old, d, y, n_function_evals = jax.lax.while_loop(while_cond,
while_body,
while_state)
k, not_converged, dy_norm_old, d, y, n_function_evals = \
jax.lax.while_loop(while_cond,
while_body,
while_state)
return not_converged, k, y, d, state._replace(n_function_evals=n_function_evals)


def rms_norm(arg):
return jnp.sqrt(jnp.mean(arg**2))

Expand All @@ -508,7 +514,7 @@ def _prepare_next_step(state, d):
D = _update_difference_for_next_step(state, d)
psi = _update_psi(state, D)
y0, scale_y0 = _predict(state, D)
return state._replace(D=D,psi=psi,y0=y0,scale_y0=scale_y0)
return state._replace(D=D, psi=psi, y0=y0, scale_y0=scale_y0)


def _prepare_next_step_order_change(state, d, y, n_iter):
Expand Down Expand Up @@ -543,7 +549,7 @@ def _prepare_next_step_order_change(state, d, y, n_iter):
# now we have the three factors for orders k-1, k and k+1, pick the maximum in
# order to maximise the resultant step size
max_index = jnp.argmax(factors)
order = order + max_index - 1
order += max_index - 1

factor = jnp.min((MAX_FACTOR, safety * factors[max_index]))

Expand Down Expand Up @@ -578,16 +584,20 @@ def while_body(while_state):
# newton iteration did not converge, but jacobian has already been
# evaluated so reduce step size by 0.3 (as per [1]) and try again
state = tree_multimap(
partial(jnp.where, not_converged * updated_jacobian),
_update_step_size(state, 0.3),
state
partial(jnp.where, not_converged * updated_jacobian),
_update_step_size(state, 0.3),
state
)

# if not converged and jacobian not updated, then update the jacobian and try again
# if not converged and jacobian not updated, then update the jacobian and try
# again
(state, updated_jacobian) = tree_multimap(
partial(jnp.where, not_converged * (updated_jacobian == False)),
(_update_jacobian(state, jac), True),
(state, False)
partial(
jnp.where,
not_converged * (updated_jacobian == False) # noqa: E712
),
(_update_jacobian(state, jac), True),
(state, False + updated_jacobian)
)

safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER + n_iter)
Expand All @@ -606,17 +616,19 @@ def while_body(while_state):
error_norm ** (-1 / (state.order + 1))))

(state, step_accepted) = tree_multimap(
partial(jnp.where, (not_converged == False) * (error_norm > 1)),
partial(
jnp.where,
(not_converged == False) * (error_norm > 1) # noqa: E712
),
(_update_step_size(state, factor), False),
(state, True)
(state, not_converged == False)
)

return [state, step_accepted, updated_jacobian, y, d, n_iter]

state, step_accepted, updated_jacobian, y, d, n_iter = \
jax.lax.while_loop(while_cond, while_body, while_state)


# take the accepted step
n_steps = state.n_steps + 1
t = state.t + state.h
Expand All @@ -625,7 +637,6 @@ def while_body(while_state):
# (see page 83 of [2])
n_equal_steps = state.n_equal_steps + 1


state = state._replace(n_equal_steps=n_equal_steps, t=t, n_steps=n_steps)

state = tree_multimap(
Expand Down Expand Up @@ -802,7 +813,7 @@ def flax_scan(f, init, xs, length=None): # pragma: no cover
return carry, onp.stack(ys)


@partial(jax.jit, static_argnums=(0, 1, 2, 3))
@jax.partial(jax.jit, static_argnums=(0, 1, 2, 3))
def _bdf_odeint_wrapper(func, mass, rtol, atol, y0, ts, *args):
y0, unravel = ravel_pytree(y0)
if mass is None:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_solvers/test_jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_model_solver(self):
t_first_solve = time.perf_counter() - t0
np.testing.assert_array_equal(solution.t, t_eval)
np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t),
rtol=1e-7, atol=1e-7)
rtol=1e-6, atol=1e-6)

# Test time
self.assertEqual(
Expand Down

0 comments on commit 3e31c5f

Please sign in to comment.