Skip to content

Commit

Permalink
Merge pull request #209 from SciML/integrating
Browse files Browse the repository at this point in the history
Fix IntegratingCallback and IntegratingSumCallback
  • Loading branch information
ChrisRackauckas authored Mar 4, 2024
2 parents 1ee0410 + fc38d7f commit 105ba48
Show file tree
Hide file tree
Showing 8 changed files with 259 additions and 86 deletions.
40 changes: 40 additions & 0 deletions docs/src/integrating.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,46 @@ step to the order of the numerical differential equation solve, thus achieving a
dense interpolation to be saved. By doing this via a callback, this method is able to easily integrate with functionality
that introduces discontinuities, like other callbacks, in a way that is more accurate than a direct integration post solve.

The `IntegratingSumCallback` is the same, but instead of returning the timeseries of the interval results of the integration,
it simply returns the final integral value.

```@docs
IntegratingCallback
IntegrandValues
IntegratingSumCallback
IntegrandValuesSum
```

## Example

```@example integrating
using OrdinaryDiffEq, DiffEqCallbacks, Test
prob = ODEProblem((u, p, t) -> [1.0], [0.0], (0.0, 1.0))
integrated = IntegrandValues(Float64, Vector{Float64})
sol = solve(prob, Euler(),
callback = IntegratingCallback(
(u, t, integrator) -> [1.0], integrated, Float64[0.0]),
dt = 0.1)
@test all(integrated.integrand .≈ [[0.1] for i in 1:10])
integrated = IntegrandValues(Float64, Vector{Float64})
sol = solve(prob, Euler(),
callback = IntegratingCallback(
(u, t, integrator) -> [u[1]], integrated, Float64[0.0]),
dt = 0.1)
@test all(integrated.integrand .≈ [[((n * 0.1)^2 - ((n - 1) * (0.1))^2) / 2] for n in 1:10])
@test sum(integrated.integrand)[1] ≈ 0.5
integrated = IntegrandValuesSum(zeros(1))
sol = solve(prob, Euler(),
callback = IntegratingSumCallback(
(u, t, integrator) -> [1.0], integrated, Float64[0.0]),
dt = 0.1)
@test integrated.integrand[1] == 1
integrated = IntegrandValuesSum(zeros(1))
sol = solve(prob, Euler(),
callback = IntegratingSumCallback(
(u, t, integrator) -> [u[1]], integrated, Float64[0.0]),
dt = 0.1)
@test integrated.integrand[1] == 0.5
```
1 change: 1 addition & 0 deletions src/DiffEqCallbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using Parameters: @unpack

import SciMLBase

include("functor_helpers.jl")
include("autoabstol.jl")
include("manifold.jl")
include("domain.jl")
Expand Down
130 changes: 130 additions & 0 deletions src/functor_helpers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# NOTE: `fmap` can handle all these cases without us defining them, but it often makes the
# code type unstable. So we define them here to make the code type stable.
# Handle Non-Array Parameters in a Generic Fashion
"""
recursive_copyto!(y, x)
`y[:] .= vec(x)` for generic `x` and `y`. This is used to handle non-array parameters!
"""
recursive_copyto!(y::AbstractArray, x::AbstractArray) = copyto!(y, x)
recursive_copyto!(y::Tuple, x::Tuple) = map(recursive_copyto!, y, x)
function recursive_copyto!(y::NamedTuple{F}, x::NamedTuple{F}) where {F}
map(recursive_copyto!, values(y), values(x))
end
recursive_copyto!(y::T, x::T) where {T} = fmap(recursive_copyto!, y, x)
recursive_copyto!(y, ::Nothing) = y
recursive_copyto!(::Nothing, ::Nothing) = nothing

"""
neg!(x)
`x .*= -1` for generic `x`. This is used to handle non-array parameters!
"""
recursive_neg!(x::AbstractArray) = (x .*= -1)
recursive_neg!(x::Tuple) = map(recursive_neg!, x)
recursive_neg!(x::NamedTuple{F}) where {F} = NamedTuple{F}(map(recursive_neg!, values(x)))
recursive_neg!(x) = fmap(recursive_neg!, x)
recursive_neg!(::Nothing) = nothing

