Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add fields to OverrideInit, better nlsolve_alg handling #857

Merged
merged 11 commits into from
Nov 18, 2024
Merged
2 changes: 1 addition & 1 deletion .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,4 @@ jobs:
with:
file: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: true
fail_ci_if_error: false
1 change: 0 additions & 1 deletion src/ODE_nlsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,3 @@ struct ODE_NLProbData{NLProb, UNLProb, NLProbMap, NLProbPmap}
"""
nlprobpmap::NLProbPmap
end

13 changes: 11 additions & 2 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import CommonSolve: solve, init, step!, solve!
import FunctionWrappersWrappers
import RuntimeGeneratedFunctions
import EnumX
import ADTypes: AbstractADType
import ADTypes: ADTypes, AbstractADType
import Accessors: @set, @reset
using Expronicon.ADT: @match

Expand Down Expand Up @@ -351,7 +351,16 @@ struct CheckInit <: DAEInitializationAlgorithm end
"""
$(TYPEDEF)
"""
struct OverrideInit <: DAEInitializationAlgorithm end
struct OverrideInit{T1, T2, F} <: DAEInitializationAlgorithm
abstol::T1
reltol::T2
nlsolve::F
end

function OverrideInit(; abstol = nothing, reltol = nothing, nlsolve = nothing)
OverrideInit(abstol, reltol, nlsolve)
end
OverrideInit(abstol) = OverrideInit(; abstol = abstol, nlsolve = nothing)

# PDE Discretizations

Expand Down
101 changes: 64 additions & 37 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,17 +68,26 @@ function Base.showerror(io::IO, e::OverrideInitMissingAlgorithm)
"OverrideInit specified but no NonlinearSolve.jl algorithm provided. Provide an algorithm via the `nlsolve_alg` keyword argument to `get_initial_values`.")
end

struct OverrideInitNoTolerance <: Exception
tolerance::Symbol
end

function Base.showerror(io::IO, e::OverrideInitNoTolerance)
print(io,
"Tolerances were not provided to `OverrideInit`. `$(e.tolerance)` must be provided as a keyword argument to `get_initial_values` or as a keyword argument to the `OverrideInit` constructor.")
end

"""
Utility function to evaluate the RHS of the ODE, using the integrator's `tmp_cache` if
Utility function to evaluate the RHS, using the integrator's `tmp_cache` if
it is in-place or simply calling the function if not.
"""
function _evaluate_f_ode(integrator, f, isinplace::Val{true}, args...)
function _evaluate_f(integrator, f, isinplace::Val{true}, args...)
tmp = first(get_tmp_cache(integrator))
f(tmp, args...)
return tmp
end

function _evaluate_f_ode(integrator, f, isinplace::Val{false}, args...)
function _evaluate_f(integrator, f, isinplace::Val{false}, args...)
return f(args...)
end

Expand All @@ -98,53 +107,49 @@ _vec(v::AbstractVector) = v

Check if the algebraic constraints are satisfied, and error if they aren't. Returns
the `u0` and `p` as-is, and is always successful if it returns. Valid only for
`ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument.
`AbstractDEProblem` and `AbstractDAEProblem`. Requires a `DEIntegrator` as its second argument.

Keyword arguments:
- `abstol`: The absolute value below which the norm of the residual of algebraic equations
should lie. The norm function used is `integrator.opts.internalnorm` if present, and
`LinearAlgebra.norm` if not.
"""
function get_initial_values(prob::AbstractODEProblem, integrator, f, alg::CheckInit,
isinplace::Union{Val{true}, Val{false}}; kwargs...)
function get_initial_values(
prob::AbstractDEProblem, integrator::DEIntegrator, f, alg::CheckInit,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't handle DAEProblem?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's an AbstractDAEProblem dispatch below

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But then this one should just be for ODEProblem?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SDEProblem dispatches here too

isinplace::Union{Val{true}, Val{false}}; abstol, kwargs...)
u0 = state_values(integrator)
p = parameter_values(integrator)
t = current_time(integrator)
M = f.mass_matrix

