-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Organize code based on solver-direct
- Loading branch information
1 parent
6688e70
commit 537d483
Showing
5 changed files
with
179 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
using OrdinaryDiffEq | ||
|
||
function dyn!(du::Array{Complex{Float64}}, u::Array{Complex{Float64}}, p, t) | ||
ω = p[1] | ||
du[1] = u[2] | ||
du[2] = - ω^2 * u[1] | ||
end | ||
|
||
tspan = [0.0, 10.0] | ||
du = Array{Complex{Float64}}([0.0]) | ||
u0 = Array{Complex{Float64}}([0.0, 1.0]) | ||
|
||
function complexstep_differentiation(f::Function, p::Float64, ε::Float64) | ||
p_complex = p + ε * im | ||
return imag(f(p_complex)) / ε | ||
end | ||
|
||
complexstep_differentiation(x -> solve(ODEProblem(dyn!, u0, tspan, [x]), Tsit5()).u[end][1], 20., 1e-3) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
using Base: @kwdef | ||
|
||
@kwdef struct DualNumber{F <: AbstractFloat} | ||
value::F | ||
derivative::F | ||
# Inner constructors | ||
# DualNumber(value, derivative) = new(value, derivative) | ||
# function DualNumber(value::F) where {F <: AbstractFloat} | ||
# new(value, 0.0) | ||
# end | ||
end | ||
|
||
# Outer constructor | ||
function DualNumber(value::F) where {F <: AbstractFloat} | ||
DualNumber(value, 0.0) | ||
end | ||
|
||
# Do we need to define this on Base? | ||
|
||
# Chain rules for binary opperators | ||
|
||
# Binary sum | ||
Base.:(+)(a::DualNumber, b::DualNumber) = DualNumber(value = a.value + b.value, | ||
derivative = a.derivative + b.derivative) | ||
|
||
# Binary product | ||
Base.:(*)(a::DualNumber, b::DualNumber) = DualNumber(value = a.value * b.value, | ||
derivative = a.value*b.derivative + a.derivative*b.value) | ||
|
||
# Power | ||
Base.:(^)(a::DualNumber, b::AbstractFloat) = DualNumber(value = a.value ^ b, | ||
derivative = b * a.value^(b-1) * a.derivative) | ||
|
||
|
||
# Now we define a series of variables. We are interested in computing the derivative with respect to the variable "a": | ||
|
||
a = DualNumber(value=1.0, derivative=1.0) | ||
|
||
b = DualNumber(value=2.0, derivative=0.0) | ||
c = DualNumber(value=3.0, derivative=0.0) | ||
|
||
# Now, we can evaluate a new DualNumber | ||
result = a * b * c | ||
# println("The derivative of a*b*c with respect to a is: ", result.derivative) |
106 changes: 106 additions & 0 deletions
106
code/DirectMethods/DualNumbers/dualnumber_tolerances.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
using Pkg | ||
Pkg.activate(".") | ||
|
||
using SciMLSensitivity | ||
using OrdinaryDiffEq | ||
using Zygote | ||
using ForwardDiff | ||
using Infiltrator | ||
|
||
tspan = (0.0, 10.0) | ||
u0 = [0.0] | ||
reltol = 1e-6 | ||
abstol = 1e-6 | ||
|
||
""" | ||
dyn! | ||
This generates solutions u(t) = (t-θ)^5/5 that can be solved exactly with a 5th order integrator. | ||
""" | ||
function dyn!(du, u, p, t) | ||
θ = p[1] | ||
du .= (t .- θ).^4.0 | ||
end | ||
|
||
p = [1.0] | ||
|
||
prob = ODEProblem(dyn!, u0, tspan, p) | ||
sol = solve(prob, Tsit5(), reltol=reltol, abstol=abstol) | ||
|
||
# We can see that the time steps increase with non-stop | ||
# @show diff(sol.t) | ||
|
||
function loss(p, sensealg) | ||
prob = ODEProblem(dyn!, u0, tspan, p) | ||
if isnothing(sensealg) | ||
sol = solve(prob, Tsit5(), reltol=reltol, abstol=abstol) | ||
else | ||
sol = solve(prob, Tsit5(), sensealg=sensealg, reltol=reltol, abstol=abstol) | ||
end | ||
@show "Number of time steps: ", length(sol.t) | ||
sol.u[end][1] | ||
end | ||
|
||
function grad_true(p) | ||
θ = p[1] | ||
t = tspan[2] | ||
θ^4 - (t - θ)^4 | ||
end | ||
|
||
""" | ||
An implementation of discrete forward sensitivity analysis through ForwardDiff.jl. | ||
When used within adjoint differentiation (i.e. via Zygote), this will cause forward differentiation | ||
of the solve call within the reverse-mode automatic differentiation environment. | ||
https://docs.sciml.ai/SciMLSensitivity/stable/manual/differential_equation_sensitivities/#SciMLSensitivity.ForwardDiffSensitivity | ||
""" | ||
# Original AD without correction | ||
|
||
condition(u, t, integrator) = true | ||
function printstepsize!(integrator) | ||
# @infiltrate | ||
if length(integrator.sol.t) > 1 | ||
# println("Stepsize at step ", length(integrator.sol.t), ": ", integrator.sol.t[end] - integrator.sol.t[end-1]) | ||
end | ||
end | ||
|
||
cb = DiscreteCallback(condition, printstepsize!) | ||
|
||
# g1 = Zygote.gradient(p -> loss(p, ForwardDiffSensitivity()), internalnorm = (u,t) -> sum(abs2,u/length(u)), p) | ||
g1 = Zygote.gradient(p -> solve(ODEProblem(dyn!, u0, tspan, p), | ||
Tsit5(), | ||
u0 = u0, | ||
p = p, | ||
sensealg = ForwardDiffSensitivity(), | ||
saveat = 0.1, | ||
internalnorm = (u,t) -> sum(abs2, u/length(u)), | ||
callback = cb, | ||
reltol=1e-6, | ||
abstol=1e-6).u[end][1], p) | ||
@show g1 | ||
|
||
# Forward Sensitivity | ||
# g2 = Zygote.gradient(p -> loss(p, ForwardSensitivity()), p) | ||
# g2 = Zygote.gradient(p -> solve(prob, | ||
# Tsit5(), | ||
# sensealg = ForwardSensitivity(), | ||
# saveat = 0.1, | ||
# callback = cb, | ||
# reltol=1e-12, | ||
# abstol=1e-12).u[end][1], p) | ||
# @show g2 | ||
|
||
# Corrected AD | ||
# g3 = ForwardDiff.gradient(p -> loss(p, nothing), p) | ||
g3 = Zygote.gradient(p -> solve(ODEProblem(dyn!, u0, tspan, p), | ||
Tsit5(), | ||
sensealg = ForwardDiffSensitivity(), | ||
# saveat = 0.1, | ||
# callback = cb, | ||
reltol=1e-6, | ||
abstol=1e-6).u[end][1], p) | ||
@show g3 | ||
|
||
@show grad_true(p) | ||
|
||
# Define customized RK(4) solver with given timesteps to show the divergence of forward sensitivities |
This file was deleted.
Oops, something went wrong.