Skip to content

Commit

Permalink
Merge pull request #2520 from AayushSabharwal/as/late-tstops
Browse files Browse the repository at this point in the history
feat: add late binding for tstops
  • Loading branch information
ChrisRackauckas authored Nov 17, 2024
2 parents 26a5ad6 + b8a1f43 commit e09092c
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 2 deletions.
4 changes: 2 additions & 2 deletions lib/OrdinaryDiffEqCore/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Accessors = "0.1.36"
Adapt = "3.0, 4"
ArrayInterface = "7"
DataStructures = "0.18"
DiffEqBase = "6.159"
DiffEqBase = "6.160"
DiffEqDevTools = "2.44.4"
DocStringExtensions = "0.9"
EnumX = "1"
Expand All @@ -70,7 +70,7 @@ Random = "<0.0.1, 1"
RecursiveArrayTools = "2.36, 3"
Reexport = "1.0"
SafeTestsets = "0.1.0"
SciMLBase = "2.59.2"
SciMLBase = "2.60"
SciMLOperators = "0.3"
SciMLStructures = "1"
SimpleUnPack = "1"
Expand Down
3 changes: 3 additions & 0 deletions lib/OrdinaryDiffEqCore/src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ end

SciMLBase.forwarddiffs_model_time(alg::RosenbrockAlgorithm) = true

SciMLBase.allows_late_binding_tstops(::OrdinaryDiffEqAlgorithm) = true
SciMLBase.allows_late_binding_tstops(::DAEAlgorithm) = true

# isadaptive is defined below.

## OrdinaryDiffEq Internal Traits
Expand Down
13 changes: 13 additions & 0 deletions lib/OrdinaryDiffEqCore/src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,12 @@ function DiffEqBase.__init(
resType = typeof(res_prototype)
end

if tstops isa AbstractArray || tstops isa Tuple || tstops isa Number
_tstops = nothing
else
_tstops = tstops
tstops = ()
end
tstops_internal = initialize_tstops(tType, tstops, d_discontinuities, tspan)
saveat_internal = initialize_saveat(tType, saveat, tspan)
d_discontinuities_internal = initialize_d_discontinuities(tType, d_discontinuities,
Expand Down Expand Up @@ -542,6 +548,13 @@ function DiffEqBase.__init(
end
end

if _tstops !== nothing
tstops = _tstops(parameter_values(integrator), prob.tspan)
for tstop in tstops
add_tstop!(integrator, tstop)
end
end

handle_dt!(integrator)
integrator
end
Expand Down
12 changes: 12 additions & 0 deletions test/interface/ode_tstops_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,15 @@ end
prob = ODEProblem(ff, [0.0], (0.0f0, 1.0f0))
sol = solve(prob, Tsit5(), tstops = [tval], callback = cb)
end

@testset "Late binding tstops" begin
function rhs(u, p, t)
u * p + t
end
prob = ODEProblem(rhs, 1.0, (0.0, 1.0), 0.1; tstops = (p, tspan) -> tspan[1]:p:tspan[2])
sol = solve(prob, Tsit5())
@test 0.0:0.1:1.0 sol.t
prob2 = remake(prob; p = 0.07)
sol2 = solve(prob2, Tsit5())
@test 0.0:0.07:1.0 sol2.t
end

0 comments on commit e09092c

Please sign in to comment.