-
-
Notifications
You must be signed in to change notification settings - Fork 35
/
SSA_stepper.jl
333 lines (287 loc) · 11.3 KB
/
SSA_stepper.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
"""
$(TYPEDEF)
Highly efficient integrator for pure jump problems that involve only `ConstantRateJump`s,
`MassActionJump`s, and/or `VariableRateJump`s *with rate bounds*.
## Notes
- Only works with `JumpProblem`s defined from `DiscreteProblem`s.
- Only works with collections of `ConstantRateJump`s, `MassActionJump`s, and
`VariableRateJump`s with rate bounds.
- Only supports `DiscreteCallback`s for events, which are checked after every step taken by
`SSAStepper`.
- Only supports a limited subset of the output controls from the common solver interface,
specifically `save_start`, `save_end`, and `saveat`.
- As when using jumps with ODEs and SDEs, whether to save each time a jump occurs is
controlled via the `save_positions` keyword argument to `JumpProblem`. Note that when
choosing `SSAStepper` as the timestepper, `save_positions = (true,true)`, `(true,false)`,
or `(false,true)` are all equivalent. `SSAStepper` will save only the post-jump state in
the solution object in each of these cases. This is because solution objects generated via
`SSAStepper` use piecewise-constant interpolation, and can therefore exactly reconstruct
the sampled jump process path knowing only the post-jump state. That is, `sol(t)`
for any `0 <= t <= tstop` will give the exact value of the sampled solution path at `t`
provided at least one component of `save_positions` is `true`.
## Examples
SIR model:
```julia
using JumpProcesses
β = 0.1 / 1000.0; ν = .01;
p = (β,ν)
rate1(u,p,t) = p[1]*u[1]*u[2] # β*S*I
function affect1!(integrator)
integrator.u[1] -= 1 # S -> S - 1
integrator.u[2] += 1 # I -> I + 1
end
jump = ConstantRateJump(rate1,affect1!)
rate2(u,p,t) = p[2]*u[2] # ν*I
function affect2!(integrator)
integrator.u[2] -= 1 # I -> I - 1
integrator.u[3] += 1 # R -> R + 1
end
jump2 = ConstantRateJump(rate2,affect2!)
u₀ = [999,1,0]
tspan = (0.0,250.0)
prob = DiscreteProblem(u₀, tspan, p)
jump_prob = JumpProblem(prob, Direct(), jump, jump2)
sol = solve(jump_prob, SSAStepper())
```
See the
[tutorial](https://docs.sciml.ai/JumpProcesses/stable/tutorials/discrete_stochastic_example/)
for details.
"""
struct SSAStepper <: DiffEqBase.DEAlgorithm end # singleton algorithm type; has no fields
"""
$(TYPEDEF)
Integrator (solver state) for pure jump problems solved via `SSAStepper`.
## Fields
$(FIELDS)
"""
mutable struct SSAIntegrator{F, uType, tType, tdirType, P, S, CB, SA, OPT, TS} <:
               AbstractSSAIntegrator{SSAStepper, Nothing, uType, tType}
    """The underlying `prob.f` function. Not currently used."""
    f::F
    """The current solution values."""
    u::uType
    """The current solution time."""
    t::tType
    """The value of `t` at the start of the most recent step."""
    tprev::tType
    """The direction time is changing in (must be positive, indicating time is increasing)"""
    tdir::tdirType
    """The current parameters."""
    p::P
    """The current solution object."""
    sol::S
    # NOTE(review): `i` is initialized to 1 by `__init` and is not read anywhere in the
    # visible code — purpose unclear, confirm before relying on it.
    i::Int
    """The next jump time."""
    tstop::tType
    """The jump aggregator callback."""
    cb::CB
    """Times to save the solution at (`nothing` when no `saveat` was supplied)."""
    saveat::SA
    """Whether to save every time a jump occurs."""
    save_everystep::Bool
    """Whether to save at the final step."""
    save_end::Bool
    """Index of the next `saveat` time."""
    cur_saveat::Int
    """Named tuple whose `callback` field holds the user callbacks as a `CallbackSet`."""
    opts::OPT
    """User supplied times to step to, useful with callbacks."""
    tstops::TS
    """Index of the next entry of `tstops` to step to."""
    tstops_idx::Int
    """Flag set via `DiffEqBase.u_modified!`; read by `step!` to decide whether to save."""
    u_modified::Bool
    keep_stepping::Bool # false if should terminate a simulation
