Skip to content

Commit

Permalink
feat: add fields to OverrideInit, better nlsolve_alg handling
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 11, 2024
1 parent c61b13d commit 66f5229
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 12 deletions.
12 changes: 10 additions & 2 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import CommonSolve: solve, init, step!, solve!
import FunctionWrappersWrappers
import RuntimeGeneratedFunctions
import EnumX
import ADTypes: AbstractADType
import ADTypes: ADTypes, AbstractADType
import Accessors: @set, @reset
using Expronicon.ADT: @match

Expand Down Expand Up @@ -351,7 +351,15 @@ struct CheckInit <: DAEInitializationAlgorithm end
"""
$(TYPEDEF)
"""
struct OverrideInit <: DAEInitializationAlgorithm end
struct OverrideInit{T, F} <: DAEInitializationAlgorithm
abstol::T
nlsolve::F
end

function OverrideInit(; abstol = 1e-10, nlsolve = nothing)
OverrideInit(abstol, nlsolve)
end
OverrideInit(abstol) = OverrideInit(; abstol = abstol, nlsolve = nothing)

# PDE Discretizations

Expand Down
7 changes: 4 additions & 3 deletions src/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ argument, failing which this function will throw an error. The success value ret
depends on the success of the nonlinear solve.
"""
function get_initial_values(prob, valp, f, alg::OverrideInit,
isinplace::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...)
iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, autodiff = false, kwargs...)
u0 = state_values(valp)
p = parameter_values(valp)

Expand All @@ -171,15 +171,16 @@ function get_initial_values(prob, valp, f, alg::OverrideInit,
initdata::OverrideInitData = f.initialization_data
initprob = initdata.initializeprob

if nlsolve_alg === nothing
nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing))
if nlsolve_alg === nothing && state_values(initprob) !== nothing
throw(OverrideInitMissingAlgorithm())
end

if initdata.update_initializeprob! !== nothing
initdata.update_initializeprob!(initprob, valp)
end

nlsol = solve(initprob, nlsolve_alg)
nlsol = solve(initprob, nlsolve_alg; abstol = alg.abstol)

u0 = initdata.initializeprobmap(nlsol)
if initdata.initializeprobpmap !== nothing
Expand Down
42 changes: 35 additions & 7 deletions test/initialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,43 @@ end
end

@testset "Solves" begin
u0, p, success = SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(),
Val(false); nlsolve_alg = NewtonRaphson())
@testset "with explicit alg" begin
u0, p, success = SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(),
Val(false); nlsolve_alg = NewtonRaphson())

@test u0 [2.0, 2.0]
@test p 1.0
@test success
@test u0 [2.0, 2.0]
@test p 1.0
@test success

initprob.p[1] = 1.0
initprob.p[1] = 1.0
end
@testset "with alg in `OverrideInit`" begin
u0, p, success = SciMLBase.get_initial_values(
prob, integ, fn, SciMLBase.OverrideInit(nlsolve = NewtonRaphson()),
Val(false))

@test u0 [2.0, 2.0]
@test p 1.0
@test success

initprob.p[1] = 1.0
end
@testset "with trivial problem and no alg" begin
iprob = NonlinearProblem((u, p) -> 0.0, nothing, 1.0)
iprobmap = (_) -> [1.0, 1.0]
initdata = SciMLBase.OverrideInitData(iprob, nothing, iprobmap, nothing)
_fn = ODEFunction(rhs2; initialization_data = initdata)
_prob = ODEProblem(_fn, [2.0, 0.0], (0.0, 1.0), 1.0)
_integ = init(_prob; initializealg = NoInit())

u0, p, success = SciMLBase.get_initial_values(
_prob, _integ, _fn, SciMLBase.OverrideInit(), Val(false))

@test u0 [1.0, 1.0]
@test p 1.0
@test success
end
end

@testset "Solves with non-integrator value provider" begin
Expand Down

0 comments on commit 66f5229

Please sign in to comment.