Reversible Solvers #528
base: main
Conversation
This is really well done. I've left some comments but it's mostly around broader structural/testing/documentation stuff.
I've commented on your point 1 inline, and I think what you've done for point 2 looks good to me!
diffrax/_solver/reversible.py
Outdated
`adjoint=diffrax.ReversibleAdjoint()`.
"""

solver: AbstractRungeKutta
Are implicit RK methods handled here as well? According to this annotation they are but I don't think I see them in the tests.
What is it about RK methods that privileges them here btw? IIUC I think any single-step method should work?
Also what about Euler, which isn't implemented as an AbstractRungeKutta but which does have the correct properties?
(It's done separately to be able to use as example code for how to write a solver.)
Yep, you're right - any single-step method should work. The reversible solver now works with any AbstractSolver. See the discussion on fsal for more info.
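As a concrete illustration of why any single-step method can be made reversible, here is a minimal sketch of the coupled two-state algebra, using explicit Euler as the base solver. The coupling constant `lam` and the exact z-update below are illustrative reconstructions from the snippets in this thread, not necessarily the PR's actual update.

```python
def euler_step(f, t0, t1, y):
    # one step of the base solver (explicit Euler here) from t0 to t1
    return y + (t1 - t0) * f(t0, y)

def forward(f, t0, t1, y0, z0, lam=0.999):
    # coupled reversible step: y1 is built from a forward step at z0,
    # z1 from a backward-in-time step at y1
    y1 = lam * (y0 - z0) + euler_step(f, t0, t1, z0)
    z1 = z0 + y1 - euler_step(f, t1, t0, y1)
    return y1, z1

def backward(f, t0, t1, y1, z1, lam=0.999):
    # exact algebraic inversion of `forward` (up to floating point),
    # so the backward pass needs no stored trajectory
    z0 = z1 - y1 + euler_step(f, t1, t0, y1)
    y0 = (y1 - euler_step(f, t0, t1, z0)) / lam + z0
    return y0, z0

f = lambda t, y: -y  # toy vector field dy/dt = -y
y1, z1 = forward(f, 0.0, 0.1, 1.0, 1.0)
y0, z0 = backward(f, 0.0, 0.1, y1, z1)
# y0, z0 recover the initial state (1.0, 1.0) to floating-point accuracy
```

Reconstructing the trajectory by inverting the algebra like this is what lets a reversible adjoint run in O(1) memory.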
diffrax/_solver/reversible.py
Outdated
def strong_order(self, terms: PyTree[AbstractTerm]) -> Optional[RealScalarLike]:
    return self.solver.strong_order(terms)
Do you expect this technique to work for SDEs? If so then do call that out explicitly in the docstring, to reassure people! :)
(In particular I'm thinking of the asynchronous leapfrog method, which to our surprise did not work for SDEs...)
We have no theory here (does James' intuition count?), but numerically it works for SDEs! I've added SDEs to the docstring.
There's the detail that the second solver step (that steps backwards in time) should use the same Brownian increment as the first solver step. I believe this is handled by VirtualBrownianTree.
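The property being relied on here can be sketched in isolation. This toy generator is not how VirtualBrownianTree actually works (that uses a Brownian-bridge binary tree with a tolerance); it only demonstrates the requirement: the increment must be a pure function of the queried interval, so forward and backward steps see identical noise.

```python
import math
import random

def brownian_increment(t0, t1, seed=0):
    # The increment is a deterministic function of (seed, t0, t1): the
    # backward-in-time solver step re-querying the same interval gets
    # exactly the same noise as the forward step did.
    rng = random.Random((seed, round(t0, 12), round(t1, 12)))
    return rng.gauss(0.0, math.sqrt(abs(t1 - t0)))

dW_forward = brownian_increment(0.3, 0.4)
dW_backward = brownian_increment(0.3, 0.4)
assert dW_forward == dW_backward  # bit-identical on repeated queries
```

A generator that draws fresh noise on every query (UnsafeBrownianPath-style) breaks this, which motivates the check added below.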
Just added a check and test for UnsafeBrownianPath in light of the above.
And thinking about this further, we require the same conditions as BacksolveAdjoint; namely that the solver converges to the Stratonovich solution, so I've added a check and test for this.
Hmm FWIW my intuition is also that this should work for SDEs. Sounds like a follow-up paper to be written :)
But... until that theory exists, I think I'd feel more comfortable issuing an error here instead, to try to minimize the possibility of footguns. Most users treat solvers like oracles, and I try to cater to that unfootgunable UX!
diffrax/_solver/reversible.py
Outdated
`adjoint=diffrax.ReversibleAdjoint()`.
"""
Do go ahead and add a couple of references here! (See how we've done it in the other solvers.) At the very least include both your paper and the various earlier pieces of work. Also make sure whatever you put here works with diffrax.citation, so that folks have an easy way to cite you :)
What happens if I use just ReversibleAdjoint with a different solver? What happens if I use Reversible with a different adjoint? Is this safe to use with adaptive time stepping? The docstring here needs to make clear what a user should expect to happen as this interacts with the other components of Diffrax!
Thanks, good point - added.
Removing Reversible from the public API helps with control here.
test/test_reversible.py
Outdated
Can you add a test checking how this interacts with events? It's not immediately obvious to me that this will actually do the right thing.
Also, it would be good to see some 'negative tests' checking that the appropriate error is raised if Reversible is used in conjunction with e.g. SemiImplicitEuler, or any other method that isn't supported.
Events seem to work on the forward reversible solve but raise the same error as BacksolveAdjoint on the backward solve. I've added a catch to raise an error if you try to use ReversibleAdjoint with events.
Negative tests for incompatible solvers, events and saveats have been added.
diffrax/_adjoint.py
Outdated
if eqx.tree_equal(saveat, SaveAt(t1=True)) is not True:
    raise ValueError(
        "Can only use `adjoint=ReversibleAdjoint()` with "
        "`saveat=SaveAt(t1=True)`."
    )
It will probably not take long until someone asks to use this alongside SaveAt(ts=...)!
I can see that this is probably trickier to handle because of the way we do interpolation to get outputs at ts. Do you have any ideas for this?
(Either way, getting it working for that definitely isn't a prerequisite for merging, it's just a really solid nice-to-have.)
FWIW I imagine SaveAt(steps=True) is probably much easier.
I've added functionality for SaveAt(steps=True), but SaveAt(ts=...) is a tricky one.
Not a solution, but some thoughts:
The ReversibleAdjoint computes gradients accurate to the numerical operations taken, rather than an approximation to the 'idealised' continuous-time adjoint ODE. This is then tricky when the numerical operations include interpolation and not just ODE solving.
In principle, the interpolated ys are just a function of the stepped-to ys. We can therefore calculate gradients for the stepped-to ys and let AD handle the rest. This would require the interpolation routine to be separate from the solve routine, but I understand the memory drawbacks of this setup.
I imagine there isn't a huge demand to decouple the solve from the interpolation - but if it turns out this is relevant for other cases I'd be happy to give it a go!
On your thoughts -- I think this is exactly right. In principle we should be able to just solve backwards, and once we have the relevant y-values we can (re)interpolate the same solution we originally provided, and then pull the cotangents backwards through that computation via autodiff. Code-wise that may be somewhat fiddly, but if you're willing to take it on then I expect that'll actually be a really useful use-case.
I'm not sure if this would be done by decoupling solve from interpolation. I expect it would be some _, vjp_fn = jax.vjp(compute_interpolation); vjp_fn(y_cotangent) calls inside your while loop on the backward pass.
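This suggestion can be sketched as follows. The linear `interpolate` below is a hypothetical stand-in for Diffrax's actual dense interpolation; the point is only that, given the stepped-to ys reconstructed on the backward pass, jax.vjp pulls the saved output's cotangent back onto them.

```python
import jax
import jax.numpy as jnp

def interpolate(ys, ts, t):
    # hypothetical stand-in: linear interpolation of the stepped-to values
    i = jnp.clip(jnp.searchsorted(ts, t) - 1, 0, ts.shape[0] - 2)
    w = (t - ts[i]) / (ts[i + 1] - ts[i])
    return (1 - w) * ys[i] + w * ys[i + 1]

ts = jnp.array([0.0, 1.0, 2.0])   # times the solver stepped to
ys = jnp.array([0.0, 2.0, 3.0])   # states reconstructed on the backward pass

# Recompute the interpolation and pull the cotangent of the saved output
# back onto the stepped-to ys via autodiff.
y_half, vjp_fn = jax.vjp(lambda ys_: interpolate(ys_, ts, 0.5), ys)
(grad_ys,) = vjp_fn(jnp.array(1.0))
# grad_ys distributes the cotangent over the two bracketing steps: [0.5, 0.5, 0.0]
```

In the full scheme these vjp calls would sit inside the backward while loop, one per reconstructed step.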
diffrax/_adjoint.py
Outdated
if not isinstance(solver, Reversible):
    raise ValueError(
        "Can only use `adjoint=ReversibleAdjoint()` with "
        "`Reversible()` solver."
    )
Could we perhaps remove Reversible from the public API altogether, and just have solver = Reversible(solver) here? Make the Reversible solver an implementation detail of the adjoint.
I really like this idea :D
I've removed Reversible from the public API and any AbstractSolver passed to ReversibleAdjoint is auto-wrapped. There is now a _Reversible class within the _adjoint module that is exclusively used by ReversibleAdjoint. Do you think this is an appropriate home for the _Reversible class or should I keep it elsewhere?
For consistency I'd probably put it in _solver/reversible.py -- I've generally tended to organize things in this way, e.g. _terms.py::AdjointTerm is used as part of BacksolveAdjoint.
But that's only for consistency, one could imagine an alternate layout where these all lived next to their consumers instead.
diffrax/_solver/reversible.py
Outdated
    solver_state: _SolverState,
    made_jump: BoolScalarLike,
) -> tuple[Y, Optional[Y], DenseInfo, _SolverState, RESULTS]:
    (first_step, f0), z0 = solver_state
This will fail for non-FSAL Runge-Kutta solvers.
(Can you add a test for one of those to be sure we get correct behaviour?)
See comment below.
diffrax/_solver/reversible.py
Outdated
# solver to ensure the vf is correct.
# Can we avoid this re-evaluation?

f0 = self.func(terms, t0, z0, args)
I'm not sure this is really okay -- AbstractSolver.func is something we try to use approximately never: it basically exists just to handle initial step size selection and steady-state finding, which are both fairly heuristic and pretty far off the beaten path.
If I understand correctly, the issue is that your y1 isn't quite the value that is returned from a single step, so the FSAL property does not hold, and as such you need to reevaluate f0? If so then I think you should be able to avoid this issue by ensuring that the RK solvers are used in non-FSAL form. This is one of the most complicated corners of the codebase, but take a look at the comment starting at diffrax/_solver/runge_kutta.py, line 532 (commit 0cb19e9).
We now disable the FSAL and SSAL properties in the _Reversible init method (if a RK solver is used).
With this we can now make any AbstractSolver reversible, and we pass around the _Reversible solver state as (original_solver_state, z_n). We also never unpack the original_solver_state, so don't need to assume any structure.
Awesome! It makes me very happy that these are all possible, that really simplifies our lives ^^
diffrax/_solver/reversible.py
Outdated
step_z0, z_error, dense_info, _, result1 = self.solver.step(
    terms, t0, t1, z0, args, (first_step, f0), made_jump
)
y1 = (self.l * (ω(y0) - ω(z0)) + ω(step_z0)).ω

f0 = self.func(terms, t1, y1, args)
step_y1, y_error, _, _, result2 = self.solver.step(
    terms, t1, t0, y1, args, (first_step, f0), made_jump
)
On these two evaluations of .step -- take a look at eqx.internal.scan_trick, which might allow you to collapse these two callsites into one. That can be used to halve compilation time!
CMIIW, I'm not sure we can use the scan trick here as the function return signature is different for each solver step?
That is, we only want to update the original_solver_state and dense_info when taking the forward-in-time step. So we don't return these on the backward-in-time step. IIUC, collapsing the two calls into one would require the returned carry to be the same on both calls?
So it's usually possible to make this work even in arbitrary cases by welding two copies of your state together and then using a jnp.where or lax.cond to route between them based on which step you're on.
That said this is a pretty fiddly optimization, and I probably shouldn't have suggested it just yet! Once we're happy with everything else then we could do this later, but until then the un-scan-trick'd code is much easier to read.
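The 'weld the states together and route with lax.cond' idea can be sketched with a toy scan. This is not eqx.internal.scan_trick itself; the step functions are hypothetical stand-ins, and the point is only that one traced body can serve both callsites.

```python
import jax.numpy as jnp
from jax import lax

def forward_step(y):   # stand-in for the forward-in-time solver step
    return y + 1.0

def backward_step(y):  # stand-in for the backward-in-time solver step
    return y * 2.0

def body(carry, i):
    # one callsite: lax.cond routes between the two sub-steps based on
    # step parity, so both share a single compiled body
    carry = lax.cond(i % 2 == 0, forward_step, backward_step, carry)
    return carry, carry

final, history = lax.scan(body, 0.0, jnp.arange(4))
# steps: +1, *2, +1, *2, i.e. 0 -> 1 -> 2 -> 3 -> 6
```

Differing outputs per sub-step (e.g. dense_info only on the forward step) would be handled by carrying both variants and selecting with jnp.where, which is exactly what makes this fiddly.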
Commit history (all by Sam McCallum <[email protected]>):
- ec1ebac (Wed Nov 27 2024) tidy up function arguments
- 7b66f46 (Tue Nov 26 2024) beefy tests
- e713b5d (Tue Nov 26 2024) update references
- 9acf6e0 (Tue Nov 26 2024) test incorrect solver
- 861aa97 (Tue Nov 26 2024) catch already reversible solvers
- 4b8b4c0 (Tue Nov 26 2024) error estimate may be pytree
- 0b01210 (Tue Nov 26 2024) tests
- 5435ab2 (Tue Nov 26 2024) Revert "leapfrog not compatible" (reverts d88e732)
- d88e732 (Tue Nov 26 2024) leapfrog not compatible
- 6e3f2de (Tue Nov 26 2024) pytree state
- 3fa6432 (Tue Nov 26 2024) docs
- 2bfe820 (Tue Nov 26 2024) remove reversible.py solver file
- e7856d3 (Tue Nov 26 2024) fix tests for relative import
- 24d1935 (Tue Nov 26 2024) private reversible
- 8a7448e (Tue Nov 26 2024) remove debug print
- 0391bc1 (Tue Nov 26 2024) tests
- 81a9a57 (Tue Nov 26 2024) more tests
- 89f5731 (Mon Nov 25 2024) test implicit solvers + SDEs
- f30f47e (Mon Nov 25 2024) remove t0, t1, solver_state tangents
- b903176 (Mon Nov 25 2024) docs
- acaa35f (Mon Nov 25 2024) better steps=True
- 621e6f4 (Mon Nov 25 2024) remove ifs in grad_step loop
- 7dfb8e3 (Mon Nov 25 2024) Disable fsal, ssal properties to allow any solver to be made reversible
- f160295 (Fri Nov 22 2024) tests
- f327f66 (Fri Nov 22 2024) ReversibleAdjoint compatible with SaveAt(steps=True)

Reversible Solvers (v2). Changes:
- `Reversible` solver is hidden from the public API and automatically used with `ReversibleAdjoint`
- compatible with any `AbstractSolver`, except methods that are already algebraically reversible
- can now use `SaveAt(steps=True)`
- works with ODEs/CDEs/SDEs
- improved docs
- improved tests
Thanks very much for the review and suggestions! I'll reply to individual comments inline but here is an overview:
Okay, sorry for the (long) delay! 😅
So this mostly looks like it's in a good place to me. I think my one big concern is handling solver_state -- as each solver is free to define any kind of state it likes, then (a) there need not be any way to reconstruct it reversibly (so the backward pass goes wrong), and (b) it may encode information from previous y, z-values that aren't actually the ones being propagated (so the forward pass goes wrong).
I think right now you're dodging this just because of the particular choice of solvers being considered, i.e. (forced-to-be-)non-FSAL Runge--Kutta solvers.
I think we've got a couple of possible options:
- We could expand the solver API a little bit. For example a method to construct the previous state reversibly, and a flag to indicate that the y-value has changed and the solver might need to consider its state invalidated.
- We could also special-case down to e.g. just AbstractRungeKutta, and hardwire all the things about it that we know how to handle.
I am weakly leaning towards the first option, as it could allow us to support 'already reversible' solvers like ReversibleHeun or LeapfrogMidpoint (c.f. #541).
- We can introduce an AbstractReversibleSolver, subclassed by ReversibleHeun, LeapfrogMidpoint and your _Reversible.
- Then ReversibleAdjoint uses just the API provided by AbstractReversibleSolver, without making assumptions about precisely which one.
- _Reversible could afford to consume only those kinds of one-step solvers it knows how to handle (e.g. non-FSAL RK methods), and if other special use-cases arise later then we can always make more subclasses of AbstractReversibleSolver.
But I also realise that figuring out these details makes for pretty complicated research questions, and I don't want to presume upon your appetite for tackling them!
Let me know what you think / also happy to have a longer chat about this via email etc if it's easier.
#
# Information for reversible adjoint (save ts)
#
reversible_ts: Optional[eqxi.MaybeBuffer[Float[Array, " times_plus_1"]]]
By the way, I think this should be registered as a buffer in _outer_buffers, used here (line 627 in 467d95f):
cond_fun, body_fun, init_state, max_steps=max_steps, buffers=_outer_buffers
As for what buffers are, see the discussion on buffers from one of Andraz's PRs, plus the docs for eqxi.while_loop.
# Reversible info
if isinstance(adjoint, ReversibleAdjoint):
    if max_steps is None:
        raise ValueError(
            "`max_steps=None` is incompatible with `ReversibleAdjoint`"
        )
    reversible_ts = jnp.full(max_steps + 1, jnp.inf, dtype=time_dtype)
    reversible_save_index = 0
else:
    reversible_ts = None
    reversible_save_index = None
I think what might be simpler here is to have:
if max_steps is None:
reversible_ts = None
reversible_save_index = None
else:
reversible_ts = jnp.full(...)
reversible_save_index = 0
so that we are always saving this information if possible.
The benefit of this is that in principle someone else could write their own ReversibleAdjoint2 and have it work without needing to be special-cased here inside the main diffeqsolve implementation: it would just consume the information made available to it.
Finally, the ValueError can be moved inside the implementation of ReversibleAdjoint, raised if the necessary reversible_ts information is not available.
This shouldn't really impose any performance penalty (a very small compile-time one only) because for any other adjoint method it will just be DCE'd.
```
"""

l: float = 0.999
Can we pick a more descriptive name for this parameter? E.g. when writing an optimizer then mathematically we may conventionally use λ, but in code we would often use a variable with a name like learning_rate.
# `is` check because this may return a Tracer from SaveAt(ts=<array>)
if (
    eqx.tree_equal(saveat, SaveAt(t1=True)) is not True
    and eqx.tree_equal(saveat, SaveAt(steps=True)) is not True
    and eqx.tree_equal(saveat, SaveAt(t0=True, steps=True)) is not True
):
    raise ValueError(
        "Can only use `diffrax.ReversibleAdjoint` with "
        "`saveat=SaveAt(t1=True)` or `saveat=SaveAt(steps=True)`."
    )
I'm guessing that SaveAt(steps=True, t1=True) or SaveAt(t0=True, steps=True, t1=True) should also be allowed?
solver = _Reversible(solver, self.l)
tprev = init_state.tprev
tnext = init_state.tnext
y = init_state.y

init_state = eqx.tree_at(
    lambda s: s.solver_state,
    init_state,
    solver.init(terms, tprev, tnext, y, args),
    is_leaf=_is_none,
)
Note that the _Reversible.init here will re-call the underlying original solver.init, I think unnecessarily? I think we should be able to do just init_state = eqx.tree_at(lambda s: s.solver_state, init_state, (init_state.solver_state, init_state.y)), and then set class _Reversible: def init(...): assert False.
@property
def term_compatible_contr_kwargs(self):
    return self.solver.term_compatible_contr_kwargs
I think this, and term_structure, should not be necessary -- they are checked at the start of diffeqsolve, which we are already past. I think it may be neater to set these to assert False?
if isinstance(self.solver, AbstractRungeKutta):
    object.__setattr__(self.solver.tableau, "fsal", False)
    object.__setattr__(self.solver.tableau, "ssal", False)
Note that this will produce a bug if we do:
solver = Tsit5()
diffeqsolve(..., solver, ReversibleAdjoint())
diffeqsolve(..., solver, RecursiveCheckpointAdjoint())
as the solver is modified in-place.
I think I have a better solution: can we unconditionally pass made_jump=True into self.solver.step? This is our API to indicate to solvers that something has changed, and that their state may be out-of-date. Technically right now it's used to indicate jumps in the vector field, but we could re-use it (or add another flag) to indicate exogenous jumps in y.
Alternatively it may be safer for now to only allow AbstractRungeKutta here, and not general solvers -- it's not clear to me that any of this will really work with multi-step solvers like LeapfrogMidpoint, for example.
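The in-place mutation bug above comes from sharing one tableau object across calls. A minimal stdlib sketch of the functional alternative (in Diffrax/Equinox one would reach for eqx.tree_at; the Tableau class here is a hypothetical stand-in):

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class Tableau:
    # hypothetical stand-in for a Butcher tableau's property flags
    fsal: bool = True
    ssal: bool = True

tableau = Tableau()

# Instead of object.__setattr__(tableau, "fsal", False), which mutates
# state shared by every later use of the same solver instance, build a
# modified copy and leave the original intact:
non_fsal = dataclasses.replace(tableau, fsal=False, ssal=False)

# tableau.fsal is still True; non_fsal.fsal is False
```

With the copy, reusing the same solver instance with a different adjoint afterwards sees its original FSAL/SSAL behaviour.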
solver_state = (original_solver_state, z1)
result = update_result(result1, result2)

return y1, _add_maybe_none(z_error, y_error), dense_info, solver_state, result
Why add the error estimates together?
grad_z1 = jtu.tree_map(jnp.zeros_like, diff_z1)
del diff_args, diff_terms, diff_z1

def grad_step(state):
Hmm, I'm trying to figure out how solver_state is handled here. I don't think it is correct?
solver_state is some completely arbitrary information that is propagated forward step-by-step, internal to the solver. In particular we don't have an API for reconstructing this backwards in time reversibly.
A description of the performance bug with the current ReversibleAdjoint, demonstrated by the example below:

```python
import time

import diffrax as dfx
import equinox as eqx
import jax
import jax.numpy as jnp
import jax.random as jr

jax.config.update("jax_enable_x64", True)


class VectorField(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, y_dim, width_size, depth, key):
        self.mlp = eqx.nn.MLP(y_dim, y_dim, width_size, depth, key=key)

    def __call__(self, t, y, args):
        return self.mlp(y)


@eqx.filter_jit
def solve(model, y0, adjoint):
    term = dfx.ODETerm(model)
    solver = dfx.Euler()
    t0 = 0.0
    t1 = 10.0
    dt0 = 0.01
    sol = dfx.diffeqsolve(
        term,
        solver,
        t0,
        t1,
        dt0,
        y0,
        saveat=dfx.SaveAt(t1=True),
        adjoint=adjoint,
        max_steps=1000,
    )
    return sol.ys


@eqx.filter_value_and_grad
def grad_loss(model, y0, adjoint):
    ys = eqx.filter_vmap(solve, in_axes=(None, 0, None))(model, y0, adjoint)
    return jnp.mean(ys**2)


def measure_runtime(y0, model, adjoint):
    tic = time.time()
    loss, grads = grad_loss(model, y0, adjoint)
    toc = time.time()
    print(f"Compile time: {(toc - tic):.5f}")

    repeats = 10
    tic = time.time()
    for i in range(repeats):
        loss, grads = jax.block_until_ready(grad_loss(model, y0, adjoint))
    toc = time.time()
    print(f"Runtime: {((toc - tic) / repeats):.5f}")


y_dim = 100
width_size = 100
depth = 4
model = VectorField(y_dim, width_size, depth, key=jr.PRNGKey(10))

print("Batch Size = 1")
print("--------------")
y0 = jnp.ones((1, y_dim))

print("Recursive")
adjoint = dfx.RecursiveCheckpointAdjoint()
measure_runtime(y0, model, adjoint)

print("Reversible")
adjoint = dfx.ReversibleAdjoint()
measure_runtime(y0, model, adjoint)

print("\nBatch Size = 1000")
print("-----------------")
y0 = jnp.ones((1000, y_dim))

print("Recursive")
adjoint = dfx.RecursiveCheckpointAdjoint()
measure_runtime(y0, model, adjoint)

print("Reversible")
adjoint = dfx.ReversibleAdjoint()
measure_runtime(y0, model, adjoint)
```

FWIW, I don't think this is specifically a problem with vmap, but a problem of scale. In principle, this is very wrong as we are only changing the cost of each step. (The quoted runtimes are on GPU.)
Hey Patrick,
Here's an implementation of Reversible Solvers! This includes:
- making any AbstractRungeKutta method in diffrax algebraically reversible - see diffrax.Reversible
- diffrax.ReversibleAdjoint

Main details I should highlight here:
1. The current implementation relies on the _SolverState type of AbstractRungeKutta methods. Specifically, as the reversible method switches between evaluating the vector field at y and z, we ensure the fsal is correct by evaluating the vector field outside of the base Runge-Kutta step. In principle this is unnecessary but required to fit with the behaviour of AbstractRungeKutta solvers; any ideas for how to avoid this?
2. To backpropagate through the reversible solve we require knowledge of the ts that the solver visited. As this is not known a priori for adaptive step sizes, I've added a (teeny weeny) bit of infrastructure to the State in _integrate.py. This allows us to save the ts that the solver stepped to, which we make available to ReversibleAdjoint as a residual. The added State follows exactly the implementation of saving dense_ts and is only triggered when adjoint=ReversibleAdjoint.
.Best,
Sam