From 66f522909e86d195891f941445453fe39eaa00d8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 11 Nov 2024 15:42:24 +0530 Subject: [PATCH] feat: add fields to `OverrideInit`, better `nlsolve_alg` handling --- src/SciMLBase.jl | 12 ++++++++++-- src/initialization.jl | 7 ++++--- test/initialization.jl | 42 +++++++++++++++++++++++++++++++++++------- 3 files changed, 49 insertions(+), 12 deletions(-) diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 986ec0949..7670e1a99 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -21,7 +21,7 @@ import CommonSolve: solve, init, step!, solve! import FunctionWrappersWrappers import RuntimeGeneratedFunctions import EnumX -import ADTypes: AbstractADType +import ADTypes: ADTypes, AbstractADType import Accessors: @set, @reset using Expronicon.ADT: @match @@ -351,7 +351,15 @@ struct CheckInit <: DAEInitializationAlgorithm end """ $(TYPEDEF) """ -struct OverrideInit <: DAEInitializationAlgorithm end +struct OverrideInit{T, F} <: DAEInitializationAlgorithm + abstol::T + nlsolve::F +end + +function OverrideInit(; abstol = 1e-10, nlsolve = nothing) + OverrideInit(abstol, nlsolve) +end +OverrideInit(abstol) = OverrideInit(; abstol = abstol, nlsolve = nothing) # PDE Discretizations diff --git a/src/initialization.jl b/src/initialization.jl index 86c71560d..9e7da393b 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -160,7 +160,7 @@ argument, failing which this function will throw an error. The success value ret depends on the success of the nonlinear solve. """ function get_initial_values(prob, valp, f, alg::OverrideInit, - isinplace::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, kwargs...) + iip::Union{Val{true}, Val{false}}; nlsolve_alg = nothing, autodiff = false, kwargs...) u0 = state_values(valp) p = parameter_values(valp) @@ -171,7 +171,8 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, initdata::OverrideInitData = f.initialization_data initprob = initdata.initializeprob - if nlsolve_alg === nothing + nlsolve_alg = something(nlsolve_alg, alg.nlsolve, Some(nothing)) + if nlsolve_alg === nothing && state_values(initprob) !== nothing throw(OverrideInitMissingAlgorithm()) end @@ -179,7 +180,7 @@ function get_initial_values(prob, valp, f, alg::OverrideInit, initdata.update_initializeprob!(initprob, valp) end - nlsol = solve(initprob, nlsolve_alg) + nlsol = solve(initprob, nlsolve_alg; abstol = alg.abstol) u0 = initdata.initializeprobmap(nlsol) if initdata.initializeprobpmap !== nothing diff --git a/test/initialization.jl b/test/initialization.jl index ca8fb6b6c..3074aa690 100644 --- a/test/initialization.jl +++ b/test/initialization.jl @@ -101,15 +101,43 @@ end end @testset "Solves" begin - u0, p, success = SciMLBase.get_initial_values( - prob, integ, fn, SciMLBase.OverrideInit(), - Val(false); nlsolve_alg = NewtonRaphson()) + @testset "with explicit alg" begin + u0, p, success = SciMLBase.get_initial_values( + prob, integ, fn, SciMLBase.OverrideInit(), + Val(false); nlsolve_alg = NewtonRaphson()) - @test u0 ≈ [2.0, 2.0] - @test p ≈ 1.0 - @test success + @test u0 ≈ [2.0, 2.0] + @test p ≈ 1.0 + @test success - initprob.p[1] = 1.0 + initprob.p[1] = 1.0 + end + @testset "with alg in `OverrideInit`" begin + u0, p, success = SciMLBase.get_initial_values( + prob, integ, fn, SciMLBase.OverrideInit(nlsolve = NewtonRaphson()), + Val(false)) + + @test u0 ≈ [2.0, 2.0] + @test p ≈ 1.0 + @test success + + initprob.p[1] = 1.0 + end + @testset "with trivial problem and no alg" begin + iprob = NonlinearProblem((u, p) -> 0.0, nothing, 1.0) + iprobmap = (_) -> [1.0, 1.0] + initdata = SciMLBase.OverrideInitData(iprob, nothing, iprobmap, nothing) + _fn = ODEFunction(rhs2; initialization_data = initdata) + _prob = ODEProblem(_fn, [2.0, 0.0], (0.0, 1.0), 1.0) + _integ = init(_prob; initializealg = NoInit()) + + u0, p, success = SciMLBase.get_initial_values( + _prob, _integ, _fn, SciMLBase.OverrideInit(), Val(false)) + + @test u0 ≈ [1.0, 1.0] + @test p ≈ 1.0 + @test success + end end @testset "Solves with non-integrator value provider" begin