Skip to content

Commit

Permalink
#1104 convert to diagonal mass matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jul 16, 2020
1 parent 957611e commit 02bfd2c
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 36 deletions.
72 changes: 48 additions & 24 deletions pybamm/solvers/jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _bdf_odeint(fun, mass, rtol, atol, y0, t_eval, *args):
function to evaluate the time derivative of the solution `y` at time
`t` as `func(y, t, *args)`, producing the same shape/structure as `y0`.
mass: ndarray
constant mass matrix with shape (n,n)
diagonal of the mass matrix with shape (n,)
y0: ndarray
initial state vector, has shape (n,)
t_eval: ndarray
Expand Down Expand Up @@ -147,7 +147,7 @@ def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol):
function with signature (y, t), where t is a scalar time and y is a ndarray with
shape (n,), returns the jacobian matrix of fun as an ndarray with shape (n,n)
mass: ndarray
constant mass matrix with shape (n,n)
diagonal of the mass matrix with shape (n,)
t0: float
initial time
y0: ndarray
Expand Down Expand Up @@ -201,7 +201,10 @@ def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol):

J = jac(y0, t0)
state['J'] = J
state['LU'] = jax.scipy.linalg.lu_factor(state['M'] - c * J)

Psi = -c * J
Psi = jax.ops.index_add(Psi, jnp.diag_indices_from(J), state['M'])
state['LU'] = jax.scipy.linalg.lu_factor(Psi)

state['U'] = _compute_R(order, 1)
state['psi'] = None
Expand Down Expand Up @@ -241,7 +244,7 @@ def _compute_R(order, factor):

def _select_initial_conditions(fun, M, t0, y0, tol, scale_y0):
# identify algebraic variables as zeros on diagonal
algebraic_variables = jnp.diag(M) == 0.
algebraic_variables = M == 0.

# if all differentiable variables then return y0 (can use normal python if since M
# is static)
Expand Down Expand Up @@ -399,7 +402,9 @@ def _update_step_size_and_lu(state, factor):
state = _update_step_size(state, factor)

# redo lu (c has changed)
LU = jax.scipy.linalg.lu_factor(state.M - state.c * state.J)
Psi = -state.c * state.J
Psi = jax.ops.index_add(Psi, jnp.diag_indices_from(state.J), state.M)
LU = jax.scipy.linalg.lu_factor(Psi)
n_lu_decompositions = state.n_lu_decompositions + 1

