Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix nlprob to match the initialization system #860

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions src/ODE_nlsolve.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""
$(TYPEDEF)

A collection of all the data required for custom ODE Nonlinear problem solving
"""
struct ODE_NLProbData{NLProb, UNLProb, NLProbMap, NLProbPmap}
"""
The `AbstractNonlinearProblem` to define custom nonlinear problems to be used for
implicit time discretizations. This allows to use extra structure of the ODE function (e.g.
multi-level structure). The nonlinear function must match that form of the function implicit
ODE integration algorithms need do solve the a nonlinear problems,
specifically of the form `z = outer_tmp + dt⋅f(γ⋅z+inner_tmp,p,t)`.
Here `z` is the stage solution vector, `p` is the parameter of the ODE problem, `t` is
the time, `dt` the respective time increment`, `γ` is some scaling factor and the temporary
variables are some compatible vectors set by the specific solver.
Note that this field will not be used for integrators such as fully-implicit Runge-Kutta methods
that need to solve different nonlinear systems.
The inner nonlinear function of the nonlinear problem is in general of the form `g(z,p') = 0`
where `p'` is a NamedTuple with all information about the specific nonlinear problem at hand to solve
for a specific time discretization. Specifically, it is `(;dt, γ, inner_tmp, outer_tmp, t, p)`, such that
`g(z,p') = dt⋅f(γ⋅z+inner_tmp,p,t) + outer_tmp - z = 0`.
"""
nlprob::NLProb
"""
A function which takes `(nlprob, value_provider)` and updates
the parameters of the former with their values in the latter.
If absent (`nothing`) this will not be called, and the parameters
in `nlprob` will be used without modification. `value_provider`
refers to a value provider as defined by SymbolicIndexingInterface.jl.
Usually this will refer to a problem or integrator.
"""
update_nlprob!::UNLProb
"""
A function which takes the solution of `nlprob` and returns
the state vector of the original problem.
"""
nlprobmap::NLProbMap
"""
A function which takes the solution of `nlprob` and returns
the parameter object of the original problem. If absent (`nothing`),
this will not be called and the parameters of the problem being
solved will be returned as-is.
"""
nlprobpmap::NLProbPmap
end

3 changes: 2 additions & 1 deletion src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,8 @@ Internal. Used for signifying the AD context comes from a Tracker.jl context.
"""
struct TrackerOriginator <: ADOriginator end

include("initialization.jl")
include("ODE_nlsolve.jl")
include("utils.jl")
include("function_wrappers.jl")
include("scimlfunctions.jl")
Expand Down Expand Up @@ -744,7 +746,6 @@ include("ensemble/ensemble_problems.jl")
include("ensemble/basic_ensemble_solve.jl")
include("ensemble/ensemble_analysis.jl")

include("initialization.jl")
include("solve.jl")
include("interpolation.jl")
include("integrator_interface.jl")
Expand Down
4 changes: 2 additions & 2 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ Check if the algebraic constraints are satisfied, and error if they aren't. Retu
the `u0` and `p` as-is, and is always successful if it returns. Valid only for
`ODEProblem` and `DAEProblem`. Requires a `DEIntegrator` as its second argument.
"""
function get_initial_values(prob::ODEProblem, integrator, f, alg::CheckInit,
function get_initial_values(prob::AbstractODEProblem, integrator, f, alg::CheckInit,
isinplace::Union{Val{true}, Val{false}}; kwargs...)
u0 = state_values(integrator)
p = parameter_values(integrator)
Expand Down Expand Up @@ -135,7 +135,7 @@ function _evaluate_f_dae(integrator, f, isinplace::Val{false}, args...)
return f(args...)
end

function get_initial_values(prob::DAEProblem, integrator, f, alg::CheckInit,
function get_initial_values(prob::AbstractDAEProblem, integrator, f, alg::CheckInit,
isinplace::Union{Val{true}, Val{false}}; kwargs...)
u0 = state_values(integrator)
p = parameter_values(integrator)
Expand Down
56 changes: 25 additions & 31 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,6 @@ the usage of `f`. These include:
based on the sparsity pattern. Defaults to `nothing`, which means a color vector will be
internally computed on demand when required. The cost of this operation is highly dependent
on the sparsity pattern.
- `nlprob`: a `NonlinearProblem` that solves `f(u, t, p) = u_tmp`
where the nonlinear parameters are the tuple `(t, u_tmp, p)`.
This will be used as the nonlinear problem inside an implicit solver by specifying `u, u_tmp` and `t`
such that solving this function produces a solution to the implicit step of your solver.

## iip: In-Place vs Out-Of-Place

`iip` is the optional boolean for determining whether a given function is written to
Expand Down Expand Up @@ -406,7 +401,7 @@ numerically-defined functions.
"""
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
O, TCV,
SYS, ID, NLP} <: AbstractODEFunction{iip}
SYS, ID<:Union{Nothing, OverrideInitData}, NLP<:Union{Nothing, ODE_NLProbData}} <: AbstractODEFunction{iip}
f::F
mass_matrix::TMM
analytic::Ta
Expand All @@ -424,7 +419,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
colorvec::TCV
sys::SYS
initialization_data::ID
nlprob::NLP
nlprob_data::NLP
oscardssmith marked this conversation as resolved.
Show resolved Hide resolved
end

