Skip to content

Commit

Permalink
Disable fsal, ssal properties to allow any solver to be made reversible
Browse files Browse the repository at this point in the history
  • Loading branch information
sammccallum committed Nov 25, 2024
1 parent f160295 commit 7dfb8e3
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 32 deletions.
27 changes: 12 additions & 15 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,14 +942,14 @@ def _loop_reversible_bwd(
del diff_args, diff_terms, diff_z1

def grad_step(state):
def solver_step(terms, t0, t1, y1, args):
step, _, _, _, _ = solver.solver.step(
terms, t0, t1, y1, args, (first_step, f0), False
def solver_step(terms, t0, t1, y1, args, original_solver_state):
step, _, _, original_solver_state, _ = solver.solver.step(
terms, t0, t1, y1, args, original_solver_state, False
)
return step
return step, original_solver_state

ts_index, y1, solver_state, grad_ys, grad_z1, grad_args, grad_terms = state
(first_step, f0), z1 = solver_state
original_solver_state, z1 = solver_state

t1 = ts[ts_index]
t0 = ts[ts_index - 1]
Expand All @@ -960,17 +960,14 @@ def solver_step(terms, t0, t1, y1, args):
grad_y1 = grad_ys[ts_index]
grad_y0 = grad_ys[ts_index - 1]

# TODO The solver steps switch between evaluating from z0
# and y1. Therefore, we re-evaluate f0 outside of the base
# solver to ensure the vf is correct.
# Can we avoid this re-evaluation?

f0 = solver.func(terms, t1, y1, args)
step_y1, vjp_fun_y1 = eqx.filter_vjp(solver_step, terms, t1, t0, y1, args)
step_y1, vjp_fun_y1, original_solver_state = eqx.filter_vjp(
solver_step, terms, t1, t0, y1, args, original_solver_state, has_aux=True
)
z0 = (ω(z1) - ω(y1) + ω(step_y1)).ω

f0 = solver.func(terms, t0, z0, args)
step_z0, vjp_fun_z0 = eqx.filter_vjp(solver_step, terms, t0, t1, z0, args)
step_z0, vjp_fun_z0, _ = eqx.filter_vjp(
solver_step, terms, t0, t1, z0, args, original_solver_state, has_aux=True
)

y0 = ((1 / solver.l) * (ω(y1) - ω(step_z0)) + ω(z0)).ω

Expand All @@ -995,7 +992,7 @@ def solver_step(terms, t0, t1, y1, args):
return (
ts_index,
y0,
((first_step, f0), z0),
(original_solver_state, z0),
grad_ys,
grad_z0,
grad_args,
Expand Down
37 changes: 20 additions & 17 deletions diffrax/_solver/reversible.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .._custom_types import Args, BoolScalarLike, DenseInfo, RealScalarLike, VF, Y
from .._solution import RESULTS, update_result
from .._term import AbstractTerm
from .base import AbstractAdaptiveSolver, AbstractWrappedSolver
from .base import AbstractAdaptiveSolver, AbstractSolver, AbstractWrappedSolver
from .runge_kutta import AbstractRungeKutta


Expand All @@ -16,25 +16,32 @@
_SolverState: TypeAlias = tuple[_BaseSolverState, Y]


def _add_maybe_none(x, y):
if x is None:
return None
else:
return x + y


class Reversible(
AbstractAdaptiveSolver[_SolverState], AbstractWrappedSolver[_SolverState]
):
"""
Reversible solver method.
Allows any Runge-Kutta method ([`diffrax.AbstractRungeKutta`][]) to be made
Allows any solver ([`diffrax.AbstractSolver`][]) to be made
algebraically reversible.
The convergence order of the reversible solver is inherited from the wrapped
Runge-Kutta method.
solver.
Backpropagation through the reversible solver implies very low memory usage and
exact gradient calculation (up to floating point errors). This is implemented in
[`diffrax.ReversibleAdjoint`][] and passed to [`diffrax.diffeqsolve`][] as
`adjoint=diffrax.ReversibleAdjoint()`.
"""

solver: AbstractRungeKutta
solver: AbstractSolver
l: RealScalarLike = 0.999

@property
Expand Down Expand Up @@ -71,6 +78,9 @@ def init(
y0: Y,
args: Args,
) -> _SolverState:
if isinstance(self.solver, AbstractRungeKutta):
object.__setattr__(self.solver.tableau, "fsal", False)
object.__setattr__(self.solver.tableau, "ssal", False)
original_solver_init = self.solver.init(terms, t0, t1, y0, args)
return (original_solver_init, y0)

Expand All @@ -84,29 +94,22 @@ def step(
solver_state: _SolverState,
made_jump: BoolScalarLike,
) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]:
(first_step, f0), z0 = solver_state

# TODO The solver steps switch between evaluating from z0
# and y1. Therefore, we re-evaluate f0 outside of the base
# solver to ensure the vf is correct.
# Can we avoid this re-evaluation?
original_solver_state, z0 = solver_state

f0 = self.func(terms, t0, z0, args)
step_z0, z_error, dense_info, _, result1 = self.solver.step(
terms, t0, t1, z0, args, (first_step, f0), made_jump
step_z0, z_error, dense_info, original_solver_state, result1 = self.solver.step(
terms, t0, t1, z0, args, original_solver_state, made_jump
)
y1 = (self.l * (ω(y0) - ω(z0)) + ω(step_z0)).ω

f0 = self.func(terms, t1, y1, args)
step_y1, y_error, _, _, result2 = self.solver.step(
terms, t1, t0, y1, args, (first_step, f0), made_jump
terms, t1, t0, y1, args, original_solver_state, made_jump
)
z1 = (ω(y1) + ω(z0) - ω(step_y1)).ω

solver_state = ((first_step, f0), z1)
solver_state = (original_solver_state, z1)
result = update_result(result1, result2)

return y1, z_error + y_error, dense_info, solver_state, result
return y1, _add_maybe_none(z_error, y_error), dense_info, solver_state, result

def func(
self, terms: PyTree[AbstractTerm], t0: RealScalarLike, y0: Y, args: Args
Expand Down

0 comments on commit 7dfb8e3

Please sign in to comment.