Skip to content

Commit

Permalink
Allow non vector inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 18, 2023
1 parent 1bb9c2e commit b3cfe50
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 36 deletions.
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
name = "DiffEqCallbacks"
uuid = "459566f4-90b8-5000-8ac3-15dfb0a30def"
authors = ["Chris Rackauckas <[email protected]>"]
version = "2.29.1"
version = "2.30.0"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
Expand All @@ -24,23 +25,24 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
DataStructures = "0.18"
DiffEqBase = "6.53.3"
ForwardDiff = "0.10"
Functors = "0.4"
NLsolve = "4.2"
OrdinaryDiffEq = "6.14"
Parameters = "0.12"
RecipesBase = "0.7, 0.8, 1.0"
RecursiveArrayTools = "2"
SciMLBase = "1.48.1"
Sundials = "4.19.2"
StaticArraysCore = "1.4"
Sundials = "4.19.2"
julia = "1.6"

[extras]
ODEProblemLibrary = "fdc4e326-1af4-4b90-96e7-779fcce2daa5"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand Down
5 changes: 2 additions & 3 deletions src/DiffEqCallbacks.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
module DiffEqCallbacks

using DiffEqBase,
RecursiveArrayTools, DataStructures, RecipesBase, StaticArraysCore,
NLsolve, ForwardDiff
using DiffEqBase, RecursiveArrayTools, DataStructures, RecipesBase, LinearAlgebra,
StaticArraysCore, NLsolve, ForwardDiff, Functors

import Base.Iterators

Expand Down
54 changes: 40 additions & 14 deletions src/integrating.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,29 @@
# allocate_zeros
function allocate_zeros(p::AbstractArray{T}) where {T}
integral = similar(p)
fill!(integral, zero(T))
return integral
end
allocate_zeros(p::Tuple) = allocate_zeros.(p)
allocate_zeros(p::NamedTuple{F}) where {F} = NamedTuple{F}(allocate_zeros(values(p)))
allocate_zeros(p) = fmap(allocate_zeros, p)

Check warning on line 9 in src/integrating.jl

View check run for this annotation

Codecov / codecov/patch

src/integrating.jl#L9

Added line #L9 was not covered by tests

# axpy!
recursive_axpy!(α, x::AbstractArray, y::AbstractArray) = axpy!(α, x, y)
recursive_axpy!(α, x::Tuple, y::Tuple) = recursive_axpy!.(α, x, y)
function recursive_axpy!(α, x::NamedTuple{F}, y::NamedTuple{F}) where {F}
return NamedTuple{F}(recursive_axpy!(α, values(x), values(y)))
end
recursive_axpy!(α, x, y) = fmap(Base.Fix1(recursive_axpy!, α), x, y)

Check warning on line 17 in src/integrating.jl

View check run for this annotation

Codecov / codecov/patch

src/integrating.jl#L17

Added line #L17 was not covered by tests

# scalar_mul!
recursive_scalar_mul!(x::AbstractArray, α) = x .*= α
recursive_scalar_mul!(x::Tuple, α) = recursive_scalar_mul!.(x, α)
function recursive_scalar_mul!(x::NamedTuple{F}, α) where {F}
return NamedTuple{F}(recursive_scalar_mul!(values(x), α))
end
recursive_scalar_mul!(x, α) = fmap(Base.Fix1(recursive_scalar_mul!, α), x)

Check warning on line 25 in src/integrating.jl

View check run for this annotation

Codecov / codecov/patch

src/integrating.jl#L25

Added line #L25 was not covered by tests

