Skip to content

Commit

Permalink
Merge pull request #75 from adolgert/feature/multisampler-is-a-sampler
Browse files Browse the repository at this point in the history
MultiSampler derives from SSA
  • Loading branch information
adolgert authored Jun 4, 2024
2 parents f070409 + b9b8615 commit e9eefc9
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 127 deletions.
1 change: 0 additions & 1 deletion docs/src/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ freeze
misscount
misses
MultiSampler
SingleSampler
ChatReaction
DebugWatcher
TrackWatcher
Expand Down
2 changes: 1 addition & 1 deletion src/CompetingClocks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ include("prefixsearch/binarytreeprefixsearch.jl")
include("prefixsearch/cumsumprefixsearch.jl")
include("prefixsearch/keyedprefixsearch.jl")
include("lefttrunc.jl")
include("sample/sampler.jl")
include("sample/interface.jl")
include("sample/sampler.jl")
include("sample/neverdist.jl")
include("sample/track.jl")
include("sample/nrtransition.jl")
Expand Down
48 changes: 40 additions & 8 deletions src/sample/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,22 @@ using Random: AbstractRNG
using Distributions: UnivariateDistribution
import Base: getindex, keys, length, keytype

export enable!, disable!, next,
export SSA, enable!, disable!, next,
getindex, keys, length, keytype

"""
SSA{KeyType,TimeType}
This abstract type represents a stochastic simulation algorithm (SSA). It is
parametrized by the clock ID, or key, and the type used for the time, which
is typically a Float64. The type of the key can be anything you would use
as a dictionary key. This excludes mutable values but includes a wide range
of identifiers useful for simulation. For instance, it could be a `String`,
but it could be a `Tuple{Int64,Int64,Int64}`, so that it indexes into a
complicated simulation state.
"""
abstract type SSA{Key,Time} end


"""
enable!(sampler, clock, distribution, enablingtime, currenttime, RNG)
Expand All @@ -27,6 +40,7 @@ function enable!(
when::T, # current simulation time
rng::AbstractRNG
) where {K,T}
@assert false
end

"""
Expand All @@ -36,7 +50,9 @@ After a sampler is used for a simulation run, it has internal state. This
function resets that internal state to the initial value in preparation
for another sample run.
"""
function reset!(sampler::SSA{K,T}) where {K,T} end
function reset!(sampler::SSA{K,T}) where {K,T}
@assert false
end


"""
Expand All @@ -47,7 +63,9 @@ the current state of the destination sampler. This is useful for splitting
techniques where you make copies of a simulation and restart it with different
random number generators.
"""
function Base.copy!(sampler::SSA{K,T}) where {K,T} end
function Base.copy!(sampler::SSA{K,T}) where {K,T}
@assert false
end


"""
Expand All @@ -56,37 +74,51 @@ function Base.copy!(sampler::SSA{K,T}) where {K,T} end
Tell the sampler to forget a clock. We include the current simulation time
because some Next Reaction methods use this to optimize sampling.
"""
function disable!(sampler::SSA{K,T}, clock::K, when::T) where {K,T} end
function disable!(sampler::SSA{K,T}, clock::K, when::T) where {K,T}
@assert false
end

"""
next(sampler, when, rng)
Ask the sampler for what happens next, in the form of
`(when, which)::Tuple{TimeType,KeyType}`. `rng` is a random number generator.
"""
function next(sampler::SSA{K,T}, when::T, rng::AbstractRNG) where {K,T} end
function next(sampler::SSA{K,T}, when::T, rng::AbstractRNG) where {K,T}
@assert false
end


"""
getindex(sampler, clock::KeyType)
Return stored state for a particular clock. If the clock does not exist,
a `KeyError` will be thrown.
"""
function Base.getindex(sampler::SSA{K,T}, clock::K) where {K,T} end
function Base.getindex(sampler::SSA{K,T}, clock::K) where {K,T}
@assert false
end


"""
keys(sampler)
Return all stored clocks as a vector.
"""
function Base.keys(sampler::SSA) end
function Base.keys(sampler::SSA)
@assert false
end


"""
length(sampler)::Int64
Return the number of stored clocks.
"""
function Base.length(sampler::SSA) end
function Base.length(sampler::SSA)
@assert false
end


"""
keytype(sampler)
Expand Down
139 changes: 51 additions & 88 deletions src/sample/sampler.jl
Original file line number Diff line number Diff line change
@@ -1,84 +1,31 @@

