Skip to content

Commit

Permalink
Merge pull request #40 from adolgert/35-inconsistencies-in-type-param…
Browse files Browse the repository at this point in the history
…eters-for-samplers

35 inconsistencies in type parameters for samplers
  • Loading branch information
adolgert authored Mar 4, 2024
2 parents 57cb405 + 529ad76 commit f09169d
Show file tree
Hide file tree
Showing 22 changed files with 197 additions and 182 deletions.
2 changes: 2 additions & 0 deletions src/Fleck.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
module Fleck
using Documenter

const ContinuousTime = AbstractFloat

include("prefixsearch/binarytreeprefixsearch.jl")
include("prefixsearch/cumsumprefixsearch.jl")
include("prefixsearch/keyedprefixsearch.jl")
Expand Down
4 changes: 4 additions & 0 deletions src/prefixsearch/binarytreeprefixsearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ function BinaryTreePrefixSearch{T}(N=32) where {T<:Real}
end


time_type(ps::BinaryTreePrefixSearch{T}) where {T} = T
time_type(::Type{BinaryTreePrefixSearch{T}}) where {T} = T


function _btps_sizes(allocation)
@assert allocation > 0
depth = Int(ceil(log2(allocation))) + 1
Expand Down
2 changes: 2 additions & 0 deletions src/prefixsearch/cumsumprefixsearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ end


Base.length(ps::CumSumPrefixSearch) = length(ps.array)
time_type(ps::CumSumPrefixSearch{T}) where {T} = T
time_type(ps::Type{CumSumPrefixSearch{T}}) where {T} = T


function Base.push!(ps::CumSumPrefixSearch{T}, value::T) where {T}
Expand Down
20 changes: 12 additions & 8 deletions src/prefixsearch/keyedprefixsearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ end


Base.length(kp::KeyedKeepPrefixSearch) = length(kp.index)
time_type(kp::KeyedKeepPrefixSearch{T,P}) where {T,P} = time_type(P)


function Base.setindex!(kp::KeyedKeepPrefixSearch, val, clock)
idx = get(kp.index, clock, 0)
Expand All @@ -37,9 +39,9 @@ function Base.setindex!(kp::KeyedKeepPrefixSearch, val, clock)
end


Base.delete!(kp::KeyedKeepPrefixSearch, clock) = kp.prefix[kp.index[clock]] = zero(Float64)
Base.delete!(kp::KeyedKeepPrefixSearch, clock) = kp.prefix[kp.index[clock]] = zero(time_type(kp))
function Base.sum!(kp::KeyedKeepPrefixSearch)
(length(kp.index) > 0) ? sum!(kp.prefix) : zero(Float64)
(length(kp.index) > 0) ? sum!(kp.prefix) : zero(time_type(kp))
end


Expand All @@ -53,7 +55,8 @@ function Random.rand(
rng::AbstractRNG, d::Random.SamplerTrivial{KeyedKeepPrefixSearch{T,P}}
) where {T,P}
total = sum!(d[])
choose(d[], rand(rng, Uniform{Float64}(0, total)))
LocalTime = time_type(P)
choose(d[], rand(rng, Uniform{LocalTime}(zero(LocalTime), total)))
end


Expand Down Expand Up @@ -96,17 +99,17 @@ function Base.setindex!(kp::KeyedRemovalPrefixSearch, val, clock)
end


function Base.delete!(kp::KeyedRemovalPrefixSearch, clock)
function Base.delete!(kp::KeyedRemovalPrefixSearch{T,P}, clock) where {T,P}
idx = kp.index[clock]
kp.prefix[idx] = zero(Float64)
kp.prefix[idx] = zero(time_type(P))
delete!(kp.index, clock)
# kp.key[idx] is now out of date.
push!(kp.free, idx)
end


function Base.sum!(kp::KeyedRemovalPrefixSearch)
(length(kp.index) > 0) ? sum!(kp.prefix) : zero(Float64)
function Base.sum!(kp::KeyedRemovalPrefixSearch{T,P}) where {T,P}
(length(kp.index) > 0) ? sum!(kp.prefix) : zero(time_type(P))
end

function choose(kp::KeyedRemovalPrefixSearch, value)
Expand All @@ -119,5 +122,6 @@ function Random.rand(
rng::AbstractRNG, d::Random.SamplerTrivial{KeyedRemovalPrefixSearch{T,P}}
) where {T,P}
total = sum!(d[])
choose(d[], rand(rng, Uniform{Float64}(0, total)))
LocalTime = time_type(P)
choose(d[], rand(rng, Uniform{LocalTime}(zero(LocalTime), total)))
end
94 changes: 47 additions & 47 deletions src/sample/combinednr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,17 @@ invert_space(::Type{LinearSampling}, dist, survival) = cquantile(dist, survival)
invert_space(::Type{LogSampling}, dist, survival) = invlogccdf(dist, survival)::Float64


