Skip to content

Commit

Permalink
Updated JAX example
Browse files Browse the repository at this point in the history
Co-authored-by: ChrisRackauckas <[email protected]>
  • Loading branch information
facusapienza21 and ChrisRackauckas committed Jun 8, 2024
1 parent 765a475 commit 7704ab5
Showing 1 changed file with 23 additions and 35 deletions.
58 changes: 23 additions & 35 deletions code/SensitivityForwardAD/testgradient_python.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,51 @@
import jax
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Tsit5,
PIDController, BacksolveAdjoint,
RecursiveCheckpointAdjoint, DirectAdjoint
from diffrax import diffeqsolve, ODETerm, Tsit5, SaveAt, PIDController, RecursiveCheckpointAdjoint, DirectAdjoint

from jax import config
config.update("jax_enable_x64", True)

def vector_field(t, u, args):
x, y = u
a, b = args
dx = a * (x - y)
dy = b * (y - x)
a = args
dx = a * x - x * y
dy = -a * y + x * y
return dx, dy

def run(y0, adjoint = RecursiveCheckpointAdjoint(), tol = 1e-12):
def run(p0, adjoint = RecursiveCheckpointAdjoint(), tol = 1e-12):
term = ODETerm(vector_field)
solver = Tsit5(scan_kind="bounded")
stepsize_controller = PIDController(rtol=tol, atol=tol)
t0 = 0
t1 = 5.0
t1 = 10.0
ts = jnp.linspace(t0, t1, 101)
dt0 = 0.1
p0 = (1.0, 1.0)
y0 = (jnp.array(1.0), jnp.array(1.0))
saveat = SaveAt(ts=ts)

sol = diffeqsolve(term, solver, t0, t1, dt0, y0,
adjoint = adjoint,
stepsize_controller = stepsize_controller,
args=p0)
((x,), _) = sol.ys
return x
args=p0,
saveat=saveat)
return sum(sum(sol.ys))

y0 = (jnp.array(1.0), jnp.array(1.0))
p0 = 1.0
run(p0, RecursiveCheckpointAdjoint())

J = jax.jacrev(run)(y0)
# (Array(3755.79674193, dtype=float64, weak_type=True),
# Array(-3754.79674193, dtype=float64, weak_type=True))
# Array(6217.6080555, dtype=float64, weak_type=True)

J = jax.jacrev(lambda y0: run(y0, RecursiveCheckpointAdjoint(), tol = 1e-3))(y0)
# (Array(3755.79674193, dtype=float64, weak_type=True)
# Array(-3754.79674193, dtype=float64, weak_type=True))
# Array(6217.6080555, dtype=float64, weak_type=True)

J = jax.jacfwd(lambda y0: run(y0, DirectAdjoint()))(y0)
# (Array(3755.79674193, dtype=float64), Array(-3754.79674193, dtype=float64))
# Array(6217.6080555, dtype=float64)

J = jax.jacfwd(lambda y0: run(y0, DirectAdjoint(), tol = 1e-3))(y0)
# (Array(3755.79674193, dtype=float64), Array(-3754.79674193, dtype=float64))






J = jax.jacrev(lambda y0: run(y0, BacksolveAdjoint()))(y0)
# (Array(11013.73289742, dtype=float64, weak_type=True),
# Array(-11012.73289742, dtype=float64, weak_type=True))

J = jax.jacrev(lambda y0: run(y0, BacksolveAdjoint(), tol = 1e-3))(y0)
# (Array(10869.87401012, dtype=float64, weak_type=True),
# Array(-10868.87401012, dtype=float64, weak_type=True))
# Array(6217.6080555, dtype=float64)

y1 = (jnp.array(1.000001), jnp.array(1.0))
y0 = (jnp.array(1.0), jnp.array(1.0))
y1 = 1.000001
y0 = 1.0
(run(y1) - run(y0)) / .000001
# Array(11013.73156938, dtype=float64)
# Array(212.71060729, dtype=float64)

0 comments on commit 7704ab5

Please sign in to comment.