"""
zero!(x)
`x .= 0` for generic `x`. This is used to handle non-array parameters!
"""
recursive_zero!(x::AbstractArray) = (x .= 0)
recursive_zero!(x::Tuple) = map(recursive_zero!, x)
recursive_zero!(x::NamedTuple{F}) where {F} = NamedTuple{F}(map(recursive_zero!, values(x)))
recursive_zero!(x) = fmap(recursive_zero!, x)
recursive_zero!(::Nothing) = nothing

"""
recursive_sub!(y, x)
`y .-= x` for generic `x` and `y`. This is used to handle non-array parameters!
"""
recursive_sub!(y::AbstractArray, x::AbstractArray) = axpy!(-1, x, y)
recursive_sub!(y::Tuple, x::Tuple) = map(recursive_sub!, y, x)
function recursive_sub!(y::NamedTuple{F}, x::NamedTuple{F}) where {F}
NamedTuple{F}(map(recursive_sub!, values(y), values(x)))
end
recursive_sub!(y::T, x::T) where {T} = fmap(recursive_sub!, y, x)
recursive_sub!(y, ::Nothing) = y
recursive_sub!(::Nothing, ::Nothing) = nothing

"""
recursive_add!(y, x)
`y .+= x` for generic `x` and `y`. This is used to handle non-array parameters!
"""
recursive_add!(y::AbstractArray, x::AbstractArray) = y .+= x
recursive_add!(y::Tuple, x::Tuple) = recursive_add!.(y, x)
function recursive_add!(y::NamedTuple{F}, x::NamedTuple{F}) where {F}
NamedTuple{F}(recursive_add!(values(y), values(x)))
end
recursive_add!(y::T, x::T) where {T} = fmap(recursive_add!, y, x)
recursive_add!(y, ::Nothing) = y
recursive_add!(::Nothing, ::Nothing) = nothing

"""
allocate_vjp(λ, x)
allocate_vjp(x)
`similar(λ, size(x))` for generic `x`. This is used to handle non-array parameters!
"""
allocate_vjp::AbstractArray, x::AbstractArray) = similar(λ, size(x))
allocate_vjp::AbstractArray, x::Tuple) = allocate_vjp.((λ,), x)
function allocate_vjp::AbstractArray, x::NamedTuple{F}) where {F}
NamedTuple{F}(allocate_vjp.((λ,), values(x)))
end
allocate_vjp::AbstractArray, x) = fmap(Base.Fix1(allocate_vjp, λ), x)

allocate_vjp(x::AbstractArray) = similar(x)
allocate_vjp(x::Tuple) = allocate_vjp.(x)
allocate_vjp(x::NamedTuple{F}) where {F} = NamedTuple{F}(allocate_vjp.(values(x)))
allocate_vjp(x) = fmap(allocate_vjp, x)

"""
allocate_zeros(x)
`zero.(x)` for generic `x`. This is used to handle non-array parameters!
"""
allocate_zeros(x::AbstractArray) = zero.(x)
allocate_zeros(x::Tuple) = allocate_zeros.(x)
allocate_zeros(x::NamedTuple{F}) where {F} = NamedTuple{F}(allocate_zeros.(values(x)))
allocate_zeros(x) = fmap(allocate_zeros, x)

"""
recursive_copy(y)
`copy(y)` for generic `y`. This is used to handle non-array parameters!
"""
recursive_copy(y::AbstractArray) = copy(y)
recursive_copy(y::Tuple) = recursive_copy.(y)
recursive_copy(y::NamedTuple{F}) where {F} = NamedTuple{F}(recursive_copy.(values(y)))
recursive_copy(y) = fmap(recursive_copy, y)

