Skip to content

Commit

Permalink
Added example of adjoint and sensitivity methods for harmonic oscilator
Browse files Browse the repository at this point in the history
  • Loading branch information
facusapienza21 committed May 30, 2024
1 parent d940d4a commit 9b58dd9
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 0 deletions.
27 changes: 27 additions & 0 deletions code/SolverMethods/Harmonic/adjoint_continuous.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Continuous Adjoint Method

include("harmonic.jl")
using RecursiveArrayTools

# Augmented dynamicis
function f_aug(z, p, t)
u, λ, L = z
du = f(u, p, t)
= - ∂f∂u(u, p, t)' * λ
dL = - λ' * ∂f∂p(u, p, t)
VectorOfArray([du, vec(dλ), vec(dL)])
end

# Solution of original ODE
prob = ODEProblem(f, u0, tspan, p)
sol = solve(prob, Euler(), dt=0.001)

# Final state
u1 = sol.u[end]
z1 = VectorOfArray([u1, [1.0, 0.0], zeros(length(p))])

aug_prob = ODEProblem(f_aug, z1, reverse(tspan), p)
u0_, λ0, dLdp_cont = solve(aug_prob, Euler(), dt=-0.001).u[end]


@test dLdp_cont dLdp_SciML[1]
32 changes: 32 additions & 0 deletions code/SolverMethods/Harmonic/adjoint_discrete.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Discrete adjoint method

include("harmonic.jl")

function discrete_adjoint_method(u0, tspan, p, dt)
u = u0
times = tspan[1]:dt:tspan[2]

λ = [1.0, 0.0]
∂L∂p = zeros(length(p))
u_store = [u]

# Forward pass to compute solution
for t in times[1:end-1]
u += dt * f(u, p, t)
push!(u_store, u)
end

# Reverse pass to compute adjoint
for (i, t) in enumerate(reverse(times)[2:end])
u_memory = u_store[end-i+1]
λ += dt * ∂f∂u(u_memory, p, t)' * λ
∂L∂p += dt * λ' * ∂f∂p(u_memory, p, t)
end

return ∂L∂p
end

dL∂p_disc = discrete_adjoint_method(u0, tspan, p, 0.001)

# Notice that there is still some numerical error in the case of the discrete adjoint
@test vec(dL∂p_disc) dLdp_SciML rtol=1e-3
29 changes: 29 additions & 0 deletions code/SolverMethods/Harmonic/forward_sensitivity_equations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Forward sensitivity equations

include("harmonic.jl")

function sensitivityequation(u0, tspan, p, dt)
u = u0
sensitivity = zeros(length(u), length(p))
for ti in tspan[1]:dt:tspan[2]
sensitivity += dt * (∂f∂u(u, p, ti) * sensitivity + ∂f∂p(u, p, ti))
u += dt * f(u, p, ti)
end
return u, sensitivity
end

u, s = sensitivityequation(u0, tspan , p, 0.001)

using OrdinaryDiffEq, ForwardDiff, Test

s_AD = ForwardDiff.jacobian(p -> solve(ODEProblem(f, u0, tspan, p), Tsit5()).u[end], p)

@test s_AD s rtol=0.01

### Let's do this with SciMLSensitivity

prob = ODEForwardSensitivityProblem(f!, u0, tspan, p)
sol = solve(prob, Tsit5())
u, dudp = extract_local_sensitivities(sol)

@test dudp[1][:, end] s_AD rtol=1e-3
46 changes: 46 additions & 0 deletions code/SolverMethods/Harmonic/harmonic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
Harmonic oscilator
"""

ω = 0.2
p = [ω]
u0 = [0.0, 1.0]
tspan = [0.0, 10.0]

# Dynamics
function f(u, p, t)
du₁ = u[2]
du₂ = - p[1]^2 * u[1]
return [du₁, du₂]
end

function f!(du, u, p, t)
du[1] = u[2]
du[2] = - p[1]^2 * u[1]
end

# Jacobian ∂f/∂p
function ∂f∂p(u, p, t)
Jac = zeros(length(u), length(p))
Jac[2,1] = -2*p[1]*u[1]
return Jac
end

# Jacobian ∂f/∂u
function ∂f∂u(u, p, t)
Jac = zeros(length(u), length(u))
Jac[1,2] = 1
Jac[2,1] = -p[1]^2
return Jac
end

# Ground truth gradient

function cost(p)
prob = ODEProblem(f, u0, tspan, p)
return solve(prob, Euler(), dt=0.001, save_everystep=false, sensealg=BacksolveAdjoint()).u[end][1]
end
cost(p)

dLdp_SciML = Zygote.gradient(p -> cost(p), p)[1]

0 comments on commit 9b58dd9

Please sign in to comment.