struct NRTransition
struct NRTransition{T}
heap_handle::Int
survival::Float64 # value of S_j or Λ_j
survival::T # value of S_j or Λ_j
distribution::UnivariateDistribution
te::Float64 # Enabling time of distribution
t0::Float64 # Enabling time of transition
te::T # Enabling time of distribution
t0::T # Enabling time of transition
end


"""
CombinedNextReaction{KeyType}()
CombinedNextReaction{KeyType,TimeType}()
This combines Next Reaction Method and Modified Next Reaction Method.
The Next Reaction Method is from Gibson and Bruck in their 2000 paper called
Expand Down Expand Up @@ -128,18 +128,18 @@ sampling_space(::LinearGamma) = LinearSampling
If you want to test a distribution, look at `tests/nrmetric.jl` to see how
distributions are timed.
"""
struct CombinedNextReaction{T}
firing_queue::MutableBinaryHeap{OrderedSample{T}}
transition_entry::Dict{T,NRTransition}
struct CombinedNextReaction{K,T}
firing_queue::MutableBinaryHeap{OrderedSample{K,T}}
transition_entry::Dict{K,NRTransition{T}}
end


function CombinedNextReaction{T}() where {T}
heap = MutableBinaryMinHeap{OrderedSample{T}}()
CombinedNextReaction{T}(heap, Dict{T,NRTransition}())
function CombinedNextReaction{K,T}() where {K,T <: ContinuousTime}
heap = MutableBinaryMinHeap{OrderedSample{K,T}}()
CombinedNextReaction{K,T}(heap, Dict{K,NRTransition{T}}())
end

clone(nr::CombinedNextReaction{T}) where {T} = CombinedNextReaction{T}()
clone(nr::CombinedNextReaction{K,T}) where {K,T} = CombinedNextReaction{K,T}()
export clone