"""
recursive_adjoint(y)
`adjoint(y)` for generic `y`. This is used to handle non-array parameters!
"""
recursive_adjoint(y::AbstractArray) = adjoint(y)
recursive_adjoint(y::Tuple) = recursive_adjoint.(y)
recursive_adjoint(y::NamedTuple{F}) where {F} = NamedTuple{F}(recursive_adjoint.(values(y)))
recursive_adjoint(y) = fmap(recursive_adjoint, y)

# scalar_mul!
recursive_scalar_mul!(x::AbstractArray, α) = x .*= α
recursive_scalar_mul!(x::Tuple, α) = recursive_scalar_mul!.(x, α)
function recursive_scalar_mul!(x::NamedTuple{F}, α) where {F}
return NamedTuple{F}(recursive_scalar_mul!(values(x), α))
end
recursive_scalar_mul!(x, α) = fmap(Base.Fix1(recursive_scalar_mul!, α), x)

# axpy!
recursive_axpy!(α, x::AbstractArray, y::AbstractArray) = axpy!(α, x, y)
recursive_axpy!(α, x::Tuple, y::Tuple) = recursive_axpy!.(α, x, y)
function recursive_axpy!(α, x::NamedTuple{F}, y::NamedTuple{F}) where {F}
return NamedTuple{F}(recursive_axpy!(α, values(x), values(y)))
end
recursive_axpy!(α, x, y) = fmap(Base.Fix1(recursive_axpy!, α), x, y)
64 changes: 19 additions & 45 deletions src/integrating.jl
Original file line number Diff line number Diff line change
@@ -1,29 +1,3 @@
# allocate_zeros
function allocate_zeros(p::AbstractArray{T}) where {T}
integral = similar(p)
fill!(integral, zero(T))
return integral
end
allocate_zeros(p::Tuple) = allocate_zeros.(p)
allocate_zeros(p::NamedTuple{F}) where {F} = NamedTuple{F}(allocate_zeros(values(p)))
allocate_zeros(p) = fmap(allocate_zeros, p)

# axpy!
recursive_axpy!(α, x::AbstractArray, y::AbstractArray) = axpy!(α, x, y)
recursive_axpy!(α, x::Tuple, y::Tuple) = recursive_axpy!.(α, x, y)
function recursive_axpy!(α, x::NamedTuple{F}, y::NamedTuple{F}) where {F}
return NamedTuple{F}(recursive_axpy!(α, values(x), values(y)))
end
recursive_axpy!(α, x, y) = fmap(Base.Fix1(recursive_axpy!, α), x, y)

# scalar_mul!
recursive_scalar_mul!(x::AbstractArray, α) = x .*= α
recursive_scalar_mul!(x::Tuple, α) = recursive_scalar_mul!.(x, α)
function recursive_scalar_mul!(x::NamedTuple{F}, α) where {F}
return NamedTuple{F}(recursive_scalar_mul!(values(x), α))
end
recursive_scalar_mul!(x, α) = fmap(Base.Fix1(recursive_scalar_mul!, α), x)