end
# Calling the integrator returns a copy of its current state; the time argument `t` is
# accepted but not used.
function (integrator::SSAIntegrator)(t)
    return copy(integrator.u)
end

# In-place variant: broadcast the current state into `out` and return `out`; `t` is
# accepted but not used.
function (integrator::SSAIntegrator)(out, t)
    out .= integrator.u
    return out
end
# Set the integrator's `u_modified` flag to `bool` (the assigned value is returned).
DiffEqBase.u_modified!(integrator::SSAIntegrator, bool::Bool) =
    (integrator.u_modified = bool)
# Build an `SSAStepper` integrator for `jump_prob` (forwarding all keyword arguments to
# `init`), run it to completion, and return the resulting solution object.
function DiffEqBase.__solve(jump_prob::JumpProblem, alg::SSAStepper; kwargs...)
    ssa_integrator = init(jump_prob, alg; kwargs...)
    solve!(ssa_integrator)
    return ssa_integrator.sol
end
# Run the integrator to the end of the problem's time span.
#
# Steps while `should_continue_solve` holds, then sets `t` to `tspan[2]`, records any
# remaining `saveat` times strictly before the end time using the final state, optionally
# records the end time itself (`save_end`), finalizes the user callbacks, and upgrades a
# `ReturnCode.Default` retcode to `ReturnCode.Success`.
function DiffEqBase.solve!(integrator::SSAIntegrator)
    end_time = integrator.sol.prob.tspan[2]
    while should_continue_solve(integrator) # It stops before adding a tstop over
        step!(integrator)
    end
    integrator.t = end_time

    if integrator.saveat !== nothing && !isempty(integrator.saveat)
        # Split to help prediction
        # A `saveat` entry equal to the end time is not handled here; whether the end
        # point is stored is decided by `save_end` below.
        while integrator.cur_saveat <= length(integrator.saveat) &&
              integrator.saveat[integrator.cur_saveat] < integrator.t
            push!(integrator.sol.t, integrator.saveat[integrator.cur_saveat])
            push!(integrator.sol.u, copy(integrator.u))
            integrator.cur_saveat += 1
        end
    end

    # Guard against an empty time vector (e.g. `save_start = false` and nothing saved
    # while stepping): indexing `sol.t[end]` would otherwise throw a `BoundsError`.
    if integrator.save_end &&
       (isempty(integrator.sol.t) || integrator.sol.t[end] != end_time)
        push!(integrator.sol.t, end_time)
        push!(integrator.sol.u, copy(integrator.u))
    end

    DiffEqBase.finalize!(integrator.opts.callback, integrator.u, integrator.t, integrator)

    # Only overwrite the retcode if nothing (e.g. `terminate!`) has set one already.
    if integrator.sol.retcode === ReturnCode.Default
        integrator.sol = DiffEqBase.solution_new_retcode(integrator.sol, ReturnCode.Success)
    end
