From 537d4834efa443391f0e39c9a05d9cd045163059 Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Sat, 2 Mar 2024 02:24:01 +0000 Subject: [PATCH] Organize code based on solver-direct --- .../Comparison/direct-comparision.jl} | 14 ++- .../ComplexStep/complex_solver.jl | 18 +++ .../DualNumbers/dualnumber_definition.jl | 44 ++++++++ .../DualNumbers/dualnumber_tolerances.jl | 106 ++++++++++++++++++ code/FiniteDifferences/complex_solver.jl | 25 ----- 5 files changed, 179 insertions(+), 28 deletions(-) rename code/{FiniteDifferences/finite_differences.jl => DirectMethods/Comparison/direct-comparision.jl} (90%) create mode 100644 code/DirectMethods/ComplexStep/complex_solver.jl create mode 100644 code/DirectMethods/DualNumbers/dualnumber_definition.jl create mode 100644 code/DirectMethods/DualNumbers/dualnumber_tolerances.jl delete mode 100644 code/FiniteDifferences/complex_solver.jl diff --git a/code/FiniteDifferences/finite_differences.jl b/code/DirectMethods/Comparison/direct-comparision.jl similarity index 90% rename from code/FiniteDifferences/finite_differences.jl rename to code/DirectMethods/Comparison/direct-comparision.jl index 381d4c7..303f5c7 100644 --- a/code/FiniteDifferences/finite_differences.jl +++ b/code/DirectMethods/Comparison/direct-comparision.jl @@ -5,8 +5,9 @@ using OrdinaryDiffEq using CairoMakie using ComplexDiff using Zygote, ForwardDiff, SciMLSensitivity +using BenchmarkTools -include("./complex_solver.jl") +include("../ComplexStep/complex_solver.jl") # Parameters u0 = [0.0, 1.0] @@ -107,7 +108,7 @@ error_complex_high = abs.((derivative_true .- derivative_complex_high) ./ deriva # Complex step Differentiation derivative_complex_exact = ComplexDiff.derivative.(ω -> solution(t₁, u0, [ω]), p[1], stepsizes) -error_complex_exact = abs.((derivative_complex .- derivative_true)./derivative_true) +error_complex_exact = abs.((derivative_complex_exact .- derivative_true)./derivative_true) # Forward AD applied to numerical solver derivative_AD_low = Zygote.gradient(p->solve(ODEProblem(oscilatior!, u0, tspan, p), Tsit5(), reltol=1e-6, abstol=1e-6).u[end][1], p)[1][1] @@ -150,4 +151,11 @@ plot!(ax, [stepsizes[begin], stepsizes[end]],[error_AD_high, error_AD_high], col # Add legend fig[1, 2] = Legend(fig, ax) -save("FiniteDifferences/FiniteDifferences_derivative.pdf", fig) +save("Figures/DirectMethods_comparison.pdf", fig) + + +######### Benchmark ########### + +# It looks like complex step has better performance... both in speed and momory allocation. +# @benchmark derivative_complex_low = complexstep_differentiation.(Ref(x -> solve(ODEProblem(oscilatior!, u0_complex, tspan, [x]), Tsit5(), reltol=1e-6, abstol=1e-6).u[end][1]), Ref(p[1]), [1e-5]) +# @benchmark derivative_AD_low = Zygote.gradient(p->solve(ODEProblem(oscilatior!, u0, tspan, p), Tsit5(), reltol=1e-6, abstol=1e-6).u[end][1], p)[1][1] \ No newline at end of file diff --git a/code/DirectMethods/ComplexStep/complex_solver.jl b/code/DirectMethods/ComplexStep/complex_solver.jl new file mode 100644 index 0000000..fc29e6a --- /dev/null +++ b/code/DirectMethods/ComplexStep/complex_solver.jl @@ -0,0 +1,18 @@ +using OrdinaryDiffEq + +function dyn!(du::Array{Complex{Float64}}, u::Array{Complex{Float64}}, p, t) + ω = p[1] + du[1] = u[2] + du[2] = - ω^2 * u[1] +end + +tspan = [0.0, 10.0] +du = Array{Complex{Float64}}([0.0]) +u0 = Array{Complex{Float64}}([0.0, 1.0]) + +function complexstep_differentiation(f::Function, p::Float64, ε::Float64) + p_complex = p + ε * im + return imag(f(p_complex)) / ε +end + +complexstep_differentiation(x -> solve(ODEProblem(dyn!, u0, tspan, [x]), Tsit5()).u[end][1], 20., 1e-3) diff --git a/code/DirectMethods/DualNumbers/dualnumber_definition.jl b/code/DirectMethods/DualNumbers/dualnumber_definition.jl new file mode 100644 index 0000000..aa2aaae --- /dev/null +++ b/code/DirectMethods/DualNumbers/dualnumber_definition.jl @@ -0,0 +1,44 @@ +using Base: @kwdef + +@kwdef struct DualNumber{F <: AbstractFloat} + value::F + derivative::F + # Inner constructors + # DualNumber(value, derivative) = new(value, derivative) + # function DualNumber(value::F) where {F <: AbstractFloat} + # new(value, 0.0) + # end +end + +# Outer constructor +function DualNumber(value::F) where {F <: AbstractFloat} + DualNumber(value, 0.0) +end + +# Do we need to define this on Base? + +# Chain rules for binary opperators + +# Binary sum +Base.:(+)(a::DualNumber, b::DualNumber) = DualNumber(value = a.value + b.value, + derivative = a.derivative + b.derivative) + +# Binary product +Base.:(*)(a::DualNumber, b::DualNumber) = DualNumber(value = a.value * b.value, + derivative = a.value*b.derivative + a.derivative*b.value) + +# Power +Base.:(^)(a::DualNumber, b::AbstractFloat) = DualNumber(value = a.value ^ b, + derivative = b * a.value^(b-1) * a.derivative) + + +# Now we define a series of variables. We are interested in computing the derivative with respect to the variable "a": + +a = DualNumber(value=1.0, derivative=1.0) + +b = DualNumber(value=2.0, derivative=0.0) +c = DualNumber(value=3.0, derivative=0.0) + +# Now, we can evaluate a new DualNumber +result = a * b * c +# println("The derivative of a*b*c with respect to a is: ", result.derivative) \ No newline at end of file diff --git a/code/DirectMethods/DualNumbers/dualnumber_tolerances.jl b/code/DirectMethods/DualNumbers/dualnumber_tolerances.jl new file mode 100644 index 0000000..cee36d6 --- /dev/null +++ b/code/DirectMethods/DualNumbers/dualnumber_tolerances.jl @@ -0,0 +1,106 @@ +using Pkg +Pkg.activate(".") + +using SciMLSensitivity +using OrdinaryDiffEq +using Zygote +using ForwardDiff +using Infiltrator + +tspan = (0.0, 10.0) +u0 = [0.0] +reltol = 1e-6 +abstol = 1e-6 + +""" + dyn! + +This generates solutions u(t) = (t-θ)^5/5 that can be solved exactly with a 5th order integrator. +""" +function dyn!(du, u, p, t) + θ = p[1] + du .= (t .- θ).^4.0 +end + +p = [1.0] + +prob = ODEProblem(dyn!, u0, tspan, p) +sol = solve(prob, Tsit5(), reltol=reltol, abstol=abstol) + +# We can see that the time steps increase with non-stop +# @show diff(sol.t) + +function loss(p, sensealg) + prob = ODEProblem(dyn!, u0, tspan, p) + if isnothing(sensealg) + sol = solve(prob, Tsit5(), reltol=reltol, abstol=abstol) + else + sol = solve(prob, Tsit5(), sensealg=sensealg, reltol=reltol, abstol=abstol) + end + @show "Number of time steps: ", length(sol.t) + sol.u[end][1] +end + +function grad_true(p) + θ = p[1] + t = tspan[2] + θ^4 - (t - θ)^4 +end + +""" +An implementation of discrete forward sensitivity analysis through ForwardDiff.jl. +When used within adjoint differentiation (i.e. via Zygote), this will cause forward differentiation +of the solve call within the reverse-mode automatic differentiation environment. + +https://docs.sciml.ai/SciMLSensitivity/stable/manual/differential_equation_sensitivities/#SciMLSensitivity.ForwardDiffSensitivity +""" +# Original AD without correction + +condition(u, t, integrator) = true +function printstepsize!(integrator) + # @infiltrate + if length(integrator.sol.t) > 1 + # println("Stepsize at step ", length(integrator.sol.t), ": ", integrator.sol.t[end] - integrator.sol.t[end-1]) + end +end + +cb = DiscreteCallback(condition, printstepsize!) + +# g1 = Zygote.gradient(p -> loss(p, ForwardDiffSensitivity()), internalnorm = (u,t) -> sum(abs2,u/length(u)), p) +g1 = Zygote.gradient(p -> solve(ODEProblem(dyn!, u0, tspan, p), + Tsit5(), + u0 = u0, + p = p, + sensealg = ForwardDiffSensitivity(), + saveat = 0.1, + internalnorm = (u,t) -> sum(abs2, u/length(u)), + callback = cb, + reltol=1e-6, + abstol=1e-6).u[end][1], p) +@show g1 + +# Forward Sensitivity +# g2 = Zygote.gradient(p -> loss(p, ForwardSensitivity()), p) +# g2 = Zygote.gradient(p -> solve(prob, +# Tsit5(), +# sensealg = ForwardSensitivity(), +# saveat = 0.1, +# callback = cb, +# reltol=1e-12, +# abstol=1e-12).u[end][1], p) +# @show g2 + +# Corrected AD +# g3 = ForwardDiff.gradient(p -> loss(p, nothing), p) +g3 = Zygote.gradient(p -> solve(ODEProblem(dyn!, u0, tspan, p), + Tsit5(), + sensealg = ForwardDiffSensitivity(), + # saveat = 0.1, + # callback = cb, + reltol=1e-6, + abstol=1e-6).u[end][1], p) +@show g3 + +@show grad_true(p) + +# Define customized RK(4) solver with given timesteps to show the divergence of forward sensitivities \ No newline at end of file diff --git a/code/FiniteDifferences/complex_solver.jl b/code/FiniteDifferences/complex_solver.jl deleted file mode 100644 index 2ad3dee..0000000 --- a/code/FiniteDifferences/complex_solver.jl +++ /dev/null @@ -1,25 +0,0 @@ -using OrdinaryDiffEq -using CairoMakie -using ComplexDiff -using Zygote, ForwardDiff, SciMLSensitivity - -function oscilatior!(du_complex::Array{Complex{Float64}}, u::Array{Complex{Float64}}, p, t) - ω = p[1] - du_complex[1] = u[2] - du_complex[2] = - ω^2 * u[1] - nothing -end - -du_complex = Array{Complex{Float64}}([0.0]) -u0_complex = Array{Complex{Float64}}([0.0, 1.0]) -p_complex = Array{Complex{Float64}}([20.]) .+ 0.1im - -function complexstep_differentiation(f::Function, p::Float64, h::Float64) - p_complex = p .+ h * im - res = f(p_complex) - return imag(res) / h -end - -# sol = solve(ODEProblem(oscilatior!, u0_complex, tspan, p_complex), Tsit5(), reltol=1e-6, abstol=1e-6) -deriv = complexstep_differentiation(x -> solve(ODEProblem(oscilatior!, u0_complex, tspan, [x]), Tsit5(), reltol=1e-6, abstol=1e-6).u[end][1], 20., 0.001) -# deriv = ComplexDiff.derivative(x -> solve(ODEProblem(oscilatior!, u0_complex, tspan, x), Tsit5(), reltol=1e-6, abstol=1e-6), [20.], 0.1)