"""
gauss_points::Vector{Vector{Float64}}
Expand Down Expand Up @@ -181,12 +155,13 @@ end
mutable struct SavingIntegrandAffect{
IntegrandFunc,
tType,
integrandType,
integrandCacheType
IntegrandType,
IntegrandCacheType
}
integrand_func::IntegrandFunc
integrand_values::IntegrandValues{tType, integrandType}
integrand_cache::integrandCacheType
integrand_values::IntegrandValues{tType, IntegrandType}
integrand_cache::IntegrandCacheType
accumulation_cache::IntegrandCacheType
end

function (affect!::SavingIntegrandAffect)(integrator)
Expand All @@ -196,7 +171,7 @@ function (affect!::SavingIntegrandAffect)(integrator)
else
n = div(SciMLBase.alg_order(integrator.alg) + 1, 2)
end
integral = allocate_zeros(integrator.p)
accumulation_cache = recursive_zero!(affect!.accumulation_cache)
for i in 1:n
t_temp = ((integrator.t - integrator.tprev) / 2) * gauss_points[n][i] +
(integrator.t + integrator.tprev) / 2
Expand All @@ -205,30 +180,30 @@ function (affect!::SavingIntegrandAffect)(integrator)
integrator(curu, t_temp)
if affect!.integrand_cache == nothing
recursive_axpy!(gauss_weights[n][i],
affect!.integrand_func(curu, t_temp, integrator), integral)
affect!.integrand_func(curu, t_temp, integrator), accumulation_cache)
else
affect!.integrand_func(affect!.integrand_cache, curu, t_temp, integrator)
recursive_axpy!(gauss_weights[n][i], affect!.integrand_cache, integral)
recursive_axpy!(
gauss_weights[n][i], affect!.integrand_cache, accumulation_cache)
end
else
recursive_axpy!(gauss_weights[n][i],
affect!.integrand_func(integrator(t_temp), t_temp, integrator), integral)
affect!.integrand_func(integrator(t_temp), t_temp, integrator), accumulation_cache)
end
end
recursive_scalar_mul!(integral, -(integrator.t - integrator.tprev) / 2)
recursive_scalar_mul!(accumulation_cache, (integrator.t - integrator.tprev) / 2)
push!(affect!.integrand_values.ts, integrator.t)
push!(affect!.integrand_values.integrand, integral)
push!(affect!.integrand_values.integrand, recursive_copy(accumulation_cache))
u_modified!(integrator, false)
end

"""
```julia
IntegratingCallback(integrand_func,
integrand_values::IntegrandValues,
cache = nothing)
integrand_values::IntegrandValues, integrand_prototype)
```
Let one define a function `integrand_func(u, t, integrator)` which
Let one define a function `integrand_func(u, t, integrator)::typeof(integrand_prototype)` which
returns Integral(integrand_func(u(t),t)dt over the problem tspan.
## Arguments
Expand All @@ -240,9 +215,7 @@ returns Integral(integrand_func(u(t),t)dt over the problem tspan.
`integrand_func(t, u, integrator)::integrandType`. It's specified via
`IntegrandValues(integrandType)`, i.e. give the type
that `integrand_func` will output (or higher compatible type).
- `cache` is provided to store `integrand_func` output for in-place problems.
if `cache` is `nothing` but the problem is in-place, then `integrand_func`
is assumed to not be in-place and will be called as `out = integrand_func(u, t, integrator)`.
- `integrand_prototype` is a prototype of the output from the integrand.
The outputted values are saved into `integrand_values`. The values are found
via `integrand_values.integrand`.
Expand All @@ -254,9 +227,10 @@ via `integrand_values.integrand`.
If `integrand_func` is in-place, you must use `cache` to store the output of `integrand_func`.
"""
function IntegratingCallback(integrand_func, integrand_values::IntegrandValues,
cache = nothing)
affect! = SavingIntegrandAffect(integrand_func, integrand_values, cache)
function IntegratingCallback(
integrand_func, integrand_values::IntegrandValues, integrand_prototype)
affect! = SavingIntegrandAffect(integrand_func, integrand_values, integrand_prototype,
allocate_zeros(integrand_prototype))
condition = (u, t, integrator) -> true
DiscreteCallback(condition, affect!, save_positions = (false, false))
end
Expand Down
39 changes: 16 additions & 23 deletions src/integrating_sum.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
# addition
recursive_add!(x::AbstractArray, y::AbstractArray) = x .+= y
recursive_add!(x::Tuple, y::Tuple) = recursive_add!.(x, y)
function recursive_add!(x::NamedTuple{F}, y::NamedTuple{F}) where {F}
return NamedTuple{F}(recursive_add!(values(x), values(y)))
end

"""
IntegrandValuesSum{integrandType}
Expand All @@ -29,10 +22,11 @@ function Base.show(io::IO, integrand_values::IntegrandValuesSum)
"\nintegrand:\n", integrand_values.integrand)
end

mutable struct SavingIntegrandSumAffect{IntegrandFunc, integrandType, integrandCacheType}
mutable struct SavingIntegrandSumAffect{IntegrandFunc, integrandType, IntegrandCacheType}
integrand_func::IntegrandFunc
integrand_values::IntegrandValuesSum{integrandType}
integrand_cache::integrandCacheType
integrand_cache::IntegrandCacheType
accumulation_cache::IntegrandCacheType
end

function (affect!::SavingIntegrandSumAffect)(integrator)
Expand All @@ -42,7 +36,7 @@ function (affect!::SavingIntegrandSumAffect)(integrator)
else
n = div(SciMLBase.alg_order(integrator.alg) + 1, 2)
end
integral = allocate_zeros(integrator.p)
accumulation_cache = recursive_zero!(affect!.accumulation_cache)
for i in 1:n
t_temp = ((integrator.t - integrator.tprev) / 2) * gauss_points[n][i] +
(integrator.t + integrator.tprev) / 2
Expand All @@ -51,18 +45,19 @@ function (affect!::SavingIntegrandSumAffect)(integrator)
integrator(curu, t_temp)
if affect!.integrand_cache == nothing
recursive_axpy!(gauss_weights[n][i],
affect!.integrand_func(curu, t_temp, integrator), integral)
affect!.integrand_func(curu, t_temp, integrator), accumulation_cache)
else
affect!.integrand_func(affect!.integrand_cache, curu, t_temp, integrator)
recursive_axpy!(gauss_weights[n][i], affect!.integrand_cache, integral)
recursive_axpy!(
gauss_weights[n][i], affect!.integrand_cache, accumulation_cache)
end
else
recursive_axpy!(gauss_weights[n][i],
affect!.integrand_func(integrator(t_temp), t_temp, integrator), integral)
affect!.integrand_func(integrator(t_temp), t_temp, integrator), accumulation_cache)
end
end
recursive_scalar_mul!(integral, -(integrator.t - integrator.tprev) / 2)
recursive_add!(affect!.integrand_values.integrand, integral)
recursive_scalar_mul!(accumulation_cache, (integrator.t - integrator.tprev) / 2)
recursive_add!(affect!.integrand_values.integrand, accumulation_cache)
u_modified!(integrator, false)
end