Expand All @@ -150,20 +150,20 @@ on a CombinedNextReaction sampler, it returns the key associated with the
clock that fires and marks that clock as fired. Calling next() again would
return a nonsensical value.
"""
function next(nr::CombinedNextReaction, when::Float64, rng::AbstractRNG)
function next(nr::CombinedNextReaction{K,T}, when::T, rng::AbstractRNG) where {K,T}
if !isempty(nr.firing_queue)
least = first(nr.firing_queue)
# For this sampler, mark this transition as the one that will fire
# by marking its remaining cumulative time as 0.0.
entry = nr.transition_entry[least.key]
nr.transition_entry[least.key] = NRTransition(
nr.transition_entry[least.key] = NRTransition{T}(
entry.heap_handle, get_survival_zero(entry.distribution),
entry.distribution, entry.te, entry.t0
)
return (least.time, least.key)
else
# Return type is Tuple{Float64, Union{Nothing,T}} because T is not default-constructible.
return (Inf, nothing)
return (typemax(T), nothing)
end
end

Expand All @@ -175,11 +175,11 @@ function sample_shifted(
rng::AbstractRNG,
distribution::UnivariateDistribution,
::Type{S},
te::Float64,
when::Float64
) where {S <: SamplingSpaceType}
te::T,
when::T
) where {S <: SamplingSpaceType, T <: ContinuousTime}
if te < when
shifted_distribution = truncated(distribution, when - te, Inf)
shifted_distribution = truncated(distribution, when - te, typemax(T))
sample = rand(rng, shifted_distribution)
tau = te + sample
survival = survival_space(S, shifted_distribution, sample)
Expand All @@ -194,10 +194,10 @@ end


function sample_by_inversion(
distribution::UnivariateDistribution, ::Type{S}, te::Float64, when::Float64, survival::Float64
) where {S <: SamplingSpaceType}
distribution::UnivariateDistribution, ::Type{S}, te::T, when::T, survival::T
) where {S <: SamplingSpaceType, T <: ContinuousTime}
if te < when
te + invert_space(S, truncated(distribution, when - te, Inf), survival)
te + invert_space(S, truncated(distribution, when - te, typemax(T)), survival)
else # te > when
te + invert_space(S, distribution, survival)
end
Expand All @@ -215,17 +215,17 @@ te can be before t0, at t0, between t0 and tn, or at tn, or after tn.
"""
function consume_survival(
record::NRTransition, distribution::UnivariateDistribution, ::Type{S}, tn::Float64
) where {S <: LinearSampling}
record::NRTransition, distribution::UnivariateDistribution, ::Type{S}, tn::T
) where {S <: LinearSampling, T <: ContinuousTime}
survive_te_tn = if record.te < tn
ccdf(distribution, tn-record.te)::Float64
ccdf(distribution, tn-record.te)::T
else
one(Float64)
one(T)
end
survive_te_t0 = if record.te < record.t0
ccdf(distribution, record.t0-record.te)::Float64
ccdf(distribution, record.t0-record.te)::T
else
one(Float64)
one(T)
end
record.survival / (survive_te_t0 * survive_te_tn)
end
Expand All @@ -239,52 +239,52 @@ Anderson's method.
"""
function consume_survival(
record::NRTransition, distribution::UnivariateDistribution, ::Type{S}, tn::Float64
) where {S <: LogSampling}
record::NRTransition, distribution::UnivariateDistribution, ::Type{S}, tn::T
) where {S <: LogSampling, T <: ContinuousTime}
log_survive_te_tn = if record.te < tn
logccdf(distribution, tn-record.te)::Float64
logccdf(distribution, tn-record.te)::T
else
zero(Float64)
zero(T)
end
log_survive_te_t0 = if record.te < record.t0
logccdf(distribution, record.t0-record.te)::Float64
logccdf(distribution, record.t0-record.te)::T
else
zero(Float64)
zero(T)
end
record.survival - (log_survive_te_t0 + log_survive_te_tn)
end


function enable!(
nr::CombinedNextReaction{T}, clock::T, distribution::UnivariateDistribution,
te::Float64, when::Float64, rng::AbstractRNG) where {T}
nr::CombinedNextReaction{K,T}, clock::K, distribution::UnivariateDistribution,
te::T, when::T, rng::AbstractRNG) where {K,T}
enable!(nr, clock, distribution, sampling_space(distribution), te, when, rng)
nothing
end


function enable!(
nr::CombinedNextReaction{T}, clock::T, distribution::UnivariateDistribution, ::Type{S},
te::Float64, when::Float64, rng::AbstractRNG) where {T, S <: SamplingSpaceType}
nr::CombinedNextReaction{K,T}, clock::K, distribution::UnivariateDistribution, ::Type{S},
te::T, when::T, rng::AbstractRNG) where {K, T, S <: SamplingSpaceType}

# Three cases: a) never been enabled b) currently enabled c) was disabled.
record = get(
nr.transition_entry,
clock,
NRTransition(0, get_survival_zero(S), Never(), 0.0, 0.0)
NRTransition{T}(0, get_survival_zero(S), Never(), zero(T), zero(T))
)
heap_handle = record.heap_handle

# if the transition needs to be re-drawn.
if record.survival <= get_survival_zero(S)
tau, shift_survival = sample_shifted(rng, distribution, S, te, when)
sample = OrderedSample{T}(clock, tau)
sample = OrderedSample{K,T}(clock, tau)
if record.heap_handle > 0
update!(nr.firing_queue, record.heap_handle, sample)
else
heap_handle = push!(nr.firing_queue, sample)
end
nr.transition_entry[clock] = NRTransition(
nr.transition_entry[clock] = NRTransition{T}(
heap_handle, shift_survival, distribution, te, when
)

Expand All @@ -300,18 +300,18 @@ function enable!(
# Account for time between when this was last enabled and now.
survival_remain = consume_survival(record, record.distribution, S, when)
tau = sample_by_inversion(distribution, S, te, when, survival_remain)
entry = OrderedSample{T}(clock, tau)
entry = OrderedSample{K,T}(clock, tau)
update!(nr.firing_queue, record.heap_handle, entry)
nr.transition_entry[clock] = NRTransition(
nr.transition_entry[clock] = NRTransition{T}(
heap_handle, survival_remain, distribution, te, when
)
end

# The transition was previously disabled.
else
tau = sample_by_inversion(distribution, S, te, when, record.survival)
heap_handle = push!(nr.firing_queue, OrderedSample{T}(clock, tau))
nr.transition_entry[clock] = NRTransition(
heap_handle = push!(nr.firing_queue, OrderedSample{K,T}(clock, tau))
nr.transition_entry[clock] = NRTransition{T}(
heap_handle, record.survival, distribution, te, when
)
end
Expand All @@ -320,10 +320,10 @@ function enable!(
end


function disable!(nr::CombinedNextReaction{T}, clock::T, when::Float64) where {T}
function disable!(nr::CombinedNextReaction{K,T}, clock::K, when::T) where {K,T <: ContinuousTime}
record = nr.transition_entry[clock]
delete!(nr.firing_queue, record.heap_handle)
nr.transition_entry[clock] = NRTransition(
nr.transition_entry[clock] = NRTransition{T}(
0,
consume_survival(record, record.distribution, sampling_space(record.distribution), when),
record.distribution,
Expand Down
Loading

0 comments on commit f09169d

Please sign in to comment.