Skip to content

Commit

Permalink
#1104 put in proper adjoint equations for semi-explicit dae index 1
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Jul 15, 2020
1 parent acc8314 commit f08dc57
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
27 changes: 22 additions & 5 deletions pybamm/solvers/jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def _compute_R(order, factor):


def _select_initial_conditions(fun, M, t0, y0, tol, scale_y0):
# identify differentiable variables as zeros on diagonal
# identify algebraic variables as zeros on diagonal
algebraic_variables = jnp.diag(M) == 0.

# if all differentiable variables then return y0 (can use normal python if since M
Expand Down Expand Up @@ -700,9 +700,10 @@ def block_fun(i, j, Ai, Aj):

return jnp.block(blocks)

# NOTE: all code below (except the docstring on jax_bdf_integrate and other minor
# edits), to define the API of the jax solver and the ability to solve the adjoint
# sensitivities, has been copied from the JAX library at https://github.com/google/jax.
# NOTE: the code below (except the docstring on jax_bdf_integrate and other minor
# edits), has been modified from the JAX library at https://github.com/google/jax.
# The main difference is the addition of support for semi-explicit dae index 1 problems
# via the addition of a mass matrix.
# This is under an Apache license, a short form of which is given here:
#
# Copyright 2018 Google LLC
Expand Down Expand Up @@ -833,13 +834,29 @@ def _bdf_odeint_fwd(func, mass, rtol, atol, y0, ts, *args):
def _bdf_odeint_rev(func, mass, rtol, atol, res, g):
ys, ts, args = res

diag_mass = jnp.diag(mass)

def aug_dynamics(augmented_state, t, *args):
"""Original system augmented with vjp_y, vjp_t and vjp_args."""
y, y_bar, *_ = augmented_state
# `t` here is negative time, so we need to negate again to get back to
# normal time. See the `odeint` invocation in `scan_fun` below.
y_dot, vjpfun = jax.vjp(func, y, -t, *args)
return (-y_dot, *vjpfun(y_bar))

# Adjoint equations for semi-explicit dae index 1 system from
#
# Cao, Y., Li, S., Petzold, L., & Serban, R. (2003). Adjoint sensitivity
# analysis for differential-algebraic equations: The adjoint DAE system and its
# numerical solution. SIAM journal on scientific computing, 24(3), 1076-1089.
#
# y_bar_dot_d = -J_dd^T y_bar_d - J_ad^T y_bar_a
# 0 = J_da^T y_bar_d + J_aa^T y_bar_d

y_bar_dot, t_bar, args_bar = vjpfun(y_bar)
# identify algebraic variables as zeros on diagonal
y_bar_dot = jnp.where(diag_mass == 0., -y_bar_dot, y_bar_dot)

return (-y_dot, y_bar_dot, t_bar, args_bar)

y_bar = g[-1]
ts_bar = []
Expand Down
1 change: 0 additions & 1 deletion tests/unit/test_solvers/test_jax_bdf_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ def solve_bdf(rate):

self.assertAlmostEqual(grad_bdf, grad_num, places=3)

@unittest.skip("sensitivities not yet supported on for dae models")
def test_mass_matrix_with_sensitivities(self):
# Solve
t_eval = np.linspace(0.0, 1.0, 80)
Expand Down

0 comments on commit f08dc57

Please sign in to comment.