"""
gauss_points::Vector{Vector{Float64}}
Expand Down Expand Up @@ -159,27 +185,26 @@ end

function (affect!::SavingIntegrandAffect)(integrator)
n = div(SciMLBase.alg_order(integrator.alg) + 1, 2)
integral = zeros(eltype(eltype(affect!.integrand_values.integrand)),
length(integrator.p))
integral = allocate_zeros(integrator.p)
for i in 1:n
t_temp = ((integrator.t - integrator.tprev) / 2) * gauss_points[n][i] +
t_temp = ((integrator.t - integrator.tprev) / 2) * gauss_points[n][i] +
(integrator.t + integrator.tprev) / 2
if DiffEqBase.isinplace(integrator.sol.prob)
curu = first(get_tmp_cache(integrator))
integrator(curu, t_temp)
if affect!.integrand_cache == Nothing
integral .+= gauss_weights[n][i] *
affect!.integrand_func(integrator(t_temp), t_temp, integrator)
if affect!.integrand_cache == nothing
recursive_axpy!(gauss_weights[n][i],

Check warning on line 196 in src/integrating.jl

View check run for this annotation

Codecov / codecov/patch

src/integrating.jl#L196

Added line #L196 was not covered by tests
affect!.integrand_func(curu, t_temp, integrator), integral)
else
affect!.integrand_func(affect!.integrand_cache, curu, t_temp, integrator)
integral .+= gauss_weights[n][i] * affect!.integrand_cache
recursive_axpy!(gauss_weights[n][i], affect!.integrand_cache, integral)
end
else
integral .+= gauss_weights[n][i] *
affect!.integrand_func(integrator(t_temp), t_temp, integrator)
recursive_axpy!(gauss_weights[n][i],
affect!.integrand_func(integrator(t_temp), t_temp, integrator), integral)
end
end
integral *= -(integrator.t - integrator.tprev) / 2
recursive_scalar_mul!(integral, -(integrator.t - integrator.tprev) / 2)
push!(affect!.integrand_values.integrand, integral)
u_modified!(integrator, false)
end
Expand All @@ -188,7 +213,7 @@ end
```julia
IntegratingCallback(integrand_func,
integrand_values::IntegrandValues,
cache = Nothing)
cache = nothing)
```
Lets one define a function `integrand_func(u, t, integrator)` which
Expand All @@ -204,7 +229,7 @@ returns Integral(integrand_func(u(t),t)dt over the problem tspan.
`IntegrandValues(integrandType)`, i.e. give the type
that `integrand_func` will output (or higher compatible type).
- `cache` is provided to store `integrand_func` output for in-place problems.
if `cache` is `Nothing` but the problem is in-place, then `integrand_func`
if `cache` is `nothing` but the problem is in-place, then `integrand_func`
is assumed to not be in-place and will be called as `out = integrand_func(u, t, integrator)`.
The outputted values are saved into `integrand_values`. The values are found
Expand All @@ -217,10 +242,11 @@ via `integrand_values.integrand`.
If `integrand_func` is in-place, you must use `cache` to store the output of `integrand_func`.
"""
function IntegratingCallback(integrand_func, integrand_values::IntegrandValues, cache=Nothing)
function IntegratingCallback(integrand_func, integrand_values::IntegrandValues,
cache = nothing)
affect! = SavingIntegrandAffect(integrand_func, integrand_values, cache)
condition = (u, t, integrator) -> true
DiscreteCallback(condition, affect!, save_positions=(false,false))
DiscreteCallback(condition, affect!, save_positions = (false, false))
end

export IntegratingCallback, IntegrandValues
82 changes: 66 additions & 16 deletions test/interpolating_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,36 @@ function lotka_volterra(u, p, t)
return [dx, dy]
end

function lotka_volterra(u, p::NamedTuple, t)
x, y = u
α, β = p.x.αβ
δ, γ = p.δγ
dx = α * x - β * x * y
dy = -δ * y + γ * x * y
return [dx, dy]
end

function adjoint(u, p, t, sol)
return -vjp((x) -> lotka_volterra(x, p, t), sol(t), u)[1] -
Zygote.gradient((x) -> g(x, p, t), sol(t))[1]
end

function adjoint_inplace(du, u, p, t, sol)
du .= -vjp((x)->lotka_volterra(x,p,t), sol(t), u)[1] - Zygote.gradient((x)->g(x,p,t), sol(t))[1]
du .= -vjp((x) -> lotka_volterra(x, p, t), sol(t), u)[1] -
Zygote.gradient((x) -> g(x, p, t), sol(t))[1]
end

u0 = [1.0, 1.0] #initial condition
tspan = (0.0, 10.0) #simulation time
p = [1.5, 1.0, 3.0, 1.0] # Lotka-Volterra parameters
p_nt = (; x = (; αβ = [1.5, 1.0]), δγ = [3.0, 1.0]) # Lotka-Volterra parameters as NamedTuple

prob = ODEProblem(lotka_volterra, u0, tspan, p)
sol = solve(prob, Tsit5(), abstol = 1e-14, reltol = 1e-14)

prob_nt = remake(prob, p = p_nt)
sol_nt = solve(prob_nt, Tsit5(), abstol = 1e-14, reltol = 1e-14)

