Skip to content

Commit

Permalink
catch already reversible solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
sammccallum committed Nov 26, 2024
1 parent 4b8b4c0 commit 861aa97
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
9 changes: 9 additions & 0 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
AbstractSolver,
AbstractStratonovichSolver,
AbstractWrappedSolver,
LeapfrogMidpoint,
ReversibleHeun,
SemiImplicitEuler,
)
from ._term import AbstractTerm, AdjointTerm

Expand Down Expand Up @@ -1126,6 +1129,12 @@ def loop(
"`diffrax.ReversibleAdjoint` is not compatible with events."
)

if isinstance(solver, (SemiImplicitEuler, ReversibleHeun, LeapfrogMidpoint)):
raise ValueError(
"`diffrax.ReversibleAdjoint` is not compatible with solvers that are "
f"intrinsically algebraically reversible, such as {solver}."
)

solver = _Reversible(solver, self.l)
tprev = init_state.tprev
tnext = init_state.tnext
Expand Down
26 changes: 26 additions & 0 deletions test/test_reversible.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,29 @@ def test_incorrect_saveat():
stepsize_controller=diffrax.ConstantStepSize(),
pytree_state=False,
)


def test_incorrect_solver():
y0 = (jnp.array([0.9, 5.4]), jnp.array([0.9, 5.4]))
args = (0.1, -1)
terms = (
diffrax.ODETerm(_VectorField(nondiff_arg=1, diff_arg=-0.1)),
diffrax.ODETerm(_VectorField(nondiff_arg=1, diff_arg=-0.1)),
)
incompatible_solvers = (
diffrax.SemiImplicitEuler(),
diffrax.ReversibleHeun(),
diffrax.LeapfrogMidpoint(),
)
for solver in incompatible_solvers:
with pytest.raises(ValueError):
diffrax.diffeqsolve(
terms,
solver,
t0=0,
t1=5,
dt0=0.01,
y0=y0,
args=args,
adjoint=diffrax.ReversibleAdjoint(),
)

0 comments on commit 861aa97

Please sign in to comment.