@doc doc"""
Expand Down Expand Up @@ -527,8 +522,7 @@ information on generating the SplitFunction from this symbolic engine.
"""
struct SplitFunction{
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt,
TPJ, O,
TCV, SYS, ID, NLP} <: AbstractODEFunction{iip}
TPJ, O, TCV, SYS, ID<:Union{Nothing, OverrideInitData}, NLP<:Union{Nothing, ODE_NLProbData}} <: AbstractODEFunction{iip}
f1::F1
f2::F2
mass_matrix::TMM
Expand All @@ -547,8 +541,8 @@ struct SplitFunction{
observed::O
colorvec::TCV
sys::SYS
nlprob::NLP
initialization_data::ID
nlprob_data::NLP
end

@doc doc"""
Expand Down Expand Up @@ -2446,9 +2440,9 @@ function ODEFunction{iip, specialize}(f;
f.update_initializeprob! : nothing,
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
nlprob = __has_nlprob(f) ? f.nlprob : nothing,
initialization_data = __has_initialization_data(f) ? f.initialization_data :
nothing
nothing,
nlprob_data = __has_nlprob_data(f) ? f.nlprob_data : nothing,
) where {iip,
specialize
}
Expand Down Expand Up @@ -2506,10 +2500,10 @@ function ODEFunction{iip, specialize}(f;
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
Any,
typeof(_colorvec),
typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
typeof(sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initdata, nlprob)
observed, _colorvec, sys, initdata, nlprob_data)
elseif specialize === false
ODEFunction{iip, FunctionWrapperSpecialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2518,11 +2512,11 @@ function ODEFunction{iip, specialize}(f;
typeof(paramjac),
typeof(observed),
typeof(_colorvec),
typeof(sys), typeof(initdata), typeof(nlprob)}(_f, mass_matrix,
typeof(sys), typeof(initdata), typeof(nlprob_data)}(_f, mass_matrix,
analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initdata, nlprob)
observed, _colorvec, sys, initdata, nlprob_data)
else
ODEFunction{iip, specialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2531,11 +2525,11 @@ function ODEFunction{iip, specialize}(f;
typeof(paramjac),
typeof(observed),
typeof(_colorvec),
typeof(sys), typeof(initdata), typeof(nlprob)}(
typeof(sys), typeof(initdata), typeof(nlprob_data)}(
_f, mass_matrix, analytic, tgrad,
jac, jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initdata, nlprob)
observed, _colorvec, sys, initdata, nlprob_data)
end
end

Expand All @@ -2552,23 +2546,23 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
Any, Any, Any, Any, typeof(f.jac_prototype),
typeof(f.sparsity), Any, Any, Any,
Any, typeof(f.colorvec),
typeof(f.sys), Any, Any}(
typeof(f.sys), Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob)
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob_data)
else
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
typeof(f.analytic), typeof(f.tgrad),
typeof(f.jac), typeof(f.jvp), typeof(f.vjp), typeof(f.jac_prototype),
typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype),
typeof(f.paramjac),
typeof(f.observed), typeof(f.colorvec),
typeof(f.sys), typeof(f.initialization_data), typeof(f.nlprob)}(
typeof(f.sys), typeof(f.initialization_data), typeof(f.nlprob_data)}(
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob)
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob_data)
end
end

Expand Down Expand Up @@ -2703,7 +2697,7 @@ end
@add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp,
vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac,
observed, colorvec, sys, initializeprob = nothing, update_initializeprob! = nothing,
initializeprobmap = nothing, initializeprobpmap = nothing, nlprob = nothing, initialization_data = nothing)
initializeprobmap = nothing, initializeprobpmap = nothing, initialization_data = nothing, nlprob_data = nothing)
f1 = ODEFunction(f1)
f2 = ODEFunction(f2)

Expand All @@ -2721,11 +2715,11 @@ end
typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp),
typeof(vjp), typeof(jac_prototype), typeof(W_prototype), typeof(sparsity),
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec),
typeof(sys), typeof(initdata), typeof(nlprob)}(
typeof(sys), typeof(initdata), typeof(nlprob_data)}(
f1, f2, mass_matrix,
cache, analytic, tgrad, jac, jvp, vjp,
jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
initdata, nlprob)
initdata, nlprob_data)
end
function SplitFunction{iip, specialize}(f1, f2;
mass_matrix = __has_mass_matrix(f1) ?
Expand Down Expand Up @@ -2762,7 +2756,7 @@ function SplitFunction{iip, specialize}(f1, f2;
f1.update_initializeprob! : nothing,
initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing,
initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing,
nlprob = __has_nlprob(f1) ? f1.nlprob : nothing,
nlprob_data = __has_nlprob_data(f1) ? f1.nlprob_data : nothing,
initialization_data = __has_initialization_data(f1) ? f1.initialization_data :
nothing
) where {iip,
Expand All @@ -2776,23 +2770,23 @@ function SplitFunction{iip, specialize}(f1, f2;
if specialize === NoSpecialize
SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
Any, Any, Union{Nothing, OverrideInitData}, Union{Nothing, ODE_NLProbData}}(f1, f2, mass_matrix, _func_cache,
analytic,
tgrad, jac, jvp, vjp, jac_prototype, W_prototype,
sparsity, Wfact, Wfact_t, paramjac,
observed, colorvec, sys, initdata, nlprob)
observed, colorvec, sys, initdata, nlprob_data)
else
SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix),
typeof(_func_cache), typeof(analytic),
typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp),
typeof(jac_prototype), typeof(W_prototype), typeof(sparsity),
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed),
typeof(colorvec),
typeof(sys), typeof(initdata), typeof(nlprob)}(f1, f2,
typeof(sys), typeof(initdata), typeof(nlprob_data)}(f1, f2,
mass_matrix, _func_cache, analytic, tgrad, jac,
jvp, vjp, jac_prototype, W_prototype,
sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
initdata, nlprob)
initdata, nlprob_data)
end
end

Expand Down Expand Up @@ -4488,7 +4482,7 @@ __has_colorvec(f) = isdefined(f, :colorvec)
__has_sys(f) = isdefined(f, :sys)
__has_analytic_full(f) = isdefined(f, :analytic_full)
__has_resid_prototype(f) = isdefined(f, :resid_prototype)
__has_nlprob(f) = isdefined(f, :nlprob)
__has_nlprob_data(f) = isdefined(f, :nlprob_data)
function __has_initializeprob(f)
has_initialization_data(f) && isdefined(f.initialization_data, :initializeprob)
end
Expand Down
Loading