# total loss functional
function G(p)
tmp_prob = remake(prob, p = p)
Expand All @@ -59,24 +74,41 @@ function callback_saving_inplace(du, u, t, integrator, sol)
temp = sol(t)
du .= vjp((x) -> lotka_volterra(temp, x, t), integrator.p, u)[1]
end
cb = IntegratingCallback((u, t, integrator) -> callback_saving(u, t, integrator, sol),
cb = IntegratingCallback((u, t, integrator) -> callback_saving(u, t, integrator, sol),
integrand_values)
cb_inplace = IntegratingCallback((du, u, t, integrator) -> callback_saving_inplace(du, u, t, integrator, sol),
integrand_values_inplace, zeros(length(p)))
prob_adjoint = ODEProblem((u, p, t) -> adjoint(u, p, t, sol),
[0.0, 0.0],
(tspan[end], tspan[1]),
p,
callback = cb)
prob_adjoint_inplace = ODEProblem((du, u, p, t) -> adjoint_inplace(du, u, p, t, sol),
[0.0, 0.0],
(tspan[end], tspan[1]),
p,
callback = cb_inplace)
cb_inplace = IntegratingCallback((du, u, t, integrator) -> callback_saving_inplace(du,
u, t, integrator, sol),
integrand_values_inplace, zeros(length(p)))
prob_adjoint = ODEProblem((u, p, t) -> adjoint(u, p, t, sol), [0.0, 0.0],
(tspan[end], tspan[1]), p; callback = cb)
prob_adjoint_inplace = ODEProblem((du, u, p, t) -> adjoint_inplace(du, u, p, t, sol),
[0.0, 0.0], (tspan[end], tspan[1]), p; callback = cb_inplace)

sol_adjoint = solve(prob_adjoint, Tsit5(), abstol = 1e-14, reltol = 1e-14)
sol_adjoint_inplace = solve(prob_adjoint_inplace, Tsit5(), abstol = 1e-14, reltol = 1e-14)

function callback_saving_inplace_nt(du, u, t, integrator, sol)
temp = sol(t)
res = vjp((x) -> lotka_volterra(temp, x, t), integrator.p, u)[1]
DiffEqCallbacks.fmap((y, x) -> copyto!(y, x), du, res)
return du
end
integrand_values_nt = IntegrandValues(typeof(p_nt))
integrand_values_inplace_nt = IntegrandValues(typeof(p_nt))
cb = IntegratingCallback((u, t, integrator) -> callback_saving(u, t, integrator, sol),
integrand_values_nt)
cb_inplace = IntegratingCallback((du, u, t, integrator) -> callback_saving_inplace_nt(du,
u, t, integrator, sol),
integrand_values_inplace_nt, DiffEqCallbacks.allocate_zeros(p_nt))
prob_adjoint_nt = ODEProblem((u, p, t) -> adjoint(u, p, t, sol_nt), [0.0, 0.0],
(tspan[end], tspan[1]), p_nt; callback = cb)
prob_adjoint_nt_inplace = ODEProblem((du, u, p, t) -> adjoint_inplace(du, u, p, t, sol_nt),
[0.0, 0.0], (tspan[end], tspan[1]), p_nt; callback = cb_inplace)

sol_adjoint_nt = solve(prob_adjoint_nt, Tsit5(), abstol = 1e-14, reltol = 1e-14)
sol_adjoint_nt_inplace = solve(prob_adjoint_nt_inplace, Tsit5(), abstol = 1e-14,
reltol = 1e-14)

function compute_dGdp(integrand)
temp = zeros(length(integrand.integrand), length(integrand.integrand[1]))
for i in 1:length(integrand.integrand)
Expand All @@ -90,8 +122,22 @@ end
dGdp_new = compute_dGdp(integrand_values)
dGdp_new_inplace = compute_dGdp(integrand_values_inplace)

function compute_dGdp_nt(integrand)
temp = zeros(length(integrand.integrand), 4)
for i in 1:length(integrand.integrand)
temp[i, 1:2] .= integrand.integrand[i].x.αβ
temp[i, 3:4] .= integrand.integrand[i].δγ
end
return sum(temp, dims = 1)[:]
end

dGdp_new_nt = compute_dGdp_nt(integrand_values_nt)
dGdp_new_inplace_nt = compute_dGdp_nt(integrand_values_inplace_nt)

@test isapprox(dGdp_ForwardDiff, dGdp_new, atol = 1e-11, rtol = 1e-11)
@test isapprox(dGdp_ForwardDiff, dGdp_new_inplace, atol = 1e-11, rtol = 1e-11)
@test isapprox(dGdp_ForwardDiff, dGdp_new_nt, atol = 1e-11, rtol = 1e-11)
@test isapprox(dGdp_ForwardDiff, dGdp_new_inplace_nt, atol = 1e-11, rtol = 1e-11)

#### TESTING ON LINEAR SYSTEM WITH ANALYTICAL SOLUTION ####
function simple_linear_system(u, p, t)
Expand Down Expand Up @@ -184,7 +230,11 @@ function callback_saving_linear_inplace(du, u, t, integrator, sol)
end
cb = IntegratingCallback((u, t, integrator) -> callback_saving_linear(u, t, integrator, sol),
integrand_values)
cb_inplace = IntegratingCallback((du, u, t, integrator) -> callback_saving_linear_inplace(du, u, t, integrator, sol),
cb_inplace = IntegratingCallback((du, u, t, integrator) -> callback_saving_linear_inplace(du,
u,
t,
integrator,
sol),
integrand_values_inplace, zeros(length(p)))
prob_adjoint = ODEProblem((u, p, t) -> adjoint_linear(u, p, t, sol),
[0.0, 0.0],
Expand All @@ -201,7 +251,7 @@ sol_adjoint_inplace = solve(prob_adjoint_inplace, Tsit5(), abstol = 1e-14, relto

dGdp_new = compute_dGdp(integrand_values)
dGdp_new_inplace = compute_dGdp(integrand_values_inplace)
dGdp_analytical = analytical_derivative(p,tspan[end])
dGdp_analytical = analytical_derivative(p, tspan[end])

@test isapprox(dGdp_analytical, dGdp_new, atol = 1e-11, rtol = 1e-11)
@test isapprox(dGdp_analytical, dGdp_new_inplace, atol = 1e-11, rtol = 1e-11)

0 comments on commit b3cfe50

Please sign in to comment.