Skip to content

Commit

Permalink
Merge pull request #42 from adolgert/sampler-interface
Browse files Browse the repository at this point in the history
Creating a base class for samplers, called `SSA{K,T}`. Adds getindex, length, keys, keytype to all SSA classes.
  • Loading branch information
adolgert authored Mar 26, 2024
2 parents f09169d + 56b09b2 commit d78cf8d
Show file tree
Hide file tree
Showing 14 changed files with 342 additions and 48 deletions.
4 changes: 2 additions & 2 deletions src/Fleck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ include("prefixsearch/binarytreeprefixsearch.jl")
include("prefixsearch/cumsumprefixsearch.jl")
include("prefixsearch/keyedprefixsearch.jl")
include("lefttrunc.jl")
include("sample/neverdist.jl")
include("sample/sampler.jl")
include("sample/interface.jl")
include("sample/neverdist.jl")
include("sample/track.jl")
# include("sample/interface.jl")
include("sample/nrtransition.jl")
include("sample/firstreaction.jl")
include("sample/firsttofire.jl")
Expand Down
69 changes: 46 additions & 23 deletions src/prefixsearch/binarytreeprefixsearch.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import Base: sum!, push!, length, setindex!
import Base: sum!, push!, length, setindex!, getindex
using Random
using Distributions: Uniform
using Logging
Expand All @@ -10,17 +10,19 @@ easier to find the leaf such that the sum of it and all previous
leaves is greater than a given value.
"""
mutable struct BinaryTreePrefixSearch{T<:Real}
array::Array{T,1}
depth::Int64
offset::Int64
cnt::Int64
# Data structure uses an array with children of i at 2i and 2i+1.
array::Array{T,1} # length(array) > 0
depth::Int64 # length(array) = 2^depth - 1
offset::Int64 # 2^(depth - 1). Index of first leaf and number of leaves.
cnt::Int64 # Number of leaves in use. Logical number of entries. cnt > 0.
end


"""
BinaryTreePrefixSearch{T}()
BinaryTreePrefixSearch{T}([N])
Constructor of a prefix search tree from an iterable list of real numbers.
The optional hint, N, is the number of values to pre-allocate.
"""
function BinaryTreePrefixSearch{T}(N=32) where {T<:Real}
depth, offset, array_cnt = _btps_sizes(N)
Expand All @@ -33,26 +35,42 @@ time_type(ps::BinaryTreePrefixSearch{T}) where {T} = T
time_type(::Type{BinaryTreePrefixSearch{T}}) where {T} = T


"""
ceil_log2(v::Integer)
Integer log2, rounding up.
"""
function ceil_log2(v::Integer)
r = 0
power_of_two = ((v & (v - 1)) == 0) ? 0 : 1
while (v >>= 1) != 0
r += 1
end
r + power_of_two
end


# The tree must have at least `allocation` leaves.
function _btps_sizes(allocation)
@assert allocation > 0
depth = Int(ceil(log2(allocation))) + 1
allocation = (allocation > 0) ? allocation : 1
depth = ceil_log2(allocation) + 1
offset = 2^(depth - 1)
array_cnt = 2^depth - 1
@assert allocation <= offset + 1
return (depth, offset, array_cnt)
end


# newsize is the desired number of entries. It will allocate more than this.
function resize!(pst::BinaryTreePrefixSearch{T}, newsize) where {T}
depth, offset, array_cnt = _btps_sizes(newsize)
# newcnt is the desired number of entries.
function resize!(pst::BinaryTreePrefixSearch{T}, newcnt) where {T}
depth, offset, array_cnt = _btps_sizes(newcnt)
b = zeros(T, array_cnt)
will_fit = min(offset, pst.offset)
b[offset:(offset + will_fit - 1)] = pst.array[pst.offset:(pst.offset + will_fit - 1)]
pst.array = b
pst.depth = depth
pst.offset = offset
pst.cnt = newsize
pst.cnt = newcnt
calculate_prefix!(pst)
end

Expand Down Expand Up @@ -97,34 +115,35 @@ If there are multiple values to enter, then present them
at once as pairs of tuples, (index, value).
"""
function set_multiple!(pst::BinaryTreePrefixSearch, pairs)
modify=Set{Int}()
maxindex = maximum([i for (i, v) in pairs])
if maxindex > allocated(pst)
@debug "BinaryTreePrefixSearch resizing to $maxindex"
resize!(pst, maxindex)
end
if maxindex > pst.cnt
pst.cnt = maxindex
end
modify = Set{Int}()
for (pos, value) in pairs
index = pos + pst.offset - 1
pst.array[index] = value
push!(modify, div(index, 2))
end