end
# Construct an `SSAIntegrator` for `jump_prob`.
#
# Keyword arguments:
# - `save_start`, `save_end`: whether to store the initial / final state.
# - `seed`: if not `nothing`, used to seed the aggregator's RNG.
# - `alias_jump`: if `true` the problem's aggregator callback is used directly; otherwise
#   it is deep-copied and its RNG is always reseeded.
# - `saveat`: `nothing`, a collection of times, or a number (interpreted as a spacing
#   over `tspan`).
# - `callback`: user callbacks, wrapped in a `CallbackSet`.
# - `tstops`: extra times the integrator must step to.
# - `numsteps_hint`: size hint for the saved arrays when saving on every jump.
#
# Errors if the problem is not a `DiscreteProblem` or if `tspan` is non-increasing.
function DiffEqBase.__init(jump_prob::JumpProblem,
        alg::SSAStepper;
        save_start = true,
        save_end = true,
        seed = nothing,
        alias_jump = Threads.threadid() == 1,
        saveat = nothing,
        callback = nothing,
        tstops = eltype(jump_prob.prob.tspan)[],
        numsteps_hint = 100)
    jump_prob.prob isa DiscreteProblem ||
        error("SSAStepper only supports DiscreteProblems.")
    @assert isempty(jump_prob.jump_callback.continuous_callbacks)

    # The aggregator callback is the last discrete callback of the jump problem.
    cb = jump_prob.jump_callback.discrete_callbacks[end]
    if alias_jump
        seed === nothing || Random.seed!(cb.condition.rng, seed)
    else
        cb = deepcopy(cb)
        rng_seed = seed === nothing ? seed_multiplier() * rand(UInt64) : seed
        Random.seed!(cb.condition.rng, rng_seed)
    end

    opts = (callback = CallbackSet(callback),)
    prob = jump_prob.prob
    t0 = prob.tspan[1]

    # Storage for saved times/states, optionally pre-populated with the initial condition.
    if save_start
        ts = [t0]
        us = [copy(prob.u0)]
    else
        ts = typeof(t0)[]
        us = typeof(prob.u0)[]
    end

    sol = DiffEqBase.build_solution(prob, alg, ts, us, dense = false,
        calculate_error = false,
        stats = DiffEqBase.Stats(0),
        interp = DiffEqBase.ConstantInterpolation(ts, us))
    save_everystep = any(cb.save_positions)

    # A numeric `saveat` is a spacing; expand it to a range over the time span.
    _saveat = saveat isa Number ? (t0:saveat:prob.tspan[2]) : saveat
    has_saveat = _saveat !== nothing && !isempty(_saveat)

    # Skip a leading `saveat` entry that coincides with the start time.
    cur_saveat = (has_saveat && _saveat[1] == t0) ? 2 : 1

    nsaved_hint = if has_saveat
        length(_saveat) + 1
    elseif save_everystep
        numsteps_hint
    else
        save_start + save_end
    end
    sizehint!(us, nsaved_hint)
    sizehint!(ts, nsaved_hint)

    tdir = sign(prob.tspan[2] - t0)
    (tdir <= 0) &&
        error("The time interval to solve over is non-increasing, i.e. tspan[2] <= tspan[1]. This is not allowed for pure jump problem.")

    integrator = SSAIntegrator(prob.f, copy(prob.u0), t0, t0, tdir,
        prob.p, sol, 1, t0, cb, _saveat, save_everystep,
        save_end, cur_saveat, opts, tstops, 1, false, true)
    cb.initialize(cb, integrator.u, t0, integrator)
    DiffEqBase.initialize!(opts.callback, integrator.u, t0, integrator)
    return integrator
end
# Insert `tstop` into the integrator's pending stop times, keeping the not-yet-visited
# portion of `tstops` ordered. Stop times at or before the current time are ignored.
function DiffEqBase.add_tstop!(integrator::SSAIntegrator, tstop)
    tstop > integrator.t || return nothing
    first_pending = integrator.tstops_idx
    pending = @view integrator.tstops[first_pending:end]
    # position within the pending view, shifted back to an index into the full vector
    offset = searchsortedfirst(pending, tstop)
    return Base.insert!(integrator.tstops, first_pending + offset - 1, tstop)
end
# For SSAIntegrator the jump aggregators do not register the next jump via add_tstop!;
# the next jump time is stored directly in `integrator.tstop` for maximum performance.
# The `t` argument is not used by this method.
@inline function register_next_jump_time!(integrator::SSAIntegrator,
        p::AbstractSSAJumpAggregator, t)
    integrator.tstop = p.next_jump_time
    return nothing