Expand All @@ -85,9 +80,7 @@ returns Integral(integrand_func(u(t),t)dt over the problem tspan.
`integrand_func(t, u, integrator)::integrandType`. It's specified via
`IntegrandValues(integrandType)`, i.e. give the type
that `integrand_func` will output (or higher compatible type).
- `cache` is provided to store `integrand_func` output for in-place problems.
if `cache` is `nothing` but the problem is in-place, then `integrand_func`
is assumed to not be in-place and will be called as `out = integrand_func(u, t, integrator)`.
- `integrand_prototype` is a prototype of the output from the integrand.
The outputted values are saved into `integrand_values`. The values are found
via `integrand_values.integrand`.
Expand All @@ -96,12 +89,12 @@ via `integrand_values.integrand`.
This method is currently limited to ODE solvers of order 10 or lower. Open an issue if other
solvers are required.
If `integrand_func` is in-place, you must use `cache` to store the output of `integrand_func`.
"""
function IntegratingSumCallback(integrand_func, integrand_values::IntegrandValuesSum,
cache = nothing)
affect! = SavingIntegrandSumAffect(integrand_func, integrand_values, cache)
function IntegratingSumCallback(
integrand_func, integrand_values::IntegrandValuesSum, integrand_prototype)
affect! = SavingIntegrandSumAffect(
integrand_func, integrand_values, integrand_prototype,
allocate_zeros(integrand_prototype))
condition = (u, t, integrator) -> true
DiscreteCallback(condition, affect!, save_positions = (false, false))
end
Expand Down
Loading

0 comments on commit 105ba48

Please sign in to comment.