# everything at depth-1 is correct, and changes are in modify.
for depth = (pst.depth - 2):-1:0
for depth = (pst.depth - 1):-1:1
parents = Set{Int}()
for node_idx in modify
pst.array[node_idx] = pst.array[2 * node_idx] + pst.array[2 * node_idx + 1]
push!(parents, div(node_idx, 2))
end
modify = parents
end
@assert length(modify) == 1 && first(modify) == 0
end


function Base.push!(pst::BinaryTreePrefixSearch{T}, value::T) where T
index = pst.cnt + 1
if index <= allocated(pst)
pst.cnt += 1
else
@debug "Pushing to binarytreeprefix $index $(allocated(pst))"
resize!(pst, index)
end
pst[index] = value
set_multiple!(pst, [(pst.cnt + 1, value)])
return value
end

Expand All @@ -136,6 +155,10 @@ function Base.setindex!(pst::BinaryTreePrefixSearch{T}, value::T, index) where T
set_multiple!(pst, [(index, value)])
end

function Base.getindex(pst::BinaryTreePrefixSearch{T}, index) where {T}
return pst.array[index + pst.offset - 1]
end


# Private
function calculate_prefix!(pst::BinaryTreePrefixSearch)
Expand Down
10 changes: 9 additions & 1 deletion src/prefixsearch/keyedprefixsearch.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import Base: in, setindex!, delete!
import Base: in, setindex!, delete!, getindex
using Random


Expand Down Expand Up @@ -98,6 +98,14 @@ function Base.setindex!(kp::KeyedRemovalPrefixSearch, val, clock)
end
end

function Base.getindex(kp::KeyedRemovalPrefixSearch, clock)
if haskey(kp.index, clock)
idx = kp.index[clock]
return kp.prefix[idx]
else
throw(KeyError(clock))
end
end

function Base.delete!(kp::KeyedRemovalPrefixSearch{T,P}, clock) where {T,P}
idx = kp.index[clock]
Expand Down
24 changes: 22 additions & 2 deletions src/sample/combinednr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ sampling_space(::LinearGamma) = LinearSampling
If you want to test a distribution, look at `tests/nrmetric.jl` to see how
distributions are timed.
"""
struct CombinedNextReaction{K,T}
firing_queue::MutableBinaryHeap{OrderedSample{K,T}}
struct CombinedNextReaction{K,T} <: SSA{K,T}
firing_queue::MutableBinaryMinHeap{OrderedSample{K,T}}
transition_entry::Dict{K,NRTransition{T}}
end

Expand Down Expand Up @@ -332,3 +332,23 @@ function disable!(nr::CombinedNextReaction{K,T}, clock::K, when::T) where {K,T <
)
nothing
end

"""
For the `CombinedNextReaction` sampler, returns the stored firing time associated to the clock.
"""
function Base.getindex(nr::CombinedNextReaction{K,T}, clock::K) where {K,T}
if haskey(nr.transition_entry, clock)
heap_handle = getfield(nr.transition_entry[clock], :heap_handle)
return getfield(nr.firing_queue[heap_handle], :time)
else
throw(KeyError(clock))
end
end

function Base.keys(nr::CombinedNextReaction)
return collect(keys(nr.transition_entry))
end

function Base.length(nr::CombinedNextReaction)
return length(nr.transition_entry)
end
33 changes: 26 additions & 7 deletions src/sample/direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,15 @@ DirectCall{K,T}() where {K,T} =
```
"""
struct DirectCall{K,P}
struct DirectCall{K,T,P} <: SSA{K,T}
prefix_tree::P
DirectCall{K,P}(tree::P) where {K,P} = new(tree)
end


function DirectCall{K,T}() where {K,T<:ContinuousTime}
prefix_tree = BinaryTreePrefixSearch{T}()
keyed_prefix_tree = KeyedRemovalPrefixSearch{K,typeof(prefix_tree)}(prefix_tree)
DirectCall{K,typeof(keyed_prefix_tree)}(keyed_prefix_tree)
DirectCall{K,T,typeof(keyed_prefix_tree)}(keyed_prefix_tree)
end