end
# Advance the integrator by a single step.
#
# The step moves `t` either to the next user-supplied stop time (when it lies strictly
# before the next jump time) or to the next jump time. Pending `saveat` times that fall
# strictly before the new `t` are recorded with the state from before the jump, then the
# jump is executed, user discrete callbacks are applied, and the post-step state is saved
# via `savevalues!` unless a callback already saved. Returns `nothing`.
function DiffEqBase.step!(integrator::SSAIntegrator)
    integrator.tprev = integrator.t
    # If the stored jump time is not ahead of the current time, treat it as "no jump
    # pending" so that any remaining tstop is taken first.
    next_jump_time = integrator.tstop > integrator.t ? integrator.tstop :
                     typemax(integrator.tstop)

    doaffect = false
    if !isempty(integrator.tstops) &&
       integrator.tstops_idx <= length(integrator.tstops) &&
       integrator.tstops[integrator.tstops_idx] < next_jump_time
        integrator.t = integrator.tstops[integrator.tstops_idx]
        integrator.tstops_idx += 1
    else
        integrator.t = integrator.tstop
        doaffect = true # delay effect until after saveat
    end

    @inbounds if integrator.saveat !== nothing && !isempty(integrator.saveat)
        # Split to help prediction
        # `<=` so the final `saveat` entry is also recorded here with the pre-jump
        # state. With `<` it was deferred to `solve!`, which stores it with the final
        # state and so gives a wrong value if jumps occur after that time.
        while integrator.cur_saveat <= length(integrator.saveat) &&
              integrator.saveat[integrator.cur_saveat] < integrator.t
            push!(integrator.sol.t, integrator.saveat[integrator.cur_saveat])
            push!(integrator.sol.u, copy(integrator.u))
            integrator.cur_saveat += 1
        end
    end

    # FP error means the new time may equal the old if the next jump time is
    # sufficiently small, hence we add this check to execute jumps until
    # this is no longer true.
    # NOTE(review): loop exit relies on `affect!` advancing `integrator.tstop`
    # (presumably through `register_next_jump_time!`) — confirm against the aggregators.
    integrator.u_modified = true
    while integrator.t == integrator.tstop
        doaffect && integrator.cb.affect!(integrator)
    end

    # Capture the flag before user callbacks get a chance to change it.
    jump_modified_u = integrator.u_modified

    if !(integrator.opts.callback.discrete_callbacks isa Tuple{})
        discrete_modified,
        saved_in_cb = DiffEqBase.apply_discrete_callback!(integrator,
            integrator.opts.callback.discrete_callbacks...)
    else
        saved_in_cb = false
    end

    !saved_in_cb && jump_modified_u && savevalues!(integrator)

    nothing
end
# Append the current `(t, u)` to the solution when saving on every step is enabled or
# `force` is set. Returns the tuple `(saved, savedexactly)`, which is `(true, true)` when
# a point was stored and `(false, false)` otherwise.
function DiffEqBase.savevalues!(integrator::SSAIntegrator, force = false)
    # No saveat in here since it would only use previous values,
    # so in the specific case of SSAStepper it's already handled
    (integrator.save_everystep || force) || return (false, false)
    push!(integrator.sol.t, integrator.t)
    push!(integrator.sol.u, copy(integrator.u))
    return (true, true)
end
# Decide whether `solve!` should take another step: the integrator must not have been
# terminated, and either a jump or a user tstop must lie strictly before the end time.
function should_continue_solve(integrator::SSAIntegrator)
    integrator.keep_stepping || return false
    end_time = integrator.sol.prob.tspan[2]

    # a jump is scheduled between now and end_time
    (integrator.t < integrator.tstop < end_time) && return true

    # otherwise continue only if an unvisited tstop lies before end_time
    pending_idx = integrator.tstops_idx
    return !isempty(integrator.tstops) &&
           pending_idx <= length(integrator.tstops) &&
           integrator.tstops[pending_idx] < end_time
end
# Convenience method: forward to the three-argument `reset_aggregated_jumps!` using the
# integrator's own aggregator callback. Always returns `nothing`.
function reset_aggregated_jumps!(integrator::SSAIntegrator, uprev = nothing)
    aggregator_cb = integrator.cb
    reset_aggregated_jumps!(integrator, uprev, aggregator_cb)
    return nothing
end
# Request that the simulation stop: clear `keep_stepping` so `should_continue_solve`
# returns `false`, and stamp the solution with `retcode`. Returns `nothing`.
function DiffEqBase.terminate!(integrator::SSAIntegrator, retcode = ReturnCode.Terminated)
    integrator.keep_stepping = false
    updated_sol = DiffEqBase.solution_new_retcode(integrator.sol, retcode)
    integrator.sol = updated_sol
    return nothing
end
export SSAStepper