diff --git a/src/Fleck.jl b/src/Fleck.jl index 831f94e..2833b4a 100644 --- a/src/Fleck.jl +++ b/src/Fleck.jl @@ -1,6 +1,8 @@ module Fleck using Documenter +const ContinuousTime = AbstractFloat + include("prefixsearch/binarytreeprefixsearch.jl") include("prefixsearch/cumsumprefixsearch.jl") include("prefixsearch/keyedprefixsearch.jl") diff --git a/src/prefixsearch/binarytreeprefixsearch.jl b/src/prefixsearch/binarytreeprefixsearch.jl index e79d41d..2f70fdc 100644 --- a/src/prefixsearch/binarytreeprefixsearch.jl +++ b/src/prefixsearch/binarytreeprefixsearch.jl @@ -29,6 +29,10 @@ function BinaryTreePrefixSearch{T}(N=32) where {T<:Real} end +time_type(ps::BinaryTreePrefixSearch{T}) where {T} = T +time_type(::Type{BinaryTreePrefixSearch{T}}) where {T} = T + + function _btps_sizes(allocation) @assert allocation > 0 depth = Int(ceil(log2(allocation))) + 1 diff --git a/src/prefixsearch/cumsumprefixsearch.jl b/src/prefixsearch/cumsumprefixsearch.jl index ce13f7c..28f3428 100644 --- a/src/prefixsearch/cumsumprefixsearch.jl +++ b/src/prefixsearch/cumsumprefixsearch.jl @@ -15,6 +15,8 @@ end Base.length(ps::CumSumPrefixSearch) = length(ps.array) +time_type(ps::CumSumPrefixSearch{T}) where {T} = T +time_type(ps::Type{CumSumPrefixSearch{T}}) where {T} = T function Base.push!(ps::CumSumPrefixSearch{T}, value::T) where {T} diff --git a/src/prefixsearch/keyedprefixsearch.jl b/src/prefixsearch/keyedprefixsearch.jl index 2ee4972..29e5a19 100644 --- a/src/prefixsearch/keyedprefixsearch.jl +++ b/src/prefixsearch/keyedprefixsearch.jl @@ -23,6 +23,8 @@ end Base.length(kp::KeyedKeepPrefixSearch) = length(kp.index) +time_type(kp::KeyedKeepPrefixSearch{T,P}) where {T,P} = time_type(P) + function Base.setindex!(kp::KeyedKeepPrefixSearch, val, clock) idx = get(kp.index, clock, 0) @@ -37,9 +39,9 @@ function Base.setindex!(kp::KeyedKeepPrefixSearch, val, clock) end -Base.delete!(kp::KeyedKeepPrefixSearch, clock) = kp.prefix[kp.index[clock]] = zero(Float64) +Base.delete!(kp::KeyedKeepPrefixSearch, clock) = kp.prefix[kp.index[clock]] = zero(time_type(kp)) function Base.sum!(kp::KeyedKeepPrefixSearch) - (length(kp.index) > 0) ? sum!(kp.prefix) : zero(Float64) + (length(kp.index) > 0) ? sum!(kp.prefix) : zero(time_type(kp)) end @@ -53,7 +55,8 @@ function Random.rand( rng::AbstractRNG, d::Random.SamplerTrivial{KeyedKeepPrefixSearch{T,P}} ) where {T,P} total = sum!(d[]) - choose(d[], rand(rng, Uniform{Float64}(0, total))) + LocalTime = time_type(P) + choose(d[], rand(rng, Uniform{LocalTime}(zero(LocalTime), total))) end @@ -96,17 +99,17 @@ function Base.setindex!(kp::KeyedRemovalPrefixSearch, val, clock) end -function Base.delete!(kp::KeyedRemovalPrefixSearch, clock) +function Base.delete!(kp::KeyedRemovalPrefixSearch{T,P}, clock) where {T,P} idx = kp.index[clock] - kp.prefix[idx] = zero(Float64) + kp.prefix[idx] = zero(time_type(P)) delete!(kp.index, clock) # kp.key[idx] is now out of date. push!(kp.free, idx) end -function Base.sum!(kp::KeyedRemovalPrefixSearch) - (length(kp.index) > 0) ? sum!(kp.prefix) : zero(Float64) +function Base.sum!(kp::KeyedRemovalPrefixSearch{T,P}) where {T,P} + (length(kp.index) > 0) ? sum!(kp.prefix) : zero(time_type(P)) end function choose(kp::KeyedRemovalPrefixSearch, value) @@ -119,5 +122,6 @@ function Random.rand( rng::AbstractRNG, d::Random.SamplerTrivial{KeyedRemovalPrefixSearch{T,P}} ) where {T,P} total = sum!(d[]) - choose(d[], rand(rng, Uniform{Float64}(0, total))) + LocalTime = time_type(P) + choose(d[], rand(rng, Uniform{LocalTime}(zero(LocalTime), total))) end diff --git a/src/sample/combinednr.jl b/src/sample/combinednr.jl index ecf32a8..9683c98 100644 --- a/src/sample/combinednr.jl +++ b/src/sample/combinednr.jl @@ -85,17 +85,17 @@ invert_space(::Type{LinearSampling}, dist, survival) = cquantile(dist, survival) invert_space(::Type{LogSampling}, dist, survival) = invlogccdf(dist, survival)::Float64 -struct NRTransition +struct NRTransition{T} heap_handle::Int - survival::Float64 # value of S_j or Λ_j + survival::T # value of S_j or Λ_j distribution::UnivariateDistribution - te::Float64 # Enabling time of distribution - t0::Float64 # Enabling time of transition + te::T # Enabling time of distribution + t0::T # Enabling time of transition end """ - CombinedNextReaction{KeyType}() + CombinedNextReaction{KeyType,TimeType}() This combines Next Reaction Method and Modified Next Reaction Method. The Next Reaction Method is from Gibson and Bruck in their 2000 paper called @@ -128,18 +128,18 @@ sampling_space(::LinearGamma) = LinearSampling If you want to test a distribution, look at `tests/nrmetric.jl` to see how distributions are timed. """ -struct CombinedNextReaction{T} - firing_queue::MutableBinaryHeap{OrderedSample{T}} - transition_entry::Dict{T,NRTransition} +struct CombinedNextReaction{K,T} + firing_queue::MutableBinaryHeap{OrderedSample{K,T}} + transition_entry::Dict{K,NRTransition{T}} end -function CombinedNextReaction{T}() where {T} - heap = MutableBinaryMinHeap{OrderedSample{T}}() - CombinedNextReaction{T}(heap, Dict{T,NRTransition}()) +function CombinedNextReaction{K,T}() where {K,T <: ContinuousTime} + heap = MutableBinaryMinHeap{OrderedSample{K,T}}() + CombinedNextReaction{K,T}(heap, Dict{K,NRTransition{T}}()) end -clone(nr::CombinedNextReaction{T}) where {T} = CombinedNextReaction{T}() +clone(nr::CombinedNextReaction{K,T}) where {K,T} = CombinedNextReaction{K,T}() export clone @@ -150,20 +150,20 @@ on a CombinedNextReaction sampler, it returns the key associated with the clock that fires and marks that clock as fired. Calling next() again would return a nonsensical value. """ -function next(nr::CombinedNextReaction, when::Float64, rng::AbstractRNG) +function next(nr::CombinedNextReaction{K,T}, when::T, rng::AbstractRNG) where {K,T} if !isempty(nr.firing_queue) least = first(nr.firing_queue) # For this sampler, mark this transition as the one that will fire # by marking its remaining cumulative time as 0.0. entry = nr.transition_entry[least.key] - nr.transition_entry[least.key] = NRTransition( + nr.transition_entry[least.key] = NRTransition{T}( entry.heap_handle, get_survival_zero(entry.distribution), entry.distribution, entry.te, entry.t0 ) return (least.time, least.key) else # Return type is Tuple{Float64, Union{Nothing,T}} because T is not default-constructible. - return (Inf, nothing) + return (typemax(T), nothing) end end @@ -175,11 +175,11 @@ function sample_shifted( rng::AbstractRNG, distribution::UnivariateDistribution, ::Type{S}, - te::Float64, - when::Float64 - ) where {S <: SamplingSpaceType} + te::T, + when::T + ) where {S <: SamplingSpaceType, T <: ContinuousTime} if te < when - shifted_distribution = truncated(distribution, when - te, Inf) + shifted_distribution = truncated(distribution, when - te, typemax(T)) sample = rand(rng, shifted_distribution) tau = te + sample survival = survival_space(S, shifted_distribution, sample) @@ -194,10 +194,10 @@ end function sample_by_inversion( - distribution::UnivariateDistribution, ::Type{S}, te::Float64, when::Float64, survival::Float64 - ) where {S <: SamplingSpaceType} + distribution::UnivariateDistribution, ::Type{S}, te::T, when::T, survival::T + ) where {S <: SamplingSpaceType, T <: ContinuousTime} if te < when - te + invert_space(S, truncated(distribution, when - te, Inf), survival) + te + invert_space(S, truncated(distribution, when - te, typemax(T)), survival) else # te > when te + invert_space(S, distribution, survival) end @@ -215,17 +215,17 @@ te can be before t0, at t0, between t0 and tn, or at tn, or after tn. """ function consume_survival( - record::NRTransition, distribution::UnivariateDistribution, ::Type{S}, tn::Float64 - ) where {S <: LinearSampling} + record::NRTransition, distribution::UnivariateDistribution, ::Type{S}, tn::T + ) where {S <: LinearSampling, T <: ContinuousTime} survive_te_tn = if record.te < tn - ccdf(distribution, tn-record.te)::Float64 + ccdf(distribution, tn-record.te)::T else - one(Float64) + one(T) end survive_te_t0 = if record.te < record.t0 - ccdf(distribution, record.t0-record.te)::Float64 + ccdf(distribution, record.t0-record.te)::T else - one(Float64) + one(T) end record.survival / (survive_te_t0 * survive_te_tn) end @@ -239,52 +239,52 @@ Anderson's method. """ function consume_survival( - record::NRTransition, distribution::UnivariateDistribution, ::Type{S}, tn::Float64 - ) where {S <: LogSampling} + record::NRTransition, distribution::UnivariateDistribution, ::Type{S}, tn::T + ) where {S <: LogSampling, T <: ContinuousTime} log_survive_te_tn = if record.te < tn - logccdf(distribution, tn-record.te)::Float64 + logccdf(distribution, tn-record.te)::T else - zero(Float64) + zero(T) end log_survive_te_t0 = if record.te < record.t0 - logccdf(distribution, record.t0-record.te)::Float64 + logccdf(distribution, record.t0-record.te)::T else - zero(Float64) + zero(T) end record.survival - (log_survive_te_t0 + log_survive_te_tn) end function enable!( - nr::CombinedNextReaction{T}, clock::T, distribution::UnivariateDistribution, - te::Float64, when::Float64, rng::AbstractRNG) where {T} + nr::CombinedNextReaction{K,T}, clock::K, distribution::UnivariateDistribution, + te::T, when::T, rng::AbstractRNG) where {K,T} enable!(nr, clock, distribution, sampling_space(distribution), te, when, rng) nothing end function enable!( - nr::CombinedNextReaction{T}, clock::T, distribution::UnivariateDistribution, ::Type{S}, - te::Float64, when::Float64, rng::AbstractRNG) where {T, S <: SamplingSpaceType} + nr::CombinedNextReaction{K,T}, clock::K, distribution::UnivariateDistribution, ::Type{S}, + te::T, when::T, rng::AbstractRNG) where {K, T, S <: SamplingSpaceType} # Three cases: a) never been enabled b) currently enabled c) was disabled. record = get( nr.transition_entry, clock, - NRTransition(0, get_survival_zero(S), Never(), 0.0, 0.0) + NRTransition{T}(0, get_survival_zero(S), Never(), zero(T), zero(T)) ) heap_handle = record.heap_handle # if the transition needs to be re-drawn. if record.survival <= get_survival_zero(S) tau, shift_survival = sample_shifted(rng, distribution, S, te, when) - sample = OrderedSample{T}(clock, tau) + sample = OrderedSample{K,T}(clock, tau) if record.heap_handle > 0 update!(nr.firing_queue, record.heap_handle, sample) else heap_handle = push!(nr.firing_queue, sample) end - nr.transition_entry[clock] = NRTransition( + nr.transition_entry[clock] = NRTransition{T}( heap_handle, shift_survival, distribution, te, when ) @@ -300,9 +300,9 @@ function enable!( # Account for time between when this was last enabled and now. survival_remain = consume_survival(record, record.distribution, S, when) tau = sample_by_inversion(distribution, S, te, when, survival_remain) - entry = OrderedSample{T}(clock, tau) + entry = OrderedSample{K,T}(clock, tau) update!(nr.firing_queue, record.heap_handle, entry) - nr.transition_entry[clock] = NRTransition( + nr.transition_entry[clock] = NRTransition{T}( heap_handle, survival_remain, distribution, te, when ) end @@ -310,8 +310,8 @@ function enable!( # The transition was previously disabled. else tau = sample_by_inversion(distribution, S, te, when, record.survival) - heap_handle = push!(nr.firing_queue, OrderedSample{T}(clock, tau)) - nr.transition_entry[clock] = NRTransition( + heap_handle = push!(nr.firing_queue, OrderedSample{K,T}(clock, tau)) + nr.transition_entry[clock] = NRTransition{T}( heap_handle, record.survival, distribution, te, when ) end @@ -320,10 +320,10 @@ function enable!( end -function disable!(nr::CombinedNextReaction{T}, clock::T, when::Float64) where {T} +function disable!(nr::CombinedNextReaction{K,T}, clock::K, when::T) where {K,T <: ContinuousTime} record = nr.transition_entry[clock] delete!(nr.firing_queue, record.heap_handle) - nr.transition_entry[clock] = NRTransition( + nr.transition_entry[clock] = NRTransition{T}( 0, consume_survival(record, record.distribution, sampling_space(record.distribution), when), record.distribution, diff --git a/src/sample/direct.jl b/src/sample/direct.jl index 50f7778..ae573e0 100644 --- a/src/sample/direct.jl +++ b/src/sample/direct.jl @@ -6,39 +6,39 @@ export DirectCall, enable!, disable!, next """ - DirectCall{T} + DirectCall{K,T} DirectCall is responsible for sampling among Exponential distributions. It samples using the Direct method. In this case, there is no optimization to that Direct method, so we call it DirectCall because it recalculates everything every time you call it. -The type `T` is the type of an identifier for each transition. This identifier +The type `K` is the type of an identifier for each transition. This identifier is usually a nominal integer but can be a any key that identifies it, such as -a string or tuple of integers. Instances of type `T` are used as keys in a +a string or tuple of integers. Instances of type `K` are used as keys in a dictionary. # Example ```julia -DirectCall{T}() where {T} = - DirectCall{T,CumSumPrefixSearch{Float64}}(CumSumPrefixSearch(Float64)) +DirectCall{K,T}() where {K,T} = + DirectCall{K,CumSumPrefixSearch{T}}(CumSumPrefixSearch(T)) -DirectCall{T}() where {T} = - DirectCall{T,BinaryTreePrefixSearch{Float64}}(BinaryTreePrefixSearch(Float64)) +DirectCall{K,T}() where {K,T} = + DirectCall{K,BinaryTreePrefixSearch{T}}(BinaryTreePrefixSearch(T)) ``` """ -struct DirectCall{T,P} +struct DirectCall{K,P} prefix_tree::P - DirectCall{T,P}(tree::P) where {T,P} = new(tree) + DirectCall{K,P}(tree::P) where {K,P} = new(tree) end -function DirectCall{T}() where {T} - prefix_tree = BinaryTreePrefixSearch{Float64}() - keyed_prefix_tree = KeyedRemovalPrefixSearch{T,typeof(prefix_tree)}(prefix_tree) - DirectCall{T,typeof(keyed_prefix_tree)}(keyed_prefix_tree) +function DirectCall{K,T}() where {K,T<:ContinuousTime} + prefix_tree = BinaryTreePrefixSearch{T}() + keyed_prefix_tree = KeyedRemovalPrefixSearch{K,typeof(prefix_tree)}(prefix_tree) + DirectCall{K,typeof(keyed_prefix_tree)}(keyed_prefix_tree) end @@ -55,8 +55,8 @@ later than when it was first enabled. The `rng` is a random number generator. If a particular clock had one rate before an event and it has another rate after the event, call `enable!` to update the rate. """ -function enable!(dc::DirectCall{T}, clock::T, distribution::Exponential, - te::Float64, when::Float64, rng::AbstractRNG) where {T} +function enable!(dc::DirectCall{K,P}, clock::K, distribution::Exponential, + te, when, rng::AbstractRNG) where {K,P} dc.prefix_tree[clock] = rate(distribution) end @@ -68,13 +68,13 @@ Tell the `DirectCall` sampler to disable this clock. The `clock` argument is an identifier for the clock. The `when` argument is the time at which this clock is enabled. """ -function disable!(dc::DirectCall{T}, clock::T, when::Float64) where {T} +function disable!(dc::DirectCall{K,P}, clock::K, when) where {K,P} delete!(dc.prefix_tree, clock) end """ - next(dc::DirectCall, when::Float64, rng::AbstractRNG) + next(dc::DirectCall, when::TimeType, rng::AbstractRNG) Ask the sampler what clock will be the next to fire and at what time. This does not change the sampler. You can call this multiple times and get multiple @@ -82,13 +82,13 @@ answers. Each answer is a tuple of `(when, which clock)`. If there is no clock to fire, then the response will be `(Inf, nothing)`. That's a good sign the simulation is done. """ -function next(dc::DirectCall, when::Float64, rng::AbstractRNG) +function next(dc::DirectCall, when, rng::AbstractRNG) total = sum!(dc.prefix_tree) - if total > eps(Float64) + if total > eps(when) chosen, hazard_value = rand(rng, dc.prefix_tree) - tau = when + rand(rng, Exponential(1 / total)) + tau = when + rand(rng, Exponential(inv(total))) return (tau, chosen) else - return (Inf, nothing) + return (typemax(when), nothing) end end diff --git a/src/sample/firstreaction.jl b/src/sample/firstreaction.jl index 2eb50e3..4939d6f 100644 --- a/src/sample/firstreaction.jl +++ b/src/sample/firstreaction.jl @@ -10,32 +10,32 @@ Classic First Reaction method for Exponential and non-Exponential distributions. Every time you sample, go to each distribution and ask when it would fire. Then take the soonest and throw out the rest until the next sample. """ -struct FirstReaction{T} +struct FirstReaction{K,T} # This other class already stores the current set of distributions, so use it. - core_matrix::TrackWatcher{T} - FirstReaction{T}() where {T} = new(TrackWatcher{T}()) + core_matrix::TrackWatcher{K} + FirstReaction{K,T}() where {K,T <: ContinuousTime} = new(TrackWatcher{K,T}()) end -function enable!(fr::FirstReaction{T}, clock::T, distribution::UnivariateDistribution, - te::Float64, when::Float64, rng::AbstractRNG) where {T} +function enable!(fr::FirstReaction{K,T}, clock::K, distribution::UnivariateDistribution, + te::T, when::T, rng::AbstractRNG) where {K,T} enable!(fr.core_matrix, clock, distribution, te, when, rng) end -function disable!(fr::FirstReaction{T}, clock::T, when::Float64) where {T} +function disable!(fr::FirstReaction{K,T}, clock::K, when::T) where {K,T} disable!(fr.core_matrix, clock, when) end -function next(fr::FirstReaction{T}, when::Float64, rng) where {T} - soonest_clock::Union{Nothing,T} = nothing - soonest_time = Inf +function next(fr::FirstReaction{K,T}, when::T, rng) where {K,T} + soonest_clock::Union{Nothing,K} = nothing + soonest_time = typemax(T) - for entry::EnablingEntry{T} in fr.core_matrix + for entry::EnablingEntry{K,T} in fr.core_matrix if entry.te < when - relative_dist = truncated(entry.distribution, when - entry.te, Inf) + relative_dist = truncated(entry.distribution, when - entry.te, typemax(T)) putative_time = entry.te + rand(rng, relative_dist) else putative_time = entry.te + rand(rng, entry.distribution) @@ -53,18 +53,18 @@ end This sampler can help if it's the first time you're trying a model. It checks all of the things and uses Julia's logger to communicate them. """ -mutable struct ChatReaction{T} +mutable struct ChatReaction{K,T} # This other class already stores the current set of distributions, so use it. - core_matrix::TrackWatcher{T} + core_matrix::TrackWatcher{K} step_cnt::Int64 - enables::Set{T} - disables::Set{T} - ChatReaction{T}() where {T} = new(TrackWatcher{T}(), 0, Set{T}(), Set{T}()) + enables::Set{K} + disables::Set{K} + ChatReaction{K,T}() where {K,T<:ContinuousTime} = new(TrackWatcher{K,T}(), 0, Set{K}(), Set{K}()) end -function enable!(fr::ChatReaction{T}, clock::T, distribution::UnivariateDistribution, - te::Float64, when::Float64, rng::AbstractRNG) where {T} +function enable!(fr::ChatReaction{K,T}, clock::K, distribution::UnivariateDistribution, + te::T, when::T, rng::AbstractRNG) where {K,T} if clock ∈ keys(fr.core_matrix.enabled) @warn "Re-enabling transition $clock without disabling first" @@ -74,7 +74,7 @@ function enable!(fr::ChatReaction{T}, clock::T, distribution::UnivariateDistribu end -function disable!(fr::ChatReaction{T}, clock::T, when::Float64) where {T} +function disable!(fr::ChatReaction{K,T}, clock::K, when::T) where {K,T} if clock ∉ fr.enables @warn "Disabling a clock that was never enabled: $(clock)." end @@ -86,17 +86,17 @@ function disable!(fr::ChatReaction{T}, clock::T, when::Float64) where {T} end -function next(fr::ChatReaction{T}, when::Float64, rng) where {T} +function next(fr::ChatReaction{K,T}, when::T, rng) where {K,T} soonest_clock = nothing - soonest_time = Inf + soonest_time = typemax(T) if length(fr.enables) == 0 @warn "No transitions have ever been enabled. Sampler may not be initialized." end - for entry::EnablingEntry{T} in fr.core_matrix + for entry::EnablingEntry{K,T} in fr.core_matrix if entry.te < when - relative_dist = truncated(entry.distribution, when - entry.te, Inf) + relative_dist = truncated(entry.distribution, when - entry.te, typemax(T)) putative_time = entry.te + rand(rng, relative_dist) else putative_time = entry.te + rand(rng, entry.distribution) diff --git a/src/sample/firsttofire.jl b/src/sample/firsttofire.jl index cba3417..e82efb1 100644 --- a/src/sample/firsttofire.jl +++ b/src/sample/firsttofire.jl @@ -11,14 +11,14 @@ The soonest to fire wins. When a clock is disabled, its future firing time is removed from the list. There is no memory of previous firing times. """ struct FirstToFire{K,T} <: SSA{K,T} - firing_queue::MutableBinaryMinHeap{OrderedSample{K}} + firing_queue::MutableBinaryMinHeap{OrderedSample{K,T}} # This maps from transition to entry in the firing queue. transition_entry::Dict{K,Int} end function FirstToFire{K,T}() where {K,T} - heap = MutableBinaryMinHeap{OrderedSample{K}}() + heap = MutableBinaryMinHeap{OrderedSample{K,T}}() state = Dict{K,Int}() FirstToFire{K,T}(heap, state) end @@ -29,7 +29,7 @@ function next(propagator::FirstToFire{K,T}, when::T, rng::AbstractRNG) where {K, least = if !isempty(propagator.firing_queue) top(propagator.firing_queue) else - OrderedSample(nothing, Inf) + OrderedSample(nothing, typemax(T)) end @debug("FirstToFire.next queue length ", length(propagator.firing_queue), " least ", least) @@ -42,15 +42,15 @@ function enable!( te::T, when::T, rng::AbstractRNG) where {K,T} if te < when - when_fire = te + rand(rng, truncated(distribution, when - te, Inf)) + when_fire = te + rand(rng, truncated(distribution, when - te, typemax(T))) else when_fire = te + rand(rng, distribution) end if haskey(propagator.transition_entry, clock) heap_handle = propagator.transition_entry[clock] - update!(propagator.firing_queue, heap_handle, OrderedSample{K}(clock, when_fire)) + update!(propagator.firing_queue, heap_handle, OrderedSample{K,T}(clock, when_fire)) else - heap_handle = push!(propagator.firing_queue, OrderedSample{K}(clock, when_fire)) + heap_handle = push!(propagator.firing_queue, OrderedSample{K,T}(clock, when_fire)) propagator.transition_entry[clock] = heap_handle end end diff --git a/src/sample/multiple_direct.jl b/src/sample/multiple_direct.jl index f4c43ac..0d35208 100644 --- a/src/sample/multiple_direct.jl +++ b/src/sample/multiple_direct.jl @@ -35,7 +35,7 @@ end function enable!(md::MultipleDirect, clock, distribution::Exponential, - te::Float64, when::Float64, rng::AbstractRNG) + te, when, rng::AbstractRNG) if clock ∉ keys(md.chosen) which_prefix_search = choose_sampler(md.chooser, clock, distribution) scan_idx = md.scanmap[which_prefix_search] @@ -48,7 +48,7 @@ function enable!(md::MultipleDirect, clock, distribution::Exponential, end -function disable!(md::MultipleDirect, clock, when::Float64) +function disable!(md::MultipleDirect, clock, when) which_prefix_search = md.chosen[clock] delete!(md.scan[which_prefix_search], clock) end @@ -69,18 +69,18 @@ it is possible that a random number generator will _never_ choose a particular value because there is no guarantee that a random number generator covers every combination of bits. Using more draws decreases the likelihood of this problem. """ -function next(md::MultipleDirect, when::Float64, rng::AbstractRNG) +function next(md::MultipleDirect, when, rng::AbstractRNG) for scan_idx in eachindex(md.scan) md.totals[scan_idx] = sum!(md.scan[scan_idx]) end total = sum(md.totals) - if total > eps(Float64) + if total > eps(when) tau = when + rand(rng, Exponential(1 / total)) md.totals /= total chosen_idx = rand(rng, Categorical(md.totals)) chosen, hazard_value = rand(rng, md.scan[chosen_idx]) return (tau, chosen) else - return (Inf, nothing) + return (typemax(when), nothing) end end diff --git a/src/sample/neverdist.jl b/src/sample/neverdist.jl index 21092b9..f326592 100644 --- a/src/sample/neverdist.jl +++ b/src/sample/neverdist.jl @@ -13,21 +13,21 @@ Never() = Never{Float64}() params(d::Never) = () partype(d::Never{T}) where {T<:Real} = T -mean(d::Never) = Inf -median(d::Never) = Inf -mode(d::Never) = Inf -var(d::Never) = Inf +mean(d::Never{T}) where {T} = typemax(T) +median(d::Never{T}) where {T}= typemax(T) +mode(d::Never{T}) where {T} = typemax(T) +var(d::Never{T}) where {T} = typemax(T) skewness(d::Never{T}) where {T<:Real} = zero(T) kurtosis(d::Never{T}) where {T<:Real} = zero(T) pdf(d::Never{T}, x::Real) where {T<:Real} = zero(x) -logpdf(d::Never{T}, x::Real) where {T<:Real} = -Inf +logpdf(d::Never{T}, x::Real) where {T<:Real} = typemin(T) cdf(d::Never{T}, x::Real) where {T<:Real} = zero(x) ccdf(d::Never{T}, x::Real) where {T<:Real} = one(x) -quantile(d::Never{T}, q::Real) where {T<:Real} = Inf +quantile(d::Never{T}, q::Real) where {T<:Real} = typemax(T) mgf(d::Never{T}, x::Real) where {T<:Real} = zero(x) cf(d::Never{T}, x::Real) where {T<:Real} = zero(x) -rand(rng::Random.AbstractRNG, d::Never) = Inf +rand(rng::Random.AbstractRNG, d::Never) = typemax(T) function rand!(rng::Random.AbstractRNG, d::Never, arr::AbstractArray) - arr .= Inf + arr .= typemax(T) end diff --git a/src/sample/nrtransition.jl b/src/sample/nrtransition.jl index 9a410a9..f46d32a 100644 --- a/src/sample/nrtransition.jl +++ b/src/sample/nrtransition.jl @@ -4,9 +4,9 @@ import Base: ==, <, > A record of a transition and the time. It's sortable by time. Immutable. """ -struct OrderedSample{T} - key::T - time::Float64 +struct OrderedSample{K,T} + key::K + time::T end diff --git a/src/sample/track.jl b/src/sample/track.jl index d6caa8b..1829bd9 100644 --- a/src/sample/track.jl +++ b/src/sample/track.jl @@ -6,22 +6,22 @@ export TrackWatcher, DebugWatcher, enable!, disable! # to a model in order to provide more information about active # clocks. -struct EnablingEntry{T} - clock::T +struct EnablingEntry{K,T} + clock::K distribution::UnivariateDistribution - te::Float64 - when::Float64 + te::T + when::T end -struct DisablingEntry{T} - clock::T - when::Float64 +struct DisablingEntry{K,T} + clock::K + when::T end """ - TrackWatcher() + TrackWatcher{K,T}() This Watcher doesn't sample. It records everything enabled. You can iterate over enabled clocks with a for-loop. If we think of the @@ -39,9 +39,9 @@ for entry in tracker end ``` """ -mutable struct TrackWatcher{T} - enabled::Dict{T,EnablingEntry{T}} - TrackWatcher{T}() where {T}=new(Dict{T,EnablingEntry{T}}()) +mutable struct TrackWatcher{K,T} + enabled::Dict{K,EnablingEntry{K,T}} + TrackWatcher{K,T}() where {K,T}=new(Dict{K,EnablingEntry{K,T}}()) end @@ -59,12 +59,12 @@ function Base.length(ts::TrackWatcher) end -function enable!(ts::TrackWatcher{T}, clock::T, dist::UnivariateDistribution, te, when, rng) where {T} - ts.enabled[clock] = EnablingEntry{T}(clock, dist, te, when) +function enable!(ts::TrackWatcher{K,T}, clock::K, dist::UnivariateDistribution, te, when, rng) where {K,T} + ts.enabled[clock] = EnablingEntry{K,T}(clock, dist, te, when) end -function disable!(ts::TrackWatcher{T}, clock::T, when) where {T} +function disable!(ts::TrackWatcher{K,T}, clock::K, when) where {K,T} if haskey(ts.enabled, clock) delete!(ts.enabled, clock) end @@ -88,18 +88,18 @@ watcher.enabled[1].te, watcher.enabled[1].when) ``` """ -mutable struct DebugWatcher{T} - enabled::Vector{EnablingEntry{T}} - disabled::Vector{DisablingEntry{T}} - DebugWatcher{T}() where {T}=new(Vector{EnablingEntry{T}}(), Vector{DisablingEntry{T}}()) +mutable struct DebugWatcher{K,T} + enabled::Vector{EnablingEntry{K,T}} + disabled::Vector{DisablingEntry{K,T}} + DebugWatcher{K,T}() where {K,T}=new(Vector{EnablingEntry{K,T}}(), Vector{DisablingEntry{K,T}}()) end -function enable!(ts::DebugWatcher{T}, clock::T, dist::UnivariateDistribution, te, when, rng) where {T} +function enable!(ts::DebugWatcher{K,T}, clock::K, dist::UnivariateDistribution, te, when, rng) where {K,T} push!(ts.enabled, EnablingEntry(clock, dist, te, when)) end -function disable!(ts::DebugWatcher{T}, clock::T, when) where {T} +function disable!(ts::DebugWatcher{K,T}, clock::K, when) where {K,T} push!(ts.disabled, DisablingEntry(clock, when)) end diff --git a/test/nextreaction.jl b/test/nextreaction.jl index d888a1a..debc966 100644 --- a/test/nextreaction.jl +++ b/test/nextreaction.jl @@ -43,7 +43,7 @@ function NextReaction{T}() where {T} end -function next(nr::AbstractNextReaction{T}, when::Float64, rng::AbstractRNG) where {T} +function next(nr::AbstractNextReaction{T}, when, rng::AbstractRNG) where {T} if !isempty(nr.firing_queue) least = top(nr.firing_queue) # For this sampler, mark this transition as the one that will fire @@ -55,7 +55,7 @@ function next(nr::AbstractNextReaction{T}, when::Float64, rng::AbstractRNG) wher return (least.time, least.key) else # Return type is Tuple{Float64, Union{Nothing,T}} because T is not default-constructible. - return (Inf, nothing) + return (typemax(when), nothing) end end diff --git a/test/test_combinednr.jl b/test/test_combinednr.jl index 72b7e7b..449b3cb 100644 --- a/test/test_combinednr.jl +++ b/test/test_combinednr.jl @@ -8,7 +8,7 @@ using SafeTestsets rng = MersenneTwister(349827) for i in 1:100 - sampler = CombinedNextReaction{String}() + sampler = CombinedNextReaction{String,Float64}() @test next(sampler, 3.0, rng)[2] === nothing enable!(sampler, "walk home", Exponential(1.5), 0.0, 0.0, rng) @test next(sampler, 3.0, rng)[2] == "walk home" diff --git a/test/test_direct.jl b/test/test_direct.jl index a91d007..5424730 100644 --- a/test/test_direct.jl +++ b/test/test_direct.jl @@ -6,7 +6,7 @@ using SafeTestsets using Random: MersenneTwister using Distributions: Exponential - dc = DirectCall{Int}() + dc = DirectCall{Int,Float64}() rng = MersenneTwister(90422342) propensities = [0.3, 0.2, 0.7, 0.001, 0.25] for (i, p) in enumerate(propensities) @@ -22,7 +22,7 @@ end @safetestset direct_call_empty = "DirectCall empty hazard" begin using Fleck: DirectCall, next, enable! using Random: MersenneTwister - md = DirectCall{Int}() + md = DirectCall{Int,Float64}() rng = MersenneTwister(90497979) current = 0.0 when, which = next(md, current, rng) @@ -41,7 +41,7 @@ end # Given 10 slow distributions and 10 fast, we can figure out # that the marginal probability of a low vs a high is 1 / (1 + 1.5) = 3/5. # Check that we get the correct marginal probability. - md = DirectCall{Int}() + md = DirectCall{Int,Float64}() for i in 1:10 enable!(md, i, Exponential(1), 0.0, 0.0, rng) end @@ -66,6 +66,6 @@ end using HypothesisTests: BinomialTest, confint using ..DirectFixture: test_exponential_binomial rng = MersenneTwister(223497123) - md = DirectCall{Int}() + md = DirectCall{Int,Float64}() test_exponential_binomial(md, rng) end diff --git a/test/test_firstreaction.jl b/test/test_firstreaction.jl index 288014d..759c250 100644 --- a/test/test_firstreaction.jl +++ b/test/test_firstreaction.jl @@ -11,7 +11,7 @@ using SafeTestsets seen = Set{Int}() sample_time = 0.5 for i in 1:100 - sampler = FirstReaction{Int}() + sampler = FirstReaction{Int,Float64}() enable!(sampler, 1, Exponential(1.7), 0.0, 0.0, rng) enable!(sampler, 2, Gamma(9, 0.5), 0.0, 0.0, rng) enable!(sampler, 3, Gamma(2, 2.0), 0.0, 0.0, rng) @@ -31,7 +31,7 @@ end using Random: MersenneTwister rng = MersenneTwister(90422342) - sampler = FirstReaction{Int}() + sampler = FirstReaction{Int,Float64}() when, which = next(sampler, 5.7, rng) @test when == Inf @test which === nothing @@ -44,7 +44,7 @@ end using ..DirectFixture: test_exponential_binomial rng = MersenneTwister(12349678) - sampler = FirstReaction{Int}() + sampler = FirstReaction{Int,Float64}() test_exponential_binomial(sampler, rng) end @@ -55,7 +55,7 @@ end using ..DirectFixture: test_weibull_binomial rng = MersenneTwister(12967847) - sampler = FirstReaction{Int}() + sampler = FirstReaction{Int,Float64}() test_weibull_binomial(sampler, rng) end @@ -68,7 +68,7 @@ end rng = Xoshiro(8367109004) rand(rng, 100) # burn some early numbers - sampler = FirstReaction{Int}() + sampler = FirstReaction{Int,Float64}() dist = Weibull() sample_cnt = 1000 enable!(sampler, 1, dist, 0.0, 0.0, rng) @@ -88,7 +88,7 @@ end rng = Xoshiro(8367109004) rand(rng, 100) # burn some early numbers - sampler = FirstReaction{Int}() + sampler = FirstReaction{Int,Float64}() dist = Weibull() sample_cnt = 1000 enable!(sampler, 1, dist, 0.0, 0.0, rng) @@ -111,7 +111,7 @@ end rng = Xoshiro(8367109004) rand(rng, 100) # burn some early numbers - sampler = FirstReaction{Int}() + sampler = FirstReaction{Int,Float64}() dist = Weibull() future = 2.7 sample_cnt = 1000 diff --git a/test/test_firsttofire.jl b/test/test_firsttofire.jl index a47d5a4..6bf497c 100644 --- a/test/test_firsttofire.jl +++ b/test/test_firsttofire.jl @@ -34,7 +34,10 @@ end propagator = FirstToFire{Int64,Float64}() for (clock, when_fire) in [(1, 7.9), (2, 12.3), (3, 3.7), (4, 0.00013), (5, 0.2)] - heap_handle = push!(propagator.firing_queue, Fleck.OrderedSample{Int64}(clock, when_fire)) + heap_handle = push!( + propagator.firing_queue, + Fleck.OrderedSample{Int64,Float64}(clock, when_fire) + ) propagator.transition_entry[clock] = heap_handle end rng = Xoshiro(39472) diff --git a/test/test_nextreaction.jl b/test/test_nextreaction.jl index fdc2667..62b6916 100644 --- a/test/test_nextreaction.jl +++ b/test/test_nextreaction.jl @@ -7,7 +7,7 @@ using SafeTestsets # The NextReaction algorithm relies on the heap handle always being positive # so this test checks that is the case. - heap = MutableBinaryMinHeap{OrderedSample{Int}}() + heap = MutableBinaryMinHeap{OrderedSample{Int,Float64}}() enabled = Set{Int}() for i in 1:10000 if rand() < 0.2 && length(heap) > 0 @@ -16,9 +16,9 @@ using SafeTestsets delete!(enabled, handle) elseif rand() < 0.4 && length(heap) > 0 && length(enabled) > 0 modify = rand(enabled) - update!(heap, modify, OrderedSample{Int}(i, rand())) + update!(heap, modify, OrderedSample{Int,Float64}(i, rand())) else - handle = push!(heap, OrderedSample{Int}(i, rand())) + handle = push!(heap, OrderedSample{Int,Float64}(i, rand())) push!(enabled, handle) @test(handle > 0) end diff --git a/test/test_nrtransition.jl b/test/test_nrtransition.jl index 068c328..8159425 100644 --- a/test/test_nrtransition.jl +++ b/test/test_nrtransition.jl @@ -4,26 +4,26 @@ using SafeTestsets @safetestset OrderedSample_smoke = "nr transition can be created and compared" begin using Fleck: OrderedSample - a = OrderedSample{Int}(3, 2.2) - b = OrderedSample{Int}(1, 2.5) - c = OrderedSample{Int}(2, 2.5) + a = OrderedSample{Int,Float64}(3, 2.2) + b = OrderedSample{Int,Float64}(1, 2.5) + c = OrderedSample{Int,Float64}(2, 2.5) @test a < b @test isless(a, b) @test b > a @test b == c - a = OrderedSample{String}("S", 2.2) - b = OrderedSample{String}("I", 2.5) - c = OrderedSample{String}("R", 2.5) + a = OrderedSample{String,Float64}("S", 2.2) + b = OrderedSample{String,Float64}("I", 2.5) + c = OrderedSample{String,Float64}("R", 2.5) @test a < b @test isless(a, b) @test b > a @test b == c # You can index it with a tuple if you want. - a = OrderedSample{Tuple{String,Int}}(("S", 3), 2.2) - b = OrderedSample{Tuple{String,Int}}(("I", 7), 2.5) - c = OrderedSample{Tuple{String,Int}}(("R", 4), 2.5) + a = OrderedSample{Tuple{String,Int},Float64}(("S", 3), 2.2) + b = OrderedSample{Tuple{String,Int},Float64}(("I", 7), 2.5) + c = OrderedSample{Tuple{String,Int},Float64}(("R", 4), 2.5) @test a < b @test isless(a, b) @test b > a @@ -37,11 +37,11 @@ end # This shows that the custom isless() operator is what we need in order to # use this data structure for sampling. - heap = MutableBinaryMinHeap{OrderedSample{Int}}() - push!(heap, OrderedSample{Int}(3, 2.2)) - push!(heap, OrderedSample{Int}(1, 3.9)) - push!(heap, OrderedSample{Int}(7, 0.05)) - push!(heap, OrderedSample{Int}(5, 1.7)) + heap = MutableBinaryMinHeap{OrderedSample{Int,Float64}}() + push!(heap, OrderedSample{Int,Float64}(3, 2.2)) + push!(heap, OrderedSample{Int,Float64}(1, 3.9)) + push!(heap, OrderedSample{Int,Float64}(7, 0.05)) + push!(heap, OrderedSample{Int,Float64}(5, 1.7)) @test pop!(heap).key == 7 @test pop!(heap).key == 5 @test pop!(heap).key == 3 diff --git a/test/test_vas.jl b/test/test_vas.jl index 6c935c4..bf803ec 100644 --- a/test/test_vas.jl +++ b/test/test_vas.jl @@ -33,7 +33,7 @@ vas = VectorAdditionSystem(take, give, rates) initializer = vas_initial(vas, [1, 1, 0]) state = zero_state(vas) -track_hazards = DebugWatcher{Int}() +track_hazards = DebugWatcher{Int,Float64}() fire!(track_hazards, vas, state, initializer, 0.0, "rng") enabled = Set(entry.clock for entry in track_hazards.enabled) @test enabled == Set([1, 2, 3, 4]) @@ -51,7 +51,7 @@ initializer = vas_initial(vas, [1, 1, 0]) state = zero_state(vas) initializer(state) -track_hazards = DebugWatcher{Int}() +track_hazards = DebugWatcher{Int,Float64}() fire_index = 2 input_change = vas_delta(vas, fire_index) fire!(track_hazards, vas, state, input_change, 0.0, "rng") @@ -77,7 +77,7 @@ using ..SampleVAS: sample_transitions disabled = zeros(Int, 0) enabled = zeros(Int, 0) newly_enabled = zeros(Int, 0) - track_hazards = TrackWatcher{Int}() + track_hazards = TrackWatcher{Int,Float64}() curtime = 0.0 for i in 1:10 if isnothing(next_transition) @@ -118,7 +118,7 @@ end take, give, rates = sample_transitions() vas = VectorAdditionSystem(take, give, rates) initial_state = vas_initial(vas, [1, 1, 0]) - fsm = VectorAdditionFSM(vas, initial_state, DirectCall{Int}(), rng) + fsm = VectorAdditionFSM(vas, initial_state, DirectCall{Int,Float64}(), rng) when, next_transition = simstep!(fsm) limit = 10 while next_transition !== nothing && limit > 0 @@ -140,7 +140,7 @@ end starting[2:cnt] .= 1 starting[cnt + 1] = 1 # Start with one infected. initial_state = vas_initial(vas, starting) - fsm = VectorAdditionFSM(vas, initial_state, DirectCall{Int}(), rng) + fsm = VectorAdditionFSM(vas, initial_state, DirectCall{Int,Float64}(), rng) when, next_transition = simstep!(fsm) event_cnt = 0 while next_transition !== nothing diff --git a/test/test_vas_integrate.jl b/test/test_vas_integrate.jl index c73cd7e..c900255 100644 --- a/test/test_vas_integrate.jl +++ b/test/test_vas_integrate.jl @@ -12,7 +12,7 @@ using ..SampleVAS: sample_sir cnt = 30 vas = VectorAdditionSystem(sample_sir(cnt)...) - sampler = DirectCall{Int}() + sampler = DirectCall{Int,Float64}() starting = zeros(Int, 3 * cnt) starting[2:cnt] .= 1 @@ -37,7 +37,7 @@ end cnt = 30 vas = VectorAdditionSystem(sample_sir(cnt)...) - sampler = FirstReaction{Int}() + sampler = FirstReaction{Int,Float64}() starting = zeros(Int, 3 * cnt) starting[2:cnt] .= 1 diff --git a/test/time_combinednr.jl b/test/time_combinednr.jl index 53f3096..f07f0f4 100644 --- a/test/time_combinednr.jl +++ b/test/time_combinednr.jl @@ -29,7 +29,7 @@ function compare_with_next_reaction() iter_cnt = 0 while pure_result === nothing || dual_result === nothing pure_nr = NextReaction{Int}() - dual_nr = CombinedNextReaction{Int}() + dual_nr = CombinedNextReaction{Int,Float64}() pure_rng = Xoshiro(342432) dual_rng = Xoshiro(342432) pure_time = @timed sample_a_while(pure_nr, common_distribution, pure_rng) @@ -55,7 +55,7 @@ function compare_with_modified_next_reaction() # We run a few iterations in order to account for compilation time. for burn_one in 1:4 pure_nr = ModifiedNextReaction{Int}() - dual_nr = CombinedNextReaction{Int}() + dual_nr = CombinedNextReaction{Int,Float64}() pure_rng = Xoshiro(342432) dual_rng = Xoshiro(342432) pure_time = @timed sample_a_while(pure_nr, common_distribution, pure_rng)