From f08dc5726e1b6c768dcc241a8a9f6cfdcb9a2909 Mon Sep 17 00:00:00 2001 From: Martin Robinson Date: Wed, 15 Jul 2020 08:06:50 +0100 Subject: [PATCH] #1104 put in proper adjoint equations for semi-explicit dae index 1 --- pybamm/solvers/jax_bdf_solver.py | 27 +++++++++++++++---- .../unit/test_solvers/test_jax_bdf_solver.py | 1 - 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/pybamm/solvers/jax_bdf_solver.py b/pybamm/solvers/jax_bdf_solver.py index 67a8a92a35..e356c04f7c 100644 --- a/pybamm/solvers/jax_bdf_solver.py +++ b/pybamm/solvers/jax_bdf_solver.py @@ -239,7 +239,7 @@ def _compute_R(order, factor): def _select_initial_conditions(fun, M, t0, y0, tol, scale_y0): - # identify differentiable variables as zeros on diagonal + # identify algebraic variables as zeros on diagonal algebraic_variables = jnp.diag(M) == 0. # if all differentiable variables then return y0 (can use normal python if since M @@ -700,9 +700,10 @@ def block_fun(i, j, Ai, Aj): return jnp.block(blocks) -# NOTE: all code below (except the docstring on jax_bdf_integrate and other minor -# edits), to define the API of the jax solver and the ability to solve the adjoint -# sensitivities, has been copied from the JAX library at https://github.com/google/jax. +# NOTE: the code below (except the docstring on jax_bdf_integrate and other minor +# edits), has been modified from the JAX library at https://github.com/google/jax. +# The main difference is the addition of support for semi-explicit dae index 1 problems +# via the addition of a mass matrix. # This is under an Apache license, a short form of which is given here: # # Copyright 2018 Google LLC @@ -833,13 +834,29 @@ def _bdf_odeint_fwd(func, mass, rtol, atol, y0, ts, *args): def _bdf_odeint_rev(func, mass, rtol, atol, res, g): ys, ts, args = res + diag_mass = jnp.diag(mass) + def aug_dynamics(augmented_state, t, *args): """Original system augmented with vjp_y, vjp_t and vjp_args.""" y, y_bar, *_ = augmented_state # `t` here is negative time, so we need to negate again to get back to # normal time. See the `odeint` invocation in `scan_fun` below. y_dot, vjpfun = jax.vjp(func, y, -t, *args) - return (-y_dot, *vjpfun(y_bar)) + + # Adjoint equations for semi-explicit dae index 1 system from + # + # Cao, Y., Li, S., Petzold, L., & Serban, R. (2003). Adjoint sensitivity + # analysis for differential-algebraic equations: The adjoint DAE system and its + # numerical solution. SIAM journal on scientific computing, 24(3), 1076-1089. + # + # y_bar_dot_d = -J_dd^T y_bar_d - J_ad^T y_bar_a + # 0 = J_da^T y_bar_d + J_aa^T y_bar_d + + y_bar_dot, t_bar, args_bar = vjpfun(y_bar) + # identify algebraic variables as zeros on diagonal + y_bar_dot = jnp.where(diag_mass == 0., -y_bar_dot, y_bar_dot) + + return (-y_dot, y_bar_dot, t_bar, args_bar) y_bar = g[-1] ts_bar = [] diff --git a/tests/unit/test_solvers/test_jax_bdf_solver.py b/tests/unit/test_solvers/test_jax_bdf_solver.py index 9674145af2..f76871387e 100644 --- a/tests/unit/test_solvers/test_jax_bdf_solver.py +++ b/tests/unit/test_solvers/test_jax_bdf_solver.py @@ -140,7 +140,6 @@ def solve_bdf(rate): self.assertAlmostEqual(grad_bdf, grad_num, places=3) - @unittest.skip("sensitivities not yet supported on for dae models") def test_mass_matrix_with_sensitivities(self): # Solve t_eval = np.linspace(0.0, 1.0, 80)