Skip to content

Commit

Permalink
Merge pull request #2487 from oscardssmith/os/fix-FBDF-reinit
Browse files Browse the repository at this point in the history
fix initialize! of  FBDF
  • Loading branch information
oscardssmith authored Oct 9, 2024
2 parents 27c8076 + 09953ae commit b4a6686
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 2 deletions.
10 changes: 10 additions & 0 deletions lib/OrdinaryDiffEqBDF/src/bdf_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,11 @@ function initialize!(integrator, cache::FBDFConstantCache)
integrator.fsallast = zero(integrator.fsalfirst)
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast

u_modified = integrator.u_modified
integrator.u_modified = true
reinitFBDF!(integrator, cache)
integrator.u_modified = u_modified
end

function perform_step!(integrator, cache::FBDFConstantCache{max_order},
Expand Down Expand Up @@ -1222,6 +1227,11 @@ function initialize!(integrator, cache::FBDFCache)
integrator.k[2] = integrator.fsallast
integrator.f(integrator.fsalfirst, integrator.uprev, integrator.p, integrator.t)
OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)

u_modified = integrator.u_modified
integrator.u_modified = true
reinitFBDF!(integrator, cache)
integrator.u_modified = u_modified
end

function perform_step!(integrator, cache::FBDFCache{max_order},
Expand Down
10 changes: 10 additions & 0 deletions lib/OrdinaryDiffEqBDF/src/dae_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,11 @@ function initialize!(integrator, cache::DFBDFConstantCache)
integrator.fsallast = zero(integrator.fsalfirst)
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast

u_modified = integrator.u_modified
integrator.u_modified = true
reinitFBDF!(integrator, cache)
integrator.u_modified = u_modified
end

function perform_step!(integrator, cache::DFBDFConstantCache{max_order},
Expand Down Expand Up @@ -355,6 +360,11 @@ function initialize!(integrator, cache::DFBDFCache)
integrator.k[2] = integrator.fsallast
#integrator.f(integrator.fsalfirst, integrator.du, integrator.uprev, integrator.p, integrator.t) # For the interpolation, needs k at the updated point
#OrdinaryDiffEqCore.increment_nf!(integrator.stats, 1)

u_modified = integrator.u_modified
integrator.u_modified = true
reinitFBDF!(integrator, cache)
integrator.u_modified = u_modified
end

function perform_step!(integrator, cache::DFBDFCache{max_order},
Expand Down
3 changes: 3 additions & 0 deletions lib/OrdinaryDiffEqBDF/test/bdf_convergence_tests.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# This definitely needs cleaning
using OrdinaryDiffEqBDF, ODEProblemLibrary, DiffEqDevTools
using OrdinaryDiffEqNonlinearSolve: NLFunctional, NLAnderson, NonlinearSolveAlg
using Test, Random
Random.seed!(100)

testTol = 0.2
dts = 1 .// 2 .^ (9:-1:5)
dts3 = 1 .// 2 .^ (12:-1:7)

@testset "Implicit Solver Convergence Tests ($(["out-of-place", "in-place"][i]))" for i in 1:2
prob = (ODEProblemLibrary.prob_ode_linear,
Expand Down
20 changes: 20 additions & 0 deletions lib/OrdinaryDiffEqBDF/test/bdf_regression_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using OrdinaryDiffEqBDF, Test

foop = (u,p,t)->u
proboop = ODEProblem(foop, ones(2), (0.0, 1000.0))

fiip = (du,u,p,t)->du.=u
probiip = ODEProblem(fiip, ones(2), (0.0, 1000.0))

@testset "FBDF reinit" begin
for prob in [proboop, probiip]
integ = init(prob, FBDF(), verbose=false) #suppress warning to clean up CI
solve!(integ)
@test integ.sol.retcode != ReturnCode.Success
@test integ.sol.t[end] >= 700
reinit!(integ, prob.u0)
solve!(integ)
@test integ.sol.retcode != ReturnCode.Success
@test integ.sol.t[end] >= 700
end
end
3 changes: 3 additions & 0 deletions lib/OrdinaryDiffEqBDF/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using SafeTestsets

@time @safetestset "BDF Convergence Tests" include("bdf_convergence_tests.jl")
@time @safetestset "BDF Regression Tests" include("bdf_regression_tests.jl")

@time @safetestset "DAE Convergence Tests" include("dae_convergence_tests.jl")
@time @safetestset "DAE AD Tests" include("dae_ad_tests.jl")
@time @safetestset "DAE Event Tests" include("dae_event.jl")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ module OrdinaryDiffEqNonlinearSolve
import ADTypes: AutoFiniteDiff, AutoForwardDiff

import SciMLBase
import SciMLBase: init, solve, solve!
import SciMLBase: init, solve, solve!, remake
using SciMLBase: DAEFunction, DEIntegrator, NonlinearFunction, NonlinearProblem,
NonlinearLeastSquaresProblem, LinearProblem, ODEProblem, DAEProblem,
update_coefficients!, get_tmp_cache, AbstractSciMLOperator, ReturnCode
import DiffEqBase
import PreallocationTools
using SimpleNonlinearSolve: SimpleTrustRegion, SimpleGaussNewton
using NonlinearSolve: FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg, NewtonRaphson
using NonlinearSolve: FastShortcutNonlinearPolyalg, FastShortcutNLLSPolyalg, NewtonRaphson, step!
using MuladdMacro, FastBroadcast
import FastClosures: @closure
using LinearAlgebra: UniformScaling, UpperTriangular
Expand Down

0 comments on commit b4a6686

Please sign in to comment.