Skip to content

Commit

Permalink
current implementation, but does not seem to reduce allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
acoh64 committed Aug 15, 2023
1 parent d47cfc1 commit 56ac8c9
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 9 deletions.
11 changes: 6 additions & 5 deletions src/integrating.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,10 @@ function Base.show(io::IO, integrand_values::IntegrandValues)
"\nintegrand:\n", integrand_values.integrand)
end

mutable struct SavingIntegrandAffect{IntegrandFunc, integrandType}
mutable struct SavingIntegrandAffect{IntegrandFunc, integrandType, integrandCacheType}
integrand_func::IntegrandFunc
integrand_values::IntegrandValues{integrandType}
integrand_cache::integrandCacheType
end

function (affect!::SavingIntegrandAffect)(integrator)
Expand All @@ -166,8 +167,8 @@ function (affect!::SavingIntegrandAffect)(integrator)
if DiffEqBase.isinplace(integrator.sol.prob)
curu = first(get_tmp_cache(integrator))
integrator(curu, t_temp)
integral .+= gauss_weights[n][i] *
affect!.integrand_func(curu, t_temp, integrator)
affect!.integrand_func(affect!.integrand_cache, curu, t_temp, integrator)
integral .+= gauss_weights[n][i] * affect!.integrand_cache

Check warning on line 171 in src/integrating.jl

View check run for this annotation

Codecov / codecov/patch

src/integrating.jl#L170-L171

Added lines #L170 - L171 were not covered by tests
else
integral .+= gauss_weights[n][i] *
affect!.integrand_func(integrator(t_temp), t_temp, integrator)
Expand Down Expand Up @@ -205,8 +206,8 @@ The outputted values are saved into `integrand_values`. Time points are found vi
This method is currently limited to ODE solvers of order 10 or lower. Open an issue if other
solvers are required.
"""
function IntegratingCallback(integrand_func, integrand_values::IntegrandValues)
affect! = SavingIntegrandAffect(integrand_func, integrand_values)
function IntegratingCallback(integrand_func, integrand_values::IntegrandValues, cache)
affect! = SavingIntegrandAffect(integrand_func, integrand_values, cache)

Check warning on line 210 in src/integrating.jl

View check run for this annotation

Codecov / codecov/patch

src/integrating.jl#L209-L210

Added lines #L209 - L210 were not covered by tests
condition = (u, t, integrator) -> true
DiscreteCallback(condition, affect!, save_positions=(false,false))
end
Expand Down
32 changes: 28 additions & 4 deletions test/interpolating_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,14 @@ function callback_saving(u, t, integrator, sol)
temp = sol(t)
return vjp((x) -> lotka_volterra(temp, x, t), integrator.p, u)[1]
end
function callback_saving_inplace(du, u, t, integrator, sol)
temp = sol(t)
du .= vjp((x) -> lotka_volterra(temp, x, t), integrator.p, u)[1]
end
cb = IntegratingCallback((u, t, integrator) -> callback_saving(u, t, integrator, sol),
integrand_values)
cb_inplace = IntegratingCallback((u, t, integrator) -> callback_saving(u, t, integrator, sol),
integrand_values_inplace)
integrand_values, 0.0)
cb_inplace = IntegratingCallback((du, u, t, integrator) -> callback_saving_inplace(du, u, t, integrator, sol),
integrand_values_inplace, zeros(length(p)))
prob_adjoint = ODEProblem((u, p, t) -> adjoint(u, p, t, sol),
[0.0, 0.0],
(tspan[end], tspan[1]),
Expand Down Expand Up @@ -100,6 +104,11 @@ function adjoint_linear(u, p, t, sol)
return -[0 b; -a 0] * u - 2.0 * (sol(t) .- 1.0)
end

function adjoint_linear_inplace(du, u, p, t, sol)
a, b = p
du .= -[0 b; -a 0] * u - 2.0 * (sol(t) .- 1.0)
end

u0 = [1.0, 1.0] #initial condition
tspan = (0.0, 10.0) #simulation time
p = [1.0, 2.0] # parameters
Expand Down Expand Up @@ -166,18 +175,33 @@ function analytical_derivative(p, t)
end

integrand_values = IntegrandValues(Vector{Float64})
integrand_values_inplace = IntegrandValues(Vector{Float64})
function callback_saving_linear(u, t, integrator, sol)
return [-sol(t)[2] 0; 0 sol(t)[1]]' * u
end
function callback_saving_linear_inplace(du, u, t, integrator, sol)
du .= [-sol(t)[2] 0; 0 sol(t)[1]]' * u
end
cb = IntegratingCallback((u, t, integrator) -> callback_saving_linear(u, t, integrator, sol),
integrand_values)
integrand_values, 0.0)
cb_inplace = IntegratingCallback((du, u, t, integrator) -> callback_saving_linear_inplace(du, u, t, integrator, sol),
integrand_values_inplace, zeros(length(p)))
prob_adjoint = ODEProblem((u, p, t) -> adjoint_linear(u, p, t, sol),
[0.0, 0.0],
(tspan[end], tspan[1]),
p,
callback = cb)
prob_adjoint_inplace = ODEProblem((du, u, p, t) -> adjoint_linear_inplace(du, u, p, t, sol),
[0.0, 0.0],
(tspan[end], tspan[1]),
p,
callback = cb_inplace)
sol_adjoint = solve(prob_adjoint, Tsit5(), abstol = 1e-14, reltol = 1e-14)
sol_adjoint_inplace = solve(prob_adjoint_inplace, Tsit5(), abstol = 1e-14, reltol = 1e-14)

dGdp_new = compute_dGdp(integrand_values)
dGdp_new_inplace = compute_dGdp(integrand_values_inplace)
dGdp_analytical = analytical_derivative(p,tspan[end])

@test isapprox(dGdp_analytical, dGdp_new, atol = 1e-11, rtol = 1e-11)
@test isapprox(dGdp_analytical, dGdp_new_inplace, atol = 1e-11, rtol = 1e-11)

0 comments on commit 56ac8c9

Please sign in to comment.