From b9b86150fcb801bbcbec2ab4908fa79250efc98d Mon Sep 17 00:00:00 2001 From: Andrew Dolgert Date: Sun, 2 Jun 2024 19:29:07 -0400 Subject: [PATCH] MultiSampler derives from SSA now and SingleSampler is gone and examples include hierarchical samplers --- docs/src/reference.md | 1 - src/CompetingClocks.jl | 2 +- src/sample/interface.jl | 48 +++++++++++--- src/sample/sampler.jl | 139 +++++++++++++++------------------------- test/test_sampler.jl | 116 ++++++++++++++++++++++++--------- 5 files changed, 179 insertions(+), 127 deletions(-) diff --git a/docs/src/reference.md b/docs/src/reference.md index 3fe954c..3b91d4b 100644 --- a/docs/src/reference.md +++ b/docs/src/reference.md @@ -29,7 +29,6 @@ freeze misscount misses MultiSampler -SingleSampler ChatReaction DebugWatcher TrackWatcher diff --git a/src/CompetingClocks.jl b/src/CompetingClocks.jl index 7ed9b2c..158c43c 100644 --- a/src/CompetingClocks.jl +++ b/src/CompetingClocks.jl @@ -7,8 +7,8 @@ include("prefixsearch/binarytreeprefixsearch.jl") include("prefixsearch/cumsumprefixsearch.jl") include("prefixsearch/keyedprefixsearch.jl") include("lefttrunc.jl") -include("sample/sampler.jl") include("sample/interface.jl") +include("sample/sampler.jl") include("sample/neverdist.jl") include("sample/track.jl") include("sample/nrtransition.jl") diff --git a/src/sample/interface.jl b/src/sample/interface.jl index 2b480ae..e2b516b 100644 --- a/src/sample/interface.jl +++ b/src/sample/interface.jl @@ -2,9 +2,22 @@ using Random: AbstractRNG using Distributions: UnivariateDistribution import Base: getindex, keys, length, keytype -export enable!, disable!, next, +export SSA, enable!, disable!, next, getindex, keys, length, keytype +""" + SSA{KeyType,TimeType} + +This abstract type represents a stochastic simulation algorithm (SSA). It is +parametrized by the clock ID, or key, and the type used for the time, which +is typically a Float64. The type of the key can be anything you would use +as a dictionary key. This excludes mutable values but includes a wide range +of identifiers useful for simulation. For instance, it could be a `String`, +but it could be a `Tuple{Int64,Int64,Int64}`, so that it indexes into a +complicated simulation state. +""" +abstract type SSA{Key,Time} end + """ enable!(sampler, clock, distribution, enablingtime, currenttime, RNG) @@ -27,6 +40,7 @@ function enable!( when::T, # current simulation time rng::AbstractRNG ) where {K,T} + @assert false end """ @@ -36,7 +50,9 @@ After a sampler is used for a simulation run, it has internal state. This function resets that internal state to the initial value in preparation for another sample run. """ -function reset!(sampler::SSA{K,T}) where {K,T} end +function reset!(sampler::SSA{K,T}) where {K,T} + @assert false +end """ @@ -47,7 +63,9 @@ the current state of the destination sampler. This is useful for splitting techniques where you make copies of a simulation and restart it with different random number generators. """ -function Base.copy!(sampler::SSA{K,T}) where {K,T} end +function Base.copy!(sampler::SSA{K,T}) where {K,T} + @assert false +end """ @@ -56,7 +74,9 @@ function Base.copy!(sampler::SSA{K,T}) where {K,T} end Tell the sampler to forget a clock. We include the current simulation time because some Next Reaction methods use this to optimize sampling. """ -function disable!(sampler::SSA{K,T}, clock::K, when::T) where {K,T} end +function disable!(sampler::SSA{K,T}, clock::K, when::T) where {K,T} + @assert false +end """ next(sampler, when, rng) @@ -64,7 +84,10 @@ function disable!(sampler::SSA{K,T}, clock::K, when::T) where {K,T} end Ask the sampler for what happens next, in the form of `(when, which)::Tuple{TimeType,KeyType}`. `rng` is a random number generator. """ -function next(sampler::SSA{K,T}, when::T, rng::AbstractRNG) where {K,T} end +function next(sampler::SSA{K,T}, when::T, rng::AbstractRNG) where {K,T} + @assert false +end + """ getindex(sampler, clock::KeyType) @@ -72,21 +95,30 @@ function next(sampler::SSA{K,T}, when::T, rng::AbstractRNG) where {K,T} end Return stored state for a particular clock. If the clock does not exist, a `KeyError` will be thrown. """ -function Base.getindex(sampler::SSA{K,T}, clock::K) where {K,T} end +function Base.getindex(sampler::SSA{K,T}, clock::K) where {K,T} + @assert false +end + """ keys(sampler) Return all stored clocks as a vector. """ -function Base.keys(sampler::SSA) end +function Base.keys(sampler::SSA) + @assert false +end + """ length(sampler)::Int64 Return the number of stored clocks. """ -function Base.length(sampler::SSA) end +function Base.length(sampler::SSA) + @assert false +end + """ keytype(sampler) diff --git a/src/sample/sampler.jl b/src/sample/sampler.jl index 6652353..b965a21 100644 --- a/src/sample/sampler.jl +++ b/src/sample/sampler.jl @@ -1,70 +1,12 @@ -export SingleSampler, MultiSampler -export SSA, enable!, disable!, sample! +export MultiSampler -""" - SSA{KeyType,TimeType} - -This abstract type represents a stochastic simulation algorithm (SSA). It is -parametrized by the clock ID, or key, and the type used for the time, which -is typically a Float64. The type of the key can be anything you would use -as a dictionary key. This excludes mutable values but includes a wide range -of identifiers useful for simulation. For instance, it could be a `String`, -but it could be a `Tuple{Int64,Int64,Int64}`, so that it indexes into a -complicated simulation state. -""" -abstract type SSA{Key,Time} end - - -""" - SingleSampler{SSA,Time}(propagator::SSA) - -This makes a sampler from a single stochastic simulation algorithm. It combines -the core algorithm with the rest of the state of the system, which is just -the time. -""" -mutable struct SingleSampler{Algorithm,Time} - propagator::Algorithm - when::Time -end - - -function SingleSampler(propagator::SSA{Key,Time}) where {Key,Time} - SingleSampler{SSA{Key,Time},Time}(propagator, zero(Time)) -end -function Base.copy!(dst::SingleSampler{Algorithm,Time}, src::SingleSampler{Algorithm,Time}) where {Algorithm,Time} - copy!(dst.propagator, src.propagator) - dst.when = src.when - dst -end - -function sample!(sampler::SingleSampler, rng::AbstractRNG) - when, transition = next(sampler.propagator, sampler.when, rng) - if transition !== nothing - sampler.when = when - disable!(sampler.propagator, transition, sampler.when) - end - return (when, transition) -end - - -function enable!( - sampler::SingleSampler, clock, distribution::UnivariateDistribution, te, rng::AbstractRNG) - enable!(sampler.propagator, clock, distribution, te, sampler.when, rng) -end - - -function disable!(sampler::SingleSampler, clock) - disable!(sampler.propagator, clock, sampler.when) -end - - -abstract type SamplerChoice{Key,SamplerKey} end +abstract type SamplerChoice{SamplerKey,Key} end function choose_sampler( - chooser::SamplerChoice{Key,SamplerKey}, clock::Key, distribution::UnivariateDistribution - )::SamplerKey where {Key,SamplerKey} + chooser::SamplerChoice{SamplerKey,Key}, clock::Key, distribution::UnivariateDistribution + )::SamplerKey where {SamplerKey,Key} throw(MissingException("No sampler choice given to the MultiSampler")) end @@ -72,13 +14,18 @@ export SamplerChoice export choose_sampler """ - MultiSampler{SamplerKey,Key,Time}(which_sampler::Function) + MultiSampler{SamplerKey,Key,Time}(which_sampler::Chooser) <: SSA{Key,Time} -This makes a sampler that uses multiple stochastic sampling algorithms (SSA) to -determine the next transition to fire. It returns the soonest transition of all -of the algorithms. The `which_sampler` function looks at the clock ID, or key, -and chooses which sampler should sample this clock. Add algorithms to this -sampler like you would add them to a dictionary. +A sampler returns the soonest event, so we can make a hierarchical sampler +that returns the soonest event of the samplers it contains. This is useful because +the performance of a sampler depends on the type of the event. For instance, +some simulations have a few fast events and a lot of slow ones, so it helps +to split them into separate data structures. + +The `SamplerKey` is the type of an identifier for the samplers that this +`MultiSampler` contains. The `which_sampler` argument is a strategy object +that decides which event is sampled by which contained sampler. There is +an example of this below. Once a clock is first enabled, it will always go to the same sampler. This sampler remembers the associations, which could increase memory for @@ -118,9 +65,8 @@ sampler[:slow] = FirstToFire{Int64,Float64}() ``` """ -mutable struct MultiSampler{SamplerKey,Key,Time,Chooser} +mutable struct MultiSampler{SamplerKey,Key,Time,Chooser} <: SSA{Key,Time} propagator::Dict{SamplerKey,SSA{Key,Time}} - when::Time chooser::Chooser chosen::Dict{Key,SamplerKey} end @@ -128,11 +74,10 @@ end function MultiSampler{SamplerKey,Key,Time}( which_sampler::Chooser - ) where {SamplerKey,Key,Time,Chooser <: SamplerChoice{Key,SamplerKey}} + ) where {SamplerKey,Key,Time,Chooser <: SamplerChoice{SamplerKey,Key}} MultiSampler{SamplerKey,Key,Time,Chooser}( Dict{SamplerKey,SSA{Key,Time}}(), - zero(Time), which_sampler, Dict{Key,SamplerKey}() ) @@ -143,7 +88,6 @@ function reset!(sampler::MultiSampler) for clear_sampler in values(sampler.propagator) reset!(clear_sampler) end - sampler.when = zero(sampler.when) empty!(sampler.chosen) end @@ -154,7 +98,8 @@ function Base.copy!( ) where {SamplerKey,Key,Time,Chooser} copy!(dst.propagator, src.propagator) - dst.when = src.when + dst.chooser = src.chooser + copy!(dst.chosen, src.chosen) dst end @@ -166,40 +111,58 @@ function Base.setindex!( end -function sample!( +function next( sampler::MultiSampler{SamplerKey,Key,Time}, + when::Time, rng::AbstractRNG ) where {SamplerKey,Key,Time} least_when::Time = typemax(Time) least_transition::Union{Nothing,Key} = nothing - least_source::Union{Nothing,SamplerKey} = nothing - for (sample_key, propagator) in sampler.propagator - when, transition = next(propagator, sampler.when, rng) + for propagator in values(sampler.propagator) + when, transition = next(propagator, when, rng) if when < least_when least_when = when least_transition = transition - least_source = sample_key end end - if least_transition !== nothing - sampler.when = least_when - disable!(sampler.propagator[least_source], least_transition, least_when) - end return (least_when, least_transition) end function enable!( - sampler::MultiSampler, clock, distribution::UnivariateDistribution, te, rng::AbstractRNG - ) + sampler::MultiSampler{SamplerKey,Key,Time}, + clock::Key, + distribution::UnivariateDistribution, + te::Time, + when::Time, + rng::AbstractRNG + ) where {SamplerKey,Key,Time} + @debug "Enabling the MultiSampler" this_clock_sampler = choose_sampler(sampler.chooser, clock, distribution) sampler.chosen[clock] = this_clock_sampler propagator = sampler.propagator[this_clock_sampler] - enable!(propagator, clock, distribution, te, sampler.when, rng) + enable!(propagator, clock, distribution, te, when, rng) +end + + +function disable!( + sampler::MultiSampler{SamplerKey,Key,Time}, clock::Key, when::Time + ) where {SamplerKey,Key,Time} + disable!(sampler.propagator[sampler.chosen[clock]], clock, when) +end + + +function Base.getindex(sampler::MultiSampler, clock) + return getindex(sampler.chosen[clock], clock) +end + + +function Base.keys(sampler::MultiSampler) + return union([keys(propagator) for propagator in values(sampler.propagator)]...) end -function disable!(sampler::MultiSampler, clock) - disable!(sampler.propagator[sampler.chosen[clock]], clock, sampler.when) +function Base.length(sampler::MultiSampler) + return sum([length(propagator) for propagator in values(sampler.propagator)]) end diff --git a/test/test_sampler.jl b/test/test_sampler.jl index 87112e1..ccefa22 100644 --- a/test/test_sampler.jl +++ b/test/test_sampler.jl @@ -1,25 +1,5 @@ using SafeTestsets -@safetestset singlesampler_smoke = "SingleSampler smoke" begin - using Random: Xoshiro - using CompetingClocks: FirstToFire, SingleSampler, enable!, disable!, sample! - using Distributions: Exponential - - sampler = SingleSampler(FirstToFire{Int64,Float64}()) - rng = Xoshiro(90422342) - enabled = Set{Int64}() - for (clock_id, propensity) in enumerate([0.3, 0.2, 0.7, 0.001, 0.25]) - enable!(sampler, clock_id, Exponential(propensity), 0.0, rng) - push!(enabled, clock_id) - end - when, which = sample!(sampler, rng) - delete!(enabled, when) - todisable = collect(enabled)[1] - disable!(sampler, todisable) - enable!(sampler, 35, Exponential(), when, rng) - when, which = sample!(sampler, rng) -end - module MultiSamplerHelp using CompetingClocks @@ -37,12 +17,24 @@ module MultiSamplerHelp )::Int64 return 2 end + + struct ByRate <: SamplerChoice{String,Int64} end + + function CompetingClocks.choose_sampler( + chooser::ByRate, clock::Int64, distribution::UnivariateDistribution + )::String + if clock ≥ 50 && clock < 100 + return "fast" + else + return "slow" + end + end end @safetestset multisampler_smoke = "MultiSampler smoke" begin using Random: Xoshiro - using CompetingClocks: FirstToFire, MultiSampler, enable!, disable!, sample!, choose_sampler, reset! + using CompetingClocks: FirstToFire, MultiSampler, enable!, disable!, next, choose_sampler, reset! using ..MultiSamplerHelp: ByDistribution using Distributions: Exponential, Gamma @@ -52,18 +44,84 @@ end rng = Xoshiro(90422342) enabled = Set{Int64}() for (clock_id, propensity) in enumerate([0.3, 0.2, 0.7, 0.001, 0.25]) + @debug "Calling enable on $clock_id" if clock_id < 3 - enable!(sampler, clock_id, Exponential(propensity), 0.0, rng) + enable!(sampler, clock_id, Exponential(propensity), 0.0, 0.0, rng) else - enable!(sampler, clock_id, Gamma(propensity), 0.0, rng) + enable!(sampler, clock_id, Gamma(propensity), 0.0, 0.0, rng) end push!(enabled, clock_id) end - when, which = sample!(sampler, rng) - delete!(enabled, when) - todisable = collect(enabled)[1] - disable!(sampler, todisable) - enable!(sampler, 35, Exponential(), when, rng) - when, which = sample!(sampler, rng) + @test 1 ∈ keys(sampler.propagator[1]) + @test 2 ∈ keys(sampler.propagator[1]) + @test 3 ∈ keys(sampler.propagator[2]) + when, which = next(sampler, 0.0, rng) + @test which !== nothing + disable!(sampler, which, when) + delete!(enabled, which) + @test enabled == Set(keys(sampler)) + todisable = pop!(enabled) + disable!(sampler, todisable, when) + @test enabled == Set(keys(sampler)) + enable!(sampler, 35, Exponential(), when, when, rng) + push!(enabled, 35) + when, which = next(sampler, when, rng) + @test which ∈ enabled + when, which = next(sampler, when, rng) + @test which ∈ enabled reset!(sampler) end + + +@safetestset multisampler_hierarchical = "MultiSampler hierarchical" begin + # Let's make a hierarchical sampler that contains a hierarchical sampler. + using Random: Xoshiro + using CompetingClocks: FirstToFire, DirectCall, MultiSampler, enable! + using CompetingClocks: disable!, next, choose_sampler, reset! + using ..MultiSamplerHelp: ByDistribution, ByRate + using Distributions: Exponential, Gamma + + EventKey = Int64 + Time = Float64 + sampler = MultiSampler{Int64,EventKey,Time}(ByDistribution()) + sampler[1] = DirectCall{EventKey,Time}() + sampler[2] = FirstToFire{EventKey,Time}() + highest = MultiSampler{String,EventKey,Time}(ByRate()) + highest["fast"] = FirstToFire{EventKey,Time}() + highest["slow"] = sampler + + rng = Xoshiro(90422342) + enabled = Set{Int64}() + for (clock_id, propensity) in enumerate([0.3, 0.2, 0.7, 0.001, 0.25]) + @debug "Calling enable on $clock_id" + if clock_id < 3 + enable!(highest, clock_id, Exponential(propensity), 0.0, 0.0, rng) + else + enable!(highest, clock_id, Gamma(propensity), 0.0, 0.0, rng) + end + push!(enabled, clock_id) + end + enable!(highest, 53, Exponential(10.0), 0.0, 0.0, rng) + push!(enabled, 53) + enable!(highest, 57, Exponential(10.0), 0.0, 0.0, rng) + push!(enabled, 57) + @test 1 ∈ keys(highest.propagator["slow"].propagator[1]) + @test 2 ∈ keys(highest.propagator["slow"].propagator[1]) + @test 3 ∈ keys(highest.propagator["slow"].propagator[2]) + @test 53 ∈ keys(highest.propagator["fast"]) + when, which = next(highest, 0.0, rng) + @test which !== nothing + disable!(highest, which, when) + delete!(enabled, which) + @test enabled == Set(keys(highest)) + todisable = pop!(enabled) + disable!(highest, todisable, when) + @test enabled == Set(keys(highest)) + enable!(highest, 35, Exponential(), when, when, rng) + push!(enabled, 35) + when, which = next(highest, when, rng) + @test which ∈ enabled + when, which = next(highest, when, rng) + @test which ∈ enabled + reset!(highest) +end