Skip to content

Commit

Permalink
fix tests for relative import
Browse files Browse the repository at this point in the history
  • Loading branch information
sammccallum committed Nov 26, 2024
1 parent 24d1935 commit e7856d3
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 15 deletions.
1 change: 0 additions & 1 deletion diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@
Midpoint as Midpoint,
MultiButcherTableau as MultiButcherTableau,
Ralston as Ralston,
Reversible as Reversible,
ReversibleHeun as ReversibleHeun,
SEA as SEA,
SemiImplicitEuler as SemiImplicitEuler,
Expand Down
18 changes: 5 additions & 13 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,10 +1051,10 @@ class ReversibleAdjoint(AbstractAdjoint):
Backpropagate through [`diffrax.diffeqsolve`][] using the reversible solver
method.
This method wraps the passed solver to create an algebraically reversible version
of that solver. In doing so, gradient calculation is exact (up to floating point
errors) and backpropagation becomes a linear in time $O(t)$ and constant in memory
$O(1)$ algorithm.
This method automatically wraps the passed solver to create an algebraically
reversible version of that solver. In doing so, gradient calculation is exact
(up to floating point errors) and backpropagation becomes a linear in time $O(t)$
and constant in memory $O(1)$ algorithm.
The reversible adjoint can be used when solving ODEs/CDEs/SDEs and is
compatible with any [`diffrax.AbstractSolver`][]. Adaptive step sizes are also
Expand Down Expand Up @@ -1188,15 +1188,7 @@ class _Reversible(
Reversible solver method.
Allows any solver ([`diffrax.AbstractSolver`][]) to be made algebraically
reversible. The convergence order of the reversible solver is inherited from the
wrapped solver.
Gradient calculation through the reversible solver is exact (up to floating
point errors) and backpropagation becomes a linear in time $O(t)$ and constant in
memory $O(1)$ algorithm.
This is implemented in [`diffrax.ReversibleAdjoint`][] and passed to
[`diffrax.diffeqsolve`][] as `adjoint=diffrax.ReversibleAdjoint()`.
reversible. This is a private API, exclusively for [`diffrax.ReversibleAdjoint`][].
"""

solver: AbstractSolver
Expand Down
3 changes: 2 additions & 1 deletion test/test_reversible.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jax.random as jr
import optimistix as optx
import pytest
from diffrax._adjoint import _Reversible
from jaxtyping import Array

from .helpers import tree_allclose
Expand Down Expand Up @@ -136,7 +137,7 @@ def _loss(y0__args__term, solver, saveat, adjoint, stepsize_controller):


def _compare_loss(y0__args__term, base_solver, saveat, stepsize_controller):
reversible_solver = diffrax.Reversible(base_solver)
reversible_solver = _Reversible(base_solver)

loss, grads_base = _loss(
y0__args__term,
Expand Down

0 comments on commit e7856d3

Please sign in to comment.