Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…DiffEq.jl into Fixes10
  • Loading branch information
ParamThakkar123 committed Sep 15, 2024
2 parents 6880a2d + c16577c commit 6356272
Show file tree
Hide file tree
Showing 22 changed files with 225 additions and 182 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ jobs:
Pkg.develop(map(path ->Pkg.PackageSpec.(;path="$(@__DIR__)/lib/$(path)"), readdir("./lib")));
'
- uses: julia-actions/julia-runtest@v1
with:
coverage: false
check_bounds: auto
env:
GROUP: ${{ matrix.group }}
- uses: julia-actions/julia-processcoverage@v1
Expand Down
2 changes: 1 addition & 1 deletion docs/src/massmatrixdae/Rosenbrock.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ the former is more efficient, but the latter is more reliable.
For larger systems look at multistep methods.

!!! warn

In order to use OrdinaryDiffEqRosenbrock with DAEs that require a non-trivial
consistent initialization, a nonlinear solver is required and thus
`using OrdinaryDiffEqNonlinearSolve` is required or you must pass an `initializealg`
Expand Down
7 changes: 5 additions & 2 deletions lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1113,7 +1113,8 @@ function perform_step!(integrator, cache::FBDFConstantCache{max_order},
end
tmp = -uprev * bdf_coeffs[k, 2]
for i in 1:(k - 1)
tmp = @.. tmp - $(_reshape(view(u_corrector, :, i), axes(u))) * bdf_coeffs[k, i + 2]
tmp = @.. tmp -
$(_reshape(view(u_corrector, :, i), axes(u))) * bdf_coeffs[k, i + 2]
end
end

Expand Down Expand Up @@ -1170,7 +1171,9 @@ function perform_step!(integrator, cache::FBDFConstantCache{max_order},
terk *= abs(dt^(k))
else
for i in 2:(k + 1)
terk = @.. terk + fd_weights[i, k + 1] * $(_reshape(view(u_history, :, i - 1), axes(u)))
terk = @.. terk +
fd_weights[i, k + 1] *
$(_reshape(view(u_history, :, i - 1), axes(u)))
end
terk *= abs(dt^(k))
end
Expand Down
3 changes: 2 additions & 1 deletion lib/OrdinaryDiffEqBDF/src/controllers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,8 @@ function choose_order!(alg::FBDF, integrator,
terk_tmp = similar(u)
@.. terk_tmp = fd_weights[k - 2, 1] * _vec(u)
for i in 2:(k - 2)
@.. terk_tmp += fd_weights[i, k - 2] * $(_reshape(view(u_history, :, i - 1), axes(u)))
@.. terk_tmp += fd_weights[i, k - 2] *
$(_reshape(view(u_history, :, i - 1), axes(u)))
end
@.. terk_tmp *= abs(dt^(k - 2))
end
Expand Down
12 changes: 12 additions & 0 deletions lib/OrdinaryDiffEqCore/src/caches/basic_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ abstract type OrdinaryDiffEqMutableCache <: OrdinaryDiffEqCache end
struct ODEEmptyCache <: OrdinaryDiffEqConstantCache end
struct ODEChunkCache{CS} <: OrdinaryDiffEqConstantCache end

ismutablecache(cache::OrdinaryDiffEqMutableCache) = true
ismutablecache(cache::OrdinaryDiffEqConstantCache) = false

# Don't worry about the potential alloc on a constant cache
get_fsalfirstlast(cache::OrdinaryDiffEqConstantCache, u) = (zero(u), zero(u))

Expand All @@ -13,6 +16,10 @@ mutable struct CompositeCache{T, F} <: OrdinaryDiffEqCache
current::Int
end

function ismutablecache(cache::CompositeCache{T, F}) where {T, F}
eltype(T) <: OrdinaryDiffEqMutableCache
end

function get_fsalfirstlast(cache::CompositeCache, u)
_x = get_fsalfirstlast(cache.caches[1], u)
if first(_x) !== nothing
Expand Down Expand Up @@ -44,6 +51,11 @@ function get_fsalfirstlast(cache::DefaultCache, u)
(cache.u, cache.u) # will be overwritten by the cache choice
end

function ismutablecache(cache::DefaultCache{
T1, T2, T3, T4, T5, T6, A, F, uType}) where {T1, T2, T3, T4, T5, T6, A, F, uType}
T1 <: OrdinaryDiffEqMutableCache
end

function alg_cache(alg::CompositeAlgorithm, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
Expand Down
32 changes: 18 additions & 14 deletions lib/OrdinaryDiffEqCore/src/initialize_dae.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,20 +121,23 @@ function default_nlsolve(
end

function default_nlsolve(
::Nothing, isinplace::Val{false}, u::Nothing, ::NonlinearProblem, autodiff = false)
nothing
::Nothing, isinplace::Val{false}, u::Nothing, ::NonlinearProblem, autodiff = false)
nothing
end

function default_nlsolve(
::Nothing, isinplace::Val{false}, u::Nothing, ::NonlinearLeastSquaresProblem, autodiff = false)
nothing
::Nothing, isinplace::Val{false}, u::Nothing,
::NonlinearLeastSquaresProblem, autodiff = false)
nothing
end

function OrdinaryDiffEqCore.default_nlsolve(::Nothing, isinplace, u, ::NonlinearProblem, autodiff = false)
function OrdinaryDiffEqCore.default_nlsolve(
::Nothing, isinplace, u, ::NonlinearProblem, autodiff = false)
error("This ODE requires a DAE initialization and thus a nonlinear solve but no nonlinear solve has been loaded. To solve this problem, do `using OrdinaryDiffEqNonlinearSolve` or pass a custom `nlsolve` choice into the `initializealg`.")
end

function OrdinaryDiffEqCore.default_nlsolve(::Nothing, isinplace, u, ::NonlinearLeastSquaresProblem, autodiff = false)
function OrdinaryDiffEqCore.default_nlsolve(
::Nothing, isinplace, u, ::NonlinearLeastSquaresProblem, autodiff = false)
error("This ODE requires a DAE initialization and thus a nonlinear solve but no nonlinear solve has been loaded. To solve this problem, do `using OrdinaryDiffEqNonlinearSolve` or pass a custom `nlsolve` choice into the `initializealg`.")
end

Expand Down Expand Up @@ -179,12 +182,13 @@ end

## CheckInit
struct CheckInitFailureError <: Exception
normresid
abstol
normresid::Any
abstol::Any
end

function Base.showerror(io::IO, e::CheckInitFailureError)
print(io, "CheckInit specified but initialization not satisifed. normresid = $(e.normresid) > abstol = $(e.abstol)")
print(io,
"CheckInit specified but initialization not satisifed. normresid = $(e.normresid) > abstol = $(e.abstol)")
end

function _initialize_dae!(integrator, prob::ODEProblem, alg::CheckInit,
Expand All @@ -202,7 +206,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::CheckInit,
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))

normresid = integrator.opts.internalnorm(tmp, t)
if normresid > integrator.opts.abstol
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
end
end
Expand All @@ -221,7 +225,7 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::CheckInit,
resid = _vec(du)[algebraic_eqs]

normresid = integrator.opts.internalnorm(resid, t)
if normresid > integrator.opts.abstol
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
end
end
Expand All @@ -234,7 +238,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,

f(resid, integrator.du, u0, p, t)
normresid = integrator.opts.internalnorm(resid, t)
if normresid > integrator.opts.abstol
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
end
end
Expand All @@ -252,7 +256,7 @@ function _initialize_dae!(integrator, prob::DAEProblem,

resid = f(integrator.du, u0, p, t)
normresid = integrator.opts.internalnorm(resid, t)
if normresid > integrator.opts.abstol
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
end
end
end
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@ end
function DiffEqBase.reeval_internals_due_to_modification!(
integrator::ODEIntegrator, continuous_modification = true;
callback_initializealg = nothing)

if integrator.isdae
DiffEqBase.initialize_dae!(integrator, isnothing(callback_initializealg) ? integrator.initializealg : callback_initializealg)
DiffEqBase.initialize_dae!(integrator,
isnothing(callback_initializealg) ? integrator.initializealg :
callback_initializealg)
update_uprev!(integrator)
end

Expand Down
6 changes: 1 addition & 5 deletions lib/OrdinaryDiffEqCore/src/integrators/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -482,11 +482,7 @@ function reset_fsal!(integrator)
# Ignore DAEs but they already re-ran initialization
# Mass matrix DAEs do need to reset FSAL if available
if !(integrator.sol.prob isa DAEProblem)
if integrator.cache isa OrdinaryDiffEqMutableCache ||
(integrator.cache isa CompositeCache &&
integrator.cache.caches[1] isa OrdinaryDiffEqMutableCache) ||
(integrator.cache isa DefaultCache &&
integrator.cache.cache1 isa OrdinaryDiffEqMutableCache)
if ismutablecache(integrator.cache)
integrator.f(integrator.fsalfirst, integrator.u, integrator.p, integrator.t)
else
integrator.fsalfirst = integrator.f(integrator.u, integrator.p, integrator.t)
Expand Down
3 changes: 2 additions & 1 deletion lib/OrdinaryDiffEqCore/src/interp_func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ end

function strip_cache(cache)
if hasfield(typeof(cache), :jac_config) || hasfield(typeof(cache), :grad_config) ||
hasfield(typeof(cache), :nlsolver) || hasfield(typeof(cache), :tf) || hasfield(typeof(cache), :uf)
hasfield(typeof(cache), :nlsolver) || hasfield(typeof(cache), :tf) ||
hasfield(typeof(cache), :uf)
fieldnums = length(fieldnames(typeof(cache)))
noth_list = fill(nothing, fieldnums)
cache_type_name = Base.typename(typeof(cache)).wrapper
Expand Down
3 changes: 2 additions & 1 deletion lib/OrdinaryDiffEqDefault/test/default_solver_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using OrdinaryDiffEqDefault, OrdinaryDiffEqTsit5, OrdinaryDiffEqVerner, OrdinaryDiffEqRosenbrock, OrdinaryDiffEqBDF
using OrdinaryDiffEqDefault, OrdinaryDiffEqTsit5, OrdinaryDiffEqVerner,
OrdinaryDiffEqRosenbrock, OrdinaryDiffEqBDF
using Test, LinearSolve, LinearAlgebra, SparseArrays, StaticArrays

f_2dlinear = (du, u, p, t) -> (@. du = p * u)
Expand Down
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqDifferentiation/src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,4 @@ end
@generated function pick_static_chunksize(::Val{chunksize}) where {chunksize}
x = ForwardDiff.pickchunksize(chunksize)
:(Val{$x}())
end
end
48 changes: 24 additions & 24 deletions lib/OrdinaryDiffEqDifferentiation/src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct StaticWOperator{isinv, T, F} <: AbstractSciMLOperator{T}
# doing to how StaticArrays and StaticArraysCore are split up
StaticArrays.LU(LowerTriangular(W), UpperTriangular(W), SVector{n}(1:n))
else
lu(W, check=false)
lu(W, check = false)
end
# when constructing W for the first time for the type
# inv(W) can be singular
Expand Down Expand Up @@ -938,28 +938,28 @@ function LinearSolve.init_cacheval(
end

for alg in [LinearSolve.AppleAccelerateLUFactorization,
LinearSolve.BunchKaufmanFactorization,
LinearSolve.CHOLMODFactorization,
LinearSolve.CholeskyFactorization,
LinearSolve.CudaOffloadFactorization,
LinearSolve.DiagonalFactorization,
LinearSolve.FastLUFactorization,
LinearSolve.FastQRFactorization,
LinearSolve.GenericFactorization,
LinearSolve.GenericLUFactorization,
LinearSolve.KLUFactorization,
LinearSolve.LDLtFactorization,
LinearSolve.LUFactorization,
LinearSolve.MKLLUFactorization,
LinearSolve.MetalLUFactorization,
LinearSolve.NormalBunchKaufmanFactorization,
LinearSolve.NormalCholeskyFactorization,
LinearSolve.QRFactorization,
LinearSolve.RFLUFactorization,
LinearSolve.SVDFactorization,
LinearSolve.SimpleLUFactorization,
LinearSolve.SparspakFactorization,
LinearSolve.UMFPACKFactorization]
LinearSolve.BunchKaufmanFactorization,
LinearSolve.CHOLMODFactorization,
LinearSolve.CholeskyFactorization,
LinearSolve.CudaOffloadFactorization,
LinearSolve.DiagonalFactorization,
LinearSolve.FastLUFactorization,
LinearSolve.FastQRFactorization,
LinearSolve.GenericFactorization,
LinearSolve.GenericLUFactorization,
LinearSolve.KLUFactorization,
LinearSolve.LDLtFactorization,
LinearSolve.LUFactorization,
LinearSolve.MKLLUFactorization,
LinearSolve.MetalLUFactorization,
LinearSolve.NormalBunchKaufmanFactorization,
LinearSolve.NormalCholeskyFactorization,
LinearSolve.QRFactorization,
LinearSolve.RFLUFactorization,
LinearSolve.SVDFactorization,
LinearSolve.SimpleLUFactorization,
LinearSolve.SparspakFactorization,
LinearSolve.UMFPACKFactorization]
@eval function LinearSolve.init_cacheval(alg::$alg, A::WOperator, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool,
assumptions::OperatorAssumptions)
Expand Down Expand Up @@ -1003,4 +1003,4 @@ function resize_J_W!(cache, integrator, i)
end

nothing
end
end
20 changes: 12 additions & 8 deletions lib/OrdinaryDiffEqExtrapolation/src/extrapolation_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1193,7 +1193,7 @@ function perform_step!(integrator, cache::ImplicitDeuflhardExtrapolationCache,
for j in 2:j_int
f(k, cache.u_temp1, p, t + (j - 1) * dt_int)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
@.. broadcast=false linsolve_tmps[1]=k - (u_temp1 - u_temp2)/dt_int
@.. broadcast=false linsolve_tmps[1]=k - (u_temp1 - u_temp2) / dt_int

linsolve = cache.linsolve[1]

Expand Down Expand Up @@ -1270,7 +1270,8 @@ function perform_step!(integrator, cache::ImplicitDeuflhardExtrapolationCache,
p, t + (j - 1) * dt_int_temp)
@.. broadcast=false linsolve_tmps[Threads.threadid()]=k_tmps[Threads.threadid()] -
(u_temp3[Threads.threadid()] -
u_temp4[Threads.threadid()])/dt_int_temp
u_temp4[Threads.threadid()]) /
dt_int_temp

linsolve = cache.linsolve[Threads.threadid()]

Expand Down Expand Up @@ -1354,7 +1355,8 @@ function perform_step!(integrator, cache::ImplicitDeuflhardExtrapolationCache,
p, t + (j - 1) * dt_int_temp)
@.. broadcast=false linsolve_tmps[Threads.threadid()]=k_tmps[Threads.threadid()] -
(u_temp3[Threads.threadid()] -
u_temp4[Threads.threadid()])/dt_int_temp
u_temp4[Threads.threadid()]) /
dt_int_temp

linsolve = cache.linsolve[Threads.threadid()]

Expand Down Expand Up @@ -2555,7 +2557,7 @@ function perform_step!(integrator, cache::ImplicitHairerWannerExtrapolationCache
for j in 2:(j_int + 1)
f(k, cache.u_temp1, p, t + (j - 1) * dt_int)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
@.. broadcast=false linsolve_tmps[1]=k - (u_temp1 - u_temp2)/dt_int
@.. broadcast=false linsolve_tmps[1]=k - (u_temp1 - u_temp2) / dt_int

linsolve = cache.linsolve[1]

Expand Down Expand Up @@ -2634,9 +2636,10 @@ function perform_step!(integrator, cache::ImplicitHairerWannerExtrapolationCache
f(k_tmps[Threads.threadid()],
cache.u_temp3[Threads.threadid()],
p, t + (j - 1) * dt_int_temp)
@.. broadcast=false linsolve_tmps[Threads.threadid()]= k_tmps[Threads.threadid()] -
@.. broadcast=false linsolve_tmps[Threads.threadid()]=k_tmps[Threads.threadid()] -
(u_temp3[Threads.threadid()] -
u_temp4[Threads.threadid()]) / dt_int_temp
u_temp4[Threads.threadid()]) /
dt_int_temp

linsolve = cache.linsolve[Threads.threadid()]
if !repeat_step && j == 1
Expand Down Expand Up @@ -2717,7 +2720,8 @@ function perform_step!(integrator, cache::ImplicitHairerWannerExtrapolationCache
for j in 2:(j_int_temp + 1)
f(ktmp, cache.u_temp3[tid], p, t + (j - 1) * dt_int_temp)
@.. broadcast=false linsolvetmp=ktmp -
(u_temp3[tid] - u_temp4[tid])/dt_int_temp
(u_temp3[tid] - u_temp4[tid]) /
dt_int_temp

linsolve = cache.linsolve[tid]

Expand Down Expand Up @@ -2832,7 +2836,7 @@ function perform_step!(integrator, cache::ImplicitHairerWannerExtrapolationCache
for j in 2:(j_int + 1)
f(k, cache.u_temp1, p, t + (j - 1) * dt_int)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)
@.. broadcast=false linsolve_tmps[1]=k - (u_temp1 - u_temp2)/dt_int
@.. broadcast=false linsolve_tmps[1]=k - (u_temp1 - u_temp2) / dt_int

linsolve = cache.linsolve[1]

Expand Down
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqFIRK/src/firk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -902,7 +902,7 @@ end
rhs3 = @.. broadcast=false fw3 - β1dt * Mw2-α1dt * Mw3
rhs4 = @.. broadcast=false fw4 - α2dt * Mw4+β2dt * Mw5
rhs5 = @.. broadcast=false fw5 - β2dt * Mw4-α2dt * Mw5
dw1 = _reshape(LU1 \ _vec(rhs1), axes(u))
dw1 = _reshape(LU1 \ _vec(rhs1), axes(u))
dw23 = _reshape(LU2 \ _vec(@.. broadcast=false rhs2+rhs3 * im), axes(u))
dw45 = _reshape(LU3 \ _vec(@.. broadcast=false rhs4+rhs5 * im), axes(u))
integrator.stats.nsolve += 3
Expand Down
12 changes: 7 additions & 5 deletions lib/OrdinaryDiffEqNonlinearSolve/src/initialize_dae.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
function default_nlsolve(::Nothing, isinplace::Val{true}, u, ::NonlinearProblem, autodiff = false)
function default_nlsolve(
::Nothing, isinplace::Val{true}, u, ::NonlinearProblem, autodiff = false)
FastShortcutNonlinearPolyalg(;
autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
end
function default_nlsolve(
::Nothing, isinplace::Val{true}, u, ::NonlinearLeastSquaresProblem, autodiff = false)
::Nothing, isinplace::Val{true}, u, ::NonlinearLeastSquaresProblem, autodiff = false)
FastShortcutNLLSPolyalg(; autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
end
function default_nlsolve(::Nothing, isinplace::Val{false}, u, ::NonlinearProblem, autodiff = false)
function default_nlsolve(
::Nothing, isinplace::Val{false}, u, ::NonlinearProblem, autodiff = false)
FastShortcutNonlinearPolyalg(;
autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
end
function default_nlsolve(
::Nothing, isinplace::Val{false}, u, ::NonlinearLeastSquaresProblem, autodiff = false)
::Nothing, isinplace::Val{false}, u, ::NonlinearLeastSquaresProblem, autodiff = false)
FastShortcutNLLSPolyalg(; autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
end
function default_nlsolve(::Nothing, isinplace::Val{false}, u::StaticArray,
::NonlinearProblem, autodiff = false)
::NonlinearProblem, autodiff = false)
SimpleTrustRegion(autodiff = autodiff ? AutoForwardDiff() : AutoFiniteDiff())
end
function default_nlsolve(::Nothing, isinplace::Val{false}, u::StaticArray,
Expand Down
Loading

0 comments on commit 6356272

Please sign in to comment.