Skip to content

Commit

Permalink
better steps=True
Browse files Browse the repository at this point in the history
  • Loading branch information
sammccallum committed Nov 25, 2024
1 parent 621e6f4 commit acaa35f
Showing 1 changed file with 46 additions and 28 deletions.
74 changes: 46 additions & 28 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,20 +918,30 @@ def _loop_reversible_bwd(
del residuals, solver_state1

grad_final_state, _ = grad_final_state__aux_stats
# If true we must be using SaveAt(t1=True)
if saveat.subs.t1:
# If true we must be using SaveAt(t1=True).
t1_only = saveat.subs.t1
if t1_only:
y1 = ys[-1]
grad_y1 = grad_final_state.save_state.ys[-1]
ts_length = ts.shape[0]
ys_dim = ys.shape[1]
grad_ys = jnp.zeros((ts_length, ys_dim), dtype=ys.dtype)
grad_ys = grad_ys.at[ts_final_index].set(grad_y1)
# If false then we must be using SaveAt(t0=True, steps=True)
grad_ys = grad_final_state.save_state.ys[-1]
grad_ys = jtu.tree_map(_materialise_none, y1, grad_ys)
grad_y0_zeros = jtu.tree_map(jnp.zeros_like, grad_ys)

# Otherwise we must be using SaveAt(..., steps=True) due to the guard in
# ReversibleAdjoint. If y0 is not saved (t0=False) then we prepend grad_y0 (zeros).
else:
y1 = ys[ts_final_index]
grad_ys = grad_final_state.save_state.ys
if saveat.subs.t0:
y1 = ys[ts_final_index]
grad_ys = grad_final_state.save_state.ys
else:
y1 = ys[ts_final_index - 1]
grad_ys = grad_final_state.save_state.ys
grad_y0 = jtu.tree_map(lambda x: jnp.zeros_like(x[0]), grad_ys)
grad_ys = jtu.tree_map(
lambda x, y: jnp.concatenate([x[None], y]), grad_y0, grad_ys
)

grad_ys = jtu.tree_map(_materialise_none, ys, grad_ys)

grad_ys = jtu.tree_map(_materialise_none, ys, grad_ys)
del grad_final_state, grad_final_state__aux_stats

y, args, terms = y__args__terms
Expand All @@ -958,8 +968,13 @@ def solver_step(terms, t0, t1, y1, args, original_solver_state):
t1 = ts[ts_index]
t0 = ts[ts_index - 1]

grad_y1 = grad_ys[ts_index]
grad_y0 = grad_ys[ts_index - 1]
if t1_only:
grad_y1 = grad_ys
grad_y0 = grad_y0_zeros # pyright: ignore

else:
grad_y1 = grad_ys[ts_index]
grad_y0 = grad_ys[ts_index - 1]

step_y1, vjp_fun_y1, original_solver_state = eqx.filter_vjp(
solver_step, terms, t1, t0, y1, args, original_solver_state, has_aux=True
Expand All @@ -982,8 +997,11 @@ def solver_step(terms, t0, t1, y1, args, original_solver_state):
grad_terms = (ω(grad_terms) - ω(grad_step_y1[0]) + ω(grad_step_z0[0])).ω
grad_args = (ω(grad_args) - ω(grad_step_y1[4]) + ω(grad_step_z0[4])).ω

grad_ys = grad_ys.at[ts_index].set(grad_y1)
grad_ys = grad_ys.at[ts_index - 1].set(grad_y0)
if t1_only:
grad_ys = grad_y0
else:
grad_ys = grad_ys.at[ts_index].set(grad_y1)
grad_ys = grad_ys.at[ts_index - 1].set(grad_y0)

ts_index = ts_index - 1

Expand Down Expand Up @@ -1013,7 +1031,10 @@ def cond_fun(state):

state = eqxi.while_loop(cond_fun, grad_step, state, kind="lax")
_, _, _, grad_ys, grad_z0, grad_args, grad_terms = state
grad_y0 = grad_ys[0]
if t1_only:
grad_y0 = grad_ys
else:
grad_y0 = grad_ys[0]

return (ω(grad_y0) + ω(grad_z0)).ω, grad_args, grad_terms

Expand Down Expand Up @@ -1048,18 +1069,15 @@ def loop(
**kwargs,
):
# `is` check because this may return a Tracer from SaveAt(ts=<array>)
if eqx.tree_equal(saveat, SaveAt(t1=True)) is not True:
if eqx.tree_equal(saveat, SaveAt(steps=True)) is True:
raise ValueError(
"If saving steps, include `t0` by "
"`saveat=SaveAt(t0=True, steps=True)`."
)

elif eqx.tree_equal(saveat, SaveAt(t0=True, steps=True)) is not True:
raise ValueError(
"Can only use `adjoint=ReversibleAdjoint()` with "
"`saveat=SaveAt(t1=True)` or `saveat=SaveAt(t0=True, steps=True)`."
)
if (
eqx.tree_equal(saveat, SaveAt(t1=True)) is not True
and eqx.tree_equal(saveat, SaveAt(steps=True)) is not True
and eqx.tree_equal(saveat, SaveAt(t0=True, steps=True)) is not True
):
raise ValueError(
"Can only use `adjoint=ReversibleAdjoint()` with "
"`saveat=SaveAt(t1=True)` or `saveat=SaveAt(steps=True)`."
)

if not isinstance(solver, Reversible):
raise ValueError(
Expand Down

0 comments on commit acaa35f

Please sign in to comment.