export SingleSampler, MultiSampler
export SSA, enable!, disable!, sample!
export MultiSampler

"""
SSA{KeyType,TimeType}
This abstract type represents a stochastic simulation algorithm (SSA). It is
parametrized by the clock ID, or key, and the type used for the time, which
is typically a Float64. The type of the key can be anything you would use
as a dictionary key. This excludes mutable values but includes a wide range
of identifiers useful for simulation. For instance, it could be a `String`,
but it could be a `Tuple{Int64,Int64,Int64}`, so that it indexes into a
complicated simulation state.
"""
abstract type SSA{Key,Time} end


"""
SingleSampler{SSA,Time}(propagator::SSA)
This makes a sampler from a single stochastic simulation algorithm. It combines
the core algorithm with the rest of the state of the system, which is just
the time.
"""
mutable struct SingleSampler{Algorithm,Time}
propagator::Algorithm
when::Time
end


function SingleSampler(propagator::SSA{Key,Time}) where {Key,Time}
SingleSampler{SSA{Key,Time},Time}(propagator, zero(Time))
end

function Base.copy!(dst::SingleSampler{Algorithm,Time}, src::SingleSampler{Algorithm,Time}) where {Algorithm,Time}
copy!(dst.propagator, src.propagator)
dst.when = src.when
dst
end

function sample!(sampler::SingleSampler, rng::AbstractRNG)
when, transition = next(sampler.propagator, sampler.when, rng)
if transition !== nothing
sampler.when = when
disable!(sampler.propagator, transition, sampler.when)
end
return (when, transition)
end


function enable!(
sampler::SingleSampler, clock, distribution::UnivariateDistribution, te, rng::AbstractRNG)
enable!(sampler.propagator, clock, distribution, te, sampler.when, rng)
end


function disable!(sampler::SingleSampler, clock)
disable!(sampler.propagator, clock, sampler.when)
end


abstract type SamplerChoice{Key,SamplerKey} end
abstract type SamplerChoice{SamplerKey,Key} end

function choose_sampler(
chooser::SamplerChoice{Key,SamplerKey}, clock::Key, distribution::UnivariateDistribution
)::SamplerKey where {Key,SamplerKey}
chooser::SamplerChoice{SamplerKey,Key}, clock::Key, distribution::UnivariateDistribution
)::SamplerKey where {SamplerKey,Key}
throw(MissingException("No sampler choice given to the MultiSampler"))
end

export SamplerChoice
export choose_sampler

"""
MultiSampler{SamplerKey,Key,Time}(which_sampler::Function)
MultiSampler{SamplerKey,Key,Time}(which_sampler::Chooser) <: SSA{Key,Time}
This makes a sampler that uses multiple stochastic sampling algorithms (SSA) to
determine the next transition to fire. It returns the soonest transition of all
of the algorithms. The `which_sampler` function looks at the clock ID, or key,
and chooses which sampler should sample this clock. Add algorithms to this
sampler like you would add them to a dictionary.
A sampler returns the soonest event, so we can make a hierarchical sampler
that returns the soonest event of the samplers it contains. This is useful because
the performance of a sampler depends on the type of the event. For instance,
some simulations have a few fast events and a lot of slow ones, so it helps
to split them into separate data structures.
The `SamplerKey` is the type of an identifier for the samplers that this
`MultiSampler` contains. The `which_sampler` argument is a strategy object
that decides which event is sampled by which contained sampler. There is
an example of this below.
Once a clock is first enabled, it will always go to the same sampler.
This sampler remembers the associations, which could increase memory for
Expand Down Expand Up @@ -118,21 +65,19 @@ sampler[:slow] = FirstToFire{Int64,Float64}()
```
"""
mutable struct MultiSampler{SamplerKey,Key,Time,Chooser}
mutable struct MultiSampler{SamplerKey,Key,Time,Chooser} <: SSA{Key,Time}
propagator::Dict{SamplerKey,SSA{Key,Time}}
when::Time
chooser::Chooser
chosen::Dict{Key,SamplerKey}
end


