Skip to content

Commit

Permalink
Merge pull request #1870 from SciML/stiffnessdetect
Browse files Browse the repository at this point in the history
Make ExplicitRK handle stiffness detection
  • Loading branch information
ChrisRackauckas authored Feb 16, 2023
2 parents 229b640 + a600c14 commit 545de44
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ ArrayInterfaceGPUArrays = "0.1, 0.2"
ArrayInterfaceStaticArrays = "0.1"
ArrayInterfaceStaticArraysCore = "0.1.2"
DataStructures = "0.18"
DiffEqBase = "6.109.0"
DiffEqBase = "6.116.0"
DocStringExtensions = "0.8, 0.9"
ExponentialUtilities = "1.22"
FastBroadcast = "0.1.9, 0.2"
Expand Down
7 changes: 3 additions & 4 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,8 @@ SciMLBase.forwarddiffs_model_time(alg::RosenbrockAlgorithm) = true
## OrdinaryDiffEq Internal Traits

isfsal(alg::Union{OrdinaryDiffEqAlgorithm, DAEAlgorithm}) = true
function isfsal(tab::DiffEqBase.ExplicitRKTableau{MType, VType, fsal}) where {MType, VType,
fsal}
fsal
end
isfsal(tab::DiffEqBase.ExplicitRKTableau) = tab.fsal

# isfsal(alg::CompositeAlgorithm) = isfsal(alg.algs[alg.current])
isfsal(alg::FunctionMap) = false
isfsal(alg::Rodas5) = false
Expand Down Expand Up @@ -922,6 +920,7 @@ ssp_coefficient(alg::KYK2014DGSSPRK_3S2) = 0.8417
ssp_coefficient(alg::SSPSDIRK2) = 4

# stability regions
alg_stability_size(alg::ExplicitRK) = alg.tableau.stability_size
alg_stability_size(alg::DP5) = 3.3066
alg_stability_size(alg::Tsit5) = 3.5068
alg_stability_size(alg::Vern6) = 4.8553
Expand Down
2 changes: 1 addition & 1 deletion src/constants.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ function constructDormandPrince(T::Type = Float64)
αEEst = map(T, αEEst)
c = map(T, c)
return (DiffEqBase.ExplicitRKTableau(A, c, α, 5, αEEst = αEEst, adaptiveorder = 4,
fsal = true))
fsal = true, stability_size = 3.3066))
end

"""
Expand Down
33 changes: 25 additions & 8 deletions src/perform_step/explicit_rk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ end
@muladd function perform_step!(integrator, cache::ExplicitRKConstantCache,
repeat_step = false)
@unpack t, dt, uprev, u, f, p = integrator
alg = unwrap_alg(integrator, nothing)
alg = unwrap_alg(integrator, false)
@unpack A, c, α, αEEst, stages = cache
@unpack kk = cache

Expand All @@ -31,20 +31,28 @@ end
end

#Calc Last
utilde = zero(kk[1])
utilde_last = zero(kk[1])
for j in 1:(stages - 1)
utilde = utilde + A[j, end] * kk[j]
utilde_last = utilde_last + A[j, end] * kk[j]
end
kk[end] = f(uprev + dt * utilde, p, t + c[end] * dt)
u_beforefinal = uprev + dt * utilde_last
kk[end] = f(u_beforefinal, p, t + c[end] * dt)
integrator.destats.nf += 1
integrator.fsallast = kk[end] # Uses fsallast as temp even if not fsal

# Accumulate Result
utilde = α[1] * kk[1]
accum = α[1] * kk[1]
for i in 2:stages
utilde = utilde + α[i] * kk[i]
accum = accum + α[i] * kk[i]
end
u = uprev + dt * accum

if integrator.alg isa CompositeAlgorithm
# Hairer II, page 22
ϱu = integrator.opts.internalnorm(kk[end] - kk[end-1], t)
ϱd = integrator.opts.internalnorm(u - u_beforefinal, t)
integrator.eigen_est = ϱu / ϱd
end
u = uprev + dt * utilde

if integrator.opts.adaptive
utilde = αEEst[1] .* kk[1]
Expand Down Expand Up @@ -198,7 +206,7 @@ end

@muladd function perform_step!(integrator, cache::ExplicitRKCache, repeat_step = false)
@unpack t, dt, uprev, u, f, p = integrator
alg = unwrap_alg(integrator, nothing)
alg = unwrap_alg(integrator, false)
# αEEst is `α - αEEst`
@unpack A, c, α, αEEst, stages = cache.tab
@unpack kk, utilde, tmp, atmp = cache
Expand All @@ -211,6 +219,15 @@ end
runtime_split_fsal!(u, α, utilde, uprev, kk, dt, stages)
end

if integrator.alg isa CompositeAlgorithm
# Hairer II, page 22
@.. broadcast=false utilde= kk[end] - kk[end-1]
ϱu = integrator.opts.internalnorm(utilde, t)
@.. broadcast=false utilde=u - tmp
ϱd = integrator.opts.internalnorm(utilde, t)
integrator.eigen_est = ϱu / ϱd
end

if integrator.opts.adaptive
runtime_split_EEst!(tmp, αEEst, utilde, kk, dt, stages)
calculate_residuals!(atmp, tmp, uprev, u,
Expand Down
11 changes: 11 additions & 0 deletions test/interface/composite_algorithm_test.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using OrdinaryDiffEq, Test, LinearAlgebra
import ODEProblemLibrary: prob_ode_linear, prob_ode_2Dlinear
using DiffEqDevTools

prob = prob_ode_2Dlinear
choice_function(integrator) = (Int(integrator.t < 0.5) + 1)
alg_double = CompositeAlgorithm((Tsit5(), Tsit5()), choice_function)
Expand All @@ -11,6 +13,12 @@ alg_switch = CompositeAlgorithm((Tsit5(), Vern7()), choice_function)
@test sol1.t == sol2.t
@test sol1(0.8) == sol2(0.8)

alg_double_erk = CompositeAlgorithm((ExplicitRK(), ExplicitRK()), choice_function)
@time sol1 = solve(prob_ode_linear, alg_double_erk)
@time sol2 = solve(prob_ode_linear, ExplicitRK())
@test sol1.t == sol2.t
@test sol1(0.8) == sol2(0.8)

integrator1 = init(prob, alg_double2)
integrator2 = init(prob, Vern6())
solve!(integrator1)
Expand Down Expand Up @@ -44,3 +52,6 @@ prob = ODEProblem((du, u, p, t) -> mul!(du, A, u), zeros(6), (0.0, 1000), tstops
callback = DiscreteCallback(condition, affect!));
sol = solve(prob, alg = AutoVern7(Rodas5()))
@test sol.t[end] == 1000.0

sol = solve(prob, alg = OrdinaryDiffEq.AutoAlgSwitch(ExplicitRK(constructVerner7()), Rodas5()))
@test sol.t[end] == 1000.0

0 comments on commit 545de44

Please sign in to comment.