-
-
Notifications
You must be signed in to change notification settings - Fork 101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: add fields to OverrideInit
, better nlsolve_alg
handling
#857
Changes from all commits
e0fade7
9173f80
e1d93b1
e3217f3
d75e2f3
6e38d68
64e2ee0
77b72e3
e1c03d7
66d62c5
7f93fb2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -85,4 +85,4 @@ jobs: | |
with: | ||
file: lcov.info | ||
token: ${{ secrets.CODECOV_TOKEN }} | ||
fail_ci_if_error: true | ||
fail_ci_if_error: false |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,4 +43,3 @@ struct ODE_NLProbData{NLProb, UNLProb, NLProbMap, NLProbPmap} | |
""" | ||
nlprobpmap::NLProbPmap | ||
end | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -68,17 +68,26 @@ function Base.showerror(io::IO, e::OverrideInitMissingAlgorithm) | |
"OverrideInit specified but no NonlinearSolve.jl algorithm provided. Provide an algorithm via the `nlsolve_alg` keyword argument to `get_initial_values`.") | ||
end | ||
|
||
struct OverrideInitNoTolerance <: Exception | ||
tolerance::Symbol | ||
end | ||
|
||
function Base.showerror(io::IO, e::OverrideInitNoTolerance) | ||
print(io, | ||
"Tolerances were not provided to `OverrideInit`. `$(e.tolerance)` must be provided as a keyword argument to `get_initial_values` or as a keyword argument to the `OverrideInit` constructor.") | ||
end | ||
|
||
""" | ||
Utility function to evaluate the RHS of the ODE, using the integrator's `tmp_cache` if | ||
Utility function to evaluate the RHS, using the integrator's `tmp_cache` if | ||
it is in-place or simply calling the function if not. | ||
""" | ||
function _evaluate_f_ode(integrator, f, isinplace::Val{true}, args...) | ||
function _evaluate_f(integrator, f, isinplace::Val{true}, args...) | ||
tmp = first(get_tmp_cache(integrator)) | ||
f(tmp, args...) | ||
return tmp | ||
end | ||
|
||
function _evaluate_f_ode(integrator, f, isinplace::Val{false}, args...) | ||
function _evaluate_f(integrator, f, isinplace::Val{false}, args...) | ||
return f(args...) | ||
end | ||
|
||
|
@@ -98,53 +107,49 @@ _vec(v::AbstractVector) = v | |
|
||
Check if the algebraic constraints are satisfied, and error if they aren't. Returns | ||
the `u0` and `p` as-is, and is always successful if it returns. Valid only for | ||
`ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument. | ||
`AbstractDEProblem` and `AbstractDAEProblem`. Requires a `DEIntegrator` as its second argument. | ||
|
||
Keyword arguments: | ||
- `abstol`: The absolute value below which the norm of the residual of algebraic equations | ||
should lie. The norm function used is `integrator.opts.internalnorm` if present, and | ||
`LinearAlgebra.norm` if not. | ||
""" | ||
function get_initial_values(prob::AbstractODEProblem, integrator, f, alg::CheckInit, | ||
isinplace::Union{Val{true}, Val{false}}; kwargs...) | ||
function get_initial_values( | ||
prob::AbstractDEProblem, integrator::DEIntegrator, f, alg::CheckInit, | ||
isinplace::Union{Val{true}, Val{false}}; abstol, kwargs...) | ||
u0 = state_values(integrator) | ||
p = parameter_values(integrator) | ||
t = current_time(integrator) | ||
M = f.mass_matrix | ||
|
||
algebraic_vars = [all(iszero, x) for x in eachcol(M)] | ||
algebraic_eqs = [all(iszero, x) for x in eachrow(M)] | ||
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return | ||
(iszero(algebraic_vars) || iszero(algebraic_eqs)) && return u0, p, true | ||
update_coefficients!(M, u0, p, t) | ||
tmp = _evaluate_f_ode(integrator, f, isinplace, u0, p, t) | ||
tmp = _evaluate_f(integrator, f, isinplace, u0, p, t) | ||
tmp .= ArrayInterface.restructure(tmp, algebraic_eqs .* _vec(tmp)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm... I didn't know that. Makes sense that SciMLBase can't make the same assumptions OrdinaryDiffEq did |
||
|
||
normresid = integrator.opts.internalnorm(tmp, t) | ||
if normresid > integrator.opts.abstol | ||
throw(CheckInitFailureError(normresid, integrator.opts.abstol)) | ||
normresid = isdefined(integrator.opts, :internalnorm) ? | ||
integrator.opts.internalnorm(tmp, t) : norm(tmp) | ||
if normresid > abstol | ||
throw(CheckInitFailureError(normresid, abstol)) | ||
end | ||
return u0, p, true | ||
end | ||
|
||
""" | ||
Utility function to evaluate the RHS of the DAE, using the integrator's `tmp_cache` if | ||
it is in-place or simply calling the function if not. | ||
""" | ||
function _evaluate_f_dae(integrator, f, isinplace::Val{true}, args...) | ||
tmp = get_tmp_cache(integrator)[2] | ||
f(tmp, args...) | ||
return tmp | ||
end | ||
|
||
function _evaluate_f_dae(integrator, f, isinplace::Val{false}, args...) | ||
return f(args...) | ||
end | ||
|
||
function get_initial_values(prob::AbstractDAEProblem, integrator, f, alg::CheckInit, | ||
isinplace::Union{Val{true}, Val{false}}; kwargs...) | ||
function get_initial_values( | ||
prob::AbstractDAEProblem, integrator::DEIntegrator, f, alg::CheckInit, | ||
isinplace::Union{Val{true}, Val{false}}; abstol, kwargs...) | ||
u0 = state_values(integrator) | ||
p = parameter_values(integrator) | ||
t = current_time(integrator) | ||
|
||
resid = _evaluate_f_dae(integrator, f, isinplace, integrator.du, u0, p, t) | ||
normresid = integrator.opts.internalnorm(resid, t) | ||
if normresid > integrator.opts.abstol | ||
throw(CheckInitFailureError(normresid, integrator.opts.abstol)) | ||
resid = _evaluate_f(integrator, f, isinplace, integrator.du, u0, p, t) | ||
normresid = isdefined(integrator.opts, :internalnorm) ? | ||
integrator.opts.internalnorm(resid, t) : norm(resid) | ||
|
||
if normresid > abstol | ||
throw(CheckInitFailureError(normresid, abstol)) | ||
end | ||
return u0, p, true | ||
end | ||
|
@@ -155,12 +160,19 @@ end | |
Solve a `NonlinearProblem`/`NonlinearLeastSquaresProblem` to obtain the initial `u0` and | ||
`p`. Requires that `f` have the field `initialization_data` which is an `OverrideInitData`. | ||
If the field is absent or the value is `nothing`, return `u0` and `p` successfully as-is. | ||
The NonlinearSolve.jl algorithm to use must be specified through the `nlsolve_alg` keyword | ||
argument, failing which this function will throw an error. The success value returned | ||
depends on the success of the nonlinear solve. | ||
|
||
The success value returned depends on the success of the nonlinear solve. | ||
|
||
Keyword arguments: | ||
- `nlsolve_alg`: The NonlinearSolve.jl algorithm to use. If not provided, this function will | ||
throw an error. | ||
- `abstol`, `reltol`: The `abstol` (`reltol`) to use for the nonlinear solve. The value | ||
provided to the `OverrideInit` constructor takes priority over this keyword argument. | ||
If the former is `nothing`, this keyword argument will be used. If it is also not provided, | ||
an error will be thrown. | ||
""" | ||
function get_initial_values(prob, valp, f, alg::OverrideInit, | ||
isinplace::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...) | ||
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, abstol = nothing, reltol = nothing, kwargs...) | ||
u0 = state_values(valp) | ||
p = parameter_values(valp) | ||
|
||
|
@@ -171,15 +183,30 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, | |
initdata::OverrideInitData = f.initialization_data | ||
initprob = initdata.initializeprob | ||
|
||
if nlsolve_alg === nothing | ||
nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing)) | ||
if nlsolve_alg === nothing && state_values(initprob) !== nothing | ||
throw(OverrideInitMissingAlgorithm()) | ||
end | ||
|
||
if initdata.update_initializeprob! !== nothing | ||
initdata.update_initializeprob!(initprob, valp) | ||
end | ||
|
||
nlsol = solve(initprob, nlsolve_alg) | ||
if alg.abstol !== nothing | ||
_abstol = alg.abstol | ||
elseif abstol !== nothing | ||
_abstol = abstol | ||
else | ||
throw(OverrideInitNoTolerance(:abstol)) | ||
end | ||
if alg.reltol !== nothing | ||
_reltol = alg.reltol | ||
elseif reltol !== nothing | ||
_reltol = reltol | ||
else | ||
throw(OverrideInitNoTolerance(:reltol)) | ||
end | ||
nlsol = solve(initprob, nlsolve_alg; abstol = _abstol, reltol = _reltol) | ||
|
||
u0 = initdata.initializeprobmap(nlsol) | ||
if initdata.initializeprobpmap !== nothing | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
using OrdinaryDiffEq, Sundials, SciMLBase, Test | ||
|
||
@testset "CheckInit" begin | ||
abstol = 1e-10 | ||
@testset "Sundials + ODEProblem" begin | ||
function rhs(u, p, t) | ||
return [u[1] * t, u[1]^2 - u[2]^2] | ||
end | ||
function rhs!(du, u, p, t) | ||
du[1] = u[1] * t | ||
du[2] = u[1]^2 - u[2]^2 | ||
end | ||
|
||
oopfn = ODEFunction{false}(rhs, mass_matrix = [1 0; 0 0]) | ||
iipfn = ODEFunction{true}(rhs!, mass_matrix = [1 0; 0 0]) | ||
|
||
@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn] | ||
prob = ODEProblem(f, [1.0, 1.0], (0.0, 1.0)) | ||
integ = init(prob, Sundials.ARKODE()) | ||
u0, _, success = SciMLBase.get_initial_values( | ||
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol) | ||
@test success | ||
@test u0 == prob.u0 | ||
|
||
integ.u[2] = 2.0 | ||
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( | ||
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol) | ||
end | ||
end | ||
|
||
@testset "Sundials + DAEProblem" begin | ||
function daerhs(du, u, p, t) | ||
return [du[1] - u[1] * t - p, u[1]^2 - u[2]^2] | ||
end | ||
function daerhs!(resid, du, u, p, t) | ||
resid[1] = du[1] - u[1] * t - p | ||
resid[2] = u[1]^2 - u[2]^2 | ||
end | ||
|
||
oopfn = DAEFunction{false}(daerhs) | ||
iipfn = DAEFunction{true}(daerhs!) | ||
|
||
@testset "Inplace = $(SciMLBase.isinplace(f))" for f in [oopfn, iipfn] | ||
prob = DAEProblem(f, [1.0, 0.0], [1.0, 1.0], (0.0, 1.0), 1.0) | ||
integ = init(prob, Sundials.IDA()) | ||
u0, _, success = SciMLBase.get_initial_values( | ||
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol) | ||
@test success | ||
@test u0 == prob.u0 | ||
|
||
integ.u[2] = 2.0 | ||
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( | ||
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol) | ||
|
||
integ.u[2] = 1.0 | ||
integ.du[1] = 2.0 | ||
@test_throws SciMLBase.CheckInitFailureError SciMLBase.get_initial_values( | ||
prob, integ, f, SciMLBase.CheckInit(), Val(SciMLBase.isinplace(f)); abstol) | ||
end | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't handle DAEProblem?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's an
AbstractDAEProblem
dispatch belowThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But then this one should just be for ODEProblem?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
SDEProblem dispatches here too