function MultiSampler{SamplerKey,Key,Time}(
which_sampler::Chooser
) where {SamplerKey,Key,Time,Chooser <: SamplerChoice{Key,SamplerKey}}
) where {SamplerKey,Key,Time,Chooser <: SamplerChoice{SamplerKey,Key}}

MultiSampler{SamplerKey,Key,Time,Chooser}(
Dict{SamplerKey,SSA{Key,Time}}(),
zero(Time),
which_sampler,
Dict{Key,SamplerKey}()
)
Expand All @@ -143,7 +88,6 @@ function reset!(sampler::MultiSampler)
for clear_sampler in values(sampler.propagator)
reset!(clear_sampler)
end
sampler.when = zero(sampler.when)
empty!(sampler.chosen)
end

Expand All @@ -154,7 +98,8 @@ function Base.copy!(
) where {SamplerKey,Key,Time,Chooser}

copy!(dst.propagator, src.propagator)
dst.when = src.when
dst.chooser = src.chooser
copy!(dst.chosen, src.chosen)
dst
end

Expand All @@ -166,40 +111,58 @@ function Base.setindex!(
end


function sample!(
function next(
sampler::MultiSampler{SamplerKey,Key,Time},
when::Time,
rng::AbstractRNG
) where {SamplerKey,Key,Time}

least_when::Time = typemax(Time)
least_transition::Union{Nothing,Key} = nothing
least_source::Union{Nothing,SamplerKey} = nothing
for (sample_key, propagator) in sampler.propagator
when, transition = next(propagator, sampler.when, rng)
for propagator in values(sampler.propagator)
when, transition = next(propagator, when, rng)
if when < least_when
least_when = when
least_transition = transition
least_source = sample_key
end
end
if least_transition !== nothing
sampler.when = least_when
disable!(sampler.propagator[least_source], least_transition, least_when)
end
return (least_when, least_transition)
end


function enable!(
sampler::MultiSampler, clock, distribution::UnivariateDistribution, te, rng::AbstractRNG
)
sampler::MultiSampler{SamplerKey,Key,Time},
clock::Key,
distribution::UnivariateDistribution,
te::Time,
when::Time,
rng::AbstractRNG
) where {SamplerKey,Key,Time}
@debug "Enabling the MultiSampler"
this_clock_sampler = choose_sampler(sampler.chooser, clock, distribution)
sampler.chosen[clock] = this_clock_sampler
propagator = sampler.propagator[this_clock_sampler]
enable!(propagator, clock, distribution, te, sampler.when, rng)
enable!(propagator, clock, distribution, te, when, rng)
end


function disable!(
sampler::MultiSampler{SamplerKey,Key,Time}, clock::Key, when::Time
) where {SamplerKey,Key,Time}
disable!(sampler.propagator[sampler.chosen[clock]], clock, when)
end


function Base.getindex(sampler::MultiSampler, clock)
return getindex(sampler.chosen[clock], clock)
end


function Base.keys(sampler::MultiSampler)
return union([keys(propagator) for propagator in values(sampler.propagator)]...)
end


function disable!(sampler::MultiSampler, clock)
disable!(sampler.propagator[sampler.chosen[clock]], clock, sampler.when)
function Base.length(sampler::MultiSampler)
return sum([length(propagator) for propagator in values(sampler.propagator)])
end
Loading

0 comments on commit e9eefc9

Please sign in to comment.