From b3cfe5068fd2960e85be83dd1163dda05118167b Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 18 Sep 2023 15:22:47 -0400 Subject: [PATCH] Allow non vector inputs --- Project.toml | 8 ++-- src/DiffEqCallbacks.jl | 5 +-- src/integrating.jl | 54 +++++++++++++++++------- test/interpolating_tests.jl | 82 +++++++++++++++++++++++++++++-------- 4 files changed, 113 insertions(+), 36 deletions(-) diff --git a/Project.toml b/Project.toml index ddd6f22d..e3e12d60 100644 --- a/Project.toml +++ b/Project.toml @@ -1,12 +1,13 @@ name = "DiffEqCallbacks" uuid = "459566f4-90b8-5000-8ac3-15dfb0a30def" authors = ["Chris Rackauckas "] -version = "2.29.1" +version = "2.30.0" [deps] DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" @@ -24,14 +25,15 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" DataStructures = "0.18" DiffEqBase = "6.53.3" ForwardDiff = "0.10" +Functors = "0.4" NLsolve = "4.2" OrdinaryDiffEq = "6.14" Parameters = "0.12" RecipesBase = "0.7, 0.8, 1.0" RecursiveArrayTools = "2" SciMLBase = "1.48.1" -Sundials = "4.19.2" StaticArraysCore = "1.4" +Sundials = "4.19.2" julia = "1.6" [extras] @@ -39,8 +41,8 @@ ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" -Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/src/DiffEqCallbacks.jl 
b/src/DiffEqCallbacks.jl index 8481f705..70345581 100644 --- a/src/DiffEqCallbacks.jl +++ b/src/DiffEqCallbacks.jl @@ -1,8 +1,7 @@ module DiffEqCallbacks -using DiffEqBase, - RecursiveArrayTools, DataStructures, RecipesBase, StaticArraysCore, - NLsolve, ForwardDiff +using DiffEqBase, RecursiveArrayTools, DataStructures, RecipesBase, LinearAlgebra, + StaticArraysCore, NLsolve, ForwardDiff, Functors import Base.Iterators diff --git a/src/integrating.jl b/src/integrating.jl index 5e6e7fcd..084af67f 100644 --- a/src/integrating.jl +++ b/src/integrating.jl @@ -1,3 +1,29 @@ +# allocate_zeros +function allocate_zeros(p::AbstractArray{T}) where {T} + integral = similar(p) + fill!(integral, zero(T)) + return integral +end +allocate_zeros(p::Tuple) = allocate_zeros.(p) +allocate_zeros(p::NamedTuple{F}) where {F} = NamedTuple{F}(allocate_zeros(values(p))) +allocate_zeros(p) = fmap(allocate_zeros, p) + +# axpy! +recursive_axpy!(α, x::AbstractArray, y::AbstractArray) = axpy!(α, x, y) +recursive_axpy!(α, x::Tuple, y::Tuple) = recursive_axpy!.(α, x, y) +function recursive_axpy!(α, x::NamedTuple{F}, y::NamedTuple{F}) where {F} + return NamedTuple{F}(recursive_axpy!(α, values(x), values(y))) +end +recursive_axpy!(α, x, y) = fmap(Base.Fix1(recursive_axpy!, α), x, y) + +# scalar_mul! 
+recursive_scalar_mul!(x::AbstractArray, α) = x .*= α +recursive_scalar_mul!(x::Tuple, α) = recursive_scalar_mul!.(x, α) +function recursive_scalar_mul!(x::NamedTuple{F}, α) where {F} + return NamedTuple{F}(recursive_scalar_mul!(values(x), α)) +end +recursive_scalar_mul!(x, α) = fmap(Base.Fix2(recursive_scalar_mul!, α), x) + """ gauss_points::Vector{Vector{Float64}} @@ -159,27 +185,26 @@ end function (affect!::SavingIntegrandAffect)(integrator) n = div(SciMLBase.alg_order(integrator.alg) + 1, 2) - integral = zeros(eltype(eltype(affect!.integrand_values.integrand)), - length(integrator.p)) + integral = allocate_zeros(integrator.p) for i in 1:n - t_temp = ((integrator.t - integrator.tprev) / 2) * gauss_points[n][i] + + t_temp = ((integrator.t - integrator.tprev) / 2) * gauss_points[n][i] + (integrator.t + integrator.tprev) / 2 if DiffEqBase.isinplace(integrator.sol.prob) curu = first(get_tmp_cache(integrator)) integrator(curu, t_temp) - if affect!.integrand_cache == Nothing - integral .+= gauss_weights[n][i] * - affect!.integrand_func(integrator(t_temp), t_temp, integrator) + if affect!.integrand_cache === nothing + recursive_axpy!(gauss_weights[n][i], + affect!.integrand_func(curu, t_temp, integrator), integral) else affect!.integrand_func(affect!.integrand_cache, curu, t_temp, integrator) - integral .+= gauss_weights[n][i] * affect!.integrand_cache + recursive_axpy!(gauss_weights[n][i], affect!.integrand_cache, integral) end else - integral .+= gauss_weights[n][i] * - affect!.integrand_func(integrator(t_temp), t_temp, integrator) + recursive_axpy!(gauss_weights[n][i], + affect!.integrand_func(integrator(t_temp), t_temp, integrator), integral) end end - integral *= -(integrator.t - integrator.tprev) / 2 + recursive_scalar_mul!(integral, -(integrator.t - integrator.tprev) / 2) push!(affect!.integrand_values.integrand, integral) u_modified!(integrator, false) end @@ -188,7 +213,7 @@ end ```julia IntegratingCallback(integrand_func, integrand_values::IntegrandValues, - 
cache = Nothing) + cache = nothing) ``` Lets one define a function `integrand_func(u, t, integrator)` which @@ -204,7 +229,7 @@ returns Integral(integrand_func(u(t),t)dt over the problem tspan. `IntegrandValues(integrandType)`, i.e. give the type that `integrand_func` will output (or higher compatible type). - `cache` is provided to store `integrand_func` output for in-place problems. - if `cache` is `Nothing` but the problem is in-place, then `integrand_func` + if `cache` is `nothing` but the problem is in-place, then `integrand_func` is assumed to not be in-place and will be called as `out = integrand_func(u, t, integrator)`. The outputted values are saved into `integrand_values`. The values are found @@ -217,10 +242,11 @@ via `integrand_values.integrand`. If `integrand_func` is in-place, you must use `cache` to store the output of `integrand_func`. """ -function IntegratingCallback(integrand_func, integrand_values::IntegrandValues, cache=Nothing) +function IntegratingCallback(integrand_func, integrand_values::IntegrandValues, + cache = nothing) affect! 
= SavingIntegrandAffect(integrand_func, integrand_values, cache) condition = (u, t, integrator) -> true - DiscreteCallback(condition, affect!, save_positions=(false,false)) + DiscreteCallback(condition, affect!, save_positions = (false, false)) end export IntegratingCallback, IntegrandValues diff --git a/test/interpolating_tests.jl b/test/interpolating_tests.jl index d9685d90..b19a0bc6 100644 --- a/test/interpolating_tests.jl +++ b/test/interpolating_tests.jl @@ -24,21 +24,36 @@ function lotka_volterra(u, p, t) return [dx, dy] end +function lotka_volterra(u, p::NamedTuple, t) + x, y = u + α, β = p.x.αβ + δ, γ = p.δγ + dx = α * x - β * x * y + dy = -δ * y + γ * x * y + return [dx, dy] +end + function adjoint(u, p, t, sol) return -vjp((x) -> lotka_volterra(x, p, t), sol(t), u)[1] - Zygote.gradient((x) -> g(x, p, t), sol(t))[1] end function adjoint_inplace(du, u, p, t, sol) - du .= -vjp((x)->lotka_volterra(x,p,t), sol(t), u)[1] - Zygote.gradient((x)->g(x,p,t), sol(t))[1] + du .= -vjp((x) -> lotka_volterra(x, p, t), sol(t), u)[1] - + Zygote.gradient((x) -> g(x, p, t), sol(t))[1] end u0 = [1.0, 1.0] #initial condition tspan = (0.0, 10.0) #simulation time p = [1.5, 1.0, 3.0, 1.0] # Lotka-Volterra parameters +p_nt = (; x = (; αβ = [1.5, 1.0]), δγ = [3.0, 1.0]) # Lotka-Volterra parameters as NamedTuple + prob = ODEProblem(lotka_volterra, u0, tspan, p) sol = solve(prob, Tsit5(), abstol = 1e-14, reltol = 1e-14) +prob_nt = remake(prob, p = p_nt) +sol_nt = solve(prob_nt, Tsit5(), abstol = 1e-14, reltol = 1e-14) + # total loss functional function G(p) tmp_prob = remake(prob, p = p) @@ -59,24 +74,41 @@ function callback_saving_inplace(du, u, t, integrator, sol) temp = sol(t) du .= vjp((x) -> lotka_volterra(temp, x, t), integrator.p, u)[1] end -cb = IntegratingCallback((u, t, integrator) -> callback_saving(u, t, integrator, sol), +cb = IntegratingCallback((u, t, integrator) -> callback_saving(u, t, integrator, sol), integrand_values) -cb_inplace = IntegratingCallback((du, u, t, 
integrator) -> callback_saving_inplace(du, u, t, integrator, sol), - integrand_values_inplace, zeros(length(p))) -prob_adjoint = ODEProblem((u, p, t) -> adjoint(u, p, t, sol), - [0.0, 0.0], - (tspan[end], tspan[1]), - p, - callback = cb) -prob_adjoint_inplace = ODEProblem((du, u, p, t) -> adjoint_inplace(du, u, p, t, sol), - [0.0, 0.0], - (tspan[end], tspan[1]), - p, - callback = cb_inplace) +cb_inplace = IntegratingCallback((du, u, t, integrator) -> callback_saving_inplace(du, + u, t, integrator, sol), + integrand_values_inplace, zeros(length(p))) +prob_adjoint = ODEProblem((u, p, t) -> adjoint(u, p, t, sol), [0.0, 0.0], + (tspan[end], tspan[1]), p; callback = cb) +prob_adjoint_inplace = ODEProblem((du, u, p, t) -> adjoint_inplace(du, u, p, t, sol), + [0.0, 0.0], (tspan[end], tspan[1]), p; callback = cb_inplace) sol_adjoint = solve(prob_adjoint, Tsit5(), abstol = 1e-14, reltol = 1e-14) sol_adjoint_inplace = solve(prob_adjoint_inplace, Tsit5(), abstol = 1e-14, reltol = 1e-14) +function callback_saving_inplace_nt(du, u, t, integrator, sol) + temp = sol(t) + res = vjp((x) -> lotka_volterra(temp, x, t), integrator.p, u)[1] + DiffEqCallbacks.fmap((y, x) -> copyto!(y, x), du, res) + return du +end +integrand_values_nt = IntegrandValues(typeof(p_nt)) +integrand_values_inplace_nt = IntegrandValues(typeof(p_nt)) +cb = IntegratingCallback((u, t, integrator) -> callback_saving(u, t, integrator, sol), + integrand_values_nt) +cb_inplace = IntegratingCallback((du, u, t, integrator) -> callback_saving_inplace_nt(du, + u, t, integrator, sol), + integrand_values_inplace_nt, DiffEqCallbacks.allocate_zeros(p_nt)) +prob_adjoint_nt = ODEProblem((u, p, t) -> adjoint(u, p, t, sol_nt), [0.0, 0.0], + (tspan[end], tspan[1]), p_nt; callback = cb) +prob_adjoint_nt_inplace = ODEProblem((du, u, p, t) -> adjoint_inplace(du, u, p, t, sol_nt), + [0.0, 0.0], (tspan[end], tspan[1]), p_nt; callback = cb_inplace) + +sol_adjoint_nt = solve(prob_adjoint_nt, Tsit5(), abstol = 1e-14, reltol = 1e-14) 
+sol_adjoint_nt_inplace = solve(prob_adjoint_nt_inplace, Tsit5(), abstol = 1e-14, + reltol = 1e-14) + function compute_dGdp(integrand) temp = zeros(length(integrand.integrand), length(integrand.integrand[1])) for i in 1:length(integrand.integrand) @@ -90,8 +122,22 @@ end dGdp_new = compute_dGdp(integrand_values) dGdp_new_inplace = compute_dGdp(integrand_values_inplace) +function compute_dGdp_nt(integrand) + temp = zeros(length(integrand.integrand), 4) + for i in 1:length(integrand.integrand) + temp[i, 1:2] .= integrand.integrand[i].x.αβ + temp[i, 3:4] .= integrand.integrand[i].δγ + end + return sum(temp, dims = 1)[:] +end + +dGdp_new_nt = compute_dGdp_nt(integrand_values_nt) +dGdp_new_inplace_nt = compute_dGdp_nt(integrand_values_inplace_nt) + @test isapprox(dGdp_ForwardDiff, dGdp_new, atol = 1e-11, rtol = 1e-11) @test isapprox(dGdp_ForwardDiff, dGdp_new_inplace, atol = 1e-11, rtol = 1e-11) +@test isapprox(dGdp_ForwardDiff, dGdp_new_nt, atol = 1e-11, rtol = 1e-11) +@test isapprox(dGdp_ForwardDiff, dGdp_new_inplace_nt, atol = 1e-11, rtol = 1e-11) #### TESTING ON LINEAR SYSTEM WITH ANALYTICAL SOLUTION #### function simple_linear_system(u, p, t) @@ -184,7 +230,11 @@ function callback_saving_linear_inplace(du, u, t, integrator, sol) end cb = IntegratingCallback((u, t, integrator) -> callback_saving_linear(u, t, integrator, sol), integrand_values) -cb_inplace = IntegratingCallback((du, u, t, integrator) -> callback_saving_linear_inplace(du, u, t, integrator, sol), +cb_inplace = IntegratingCallback((du, u, t, integrator) -> callback_saving_linear_inplace(du, + u, + t, + integrator, + sol), integrand_values_inplace, zeros(length(p))) prob_adjoint = ODEProblem((u, p, t) -> adjoint_linear(u, p, t, sol), [0.0, 0.0], @@ -201,7 +251,7 @@ sol_adjoint_inplace = solve(prob_adjoint_inplace, Tsit5(), abstol = 1e-14, relto dGdp_new = compute_dGdp(integrand_values) dGdp_new_inplace = compute_dGdp(integrand_values_inplace) -dGdp_analytical = analytical_derivative(p,tspan[end]) 
+dGdp_analytical = analytical_derivative(p, tspan[end]) @test isapprox(dGdp_analytical, dGdp_new, atol = 1e-11, rtol = 1e-11) @test isapprox(dGdp_analytical, dGdp_new_inplace, atol = 1e-11, rtol = 1e-11)