Skip to content

Commit

Permalink
Organize code based on solver-direct
Browse files Browse the repository at this point in the history
  • Loading branch information
facusapienza21 committed Mar 2, 2024
1 parent 6688e70 commit 537d483
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@ using OrdinaryDiffEq
using CairoMakie
using ComplexDiff
using Zygote, ForwardDiff, SciMLSensitivity
using BenchmarkTools

include("./complex_solver.jl")
include("../ComplexStep/complex_solver.jl")

# Parameters
u0 = [0.0, 1.0]
Expand Down Expand Up @@ -107,7 +108,7 @@ error_complex_high = abs.((derivative_true .- derivative_complex_high) ./ deriva

# Complex step Differentiation
derivative_complex_exact = ComplexDiff.derivative.(ω -> solution(t₁, u0, [ω]), p[1], stepsizes)
error_complex_exact = abs.((derivative_complex .- derivative_true)./derivative_true)
error_complex_exact = abs.((derivative_complex_exact .- derivative_true)./derivative_true)

# Forward AD applied to numerical solver
derivative_AD_low = Zygote.gradient(p->solve(ODEProblem(oscilatior!, u0, tspan, p), Tsit5(), reltol=1e-6, abstol=1e-6).u[end][1], p)[1][1]
Expand Down Expand Up @@ -150,4 +151,11 @@ plot!(ax, [stepsizes[begin], stepsizes[end]],[error_AD_high, error_AD_high], col
# Add legend
fig[1, 2] = Legend(fig, ax)

save("FiniteDifferences/FiniteDifferences_derivative.pdf", fig)
save("Figures/DirectMethods_comparison.pdf", fig)


######### Benchmark ###########

# It looks like complex step has better performance... both in speed and momory allocation.
# @benchmark derivative_complex_low = complexstep_differentiation.(Ref(x -> solve(ODEProblem(oscilatior!, u0_complex, tspan, [x]), Tsit5(), reltol=1e-6, abstol=1e-6).u[end][1]), Ref(p[1]), [1e-5])
# @benchmark derivative_AD_low = Zygote.gradient(p->solve(ODEProblem(oscilatior!, u0, tspan, p), Tsit5(), reltol=1e-6, abstol=1e-6).u[end][1], p)[1][1]
18 changes: 18 additions & 0 deletions code/DirectMethods/ComplexStep/complex_solver.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
using OrdinaryDiffEq

function dyn!(du::Array{Complex{Float64}}, u::Array{Complex{Float64}}, p, t)
ω = p[1]
du[1] = u[2]
du[2] = - ω^2 * u[1]
end

tspan = [0.0, 10.0]
du = Array{Complex{Float64}}([0.0])
u0 = Array{Complex{Float64}}([0.0, 1.0])

function complexstep_differentiation(f::Function, p::Float64, ε::Float64)
p_complex = p + ε * im
return imag(f(p_complex)) / ε
end

complexstep_differentiation(x -> solve(ODEProblem(dyn!, u0, tspan, [x]), Tsit5()).u[end][1], 20., 1e-3)
44 changes: 44 additions & 0 deletions code/DirectMethods/DualNumbers/dualnumber_definition.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using Base: @kwdef

@kwdef struct DualNumber{F <: AbstractFloat}
value::F
derivative::F
# Inner constructors
# DualNumber(value, derivative) = new(value, derivative)
# function DualNumber(value::F) where {F <: AbstractFloat}
# new(value, 0.0)
# end
end

# Outer constructor
function DualNumber(value::F) where {F <: AbstractFloat}
DualNumber(value, 0.0)
end

# Do we need to define this on Base?

# Chain rules for binary opperators

# Binary sum
Base.:(+)(a::DualNumber, b::DualNumber) = DualNumber(value = a.value + b.value,
derivative = a.derivative + b.derivative)

# Binary product
Base.:(*)(a::DualNumber, b::DualNumber) = DualNumber(value = a.value * b.value,
derivative = a.value*b.derivative + a.derivative*b.value)

# Power
Base.:(^)(a::DualNumber, b::AbstractFloat) = DualNumber(value = a.value ^ b,
derivative = b * a.value^(b-1) * a.derivative)


# Now we define a series of variables. We are interested in computing the derivative with respect to the variable "a":

a = DualNumber(value=1.0, derivative=1.0)

b = DualNumber(value=2.0, derivative=0.0)
c = DualNumber(value=3.0, derivative=0.0)

# Now, we can evaluate a new DualNumber
result = a * b * c
# println("The derivative of a*b*c with respect to a is: ", result.derivative)
106 changes: 106 additions & 0 deletions code/DirectMethods/DualNumbers/dualnumber_tolerances.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
using Pkg
Pkg.activate(".")

using SciMLSensitivity
using OrdinaryDiffEq
using Zygote
using ForwardDiff
using Infiltrator

tspan = (0.0, 10.0)
u0 = [0.0]
reltol = 1e-6
abstol = 1e-6

"""
dyn!
This generates solutions u(t) = (t-θ)^5/5 that can be solved exactly with a 5th order integrator.
"""
function dyn!(du, u, p, t)
θ = p[1]
du .= (t .- θ).^4.0
end

p = [1.0]

prob = ODEProblem(dyn!, u0, tspan, p)
sol = solve(prob, Tsit5(), reltol=reltol, abstol=abstol)

# We can see that the time steps increase with non-stop
# @show diff(sol.t)

function loss(p, sensealg)
prob = ODEProblem(dyn!, u0, tspan, p)
if isnothing(sensealg)
sol = solve(prob, Tsit5(), reltol=reltol, abstol=abstol)
else
sol = solve(prob, Tsit5(), sensealg=sensealg, reltol=reltol, abstol=abstol)
end
@show "Number of time steps: ", length(sol.t)
sol.u[end][1]
end

function grad_true(p)
θ = p[1]
t = tspan[2]
θ^4 - (t - θ)^4
end

"""
An implementation of discrete forward sensitivity analysis through ForwardDiff.jl.
When used within adjoint differentiation (i.e. via Zygote), this will cause forward differentiation
of the solve call within the reverse-mode automatic differentiation environment.
https://docs.sciml.ai/SciMLSensitivity/stable/manual/differential_equation_sensitivities/#SciMLSensitivity.ForwardDiffSensitivity
"""
# Original AD without correction

condition(u, t, integrator) = true
function printstepsize!(integrator)
# @infiltrate
if length(integrator.sol.t) > 1
# println("Stepsize at step ", length(integrator.sol.t), ": ", integrator.sol.t[end] - integrator.sol.t[end-1])
end
end

cb = DiscreteCallback(condition, printstepsize!)

# g1 = Zygote.gradient(p -> loss(p, ForwardDiffSensitivity()), internalnorm = (u,t) -> sum(abs2,u/length(u)), p)
g1 = Zygote.gradient(p -> solve(ODEProblem(dyn!, u0, tspan, p),
Tsit5(),
u0 = u0,
p = p,
sensealg = ForwardDiffSensitivity(),
saveat = 0.1,
internalnorm = (u,t) -> sum(abs2, u/length(u)),
callback = cb,
reltol=1e-6,
abstol=1e-6).u[end][1], p)
@show g1

# Forward Sensitivity
# g2 = Zygote.gradient(p -> loss(p, ForwardSensitivity()), p)
# g2 = Zygote.gradient(p -> solve(prob,
# Tsit5(),
# sensealg = ForwardSensitivity(),
# saveat = 0.1,
# callback = cb,
# reltol=1e-12,
# abstol=1e-12).u[end][1], p)
# @show g2

# Corrected AD
# g3 = ForwardDiff.gradient(p -> loss(p, nothing), p)
g3 = Zygote.gradient(p -> solve(ODEProblem(dyn!, u0, tspan, p),
Tsit5(),
sensealg = ForwardDiffSensitivity(),
# saveat = 0.1,
# callback = cb,
reltol=1e-6,
abstol=1e-6).u[end][1], p)
@show g3

@show grad_true(p)

# Define customized RK(4) solver with given timesteps to show the divergence of forward sensitivities
25 changes: 0 additions & 25 deletions code/FiniteDifferences/complex_solver.jl

This file was deleted.

0 comments on commit 537d483

Please sign in to comment.