Expand All @@ -55,11 +54,16 @@ later than when it was first enabled. The `rng` is a random number generator.
If a particular clock had one rate before an event and it has another rate
after the event, call `enable!` to update the rate.
"""
function enable!(dc::DirectCall{K,P}, clock::K, distribution::Exponential,
te, when, rng::AbstractRNG) where {K,P}
function enable!(dc::DirectCall{K,T,P}, clock::K, distribution::Exponential,
te::T, when::T, rng::AbstractRNG) where {K,T,P}
dc.prefix_tree[clock] = rate(distribution)
end

function enable!(dc::DirectCall{K,T,P}, clock::K, distribution::D,
te::T, when::T, rng::AbstractRNG) where {K,T,P,D<:UnivariateDistribution}
error("DirectCall can only be used with Exponential type distributions")
end


"""
disable!(dc::DirectCall, clock::T, when)
Expand All @@ -68,7 +72,7 @@ Tell the `DirectCall` sampler to disable this clock. The `clock` argument is
an identifier for the clock. The `when` argument is the time at which this
clock is enabled.
"""
function disable!(dc::DirectCall{K,P}, clock::K, when) where {K,P}
function disable!(dc::DirectCall{K,T,P}, clock::K, when::T) where {K,T,P}
delete!(dc.prefix_tree, clock)
end

Expand All @@ -82,7 +86,7 @@ answers. Each answer is a tuple of `(when, which clock)`. If there is no clock
to fire, then the response will be `(Inf, nothing)`. That's a good sign the
simulation is done.
"""
function next(dc::DirectCall, when, rng::AbstractRNG)
function next(dc::DirectCall{K,T,P}, when::T, rng::AbstractRNG) where {K,T,P}
total = sum!(dc.prefix_tree)
if total > eps(when)
chosen, hazard_value = rand(rng, dc.prefix_tree)
Expand All @@ -92,3 +96,18 @@ function next(dc::DirectCall, when, rng::AbstractRNG)
return (typemax(when), nothing)
end
end

"""
For the `DirectCall` sampler, returns the rate parameter associated to the clock.
"""
function Base.getindex(dc::DirectCall{K,T,P}, clock::K) where {K,T,P}
return dc.prefix_tree[clock]
end

function Base.keys(dc::DirectCall)
return collect(keys(dc.prefix_tree.index))
end

function Base.length(dc::DirectCall)
return length(dc.prefix_tree)
end
23 changes: 21 additions & 2 deletions src/sample/firstreaction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Classic First Reaction method for Exponential and non-Exponential
distributions. Every time you sample, go to each distribution and ask when it
would fire. Then take the soonest and throw out the rest until the next sample.
"""
struct FirstReaction{K,T}
struct FirstReaction{K,T} <: SSA{K,T}
# This other class already stores the current set of distributions, so use it.
core_matrix::TrackWatcher{K}
FirstReaction{K,T}() where {K,T <: ContinuousTime} = new(TrackWatcher{K,T}())
Expand All @@ -29,7 +29,7 @@ function disable!(fr::FirstReaction{K,T}, clock::K, when::T) where {K,T}
end


function next(fr::FirstReaction{K,T}, when::T, rng) where {K,T}
function next(fr::FirstReaction{K,T}, when::T, rng::AbstractRNG) where {K,T}
soonest_clock::Union{Nothing,K} = nothing
soonest_time = typemax(T)

Expand All @@ -48,6 +48,25 @@ function next(fr::FirstReaction{K,T}, when::T, rng) where {K,T}
return (soonest_time, soonest_clock)
end

"""
For the `FirstReaction` sampler, returns the distribution object associated to the clock.
"""
function Base.getindex(fr::FirstReaction{K,T}, clock::K) where {K,T}
if haskey(fr.core_matrix.enabled, clock)
return getfield(fr.core_matrix.enabled[clock], :distribution)
else
throw(KeyError(clock))
end
end

function Base.keys(fr::FirstReaction)
return collect(keys(fr.core_matrix.enabled))
end

function Base.length(fr::FirstReaction)
return length(fr.core_matrix.enabled)
end


"""
This sampler can help if it's the first time you're trying a model. It checks
Expand Down
22 changes: 21 additions & 1 deletion src/sample/firsttofire.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ end
# Finds the next one without removing it from the queue.
function next(propagator::FirstToFire{K,T}, when::T, rng::AbstractRNG) where {K,T}
least = if !isempty(propagator.firing_queue)
top(propagator.firing_queue)
first(propagator.firing_queue)
else
OrderedSample(nothing, typemax(T))
end
Expand Down Expand Up @@ -61,3 +61,23 @@ function disable!(propagator::FirstToFire{K,T}, clock::K, when::T) where {K,T}
delete!(propagator.firing_queue, heap_handle)
delete!(propagator.transition_entry, clock)
end

"""
For the `FirstToFire` sampler, returns the stored firing time associated to the clock.
"""
function Base.getindex(propagator::FirstToFire{K,T}, clock::K) where {K,T}
if haskey(propagator.transition_entry, clock)
heap_handle = propagator.transition_entry[clock]
return getfield(propagator.firing_queue[heap_handle], :time)
else
throw(KeyError(clock))
end
end

function Base.keys(propagator::FirstToFire)
return collect(keys(propagator.transition_entry))
end

function Base.length(propagator::FirstToFire)
return length(propagator.transition_entry)
end
Loading

0 comments on commit d78cf8d

Please sign in to comment.