return state._replace(LU=LU, n_lu_decompositions=n_lu_decompositions)
Expand Down Expand Up @@ -450,7 +455,9 @@ def _update_jacobian(state, jac):
"""
J = jac(state.y0, state.t + state.h)
n_jacobian_evals = state.n_jacobian_evals + 1
LU = jax.scipy.linalg.lu_factor(state.M - state.c * J)
Psi = -state.c * state.J
Psi = jax.ops.index_add(Psi, jnp.diag_indices_from(state.J), state.M)
LU = jax.scipy.linalg.lu_factor(Psi)
n_lu_decompositions = state.n_lu_decompositions + 1
return state._replace(J=J, n_jacobian_evals=n_jacobian_evals, LU=LU,
n_lu_decompositions=n_lu_decompositions)
Expand Down Expand Up @@ -482,7 +489,7 @@ def while_body(while_state):
k, converged, dy_norm_old, d, y, n_function_evals = while_state
f_eval = fun(y, t)
n_function_evals += 1
b = c * f_eval - M @ (psi + d)
b = c * f_eval - M * (psi + d)
dy = jax.scipy.linalg.lu_solve(LU, b)
dy_norm = jnp.sqrt(jnp.mean((dy / scale_y0)**2))
rate = dy_norm / dy_norm_old
Expand Down Expand Up @@ -765,7 +772,7 @@ def jax_bdf_integrate(func, y0, t_eval, *args, rtol=1e-6, atol=1e-6, mass=None):
atol: (optional) float
absolute tolerance for the solver
mass: (optional) ndarray
constant mass matrix with shape (n,n)
diagonal of the mass matrix with shape (n,)
Returns
-------
Expand Down Expand Up @@ -794,6 +801,9 @@ def _check_arg(arg):
in_avals = tuple(safe_map(abstractify, flat_args))
converted, consts = closure_convert(func, in_tree, in_avals)

if mass is None:
mass = onp.ones(y0.shape[0], dtype=y0.dtype)

return _bdf_odeint_wrapper(converted, mass, rtol, atol, y0, t_eval, *consts, *args)


Expand Down Expand Up @@ -834,10 +844,7 @@ def flax_scan(f, init, xs, length=None): # pragma: no cover
@jax.partial(jax.jit, static_argnums=(0, 1, 2, 3))
def _bdf_odeint_wrapper(func, mass, rtol, atol, y0, ts, *args):
y0, unravel = ravel_pytree(y0)
if mass is None:
mass = jnp.identity(y0.shape[0], dtype=y0.dtype)
else:
mass = block_diag(tree_flatten(mass)[0])
mass, _ = ravel_pytree(mass)
func = ravel_first_arg(func, unravel)
out = _bdf_odeint(func, mass, rtol, atol, y0, ts, *args)
return jax.vmap(unravel)(out)
Expand All @@ -851,8 +858,6 @@ def _bdf_odeint_fwd(func, mass, rtol, atol, y0, ts, *args):
def _bdf_odeint_rev(func, mass, rtol, atol, res, g):
ys, ts, args = res

diag_mass = jnp.diag(mass)

def aug_dynamics(augmented_state, t, *args):
"""Original system augmented with vjp_y, vjp_t and vjp_args."""
y, y_bar, *_ = augmented_state
Expand All @@ -862,27 +867,46 @@ def aug_dynamics(augmented_state, t, *args):

# Adjoint equations for semi-explicit dae index 1 system from
#
# Cao, Y., Li, S., Petzold, L., & Serban, R. (2003). Adjoint sensitivity
# [1] Cao, Y., Li, S., Petzold, L., & Serban, R. (2003). Adjoint sensitivity
# analysis for differential-algebraic equations: The adjoint DAE system and its
# numerical solution. SIAM journal on scientific computing, 24(3), 1076-1089.
#
# y_bar_dot_d = -J_dd^T y_bar_d - J_ad^T y_bar_a
# 0 = J_da^T y_bar_d + J_aa^T y_bar_a

y_bar_dot, *rest = vjpfun(y_bar)
# identify algebraic variables as zeros on diagonal
y_bar_dot = jnp.where(diag_mass == 0., -y_bar_dot, y_bar_dot)

return (-y_dot, y_bar_dot, *rest)

y_bar = g[-1]
algebraic_variables = mass == 0.
differentiable_variables = algebraic_variables == False # noqa: E712

def initialise(g0, y0, t0):
# [1] gives init conditions for y_bar_d = g_d - J_ad^T (J_aa^T)^-1 g_a
if jnp.any(algebraic_variables):
J = jax.jacfwd(func)(y0, t0, *args)

# boolean arguments not implemented in jnp.ix_
J_aa = J[onp.ix_(algebraic_variables, algebraic_variables)]
J_ad = J[onp.ix_(algebraic_variables, differentiable_variables)]
LU = jax.scipy.linalg.lu_factor(J_aa)
g0_a = g0[algebraic_variables]
invJ_aa = jax.scipy.linalg.lu_solve(LU, g0_a)
y_bar = jax.ops.index_update(
g0, differentiable_variables,
(g0_a - J_ad @ invJ_aa) / mass[differentiable_variables]
)
else:
y_bar = g0 / mass
return y_bar

y_bar = initialise(g[-1], ys[-1], ts[-1])
ts_bar = []
t0_bar = 0.

def arg_to_identity(arg):
return jnp.identity(arg.shape[0] if arg.ndim > 0 else 1, dtype=arg.dtype)

aug_mass = (mass, mass, jnp.array(1.), tree_map(arg_to_identity, args))
def arg_to_ones(arg):
return onp.ones(arg.shape[0] if arg.ndim > 0 else 1, dtype=arg.dtype)
aug_mass = (mass, mass, jnp.array(1.), tree_map(arg_to_ones, args))

def scan_fun(carry, i):
y_bar, t0_bar, args_bar = carry
Expand All @@ -897,10 +921,10 @@ def scan_fun(carry, i):
rtol=rtol, atol=atol)
y_bar, t0_bar, args_bar = tree_map(op.itemgetter(1), (y_bar, t0_bar, args_bar))
# Add gradient from current output
y_bar = y_bar + g[i - 1]
y_bar = y_bar + initialise(g[i - 1], ys[i - 1], ts[i - 1])
return (y_bar, t0_bar, args_bar), t_bar

init_carry = (g[-1], 0., tree_map(jnp.zeros_like, args))
init_carry = (y_bar, t0_bar, tree_map(jnp.zeros_like, args))
(y_bar, t0_bar, args_bar), rev_ts_bar = jax.lax.scan(
scan_fun, init_carry, jnp.arange(len(ts) - 1, 0, -1))
ts_bar = jnp.concatenate([jnp.array([t0_bar]), rev_ts_bar[::-1]])
Expand Down
6 changes: 4 additions & 2 deletions pybamm/solvers/jax_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@ def create_solve(self, model, t_eval):
y0 = jnp.array(model.y0).reshape(-1)
mass = None
if self.method == 'BDF':
mass = model.mass_matrix.entries.toarray()
mass = model.mass_matrix.entries.diagonal()
if onp.count_nonzero(mass) != model.mass_matrix.entries.nnz:
raise RuntimeError("Solver only supports a diagonal mass matrix")

def rhs_ode(y, t, inputs):
return model.rhs_eval(t, y, inputs),
Expand All @@ -146,7 +148,7 @@ def solve_model_rk45(inputs):
return jnp.transpose(y)

def solve_model_bdf(inputs):
y, stepper = pybamm.jax_bdf_integrate(
y = pybamm.jax_bdf_integrate(
rhs_dae,
y0,
t_eval,
Expand Down
14 changes: 4 additions & 10 deletions tests/unit/test_solvers/test_jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,7 @@ def fun(y, t):
y[1] - 2.0 * y[0],
])

mass = jax.numpy.array([
[1.0, 0.0],
[0.0, 0.0],
])
mass = jax.numpy.array([2.0, 0.0])

# give some bad initial conditions, solver should calculate correct ones using
# this as a guess
Expand All @@ -78,7 +75,7 @@ def fun(y, t):
t1 = time.perf_counter() - t0

# test accuracy
soln = np.exp(0.1 * t_eval)
soln = np.exp(0.05 * t_eval)
np.testing.assert_allclose(y[:, 0], soln,
rtol=1e-7, atol=1e-7)
np.testing.assert_allclose(y[:, 1], 2.0 * soln,
Expand All @@ -92,7 +89,7 @@ def fun(y, t):
self.assertLess(t2, t1)

# test second run is accurate
np.testing.assert_allclose(y[:, 0], np.exp(0.1 * t_eval),
np.testing.assert_allclose(y[:, 0], np.exp(0.05 * t_eval),
rtol=1e-7, atol=1e-7)

def test_solver_sensitivities(self):
Expand Down Expand Up @@ -150,10 +147,7 @@ def fun(y, t, inputs):
y[1] - 2.0 * y[0],
])

mass = jax.numpy.array([
[1.0, 0.0],
[0.0, 0.0],
])
mass = jax.numpy.array([2.0, 0.0])

y0 = jax.numpy.array([1.0, 2.0])

Expand Down

0 comments on commit 02bfd2c

Please sign in to comment.