algebraic_vars = [all(iszero, x) for x in eachcol(M)]
algebraic_eqs = [all(iszero, x) for x in eachrow(M)]
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return u0, p, true
update_coefficients!(M, u0, p, t)
tmp = _evaluate_f_ode(integrator, f, isinplace, u0, p, t)
tmp = _evaluate_f(integrator, f, isinplace, u0, p, t)
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

integrator.opts.internalnorm: This doesn't handle Sundials or ODEInterface?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm... I didn't know that. Makes sense that SciMLBase can't make the same assumptions OrdinaryDiffEq did


normresid = integrator.opts.internalnorm(tmp, t)
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
normresid = isdefined(integrator.opts, :internalnorm) ?
integrator.opts.internalnorm(tmp, t) : norm(tmp)
if normresid > abstol
throw(CheckInitFailureError(normresid, abstol))
end
return u0, p, true
end

"""
Utility function to evaluate the RHS of the DAE, using the integrator's `tmp_cache` if
it is in-place or simply calling the function if not.
"""
function _evaluate_f_dae(integrator, f, isinplace::Val{true}, args...)
tmp = get_tmp_cache(integrator)[2]
f(tmp, args...)
return tmp
end

function _evaluate_f_dae(integrator, f, isinplace::Val{false}, args...)
return f(args...)
end

function get_initial_values(prob::AbstractDAEProblem, integrator, f, alg::CheckInit,
isinplace::Union{Val{true}, Val{false}}; kwargs...)
function get_initial_values(
prob::AbstractDAEProblem, integrator::DEIntegrator, f, alg::CheckInit,
isinplace::Union{Val{true}, Val{false}}; abstol, kwargs...)
u0 = state_values(integrator)
p = parameter_values(integrator)
t = current_time(integrator)

resid = _evaluate_f_dae(integrator, f, isinplace, integrator.du, u0, p, t)
normresid = integrator.opts.internalnorm(resid, t)
if normresid > integrator.opts.abstol
throw(CheckInitFailureError(normresid, integrator.opts.abstol))
resid = _evaluate_f(integrator, f, isinplace, integrator.du, u0, p, t)
normresid = isdefined(integrator.opts, :internalnorm) ?
integrator.opts.internalnorm(resid, t) : norm(resid)

if normresid > abstol
throw(CheckInitFailureError(normresid, abstol))
end
return u0, p, true
end
Expand All @@ -155,12 +160,19 @@ end
Solve a `NonlinearProblem`/`NonlinearLeastSquaresProblem` to obtain the initial `u0` and
`p`. Requires that `f` have the field `initialization_data` which is an `OverrideInitData`.
If the field is absent or the value is `nothing`, return `u0` and `p` successfully as-is.
The NonlinearSolve.jl algorithm to use must be specified through the `nlsolve_alg` keyword
argument, failing which this function will throw an error. The success value returned
depends on the success of the nonlinear solve.

The success value returned depends on the success of the nonlinear solve.

Keyword arguments:
- `nlsolve_alg`: The NonlinearSolve.jl algorithm to use. If not provided, this function will
throw an error.
- `abstol`, `reltol`: The `abstol` (`reltol`) to use for the nonlinear solve. The value
provided to the `OverrideInit` constructor takes priority over this keyword argument.
If the former is `nothing`, this keyword argument will be used. If it is also not provided,
an error will be thrown.
"""
function get_initial_values(prob, valp, f, alg::OverrideInit,
isinplace::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...)
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...)
u0 = state_values(valp)
p = parameter_values(valp)

Expand All @@ -171,15 +183,30 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
initdata::OverrideInitData = f.initialization_data
initprob = initdata.initializeprob

if nlsolve_alg === nothing
nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing))
if nlsolve_alg === nothing && state_values(initprob) !== nothing
throw(OverrideInitMissingAlgorithm())
end

if initdata.update_initializeprob! !== nothing
initdata.update_initializeprob!(initprob, valp)
end

nlsol = solve(initprob, nlsolve_alg)
if alg.abstol !== nothing
_abstol = alg.abstol
elseif abstol !== nothing
_abstol = abstol
else
throw(OverrideInitNoTolerance(:abstol))
end
if alg.reltol !== nothing
_reltol = alg.reltol
elseif reltol !== nothing
_reltol = reltol
else
throw(OverrideInitNoTolerance(:reltol))
end
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol)

u0 = initdata.initializeprobmap(nlsol)
if initdata.initializeprobpmap !== nothing
Expand Down
14 changes: 9 additions & 5 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,8 @@ numerically-defined functions.
"""
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
O, TCV,
SYS, ID<:Union{Nothing, OverrideInitData}, NLP<:Union{Nothing, ODE_NLProbData}} <: AbstractODEFunction{iip}
SYS, ID <: Union{Nothing, OverrideInitData}, NLP <: Union{Nothing, ODE_NLProbData}} <:
AbstractODEFunction{iip}
f::F
mass_matrix::TMM
analytic::Ta
Expand Down Expand Up @@ -522,7 +523,8 @@ information on generating the SplitFunction from this symbolic engine.
"""
struct SplitFunction{
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt,
TPJ, O, TCV, SYS, ID<:Union{Nothing, OverrideInitData}, NLP<:Union{Nothing, ODE_NLProbData}} <: AbstractODEFunction{iip}
TPJ, O, TCV, SYS, ID <: Union{Nothing, OverrideInitData},
NLP <: Union{Nothing, ODE_NLProbData}} <: AbstractODEFunction{iip}
f1::F1
f2::F2
mass_matrix::TMM
Expand Down Expand Up @@ -2442,7 +2444,7 @@ function ODEFunction{iip, specialize}(f;
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
initialization_data = __has_initialization_data(f) ? f.initialization_data :
nothing,
nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing,
nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing
) where {iip,
specialize
}
Expand Down Expand Up @@ -2500,7 +2502,8 @@ function ODEFunction{iip, specialize}(f;
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
Any,
typeof(_colorvec),
typeof(sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(_f, mass_matrix, analytic, tgrad, jac,
typeof(sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(
_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initdata, nlprob_data)
Expand Down Expand Up @@ -2770,7 +2773,8 @@ function SplitFunction{iip, specialize}(f1, f2;
if specialize === NoSpecialize
SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any,
Any, Any, Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(f1, f2, mass_matrix, _func_cache,
Any, Any, Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(
f1, f2, mass_matrix, _func_cache,
analytic,
tgrad, jac, jvp, vjp, jac_prototype, W_prototype,
sparsity, Wfact, Wfact_t, paramjac,
Expand Down
61 changes: 61 additions & 0 deletions test/downstream/initialization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using OrdinaryDiffEq, Sundials, SciMLBase, Test

@testset "CheckInit" begin
abstol = 1e-10
@testset "Sundials + ODEProblem" begin
function rhs(u, p, t)
return [u[1] * t, u[1]^2 - u[2]^2]
end
function rhs!(du, u, p, t)
du[1] = u[1] * t
du[2] = u[1]^2 - u[2]^2
end

oopfn = ODEFunction{false}(rhs, mass_matrix = [1 0; 0 0])
iipfn = ODEFunction{true}(rhs!, mass_matrix = [1 0; 0 0])

@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn]
prob = ODEProblem(f, [1.0, 1.0], (0.0, 1.0))
integ = init(prob, Sundials.ARKODE())
u0, _, success = SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol)
@test success
@test u0 == prob.u0

integ.u[2] = 2.0
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol)
end
end

@testset "Sundials + DAEProblem" begin
function daerhs(du, u, p, t)
return [du[1] - u[1] * t - p, u[1]^2 - u[2]^2]
end
function daerhs!(resid, du, u, p, t)
resid[1] = du[1] - u[1] * t - p
resid[2] = u[1]^2 - u[2]^2
end

oopfn = DAEFunction{false}(daerhs)
iipfn = DAEFunction{true}(daerhs!)

@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn]
prob = DAEProblem(f, [1.0, 0.0], [1.0, 1.0], (0.0, 1.0), 1.0)
integ = init(prob, Sundials.IDA())
u0, _, success = SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol)
@test success
@test u0 == prob.u0

integ.u[2] = 2.0
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol)

integ.u[2] = 1.0
integ.du[1] = 2.0
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values(
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol)
end
end
end
Loading
Loading