From b3db4495891639e743fbb768f5d90285afdf1c54 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 28 Jul 2023 21:15:32 +0100 Subject: [PATCH 01/58] initial work on the new gibbs sampler --- src/mcmc/gibbs_new.jl | 278 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 278 insertions(+) create mode 100644 src/mcmc/gibbs_new.jl diff --git a/src/mcmc/gibbs_new.jl b/src/mcmc/gibbs_new.jl new file mode 100644 index 000000000..9713ce457 --- /dev/null +++ b/src/mcmc/gibbs_new.jl @@ -0,0 +1,278 @@ +function unique_tuple(xs::Tuple, acc::Tuple = ()) + return if Base.first(xs) ∈ acc + unique_tuple(Base.tail(xs), acc) + else + unique_tuple(Base.tail(xs), (acc..., Base.first(xs))) + end +end +unique_tuple(::Tuple{}, acc::Tuple = ()) = acc + +subset(vi::DynamicPPL.TypedVarInfo, vns::Union{Tuple,AbstractArray}) = subset(vi, vns...) +function subset(vi::DynamicPPL.TypedVarInfo, vns::VarName...) + # TODO: peform proper check of the meatdatas corresponding to different symbols. + # F. ex. we might have vns `(@varname(x[1]), @varname(x[2]))`, in which case they + # have the same `metadata`. If they don't, we should error. + vns_unique_syms = unique_tuple(map(DynamicPPL.getsym, vns)) + mds = map(Base.Fix1(DynamicPPL.getfield, vi.metadata), vns_unique_syms) + return DynamicPPL.VarInfo(NamedTuple{vns_unique_syms}(mds), vi.logp, vi.num_produce) +end + +subset(vi::DynamicPPL.SimpleVarInfo, vns::Union{Tuple,AbstractArray}) = subset(vi, vns...) +function subset(vi::DynamicPPL.SimpleVarInfo, vns::VarName...) + vals = map(Base.Fix1(getindex, vi), vns) + return DynamicPPL.BangBang.@set!! vi.values = vals +end + +function Base.merge(md::DynamicPPL.Metadata, md_subset::DynamicPPL.Metadata) + @assert md.vns == md_subset.vns "Cannot merge metadata with different vns." + @assert length(md.vals) == length(md_subset.vals) "Cannot merge metadata with different length vals." + + # TODO: Re-adjust `ranges`, etc. so we can support things like changing support, etc. + return DynamicPPL.Metadata( + md_subset.idcs, + md_subset.vns, + md_subset.ranges, + md_subset.vals, + md_subset.dists, + md_subset.gids, + md_subset.orders, + md_subset.flags, + ) +end + +function Base.merge( + vi::DynamicPPL.VarInfo{<:NamedTuple{names}}, + vi_subset::TypedVarInfo, +) where {names} + # Assumes `vi` is a superset of `vi_subset`. + metadata_vals = map(names) do vn_sym + # TODO: Make generated. + return if haskey(vi_subset, VarName{vn_sym}()) + merge(vi.metadata[vn_sym], vi_subset.metadata[vn_sym]) + else + vi.metadata[vn_sym] + end + end + + # TODO: Is this the right way to do this? + return DynamicPPL.VarInfo(NamedTuple{names}(metadata_vals), vi.logp, vi.num_produce) +end + +function Base.merge(vi_left::SimpleVarInfo, vi_right::SimpleVarInfo) + return SimpleVarInfo( + merge(vi_left.values, vi_right.values), + vi_left.logp + vi_right.logp, + ) +end + +function Base.merge(vi_left::TypedVarInfo, vi_right::TypedVarInfo) + return TypedVarInfo( + merge(vi_left.metadata, vi_right.metadata), + vi_left.logp + vi_right.logp, + ) +end + +# TODO: Move to DynamicPPL. +DynamicPPL.condition(model::Model, varinfo::SimpleVarInfo) = + DynamicPPL.condition(model, DynamicPPL.values_as(varinfo)) +function DynamicPPL.condition(model::Model, varinfo::VarInfo) + # Use `OrderedDict` as default for `VarInfo`. + # TODO: Do better! + return DynamicPPL.condition(model, DynamicPPL.values_as(varinfo, OrderedDict)) +end + +# Recursive definition. +function DynamicPPL.condition(model::Model, varinfos::AbstractVarInfo...) + return DynamicPPL.condition( + DynamicPPL.condition(model, first(varinfos)), + Base.tail(varinfos)..., + ) +end +DynamicPPL.condition(model::Model, ::Tuple{}) = model + + +""" + make_conditional_model(model, varinfo, varinfos) + +Construct a conditional model from `model` conditioned `varinfos`, excluding `varinfo` if present. + +# Examples +```julia-repl +julia> model = DynamicPPL.TestUtils.demo_assume_dot_observe(); + +julia> # A separate varinfo for each variable in `model`. + varinfos = (SimpleVarInfo(s=1.0), SimpleVarInfo(m=10.0)); + +julia> # The varinfo we want to NOT condition on. + target_varinfo = first(varinfos); + +julia> # Results in a model with only `m` conditioned. + conditional_model = make_conditional(model, target_varinfo, varinfos); + +julia> rand(conditioned_model) +``` + +""" +function make_conditional(model::Model, target_varinfo::AbstractVarInfo, varinfos) + # TODO: Check if this is known at compile-time if `varinfos isa Tuple`. + return DynamicPPL.condition(model, filter(Base.Fix1(!==, target_varinfo), varinfos)...) +end + +wrap_algorithm_maybe(x) = x +wrap_algorithm_maybe(x::InferenceAlgorithm) = Sampler(x) + +struct GibbsV2{V,A} <: InferenceAlgorithm + varnames::V + samplers::A +end + +# NamedTuple +GibbsV2(; algs...) = GibbsV2(NamedTuple(algs)) +function GibbsV2(algs::NamedTuple) + return GibbsV2( + map(s -> VarName{s}(), keys(algs)), + map(wrap_algorithm_maybe, values(algs)), + ) +end + +# AbstractDict +GibbsV2(algs::AbstractDict) = GibbsV2(keys(algs), map(wrap_algorithm_maybe, values(algs))) +function GibbsV2(algs::Pair...) + return GibbsV2(map(first, algs), map(wrap_algorithm_maybe, map(last, algs))) +end +GibbsV2(algs::Tuple) = GibbsV2(Dict(algs)) + +struct GibbsV2State{V<:AbstractVarInfo,S} + vi::V + states::S +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::Model, + spl::Sampler{<:GibbsV2}; + kwargs..., +) + alg = spl.alg + varnames = alg.varnames + samplers = alg.samplers + + # 1. Run the model once to get the varnames present + initial values to condition on. + vi_base = DynamicPPL.VarInfo(model) + @info "" varnames map(Base.Fix1(map, DynamicPPL.getsym), varnames) + varinfos = map(Base.Fix1(subset, vi_base), varnames) + + # 2. Construct a varinfo for every vn + sampler combo. + states_and_varinfos = map(samplers, varinfos) do sampler_local, varinfo_local + # Construct the conditional model. + model_local = make_conditional(model, varinfo_local, varinfos) + + # Take initial step. + new_state_local = + last(AbstractMCMC.step(rng, model_local, sampler_local; kwargs...)) + + # Return the new state and the invlinked `varinfo`. + vi_local_state = varinfo(new_state_local) + vi_local_state_linked = if DynamicPPL.istrans(vi_local_state) + DynamicPPL.invlink!!(deepcopy(vi_local_state), sampler_local, model_local) + else + vi_local_state + end + return (new_state_local, vi_local_state_linked) + end + + states = map(first, states_and_varinfos) + varinfos = map(last, states_and_varinfos) + + # Update the base varinfo from the first varinfo and replace it. + varinfos_new = DynamicPPL.setindex!!(varinfos, vi_base, 1) + # Merge the updated initial varinfo with the rest of the varinfos + update the logp. + vi = DynamicPPL.setlogp!!( + reduce(merge, varinfos_new), + DynamicPPL.getlogp(last(varinfos)), + ) + + return Transition(vi), GibbsV2State(vi, states) +end + +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::Model, + spl::Sampler{<:GibbsV2}, + state::GibbsV2State; + kwargs..., +) + alg = spl.alg + samplers = alg.samplers + + varinfos = map(varinfo, state.states) + @assert length(samplers) == length(state.states) + + + states_and_varinfos = + map(samplers, state.states, varinfos) do sampler_local, state_local, varinfo_local + # Construct the conditional model. + # NOTE: Here it's crucial that all the `varinfos` are in the constrained space, + # otherwise we're conditioning on values which are not in the support of the + # distributions. + model_local = make_conditional(model, varinfo_local, varinfos) + + # TODO: Might need to re-run the model. + # NOTE: We use `logjoint` instead of `evaluate!!` and capturing the resulting varinfo because + # the resulting varinfo might be in un-transformed space even if `varinfo_local` + # is in transformed space. This can occur if we hit `maybe_invlink_before_eval!!`. + varinfo_local = DynamicPPL.setlogp!!( + varinfo_local, + DynamicPPL.logjoint(model_local, varinfo_local), + ) + + # Update the state we're about to use if need be. + # If the sampler requires a linked varinfo, this should be done in `gibbs_state`. + current_state_local = + gibbs_state(model_local, sampler_local, state_local, varinfo_local) + + # Take a step. + new_state_local = last( + AbstractMCMC.step( + rng, + model_local, + sampler_local, + current_state_local; + kwargs..., + ), + ) + + # Return the resulting state and invlinked `varinfo`. + # NOTE: We have to `deepcopy` to avoid potentially changing + # the varinfo in the `new_state_local`. + varinfo_local_state = deepcopy(varinfo(new_state_local)) + varinfo_local_state_invlinked = if DynamicPPL.istrans(varinfo_local_state) + DynamicPPL.invlink!!(varinfo_local_state, sampler_local, model_local) + else + varinfo_local_state + end + return (new_state_local, varinfo_local_state_invlinked) + end + + states_and_varinfos = tuple(states_and_varinfos...) + states = map(first, states_and_varinfos) + varinfos = map(last, states_and_varinfos) + + # Combine the resulting varinfo objects. + # The last varinfo holds the correctly computed logp. + vi_base = state.vi + + # Update the base varinfo from the first varinfo and replace it. + varinfos_new = DynamicPPL.setindex!!( + varinfos, + merge(vi_base, first(varinfos)), + firstindex(varinfos), + ) + # Merge the updated initial varinfo with the rest of the varinfos + update the logp. + vi = DynamicPPL.setlogp!!( + reduce(merge, varinfos_new), + DynamicPPL.getlogp(last(varinfos)), + ) + + return Transition(vi), GibbsV2State(vi, states) +end From efe15db1d50640f86d0798e2ab89befd040b483f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 28 Jul 2023 21:15:44 +0100 Subject: [PATCH 02/58] added tests for the new Gibbs sampler --- test/mcmc/gibbs_new.jl | 145 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 145 insertions(+) create mode 100644 test/mcmc/gibbs_new.jl diff --git a/test/mcmc/gibbs_new.jl b/test/mcmc/gibbs_new.jl new file mode 100644 index 000000000..b70207ecf --- /dev/null +++ b/test/mcmc/gibbs_new.jl @@ -0,0 +1,145 @@ +using Turing, DynamicPPL + +# Okay, so what do we actually need to test here. +# 1. Needs to be compatible with most models. +# 2. Restricted to usage of pairs for now to make things simple. + +# TODO: Don't require usage of tuples due to potential of blowing up compilation times. + +# FIXME: Currently failing for `demo_assume_index_observe`. +# Likely an issue with not linking correctly. +@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + vns = DynamicPPL.TestUtils.varnames(model) + # Run one sampler on variables starting with `s` and another on variables starting with `m`. + vns_s = filter(vns) do vn + DynamicPPL.getsym(vn) == :s + end + vns_m = filter(vns) do vn + DynamicPPL.getsym(vn) == :m + end + + # Construct the sampler. + sampler = Turing.Inference.GibbsV2( + vns_s => Turing.Inference.NUTS(), + vns_m => Turing.Inference.NUTS(), + ) + + # Check that taking steps performs as expected. + rng = Random.default_rng() + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler)) + @test keys(transition.θ) == Tuple(unique(map(DynamicPPL.getsym, vns))) + + for _ = 1:10 + transition, state = + AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state) + @test keys(transition.θ) == Tuple(unique(map(DynamicPPL.getsym, vns))) + end +end + +# @testset "Gibbs using `condition`" begin +# @testset "demo_assume_dot_observe" begin +# model = DynamicPPL.TestUtils.demo_assume_dot_observe() +# # Construct the different varinfos to be used. +# varinfos = (SimpleVarInfo(s = 1.0), SimpleVarInfo(m = 10.0)) +# # Construct the varinfo for the particular variable we want to sample. +# target_varinfo = first(varinfos) + +# # Create the conditional model. +# conditional_model = +# Turing.Inference.make_conditional(model, target_varinfo, varinfos) + +# # Sample! +# sampler = Turing.Inference.GibbsV2(@varname(s) => MH(), @varname(m) => MH()) +# rng = Random.default_rng() + +# @testset "step" begin +# transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler)) +# @test keys(transition.θ) == (:s, :m) + +# transition, state = +# AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state) +# @test keys(transition.θ) == (:s, :m) + +# transition, state = +# AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state) +# @test keys(transition.θ) == (:s, :m) +# end + +# @testset "sample" begin +# chain = sample(model, sampler, 1000) +# @test size(chain, 1) == 1000 +# display(mean(chain)) +# end +# end + +# # @testset "gdemo" begin +# # Random.seed!(100) +# # alg = Turing.Inference.GibbsV2(@varname(s) => CSMC(15), @varname(m) => ESS(:m)) +# # chain = sample(gdemo(1.5, 2.0), alg, 10_000) +# # end + +# # @testset "MoGtest" begin +# # Random.seed!(125) +# # alg = Gibbs(CSMC(15, :z1, :z2, :z3, :z4), ESS(:mu1), ESS(:mu2)) +# # chain = sample(MoGtest_default, alg, 6000) +# # check_MoGtest_default(chain, atol = 0.1) +# # end + +# @testset "multiple varnames" begin +# rng = Random.default_rng() + +# # With both `s` and `m` as random. +# model = gdemo(1.5, 2.0) +# alg = Turing.Inference.GibbsV2((@varname(s), @varname(m)) => MH()) + +# transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) +# @test keys(transition.θ) == (:s, :m) + +# transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) +# @test keys(transition.θ) == (:s, :m) + +# transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) +# @test keys(transition.θ) == (:s, :m) + +# # Sample. +# chain = sample(model, alg, 10_000) +# check_numerical(chain, [:s, :m], [49 / 24, 7 / 6], atol = 0.1) + + +# # Without `m` as random. +# model = gdemo(1.5, 2.0) | (m = 7 / 6,) +# alg = Turing.Inference.GibbsV2((@varname(s),) => MH()) +# @info "" alg alg.varnames + +# transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) +# @test keys(transition.θ) == (:s,) + +# transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) +# @test keys(transition.θ) == (:s,) + +# transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) +# @test keys(transition.θ) == (:s,) +# end + +# @testset "CSMS + ESS" begin +# rng = Random.default_rng() +# model = MoGtest_default +# alg = Turing.Inference.GibbsV2( +# (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), +# @varname(mu1) => ESS(), +# @varname(mu2) => ESS(), +# ) +# transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) +# @test keys(transition.θ) == (:mu1, :mu2, :z1, :z2, :z3, :z4) + +# transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) +# @test keys(transition.θ) == (:mu1, :mu2, :z1, :z2, :z3, :z4) + +# transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) +# @test keys(transition.θ) == (:mu1, :mu2, :z1, :z2, :z3, :z4) + +# # Sample! +# chain = sample(MoGtest_default, alg, 1000) +# check_MoGtest_default(chain, atol = 0.1) +# end +# end From 35f9f1538c29fcca410a499837e6ac5f27b5bca7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Oct 2023 18:25:19 +0100 Subject: [PATCH 03/58] added tests for new Gibbs --- test/mcmc/gibbs_new.jl | 273 +++++++++++++++++++++-------------------- 1 file changed, 143 insertions(+), 130 deletions(-) diff --git a/test/mcmc/gibbs_new.jl b/test/mcmc/gibbs_new.jl index b70207ecf..ed94aeb62 100644 --- a/test/mcmc/gibbs_new.jl +++ b/test/mcmc/gibbs_new.jl @@ -1,5 +1,15 @@ using Turing, DynamicPPL +function check_transition_varnames(transition::Turing.Inference.Transition, parent_varnames) + transition_varnames = mapreduce(vcat, transition.θ) do vn_and_val + [first(vn_and_val)] + end + # Varnames in `transition` should be subsumed by those in `vns`. + for vn in transition_varnames + @test any(Base.Fix2(DynamicPPL.subsumes, vn), parent_varnames) + end +end + # Okay, so what do we actually need to test here. # 1. Needs to be compatible with most models. # 2. Restricted to usage of pairs for now to make things simple. @@ -8,138 +18,141 @@ using Turing, DynamicPPL # FIXME: Currently failing for `demo_assume_index_observe`. # Likely an issue with not linking correctly. -@testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS - vns = DynamicPPL.TestUtils.varnames(model) - # Run one sampler on variables starting with `s` and another on variables starting with `m`. - vns_s = filter(vns) do vn - DynamicPPL.getsym(vn) == :s +@testset "Demo models" begin + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + vns = DynamicPPL.TestUtils.varnames(model) + # Run one sampler on variables starting with `s` and another on variables starting with `m`. + vns_s = filter(vns) do vn + DynamicPPL.getsym(vn) == :s + end + vns_m = filter(vns) do vn + DynamicPPL.getsym(vn) == :m + end + + # Construct the sampler. + sampler = Turing.Inference.GibbsV2( + vns_s => Turing.Inference.NUTS(), + vns_m => Turing.Inference.NUTS(), + ) + + # Check that taking steps performs as expected. + rng = Random.default_rng() + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler)) + check_transition_varnames(transition, vns) + + for _ = 1:10 + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state) + transition_varnames = mapreduce(first, vcat, transition.θ) + check_transition_varnames(transition, vns) + end + end +end + +@testset "Gibbs using `condition`" begin + @testset "demo_assume_dot_observe" begin + model = DynamicPPL.TestUtils.demo_assume_dot_observe() + # Construct the different varinfos to be used. + varinfos = (SimpleVarInfo(s = 1.0), SimpleVarInfo(m = 10.0)) + # Construct the varinfo for the particular variable we want to sample. + target_varinfo = first(varinfos) + + # Create the conditional model. + conditional_model = Turing.Inference.make_conditional(model, target_varinfo, varinfos) + + # Sample! + sampler = Turing.Inference.GibbsV2(@varname(s) => MH(), @varname(m) => MH()) + rng = Random.default_rng() + + vns = [@varname(s), @varname(m)] + + @testset "step" begin + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler)) + check_transition_varnames(transition, vns) + + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state) + check_transition_varnames(transition, vns) + + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state) + check_transition_varnames(transition, vns) + end + + @testset "sample" begin + chain = sample(model, sampler, 1000) + @test size(chain, 1) == 1000 + display(mean(chain)) + end end - vns_m = filter(vns) do vn - DynamicPPL.getsym(vn) == :m + + # @testset "gdemo" begin + # Random.seed!(100) + # alg = Turing.Inference.GibbsV2(@varname(s) => CSMC(15), @varname(m) => ESS(:m)) + # chain = sample(gdemo(1.5, 2.0), alg, 10_000) + # end + + # @testset "MoGtest" begin + # Random.seed!(125) + # alg = Gibbs(CSMC(15, :z1, :z2, :z3, :z4), ESS(:mu1), ESS(:mu2)) + # chain = sample(MoGtest_default, alg, 6000) + # check_MoGtest_default(chain, atol = 0.1) + # end + + @testset "multiple varnames" begin + rng = Random.default_rng() + + # With both `s` and `m` as random. + model = gdemo(1.5, 2.0) + vns = (@varname(s), @varname(m)) + alg = Turing.Inference.GibbsV2(vns => MH()) + + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + check_transition_varnames(transition, vns) + + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) + check_transition_varnames(transition, vns) + + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) + check_transition_varnames(transition, vns) + + # Sample. + chain = sample(model, alg, 10_000) + check_numerical(chain, [:s, :m], [49 / 24, 7 / 6], atol = 0.1) + + + # Without `m` as random. + model = gdemo(1.5, 2.0) | (m = 7 / 6,) + vns = (@varname(s),) + alg = Turing.Inference.GibbsV2(vns => MH()) + + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + check_transition_varnames(transition, vns) + + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) + check_transition_varnames(transition, vns) + + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) + check_transition_varnames(transition, vns) end - # Construct the sampler. - sampler = Turing.Inference.GibbsV2( - vns_s => Turing.Inference.NUTS(), - vns_m => Turing.Inference.NUTS(), - ) - - # Check that taking steps performs as expected. - rng = Random.default_rng() - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler)) - @test keys(transition.θ) == Tuple(unique(map(DynamicPPL.getsym, vns))) - - for _ = 1:10 - transition, state = - AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state) - @test keys(transition.θ) == Tuple(unique(map(DynamicPPL.getsym, vns))) + @testset "CSMS + ESS" begin + rng = Random.default_rng() + model = MoGtest_default + alg = Turing.Inference.GibbsV2( + (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), + @varname(mu1) => ESS(), + @varname(mu2) => ESS(), + ) + vns = (@varname(z1), @varname(z2), @varname(z3), @varname(z4), @varname(mu1), @varname(mu2)) + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + check_transition_varnames(transition, vns) + + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) + check_transition_varnames(transition, vns) + + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) + check_transition_varnames(transition, vns) + + # Sample! + chain = sample(MoGtest_default, alg, 1000) + check_MoGtest_default(chain, atol = 0.2) end end - -# @testset "Gibbs using `condition`" begin -# @testset "demo_assume_dot_observe" begin -# model = DynamicPPL.TestUtils.demo_assume_dot_observe() -# # Construct the different varinfos to be used. -# varinfos = (SimpleVarInfo(s = 1.0), SimpleVarInfo(m = 10.0)) -# # Construct the varinfo for the particular variable we want to sample. -# target_varinfo = first(varinfos) - -# # Create the conditional model. -# conditional_model = -# Turing.Inference.make_conditional(model, target_varinfo, varinfos) - -# # Sample! -# sampler = Turing.Inference.GibbsV2(@varname(s) => MH(), @varname(m) => MH()) -# rng = Random.default_rng() - -# @testset "step" begin -# transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler)) -# @test keys(transition.θ) == (:s, :m) - -# transition, state = -# AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state) -# @test keys(transition.θ) == (:s, :m) - -# transition, state = -# AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state) -# @test keys(transition.θ) == (:s, :m) -# end - -# @testset "sample" begin -# chain = sample(model, sampler, 1000) -# @test size(chain, 1) == 1000 -# display(mean(chain)) -# end -# end - -# # @testset "gdemo" begin -# # Random.seed!(100) -# # alg = Turing.Inference.GibbsV2(@varname(s) => CSMC(15), @varname(m) => ESS(:m)) -# # chain = sample(gdemo(1.5, 2.0), alg, 10_000) -# # end - -# # @testset "MoGtest" begin -# # Random.seed!(125) -# # alg = Gibbs(CSMC(15, :z1, :z2, :z3, :z4), ESS(:mu1), ESS(:mu2)) -# # chain = sample(MoGtest_default, alg, 6000) -# # check_MoGtest_default(chain, atol = 0.1) -# # end - -# @testset "multiple varnames" begin -# rng = Random.default_rng() - -# # With both `s` and `m` as random. -# model = gdemo(1.5, 2.0) -# alg = Turing.Inference.GibbsV2((@varname(s), @varname(m)) => MH()) - -# transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) -# @test keys(transition.θ) == (:s, :m) - -# transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) -# @test keys(transition.θ) == (:s, :m) - -# transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) -# @test keys(transition.θ) == (:s, :m) - -# # Sample. -# chain = sample(model, alg, 10_000) -# check_numerical(chain, [:s, :m], [49 / 24, 7 / 6], atol = 0.1) - - -# # Without `m` as random. -# model = gdemo(1.5, 2.0) | (m = 7 / 6,) -# alg = Turing.Inference.GibbsV2((@varname(s),) => MH()) -# @info "" alg alg.varnames - -# transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) -# @test keys(transition.θ) == (:s,) - -# transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) -# @test keys(transition.θ) == (:s,) - -# transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) -# @test keys(transition.θ) == (:s,) -# end - -# @testset "CSMS + ESS" begin -# rng = Random.default_rng() -# model = MoGtest_default -# alg = Turing.Inference.GibbsV2( -# (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), -# @varname(mu1) => ESS(), -# @varname(mu2) => ESS(), -# ) -# transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) -# @test keys(transition.θ) == (:mu1, :mu2, :z1, :z2, :z3, :z4) - -# transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) -# @test keys(transition.θ) == (:mu1, :mu2, :z1, :z2, :z3, :z4) - -# transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) -# @test keys(transition.θ) == (:mu1, :mu2, :z1, :z2, :z3, :z4) - -# # Sample! -# chain = sample(MoGtest_default, alg, 1000) -# check_MoGtest_default(chain, atol = 0.1) -# end -# end From 072cf6b332c1e4927a5b2e02ff574c7988789e74 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Oct 2023 18:25:44 +0100 Subject: [PATCH 04/58] new Gibbs is now sampling (correctly) sequentially --- src/mcmc/Inference.jl | 1 + src/mcmc/gibbs_new.jl | 137 +++++++++++++++++++++++++----------------- 2 files changed, 82 insertions(+), 56 deletions(-) diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 1b09668ec..95a2e9f7b 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -518,6 +518,7 @@ include("gibbs.jl") include("sghmc.jl") include("emcee.jl") include("abstractmcmc.jl") +include("gibbs_new.jl") ################ # Typing tools # diff --git a/src/mcmc/gibbs_new.jl b/src/mcmc/gibbs_new.jl index 9713ce457..7005cd825 100644 --- a/src/mcmc/gibbs_new.jl +++ b/src/mcmc/gibbs_new.jl @@ -159,7 +159,6 @@ function AbstractMCMC.step( # 1. Run the model once to get the varnames present + initial values to condition on. vi_base = DynamicPPL.VarInfo(model) - @info "" varnames map(Base.Fix1(map, DynamicPPL.getsym), varnames) varinfos = map(Base.Fix1(subset, vi_base), varnames) # 2. Construct a varinfo for every vn + sampler combo. @@ -168,13 +167,12 @@ function AbstractMCMC.step( model_local = make_conditional(model, varinfo_local, varinfos) # Take initial step. - new_state_local = - last(AbstractMCMC.step(rng, model_local, sampler_local; kwargs...)) + new_state_local = last(AbstractMCMC.step(rng, model_local, sampler_local; kwargs...)) # Return the new state and the invlinked `varinfo`. vi_local_state = varinfo(new_state_local) vi_local_state_linked = if DynamicPPL.istrans(vi_local_state) - DynamicPPL.invlink!!(deepcopy(vi_local_state), sampler_local, model_local) + DynamicPPL.invlink(vi_local_state, sampler_local, model_local) else vi_local_state end @@ -192,7 +190,66 @@ function AbstractMCMC.step( DynamicPPL.getlogp(last(varinfos)), ) - return Transition(vi), GibbsV2State(vi, states) + return Transition(model, vi), GibbsV2State(vi, states) +end + +function gibbs_step_inner( + rng::Random.AbstractRNG, + model::Model, + samplers, + states, + varinfos, + index; + kwargs..., +) + # Needs to do a a few things. + sampler_local = samplers[index] + state_local = states[index] + varinfo_local = varinfos[index] + + # 1. Create conditional model. + # Construct the conditional model. + # NOTE: Here it's crucial that all the `varinfos` are in the constrained space, + # otherwise we're conditioning on values which are not in the support of the + # distributions. + model_local = make_conditional(model, varinfo_local, varinfos) + + # TODO: Might need to re-run the model. + # NOTE: We use `logjoint` instead of `evaluate!!` and capturing the resulting varinfo because + # the resulting varinfo might be in un-transformed space even if `varinfo_local` + # is in transformed space. This can occur if we hit `maybe_invlink_before_eval!!`. + varinfo_local = DynamicPPL.setlogp!!( + varinfo_local, + DynamicPPL.logjoint(model_local, varinfo_local), + ) + + # 2. Take step with local sampler. + # Update the state we're about to use if need be. + # If the sampler requires a linked varinfo, this should be done in `gibbs_state`. + current_state_local = gibbs_state(model_local, sampler_local, state_local, varinfo_local) + + # Take a step. + new_state_local = last( + AbstractMCMC.step( + rng, + model_local, + sampler_local, + current_state_local; + kwargs..., + ), + ) + + # 3. Extract the new varinfo. + # Return the resulting state and invlinked `varinfo`. + varinfo_local_state = varinfo(new_state_local) + varinfo_local_state_invlinked = if DynamicPPL.istrans(varinfo_local_state) + DynamicPPL.invlink(varinfo_local_state, sampler_local, model_local) + else + varinfo_local_state + end + + # TODO: alternatively, we can return `states_new, varinfos_new, index_new` + return (new_state_local, varinfo_local_state_invlinked) end function AbstractMCMC.step( @@ -204,59 +261,27 @@ function AbstractMCMC.step( ) alg = spl.alg samplers = alg.samplers - + states = state.states varinfos = map(varinfo, state.states) @assert length(samplers) == length(state.states) - - states_and_varinfos = - map(samplers, state.states, varinfos) do sampler_local, state_local, varinfo_local - # Construct the conditional model. - # NOTE: Here it's crucial that all the `varinfos` are in the constrained space, - # otherwise we're conditioning on values which are not in the support of the - # distributions. - model_local = make_conditional(model, varinfo_local, varinfos) - - # TODO: Might need to re-run the model. - # NOTE: We use `logjoint` instead of `evaluate!!` and capturing the resulting varinfo because - # the resulting varinfo might be in un-transformed space even if `varinfo_local` - # is in transformed space. This can occur if we hit `maybe_invlink_before_eval!!`. - varinfo_local = DynamicPPL.setlogp!!( - varinfo_local, - DynamicPPL.logjoint(model_local, varinfo_local), - ) - - # Update the state we're about to use if need be. - # If the sampler requires a linked varinfo, this should be done in `gibbs_state`. - current_state_local = - gibbs_state(model_local, sampler_local, state_local, varinfo_local) - - # Take a step. - new_state_local = last( - AbstractMCMC.step( - rng, - model_local, - sampler_local, - current_state_local; - kwargs..., - ), - ) - - # Return the resulting state and invlinked `varinfo`. - # NOTE: We have to `deepcopy` to avoid potentially changing - # the varinfo in the `new_state_local`. - varinfo_local_state = deepcopy(varinfo(new_state_local)) - varinfo_local_state_invlinked = if DynamicPPL.istrans(varinfo_local_state) - DynamicPPL.invlink!!(varinfo_local_state, sampler_local, model_local) - else - varinfo_local_state - end - return (new_state_local, varinfo_local_state_invlinked) - end - - states_and_varinfos = tuple(states_and_varinfos...) - states = map(first, states_and_varinfos) - varinfos = map(last, states_and_varinfos) + # TODO: move this into a recursive function so we can unroll when reasonable? + for index = 1:length(samplers) + # Take the inner step. + new_state_local, new_varinfo_local = gibbs_step_inner( + rng, + model, + samplers, + states, + varinfos, + index; + kwargs..., + ) + + # Update the `states` and `varinfos`. + states = Setfield.setindex(states, new_state_local, index) + varinfos = Setfield.setindex(varinfos, new_varinfo_local, index) + end # Combine the resulting varinfo objects. # The last varinfo holds the correctly computed logp. @@ -274,5 +299,5 @@ function AbstractMCMC.step( DynamicPPL.getlogp(last(varinfos)), ) - return Transition(vi), GibbsV2State(vi, states) + return Transition(model, vi), GibbsV2State(vi, states) end From 2f25199ca71143c9a280f6729c6bc31953ae7032 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 6 Oct 2023 18:20:26 +0100 Subject: [PATCH 05/58] let's not overload merge just yet --- src/mcmc/gibbs_new.jl | 44 +++++++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/src/mcmc/gibbs_new.jl b/src/mcmc/gibbs_new.jl index 7005cd825..f631cd8fe 100644 --- a/src/mcmc/gibbs_new.jl +++ b/src/mcmc/gibbs_new.jl @@ -23,7 +23,7 @@ function subset(vi::DynamicPPL.SimpleVarInfo, vns::VarName...) return DynamicPPL.BangBang.@set!! vi.values = vals end -function Base.merge(md::DynamicPPL.Metadata, md_subset::DynamicPPL.Metadata) +function merge_metadata(md::DynamicPPL.Metadata, md_subset::DynamicPPL.Metadata) @assert md.vns == md_subset.vns "Cannot merge metadata with different vns." @assert length(md.vals) == length(md_subset.vals) "Cannot merge metadata with different length vals." @@ -40,7 +40,7 @@ function Base.merge(md::DynamicPPL.Metadata, md_subset::DynamicPPL.Metadata) ) end -function Base.merge( +function merge_varinfo( vi::DynamicPPL.VarInfo{<:NamedTuple{names}}, vi_subset::TypedVarInfo, ) where {names} @@ -48,7 +48,7 @@ function Base.merge( metadata_vals = map(names) do vn_sym # TODO: Make generated. return if haskey(vi_subset, VarName{vn_sym}()) - merge(vi.metadata[vn_sym], vi_subset.metadata[vn_sym]) + merge_metadata(vi.metadata[vn_sym], vi_subset.metadata[vn_sym]) else vi.metadata[vn_sym] end @@ -58,16 +58,16 @@ function Base.merge( return DynamicPPL.VarInfo(NamedTuple{names}(metadata_vals), vi.logp, vi.num_produce) end -function Base.merge(vi_left::SimpleVarInfo, vi_right::SimpleVarInfo) +function merge_varinfo(vi_left::SimpleVarInfo, vi_right::SimpleVarInfo) return SimpleVarInfo( merge(vi_left.values, vi_right.values), vi_left.logp + vi_right.logp, ) end -function Base.merge(vi_left::TypedVarInfo, vi_right::TypedVarInfo) +function merge_varinfo(vi_left::TypedVarInfo, vi_right::TypedVarInfo) return TypedVarInfo( - merge(vi_left.metadata, vi_right.metadata), + merge_metadata(vi_left.metadata, vi_right.metadata), vi_left.logp + vi_right.logp, ) end @@ -101,21 +101,29 @@ Construct a conditional model from `model` conditioned `varinfos`, excluding `va julia> model = DynamicPPL.TestUtils.demo_assume_dot_observe(); julia> # A separate varinfo for each variable in `model`. - varinfos = (SimpleVarInfo(s=1.0), SimpleVarInfo(m=10.0)); + varinfos = (DynamicPPL.SimpleVarInfo(s=1.0), DynamicPPL.SimpleVarInfo(m=10.0)); julia> # The varinfo we want to NOT condition on. target_varinfo = first(varinfos); julia> # Results in a model with only `m` conditioned. - conditional_model = make_conditional(model, target_varinfo, varinfos); + conditioned_model = Turing.Inference.make_conditional(model, target_varinfo, varinfos); -julia> rand(conditioned_model) -``` +julia> result = conditioned_model(); + +julia> result.m == 10.0 # we conditioned on varinfo with `m = 10.0` +true +julia> result.s != 1.0 # we did NOT want to condition on varinfo with `s = 1.0` +true +``` """ function make_conditional(model::Model, target_varinfo::AbstractVarInfo, varinfos) # TODO: Check if this is known at compile-time if `varinfos isa Tuple`. - return DynamicPPL.condition(model, filter(Base.Fix1(!==, target_varinfo), varinfos)...) + return DynamicPPL.condition( + model, + filter(Base.Fix1(!==, target_varinfo), varinfos)... + ) end wrap_algorithm_maybe(x) = x @@ -136,7 +144,9 @@ function GibbsV2(algs::NamedTuple) end # AbstractDict -GibbsV2(algs::AbstractDict) = GibbsV2(keys(algs), map(wrap_algorithm_maybe, values(algs))) +function GibbsV2(algs::AbstractDict) + return GibbsV2(keys(algs), map(wrap_algorithm_maybe, values(algs))) +end function GibbsV2(algs::Pair...) return GibbsV2(map(first, algs), map(wrap_algorithm_maybe, map(last, algs))) end @@ -186,7 +196,7 @@ function AbstractMCMC.step( varinfos_new = DynamicPPL.setindex!!(varinfos, vi_base, 1) # Merge the updated initial varinfo with the rest of the varinfos + update the logp. vi = DynamicPPL.setlogp!!( - reduce(merge, varinfos_new), + reduce(merge_varinfo, varinfos_new), DynamicPPL.getlogp(last(varinfos)), ) @@ -226,7 +236,9 @@ function gibbs_step_inner( # 2. Take step with local sampler. # Update the state we're about to use if need be. # If the sampler requires a linked varinfo, this should be done in `gibbs_state`. - current_state_local = gibbs_state(model_local, sampler_local, state_local, varinfo_local) + current_state_local = gibbs_state( + model_local, sampler_local, state_local, varinfo_local + ) # Take a step. new_state_local = last( @@ -290,12 +302,12 @@ function AbstractMCMC.step( # Update the base varinfo from the first varinfo and replace it. varinfos_new = DynamicPPL.setindex!!( varinfos, - merge(vi_base, first(varinfos)), + merge_varinfo(vi_base, first(varinfos)), firstindex(varinfos), ) # Merge the updated initial varinfo with the rest of the varinfos + update the logp. vi = DynamicPPL.setlogp!!( - reduce(merge, varinfos_new), + reduce(merge_varinfo, varinfos_new), DynamicPPL.getlogp(last(varinfos)), ) From 7ca26bbcfda2ad8a2e0839e3b46ec765dc304081 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 6 Oct 2023 19:01:18 +0100 Subject: [PATCH 06/58] export GibbsV2 + added more samplers to the tests --- src/Turing.jl | 1 + src/mcmc/Inference.jl | 1 + test/mcmc/gibbs_new.jl | 95 ++++++++++++++++++++++++++++-------------- 3 files changed, 65 insertions(+), 32 deletions(-) diff --git a/src/Turing.jl b/src/Turing.jl index 3b15beda5..dc3b92d9b 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -74,6 +74,7 @@ export @model, # modelling ESS, Gibbs, GibbsConditional, + GibbsV2, HMC, # Hamiltonian-like sampling SGLD, diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 95a2e9f7b..e677106fa 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -48,6 +48,7 @@ export InferenceAlgorithm, Emcee, Gibbs, # classic sampling GibbsConditional, + GibbsV2, HMC, SGLD, PolynomialStepsize, diff --git a/test/mcmc/gibbs_new.jl b/test/mcmc/gibbs_new.jl index ed94aeb62..8952457c9 100644 --- a/test/mcmc/gibbs_new.jl +++ b/test/mcmc/gibbs_new.jl @@ -1,6 +1,9 @@ using Turing, DynamicPPL -function check_transition_varnames(transition::Turing.Inference.Transition, parent_varnames) +function check_transition_varnames( + transition::Turing.Inference.Transition, + parent_varnames +) transition_varnames = mapreduce(vcat, transition.θ) do vn_and_val [first(vn_and_val)] end @@ -11,11 +14,22 @@ function check_transition_varnames(transition::Turing.Inference.Transition, pare end # Okay, so what do we actually need to test here. -# 1. Needs to be compatible with most models. -# 2. Restricted to usage of pairs for now to make things simple. +# 1. (✓) Needs to be compatible with most models. +# 2. (???) Restricted to usage of pairs for now to make things simple. # TODO: Don't require usage of tuples due to potential of blowing up compilation times. +const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{ + Model{typeof(DynamicPPL.TestUtils.demo_assume_index_observe)}, + Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe)}, + Model{typeof(DynamicPPL.TestUtils.demo_assume_dot_observe)}, + Model{typeof(DynamicPPL.TestUtils.demo_assume_observe_literal)}, + Model{typeof(DynamicPPL.TestUtils.demo_assume_literal_dot_observe)}, + Model{typeof(DynamicPPL.TestUtils.demo_assume_matrix_dot_observe_matrix)}, +} +has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false +has_dot_assume(::Model) = true + # FIXME: Currently failing for `demo_assume_index_observe`. # Likely an issue with not linking correctly. @testset "Demo models" begin @@ -29,21 +43,45 @@ end DynamicPPL.getsym(vn) == :m end - # Construct the sampler. - sampler = Turing.Inference.GibbsV2( - vns_s => Turing.Inference.NUTS(), - vns_m => Turing.Inference.NUTS(), - ) - - # Check that taking steps performs as expected. - rng = Random.default_rng() - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler)) - check_transition_varnames(transition, vns) + samplers = [ + GibbsV2( + vns_s => NUTS(), + vns_m => NUTS(), + ), + GibbsV2( + vns_s => NUTS(), + vns_m => HMC(0.01, 4), + ) + ] + + if !has_dot_assume(model) + # Add in some MH samplers + append!( + samplers, + [ + GibbsV2( + vns_s => HMC(0.01, 4), + vns_m => MH(), + ), + GibbsV2( + vns_s => MH(), + vns_m => HMC(0.01, 4), + ) + ] + ) + end - for _ = 1:10 - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state) - transition_varnames = mapreduce(first, vcat, transition.θ) + @testset "$sampler" for sampler in samplers + # Check that taking steps performs as expected. + rng = Random.default_rng() + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler)) check_transition_varnames(transition, vns) + + for _ = 1:10 + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state) + transition_varnames = mapreduce(first, vcat, transition.θ) + check_transition_varnames(transition, vns) + end end end end @@ -60,7 +98,7 @@ end conditional_model = Turing.Inference.make_conditional(model, target_varinfo, varinfos) # Sample! - sampler = Turing.Inference.GibbsV2(@varname(s) => MH(), @varname(m) => MH()) + sampler = GibbsV2(@varname(s) => MH(), @varname(m) => MH()) rng = Random.default_rng() vns = [@varname(s), @varname(m)] @@ -83,18 +121,11 @@ end end end - # @testset "gdemo" begin - # Random.seed!(100) - # alg = Turing.Inference.GibbsV2(@varname(s) => CSMC(15), @varname(m) => ESS(:m)) - # chain = sample(gdemo(1.5, 2.0), alg, 10_000) - # end - - # @testset "MoGtest" begin - # Random.seed!(125) - # alg = Gibbs(CSMC(15, :z1, :z2, :z3, :z4), ESS(:mu1), ESS(:mu2)) - # chain = sample(MoGtest_default, alg, 6000) - # check_MoGtest_default(chain, atol = 0.1) - # end + @testset "gdemo" begin + Random.seed!(100) + alg = GibbsV2(@varname(s) => CSMC(15), @varname(m) => ESS(:m)) + chain = sample(gdemo(1.5, 2.0), alg, 10_000) + end @testset "multiple varnames" begin rng = Random.default_rng() @@ -102,7 +133,7 @@ end # With both `s` and `m` as random. model = gdemo(1.5, 2.0) vns = (@varname(s), @varname(m)) - alg = Turing.Inference.GibbsV2(vns => MH()) + alg = GibbsV2(vns => MH()) transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) check_transition_varnames(transition, vns) @@ -121,7 +152,7 @@ end # Without `m` as random. model = gdemo(1.5, 2.0) | (m = 7 / 6,) vns = (@varname(s),) - alg = Turing.Inference.GibbsV2(vns => MH()) + alg = GibbsV2(vns => MH()) transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) check_transition_varnames(transition, vns) @@ -136,7 +167,7 @@ end @testset "CSMS + ESS" begin rng = Random.default_rng() model = MoGtest_default - alg = Turing.Inference.GibbsV2( + alg = GibbsV2( (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), @varname(mu1) => ESS(), @varname(mu2) => ESS(), From 8b0de2178f89e0d84d5b15384e0e200d882399d9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 8 Oct 2023 01:08:33 +0100 Subject: [PATCH 07/58] added TODO comment --- src/mcmc/gibbs_new.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/mcmc/gibbs_new.jl b/src/mcmc/gibbs_new.jl index f631cd8fe..02e9291f0 100644 --- a/src/mcmc/gibbs_new.jl +++ b/src/mcmc/gibbs_new.jl @@ -12,6 +12,8 @@ function subset(vi::DynamicPPL.TypedVarInfo, vns::VarName...) # TODO: peform proper check of the meatdatas corresponding to different symbols. # F. ex. we might have vns `(@varname(x[1]), @varname(x[2]))`, in which case they # have the same `metadata`. If they don't, we should error. + + # TODO: Handle mixing of symbols, e.g. `(@varname(x[1]), @varname(y[1]))`. vns_unique_syms = unique_tuple(map(DynamicPPL.getsym, vns)) mds = map(Base.Fix1(DynamicPPL.getfield, vi.metadata), vns_unique_syms) return DynamicPPL.VarInfo(NamedTuple{vns_unique_syms}(mds), vi.logp, vi.num_produce) From 014fbe2c9ce24d3caf33b7ac43b3a54f7bf43fa5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 16 Nov 2023 14:13:08 +0000 Subject: [PATCH 08/58] removed lots of varinfo related merging functionality that is now available in DynamicPPL --- src/mcmc/gibbs_new.jl | 88 ++++-------------------------------------- test/mcmc/gibbs_new.jl | 77 +++++++++++++++--------------------- 2 files changed, 40 insertions(+), 125 deletions(-) diff --git a/src/mcmc/gibbs_new.jl b/src/mcmc/gibbs_new.jl index 02e9291f0..d5abc8d96 100644 --- a/src/mcmc/gibbs_new.jl +++ b/src/mcmc/gibbs_new.jl @@ -1,79 +1,3 @@ -function unique_tuple(xs::Tuple, acc::Tuple = ()) - return if Base.first(xs) ∈ acc - unique_tuple(Base.tail(xs), acc) - else - unique_tuple(Base.tail(xs), (acc..., Base.first(xs))) - end -end -unique_tuple(::Tuple{}, acc::Tuple = ()) = acc - -subset(vi::DynamicPPL.TypedVarInfo, vns::Union{Tuple,AbstractArray}) = subset(vi, vns...) -function subset(vi::DynamicPPL.TypedVarInfo, vns::VarName...) - # TODO: peform proper check of the meatdatas corresponding to different symbols. - # F. ex. we might have vns `(@varname(x[1]), @varname(x[2]))`, in which case they - # have the same `metadata`. If they don't, we should error. - - # TODO: Handle mixing of symbols, e.g. `(@varname(x[1]), @varname(y[1]))`. - vns_unique_syms = unique_tuple(map(DynamicPPL.getsym, vns)) - mds = map(Base.Fix1(DynamicPPL.getfield, vi.metadata), vns_unique_syms) - return DynamicPPL.VarInfo(NamedTuple{vns_unique_syms}(mds), vi.logp, vi.num_produce) -end - -subset(vi::DynamicPPL.SimpleVarInfo, vns::Union{Tuple,AbstractArray}) = subset(vi, vns...) -function subset(vi::DynamicPPL.SimpleVarInfo, vns::VarName...) - vals = map(Base.Fix1(getindex, vi), vns) - return DynamicPPL.BangBang.@set!! vi.values = vals -end - -function merge_metadata(md::DynamicPPL.Metadata, md_subset::DynamicPPL.Metadata) - @assert md.vns == md_subset.vns "Cannot merge metadata with different vns." - @assert length(md.vals) == length(md_subset.vals) "Cannot merge metadata with different length vals." - - # TODO: Re-adjust `ranges`, etc. so we can support things like changing support, etc. - return DynamicPPL.Metadata( - md_subset.idcs, - md_subset.vns, - md_subset.ranges, - md_subset.vals, - md_subset.dists, - md_subset.gids, - md_subset.orders, - md_subset.flags, - ) -end - -function merge_varinfo( - vi::DynamicPPL.VarInfo{<:NamedTuple{names}}, - vi_subset::TypedVarInfo, -) where {names} - # Assumes `vi` is a superset of `vi_subset`. - metadata_vals = map(names) do vn_sym - # TODO: Make generated. - return if haskey(vi_subset, VarName{vn_sym}()) - merge_metadata(vi.metadata[vn_sym], vi_subset.metadata[vn_sym]) - else - vi.metadata[vn_sym] - end - end - - # TODO: Is this the right way to do this? - return DynamicPPL.VarInfo(NamedTuple{names}(metadata_vals), vi.logp, vi.num_produce) -end - -function merge_varinfo(vi_left::SimpleVarInfo, vi_right::SimpleVarInfo) - return SimpleVarInfo( - merge(vi_left.values, vi_right.values), - vi_left.logp + vi_right.logp, - ) -end - -function merge_varinfo(vi_left::TypedVarInfo, vi_right::TypedVarInfo) - return TypedVarInfo( - merge_metadata(vi_left.metadata, vi_right.metadata), - vi_left.logp + vi_right.logp, - ) -end - # TODO: Move to DynamicPPL. DynamicPPL.condition(model::Model, varinfo::SimpleVarInfo) = DynamicPPL.condition(model, DynamicPPL.values_as(varinfo)) @@ -159,6 +83,10 @@ struct GibbsV2State{V<:AbstractVarInfo,S} states::S end +_maybevec(x) = vec(x) # assume it's iterable +_maybevec(x::Tuple) = [x...] +_maybevec(x::VarName) = [x] + function AbstractMCMC.step( rng::Random.AbstractRNG, model::Model, @@ -171,7 +99,7 @@ function AbstractMCMC.step( # 1. Run the model once to get the varnames present + initial values to condition on. vi_base = DynamicPPL.VarInfo(model) - varinfos = map(Base.Fix1(subset, vi_base), varnames) + varinfos = map(Base.Fix1(DynamicPPL.subset, vi_base) ∘ _maybevec, varnames) # 2. Construct a varinfo for every vn + sampler combo. states_and_varinfos = map(samplers, varinfos) do sampler_local, varinfo_local @@ -198,7 +126,7 @@ function AbstractMCMC.step( varinfos_new = DynamicPPL.setindex!!(varinfos, vi_base, 1) # Merge the updated initial varinfo with the rest of the varinfos + update the logp. vi = DynamicPPL.setlogp!!( - reduce(merge_varinfo, varinfos_new), + reduce(merge, varinfos_new), DynamicPPL.getlogp(last(varinfos)), ) @@ -304,12 +232,12 @@ function AbstractMCMC.step( # Update the base varinfo from the first varinfo and replace it. varinfos_new = DynamicPPL.setindex!!( varinfos, - merge_varinfo(vi_base, first(varinfos)), + merge(vi_base, first(varinfos)), firstindex(varinfos), ) # Merge the updated initial varinfo with the rest of the varinfos + update the logp. vi = DynamicPPL.setlogp!!( - reduce(merge_varinfo, varinfos_new), + reduce(merge, varinfos_new), DynamicPPL.getlogp(last(varinfos)), ) diff --git a/test/mcmc/gibbs_new.jl b/test/mcmc/gibbs_new.jl index 8952457c9..4cd6132bf 100644 --- a/test/mcmc/gibbs_new.jl +++ b/test/mcmc/gibbs_new.jl @@ -1,4 +1,4 @@ -using Turing, DynamicPPL +using Test, Random, Turing, DynamicPPL function check_transition_varnames( transition::Turing.Inference.Transition, @@ -30,7 +30,6 @@ const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{ has_dot_assume(::DEMO_MODELS_WITHOUT_DOT_ASSUME) = false has_dot_assume(::Model) = true -# FIXME: Currently failing for `demo_assume_index_observe`. # Likely an issue with not linking correctly. @testset "Demo models" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS @@ -55,7 +54,7 @@ has_dot_assume(::Model) = true ] if !has_dot_assume(model) - # Add in some MH samplers + # Add in some MH samplers, which are not compatible with `.~`. append!( samplers, [ @@ -76,10 +75,8 @@ has_dot_assume(::Model) = true rng = Random.default_rng() transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler)) check_transition_varnames(transition, vns) - - for _ = 1:10 + for _ = 1:5 transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state) - transition_varnames = mapreduce(first, vcat, transition.θ) check_transition_varnames(transition, vns) end end @@ -89,42 +86,36 @@ end @testset "Gibbs using `condition`" begin @testset "demo_assume_dot_observe" begin model = DynamicPPL.TestUtils.demo_assume_dot_observe() - # Construct the different varinfos to be used. - varinfos = (SimpleVarInfo(s = 1.0), SimpleVarInfo(m = 10.0)) - # Construct the varinfo for the particular variable we want to sample. - target_varinfo = first(varinfos) - - # Create the conditional model. - conditional_model = Turing.Inference.make_conditional(model, target_varinfo, varinfos) # Sample! - sampler = GibbsV2(@varname(s) => MH(), @varname(m) => MH()) rng = Random.default_rng() - vns = [@varname(s), @varname(m)] + sampler = GibbsV2(map(Base.Fix2(Pair, MH()), vns)...) @testset "step" begin transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler)) check_transition_varnames(transition, vns) - - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state) - check_transition_varnames(transition, vns) - - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state) - check_transition_varnames(transition, vns) + for _ = 1:5 + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler), state) + check_transition_varnames(transition, vns) + end end @testset "sample" begin - chain = sample(model, sampler, 1000) + chain = sample(model, sampler, 1000; progress=false) @test size(chain, 1) == 1000 display(mean(chain)) end end - @testset "gdemo" begin + @testset "gdemo with CSMC & ESS" begin + # `GibbsV2` does not work with SMC samplers, e.g. `CSMC`. + # FIXME: Oooor it is (see tests below). Uncertain. Random.seed!(100) alg = GibbsV2(@varname(s) => CSMC(15), @varname(m) => ESS(:m)) chain = sample(gdemo(1.5, 2.0), alg, 10_000) + @test_broken mean(chain[:s]) ≈ 49 / 24 + @test_broken mean(chain[:m]) ≈ 7 / 6 end @testset "multiple varnames" begin @@ -135,33 +126,30 @@ end vns = (@varname(s), @varname(m)) alg = GibbsV2(vns => MH()) + # `step` transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) check_transition_varnames(transition, vns) + for _ = 1:5 + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) + check_transition_varnames(transition, vns) + end - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) - check_transition_varnames(transition, vns) - - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) - check_transition_varnames(transition, vns) - - # Sample. - chain = sample(model, alg, 10_000) + # `sample` + chain = sample(model, alg, 10_000; progress=false) check_numerical(chain, [:s, :m], [49 / 24, 7 / 6], atol = 0.1) - # Without `m` as random. model = gdemo(1.5, 2.0) | (m = 7 / 6,) vns = (@varname(s),) alg = GibbsV2(vns => MH()) + # `step` transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) check_transition_varnames(transition, vns) - - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) - check_transition_varnames(transition, vns) - - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) - check_transition_varnames(transition, vns) + for _ = 1:5 + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) + check_transition_varnames(transition, vns) + end end @testset "CSMS + ESS" begin @@ -173,17 +161,16 @@ end @varname(mu2) => ESS(), ) vns = (@varname(z1), @varname(z2), @varname(z3), @varname(z4), @varname(mu1), @varname(mu2)) + # `step` transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) check_transition_varnames(transition, vns) - - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) - check_transition_varnames(transition, vns) - - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) - check_transition_varnames(transition, vns) + for _ = 1:5 + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) + check_transition_varnames(transition, vns) + end # Sample! - chain = sample(MoGtest_default, alg, 1000) + chain = sample(MoGtest_default, alg, 1000; progress=true) check_MoGtest_default(chain, atol = 0.2) end end From 63d64e668202ecd09ffc5fe46924b6b8246272e8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 16 Nov 2023 14:58:53 +0000 Subject: [PATCH 09/58] shifting some code around --- src/mcmc/gibbs_new.jl | 100 +++++++++++++++++++++--------------------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/src/mcmc/gibbs_new.jl b/src/mcmc/gibbs_new.jl index d5abc8d96..73a67e84b 100644 --- a/src/mcmc/gibbs_new.jl +++ b/src/mcmc/gibbs_new.jl @@ -133,6 +133,56 @@ function AbstractMCMC.step( return Transition(model, vi), GibbsV2State(vi, states) end +function AbstractMCMC.step( + rng::Random.AbstractRNG, + model::Model, + spl::Sampler{<:GibbsV2}, + state::GibbsV2State; + kwargs..., +) + alg = spl.alg + samplers = alg.samplers + states = state.states + varinfos = map(varinfo, state.states) + @assert length(samplers) == length(state.states) + + # TODO: move this into a recursive function so we can unroll when reasonable? + for index = 1:length(samplers) + # Take the inner step. + new_state_local, new_varinfo_local = gibbs_step_inner( + rng, + model, + samplers, + states, + varinfos, + index; + kwargs..., + ) + + # Update the `states` and `varinfos`. + states = Setfield.setindex(states, new_state_local, index) + varinfos = Setfield.setindex(varinfos, new_varinfo_local, index) + end + + # Combine the resulting varinfo objects. + # The last varinfo holds the correctly computed logp. + vi_base = state.vi + + # Update the base varinfo from the first varinfo and replace it. + varinfos_new = DynamicPPL.setindex!!( + varinfos, + merge(vi_base, first(varinfos)), + firstindex(varinfos), + ) + # Merge the updated initial varinfo with the rest of the varinfos + update the logp. + vi = DynamicPPL.setlogp!!( + reduce(merge, varinfos_new), + DynamicPPL.getlogp(last(varinfos)), + ) + + return Transition(model, vi), GibbsV2State(vi, states) +end + function gibbs_step_inner( rng::Random.AbstractRNG, model::Model, @@ -193,53 +243,3 @@ function gibbs_step_inner( # TODO: alternatively, we can return `states_new, varinfos_new, index_new` return (new_state_local, varinfo_local_state_invlinked) end - -function AbstractMCMC.step( - rng::Random.AbstractRNG, - model::Model, - spl::Sampler{<:GibbsV2}, - state::GibbsV2State; - kwargs..., -) - alg = spl.alg - samplers = alg.samplers - states = state.states - varinfos = map(varinfo, state.states) - @assert length(samplers) == length(state.states) - - # TODO: move this into a recursive function so we can unroll when reasonable? - for index = 1:length(samplers) - # Take the inner step. - new_state_local, new_varinfo_local = gibbs_step_inner( - rng, - model, - samplers, - states, - varinfos, - index; - kwargs..., - ) - - # Update the `states` and `varinfos`. - states = Setfield.setindex(states, new_state_local, index) - varinfos = Setfield.setindex(varinfos, new_varinfo_local, index) - end - - # Combine the resulting varinfo objects. - # The last varinfo holds the correctly computed logp. - vi_base = state.vi - - # Update the base varinfo from the first varinfo and replace it. - varinfos_new = DynamicPPL.setindex!!( - varinfos, - merge(vi_base, first(varinfos)), - firstindex(varinfos), - ) - # Merge the updated initial varinfo with the rest of the varinfos + update the logp. - vi = DynamicPPL.setlogp!!( - reduce(merge, varinfos_new), - DynamicPPL.getlogp(last(varinfos)), - ) - - return Transition(model, vi), GibbsV2State(vi, states) -end From 0dcd5bff0f70b908f6a4f0adbd9b9dc8ac7ad029 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 16 Nov 2023 15:34:28 +0000 Subject: [PATCH 10/58] removed redundant constructor for GibbsV2 --- src/mcmc/gibbs_new.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/mcmc/gibbs_new.jl b/src/mcmc/gibbs_new.jl index 73a67e84b..53aae23b3 100644 --- a/src/mcmc/gibbs_new.jl +++ b/src/mcmc/gibbs_new.jl @@ -76,7 +76,6 @@ end function GibbsV2(algs::Pair...) return GibbsV2(map(first, algs), map(wrap_algorithm_maybe, map(last, algs))) end -GibbsV2(algs::Tuple) = GibbsV2(Dict(algs)) struct GibbsV2State{V<:AbstractVarInfo,S} vi::V @@ -204,7 +203,6 @@ function gibbs_step_inner( # distributions. model_local = make_conditional(model, varinfo_local, varinfos) - # TODO: Might need to re-run the model. # NOTE: We use `logjoint` instead of `evaluate!!` and capturing the resulting varinfo because # the resulting varinfo might be in un-transformed space even if `varinfo_local` # is in transformed space. This can occur if we hit `maybe_invlink_before_eval!!`. From c29efc11d06c311cce99c6600821e28c71aa33d9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 18 Nov 2023 16:09:04 +0000 Subject: [PATCH 11/58] added GibbsContext which is similar to FixContext but also computes the log-prob of the fixed variables --- src/mcmc/gibbs_new.jl | 122 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 106 insertions(+), 16 deletions(-) diff --git a/src/mcmc/gibbs_new.jl b/src/mcmc/gibbs_new.jl index 53aae23b3..9016602a5 100644 --- a/src/mcmc/gibbs_new.jl +++ b/src/mcmc/gibbs_new.jl @@ -1,20 +1,110 @@ -# TODO: Move to DynamicPPL. -DynamicPPL.condition(model::Model, varinfo::SimpleVarInfo) = - DynamicPPL.condition(model, DynamicPPL.values_as(varinfo)) -function DynamicPPL.condition(model::Model, varinfo::VarInfo) - # Use `OrderedDict` as default for `VarInfo`. - # TODO: Do better! - return DynamicPPL.condition(model, DynamicPPL.values_as(varinfo, OrderedDict)) -end - -# Recursive definition. -function DynamicPPL.condition(model::Model, varinfos::AbstractVarInfo...) - return DynamicPPL.condition( - DynamicPPL.condition(model, first(varinfos)), - Base.tail(varinfos)..., +# Basically like a `DynamicPPL.FixedContext` but +# 1. Hijacks the tilde pipeline to fix variables. +# 2. Computes the log-probability of the fixed variables. +struct GibbsContext{Values,Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext + values::Values + context::Ctx +end + +Gibbscontext(values) = GibbsContext(values, DynamicPPL.DefaultContext()) + +DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent() +DynamicPPL.childcontext(context::GibbsContext) = context.context +DynamicPPL.setchildcontext(context::GibbsContext, childcontext) = GibbsContext(context.values, childcontext) + +# has and get +has_conditioned_gibbs(context::GibbsContext, vn::VarName) = DynamicPPL.hasvalue(context.values, vn) +function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) + return all(Base.Fix1(has_conditioned_gibbs, context), vns) +end + +get_conditioned_gibbs(context::GibbsContext, vn::VarName) = DynamicPPL.getvalue(context.values, vn) +function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) + return map(Base.Fix1(get_conditioned_gibbs, context), vns) +end + +# Tilde pipeline +function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) + # Short-circuits the tilde assume if `vn` is present in `context`. + if has_conditioned_gibbs(context, vn) + value = get_conditioned_gibbs(context, vn) + return value, logpdf(right, value), vi + end + + # Otherwise, falls back to the default behavior. + return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) +end + +function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi) + # Short-circuits the tilde assume if `vn` is present in `context`. + if has_conditioned_gibbs(context, vn) + value = get_conditioned_gibbs(context, vn) + return value, logpdf(right, value), vi + end + + # Otherwise, falls back to the default behavior. + return DynamicPPL.tilde_assume(rng, DynamicPPL.childcontext(context), sampler, right, vn, vi) +end + +function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi) + # Short-circuits the tilde assume if `vn` is present in `context`. + # FIXME: This probably won't work as is. + @info "dot_tilde_assume" vns value + if has_conditioned_gibbs(context, vns) + value = get_conditioned_gibbs(context, vns) + return value, sum(logpdf.(right, value)), vi + end + + # Otherwise, falls back to the default behavior. + return DynamicPPL.dot_tilde_assume(DynamicPPL.childcontext(context), right, left, vns, vi) +end + +function DynamicPPL.dot_tilde_assume( + rng::Random.AbstractRNG, context::GibbsContext, sampler, right, left, vns, vi +) + # Short-circuits the tilde assume if `vn` is present in `context`. + if has_conditioned_gibbs(context, vns) + values = get_conditioned_gibbs(context, vns) + return values, sum(logpdf.(right, values)), vi + end + + # Otherwise, falls back to the default behavior. + return DynamicPPL.dot_tilde_assume(rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi) +end + + +preferred_value_type(::AbstractVarInfo) = OrderedDict +preferred_value_type(::SimpleVarInfo{<:NamedTuple}) = NamedTuple +function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) + # We can only do this in the scenario where all the varnames are `Setfield.IdentityLens`. + namedtuple_compatible = all(varinfo.metadata) do md + eltype(md.vns) <: VarName{<:Any,DynamicPPL.Setfield.IdentityLens} + end + return namedtuple_compatible ? NamedTuple : OrderedDict +end + +# No-op if no values are provided. +condition_gibbs(context::DynamicPPL.AbstractContext) = context +# For `NamedTuple` and `AbstractDict` we just construct the context. +function condition_gibbs(context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict}) + return GibbsContext(values, context) +end +# If we get more than one argument, we just recurse. +function condition_gibbs(context::DynamicPPL.AbstractContext, value, values...) + return condition_gibbs( + condition_gibbs(context, value), + values... ) end -DynamicPPL.condition(model::Model, ::Tuple{}) = model +# For `AbstractVarInfo` we just extract the values. +function condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::AbstractVarInfo) + # TODO: Determine when it's okay to use `NamedTuple` and use that instead. + return condition_gibbs(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo))) +end +# Allow calling this on a `Model` directly. +function condition_gibbs(model::Model, values...) + return DynamicPPL.contextualize(model, condition_gibbs(model.context, values...)) +end """ @@ -46,7 +136,7 @@ true """ function make_conditional(model::Model, target_varinfo::AbstractVarInfo, varinfos) # TODO: Check if this is known at compile-time if `varinfos isa Tuple`. - return DynamicPPL.condition( + return condition_gibbs( model, filter(Base.Fix1(!==, target_varinfo), varinfos)... ) From bff078637b80d34ab8996799325c86002d6d7fcb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 19 Nov 2023 15:33:01 +0000 Subject: [PATCH 12/58] adopted the rerun mechanism in Gibbs for GibbsV2, thus fixing the issues with some of the tests for GibbsV2 --- src/mcmc/gibbs_new.jl | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/src/mcmc/gibbs_new.jl b/src/mcmc/gibbs_new.jl index 9016602a5..750b20e70 100644 --- a/src/mcmc/gibbs_new.jl +++ b/src/mcmc/gibbs_new.jl @@ -176,10 +176,11 @@ _maybevec(x) = vec(x) # assume it's iterable _maybevec(x::Tuple) = [x...] _maybevec(x::VarName) = [x] -function AbstractMCMC.step( +function DynamicPPL.initialstep( rng::Random.AbstractRNG, model::Model, - spl::Sampler{<:GibbsV2}; + spl::Sampler{<:GibbsV2}, + vi_base::AbstractVarInfo; kwargs..., ) alg = spl.alg @@ -272,6 +273,14 @@ function AbstractMCMC.step( return Transition(model, vi), GibbsV2State(vi, states) end +function make_rerun_sampler(model::DynamicPPL.Model, sampler::DynamicPPL.Sampler, sampler_previous::DynamicPPL.Sampler) + selector = DynamicPPL.Selector( + Symbol(typeof(sampler.alg)), + gibbs_rerun(sampler_previous.alg, sampler.alg) + ) + return DynamicPPL.Sampler(sampler.alg, model, selector) +end + function gibbs_step_inner( rng::Random.AbstractRNG, model::Model, @@ -286,6 +295,8 @@ function gibbs_step_inner( state_local = states[index] varinfo_local = varinfos[index] + # We need the previous sampler to determine whether we'll need to rerun. + sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] # 1. Create conditional model. # Construct the conditional model. # NOTE: Here it's crucial that all the `varinfos` are in the constrained space, @@ -296,11 +307,20 @@ function gibbs_step_inner( # NOTE: We use `logjoint` instead of `evaluate!!` and capturing the resulting varinfo because # the resulting varinfo might be in un-transformed space even if `varinfo_local` # is in transformed space. This can occur if we hit `maybe_invlink_before_eval!!`. - varinfo_local = DynamicPPL.setlogp!!( - varinfo_local, - DynamicPPL.logjoint(model_local, varinfo_local), - ) + # Re-run the sampler if needed. + if gibbs_rerun(sampler_local, sampler_previous) + # Make the re-run sampler. + # NOTE: Need to do this because some samplers might need some other quantity than the log-joint, + # e.g. log-likelihood in the scenario of `ESS`. + # TODO: Check if `sampler_rerun` should be replacing `sampler_local` or not. + sampler_rerun = make_rerun_sampler(model_local, sampler_local, sampler_previous) + varinfo_local = last(DynamicPPL.evaluate!!( + model_local, + varinfo_local, + DynamicPPL.SamplingContext(rng, sampler_rerun) + )) + end # 2. Take step with local sampler. # Update the state we're about to use if need be. # If the sampler requires a linked varinfo, this should be done in `gibbs_state`. From 5955167222da31125f0674c7b1b9a0f569b27e21 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 19 Nov 2023 15:36:15 +0000 Subject: [PATCH 13/58] broken tests are no longer broken --- test/mcmc/gibbs_new.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/mcmc/gibbs_new.jl b/test/mcmc/gibbs_new.jl index 4cd6132bf..977a221fb 100644 --- a/test/mcmc/gibbs_new.jl +++ b/test/mcmc/gibbs_new.jl @@ -114,8 +114,7 @@ end Random.seed!(100) alg = GibbsV2(@varname(s) => CSMC(15), @varname(m) => ESS(:m)) chain = sample(gdemo(1.5, 2.0), alg, 10_000) - @test_broken mean(chain[:s]) ≈ 49 / 24 - @test_broken mean(chain[:m]) ≈ 7 / 6 + check_gdemo(chain) end @testset "multiple varnames" begin From 614dc5279f040570b1eea8d8a1291f51a414f7c3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 19 Nov 2023 20:38:06 +0000 Subject: [PATCH 14/58] fix issues with dot_tilde_* impls for GibbsContext --- src/mcmc/gibbs_new.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/mcmc/gibbs_new.jl b/src/mcmc/gibbs_new.jl index 750b20e70..f57a9cd29 100644 --- a/src/mcmc/gibbs_new.jl +++ b/src/mcmc/gibbs_new.jl @@ -46,13 +46,15 @@ function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, context::GibbsContext, return DynamicPPL.tilde_assume(rng, DynamicPPL.childcontext(context), sampler, right, vn, vi) end +make_broadcastable(x) = x +make_broadcastable(dist::Distribution) = tuple(dist) + function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi) # Short-circuits the tilde assume if `vn` is present in `context`. # FIXME: This probably won't work as is. - @info "dot_tilde_assume" vns value if has_conditioned_gibbs(context, vns) value = get_conditioned_gibbs(context, vns) - return value, sum(logpdf.(right, value)), vi + return value, sum(logpdf.(make_broadcastable(right), value)), vi end # Otherwise, falls back to the default behavior. @@ -65,7 +67,7 @@ function DynamicPPL.dot_tilde_assume( # Short-circuits the tilde assume if `vn` is present in `context`. if has_conditioned_gibbs(context, vns) values = get_conditioned_gibbs(context, vns) - return values, sum(logpdf.(right, values)), vi + return values, sum(logpdf.(make_broadcastable(right), values)), vi end # Otherwise, falls back to the default behavior. From 7776f8277fa83a8c9fb015278519a1a6ee1ca1da Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 19 Nov 2023 20:56:41 +0000 Subject: [PATCH 15/58] fix for dot_tilde_assume when using GibbsContext --- src/mcmc/gibbs_new.jl | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/src/mcmc/gibbs_new.jl b/src/mcmc/gibbs_new.jl index f57a9cd29..741b1f1ed 100644 --- a/src/mcmc/gibbs_new.jl +++ b/src/mcmc/gibbs_new.jl @@ -46,15 +46,29 @@ function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, context::GibbsContext, return DynamicPPL.tilde_assume(rng, DynamicPPL.childcontext(context), sampler, right, vn, vi) end +# Some utility methods for handling the `logpdf` computations in dot-tilde the pipeline. make_broadcastable(x) = x make_broadcastable(dist::Distribution) = tuple(dist) +# Need the following two methods to properly support broadcasting over columns. +broadcast_logpdf(dist, x) = sum(logpdf.(make_broadcastable(dist), x)) +function broadcast_logpdf(dist::MultivariateDistribution, x::AbstractMatrix) + return loglikelihood(dist, x) +end + +reconstruct_getvalue(dist, x) = x +function reconstruct_getvalue( + dist::MultivariateDistribution, + x::AbstractVector{<:AbstractVector{<:Real}} +) + return reduce(hcat, x[2:end]; init=x[1]) +end + function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi) # Short-circuits the tilde assume if `vn` is present in `context`. - # FIXME: This probably won't work as is. if has_conditioned_gibbs(context, vns) - value = get_conditioned_gibbs(context, vns) - return value, sum(logpdf.(make_broadcastable(right), value)), vi + value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) + return value, broadcast_logpdf(right, values), vi end # Otherwise, falls back to the default behavior. @@ -66,8 +80,8 @@ function DynamicPPL.dot_tilde_assume( ) # Short-circuits the tilde assume if `vn` is present in `context`. if has_conditioned_gibbs(context, vns) - values = get_conditioned_gibbs(context, vns) - return values, sum(logpdf.(make_broadcastable(right), values)), vi + values = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) + return values, broadcast_logpdf(right, values), vi end # Otherwise, falls back to the default behavior. From c65b7e93a4c54bf586d98d6ad329d7c08b08ad7a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 19 Nov 2023 20:57:13 +0000 Subject: [PATCH 16/58] fixed re-running of models for Gibbs sampling properly this time --- src/mcmc/gibbs_new.jl | 49 ++++++++++++++++++++++++++----------------- 1 file changed, 30 insertions(+), 19 deletions(-) diff --git a/src/mcmc/gibbs_new.jl b/src/mcmc/gibbs_new.jl index 741b1f1ed..148bb436e 100644 --- a/src/mcmc/gibbs_new.jl +++ b/src/mcmc/gibbs_new.jl @@ -290,13 +290,35 @@ function AbstractMCMC.step( end function make_rerun_sampler(model::DynamicPPL.Model, sampler::DynamicPPL.Sampler, sampler_previous::DynamicPPL.Sampler) - selector = DynamicPPL.Selector( - Symbol(typeof(sampler.alg)), - gibbs_rerun(sampler_previous.alg, sampler.alg) - ) - return DynamicPPL.Sampler(sampler.alg, model, selector) + # NOTE: This is different from the implementation used in the old `Gibbs` sampler, where we specifically provide + # a `gid`. Here, because `model` only contains random variables to be sampled by `sampler`, we just use the exact + # same `selector` as before but now with `rerun` set to `true` if needed. + return DynamicPPL.Setfield.@set sampler.selector.rerun = gibbs_rerun(sampler_previous.alg, sampler.alg) end +function gibbs_rerun_maybe( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + sampler::DynamicPPL.Sampler, + sampler_previous::DynamicPPL.Sampler, + varinfo::AbstractVarInfo, +) + # Return early if we don't need it. + gibbs_rerun(sampler, sampler_previous) || return varinfo + + # Make the re-run sampler. + # NOTE: Need to do this because some samplers might need some other quantity than the log-joint, + # e.g. log-likelihood in the scenario of `ESS`. + # NOTE: Need to update `sampler` too because the `gid` might change in the re-run of the model. + sampler_rerun = make_rerun_sampler(model, sampler, sampler_previous) + # NOTE: If we hit `DynamicPPL.maybe_invlink_before_eval!!`, then this will result in a `invlink`ed + # `varinfo`, even if `varinfo` was linked. + return last(DynamicPPL.evaluate!!( + model, + varinfo, + DynamicPPL.SamplingContext(rng, sampler_rerun) + )) +end function gibbs_step_inner( rng::Random.AbstractRNG, model::Model, @@ -311,8 +333,6 @@ function gibbs_step_inner( state_local = states[index] varinfo_local = varinfos[index] - # We need the previous sampler to determine whether we'll need to rerun. - sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] # 1. Create conditional model. # Construct the conditional model. # NOTE: Here it's crucial that all the `varinfos` are in the constrained space, @@ -325,18 +345,9 @@ function gibbs_step_inner( # is in transformed space. This can occur if we hit `maybe_invlink_before_eval!!`. # Re-run the sampler if needed. - if gibbs_rerun(sampler_local, sampler_previous) - # Make the re-run sampler. - # NOTE: Need to do this because some samplers might need some other quantity than the log-joint, - # e.g. log-likelihood in the scenario of `ESS`. - # TODO: Check if `sampler_rerun` should be replacing `sampler_local` or not. - sampler_rerun = make_rerun_sampler(model_local, sampler_local, sampler_previous) - varinfo_local = last(DynamicPPL.evaluate!!( - model_local, - varinfo_local, - DynamicPPL.SamplingContext(rng, sampler_rerun) - )) - end + sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] + varinfo_local = gibbs_rerun_maybe(rng, model_local, sampler_local, sampler_previous, varinfo_local) + # 2. Take step with local sampler. # Update the state we're about to use if need be. # If the sampler requires a linked varinfo, this should be done in `gibbs_state`. From 16ddca2027d8dac0256d23399f2326a4f03fcbaa Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 19 Nov 2023 23:18:27 +0000 Subject: [PATCH 17/58] added new gibbs to tests --- test/runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/runtests.jl b/test/runtests.jl index b57ccda73..76c77775d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -74,6 +74,7 @@ macro timeit_include(path::AbstractString) :(@timeit TIMEROUTPUT $path include($ @testset "samplers" begin @timeit_include("mcmc/gibbs.jl") @timeit_include("mcmc/gibbs_conditional.jl") + @timeit_include("mcmc/gibbs_new.jl") @timeit_include("mcmc/hmc.jl") @timeit_include("mcmc/Inference.jl") @timeit_include("mcmc/sghmc.jl") From ba8c6e13759447aae41164909e955f35fb1b5834 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 19 Nov 2023 23:27:44 +0000 Subject: [PATCH 18/58] added some further comments on why we need `GibbsContext` --- src/mcmc/gibbs_new.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/mcmc/gibbs_new.jl b/src/mcmc/gibbs_new.jl index 148bb436e..498adf60a 100644 --- a/src/mcmc/gibbs_new.jl +++ b/src/mcmc/gibbs_new.jl @@ -1,6 +1,15 @@ # Basically like a `DynamicPPL.FixedContext` but # 1. Hijacks the tilde pipeline to fix variables. # 2. Computes the log-probability of the fixed variables. +# +# Purpose: avoid triggering resampling of variables we're conditioning on. +# - Using standard `DynamicPPL.condition` results in conditioned variables being treated +# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`. +# - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to +# undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable +# rather than only for the "true" observations. +# - `GibbsContext` allows us to perform conditioning while still hit the `assume` pipeline +# rather than the `observe` pipeline for the conditioned variables. struct GibbsContext{Values,Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext values::Values context::Ctx From 53bd7072dbbdfd52713da35b8be4100390a6a43e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 19 Nov 2023 23:28:58 +0000 Subject: [PATCH 19/58] went back to using `DynamicPPL.condition` rather than using custom `GibbsContext` while we wait for https://github.com/TuringLang/DynamicPPL.jl/pull/563 to be merged --- src/mcmc/gibbs_new.jl | 121 +----------------------------------------- 1 file changed, 2 insertions(+), 119 deletions(-) diff --git a/src/mcmc/gibbs_new.jl b/src/mcmc/gibbs_new.jl index 498adf60a..6dc0ac022 100644 --- a/src/mcmc/gibbs_new.jl +++ b/src/mcmc/gibbs_new.jl @@ -1,103 +1,3 @@ -# Basically like a `DynamicPPL.FixedContext` but -# 1. Hijacks the tilde pipeline to fix variables. -# 2. Computes the log-probability of the fixed variables. -# -# Purpose: avoid triggering resampling of variables we're conditioning on. -# - Using standard `DynamicPPL.condition` results in conditioned variables being treated -# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`. -# - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to -# undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable -# rather than only for the "true" observations. -# - `GibbsContext` allows us to perform conditioning while still hit the `assume` pipeline -# rather than the `observe` pipeline for the conditioned variables. -struct GibbsContext{Values,Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext - values::Values - context::Ctx -end - -Gibbscontext(values) = GibbsContext(values, DynamicPPL.DefaultContext()) - -DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent() -DynamicPPL.childcontext(context::GibbsContext) = context.context -DynamicPPL.setchildcontext(context::GibbsContext, childcontext) = GibbsContext(context.values, childcontext) - -# has and get -has_conditioned_gibbs(context::GibbsContext, vn::VarName) = DynamicPPL.hasvalue(context.values, vn) -function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) - return all(Base.Fix1(has_conditioned_gibbs, context), vns) -end - -get_conditioned_gibbs(context::GibbsContext, vn::VarName) = DynamicPPL.getvalue(context.values, vn) -function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) - return map(Base.Fix1(get_conditioned_gibbs, context), vns) -end - -# Tilde pipeline -function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vn) - value = get_conditioned_gibbs(context, vn) - return value, logpdf(right, value), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) -end - -function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vn) - value = get_conditioned_gibbs(context, vn) - return value, logpdf(right, value), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.tilde_assume(rng, DynamicPPL.childcontext(context), sampler, right, vn, vi) -end - -# Some utility methods for handling the `logpdf` computations in dot-tilde the pipeline. -make_broadcastable(x) = x -make_broadcastable(dist::Distribution) = tuple(dist) - -# Need the following two methods to properly support broadcasting over columns. -broadcast_logpdf(dist, x) = sum(logpdf.(make_broadcastable(dist), x)) -function broadcast_logpdf(dist::MultivariateDistribution, x::AbstractMatrix) - return loglikelihood(dist, x) -end - -reconstruct_getvalue(dist, x) = x -function reconstruct_getvalue( - dist::MultivariateDistribution, - x::AbstractVector{<:AbstractVector{<:Real}} -) - return reduce(hcat, x[2:end]; init=x[1]) -end - -function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vns) - value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) - return value, broadcast_logpdf(right, values), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.dot_tilde_assume(DynamicPPL.childcontext(context), right, left, vns, vi) -end - -function DynamicPPL.dot_tilde_assume( - rng::Random.AbstractRNG, context::GibbsContext, sampler, right, left, vns, vi -) - # Short-circuits the tilde assume if `vn` is present in `context`. - if has_conditioned_gibbs(context, vns) - values = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) - return values, broadcast_logpdf(right, values), vi - end - - # Otherwise, falls back to the default behavior. - return DynamicPPL.dot_tilde_assume(rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi) -end - - preferred_value_type(::AbstractVarInfo) = OrderedDict preferred_value_type(::SimpleVarInfo{<:NamedTuple}) = NamedTuple function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) @@ -108,28 +8,10 @@ function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) return namedtuple_compatible ? NamedTuple : OrderedDict end -# No-op if no values are provided. -condition_gibbs(context::DynamicPPL.AbstractContext) = context -# For `NamedTuple` and `AbstractDict` we just construct the context. -function condition_gibbs(context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict}) - return GibbsContext(values, context) -end -# If we get more than one argument, we just recurse. -function condition_gibbs(context::DynamicPPL.AbstractContext, value, values...) - return condition_gibbs( - condition_gibbs(context, value), - values... - ) -end -# For `AbstractVarInfo` we just extract the values. -function condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::AbstractVarInfo) +function DynamicPPL.condition(context::DynamicPPL.AbstractContext, varinfo::AbstractVarInfo) # TODO: Determine when it's okay to use `NamedTuple` and use that instead. return condition_gibbs(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo))) end -# Allow calling this on a `Model` directly. -function condition_gibbs(model::Model, values...) - return DynamicPPL.contextualize(model, condition_gibbs(model.context, values...)) -end """ @@ -162,6 +44,7 @@ true function make_conditional(model::Model, target_varinfo::AbstractVarInfo, varinfos) # TODO: Check if this is known at compile-time if `varinfos isa Tuple`. return condition_gibbs( + return condition( model, filter(Base.Fix1(!==, target_varinfo), varinfos)... ) From b38a82a797fb2dd5ff7a6476f1edda9be6d45ba0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 19 Nov 2023 23:29:59 +0000 Subject: [PATCH 20/58] add concrete comment about reverting changes for `gibbs_condition` --- src/mcmc/gibbs_new.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/mcmc/gibbs_new.jl b/src/mcmc/gibbs_new.jl index 6dc0ac022..10f56d771 100644 --- a/src/mcmc/gibbs_new.jl +++ b/src/mcmc/gibbs_new.jl @@ -43,7 +43,8 @@ true """ function make_conditional(model::Model, target_varinfo::AbstractVarInfo, varinfos) # TODO: Check if this is known at compile-time if `varinfos isa Tuple`. - return condition_gibbs( + # FIXME: Revert commit 53bd7072 and use `gibbs_condition` as soon as + # https://github.com/TuringLang/DynamicPPL.jl/pull/563 is merged. return condition( model, filter(Base.Fix1(!==, target_varinfo), varinfos)... From adc67be18c818768fc47a83554216b860f178e9a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 20 Nov 2023 12:28:59 +0000 Subject: [PATCH 21/58] Update test/mcmc/gibbs_new.jl Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- test/mcmc/gibbs_new.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mcmc/gibbs_new.jl b/test/mcmc/gibbs_new.jl index 977a221fb..c5d0bca54 100644 --- a/test/mcmc/gibbs_new.jl +++ b/test/mcmc/gibbs_new.jl @@ -151,7 +151,7 @@ end end end - @testset "CSMS + ESS" begin + @testset "CSMC + ESS" begin rng = Random.default_rng() model = MoGtest_default alg = GibbsV2( From 3b5a74c80dac1f69ed02031e2165ffb2ef276825 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 20 Nov 2023 17:37:40 +0000 Subject: [PATCH 22/58] fixed recursive definition of `condition` varinfos --- src/mcmc/gibbs_new.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/mcmc/gibbs_new.jl b/src/mcmc/gibbs_new.jl index 10f56d771..0884aa02a 100644 --- a/src/mcmc/gibbs_new.jl +++ b/src/mcmc/gibbs_new.jl @@ -10,7 +10,14 @@ end function DynamicPPL.condition(context::DynamicPPL.AbstractContext, varinfo::AbstractVarInfo) # TODO: Determine when it's okay to use `NamedTuple` and use that instead. - return condition_gibbs(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo))) + return DynamicPPL.condition(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo))) +end +function DynamicPPL.condition( + context::DynamicPPL.AbstractContext, + varinfo::AbstractVarInfo, + varinfos::AbstractVarInfo... +) + return DynamicPPL.condition(DynamicPPL.condition(context, varinfo), varinfos...) end From f87e2d1195d603b43da4f75721a1bd6c23b2a10f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 20 Nov 2023 19:22:38 +0000 Subject: [PATCH 23/58] use `fix` instead of `condition` --- src/mcmc/gibbs_new.jl | 10 +++++----- test/mcmc/gibbs_new.jl | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/mcmc/gibbs_new.jl b/src/mcmc/gibbs_new.jl index 0884aa02a..b7fe8e291 100644 --- a/src/mcmc/gibbs_new.jl +++ b/src/mcmc/gibbs_new.jl @@ -8,16 +8,16 @@ function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) return namedtuple_compatible ? NamedTuple : OrderedDict end -function DynamicPPL.condition(context::DynamicPPL.AbstractContext, varinfo::AbstractVarInfo) +function DynamicPPL.fix(context::DynamicPPL.AbstractContext, varinfo::AbstractVarInfo) # TODO: Determine when it's okay to use `NamedTuple` and use that instead. - return DynamicPPL.condition(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo))) + return DynamicPPL.fix(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo))) end -function DynamicPPL.condition( +function DynamicPPL.fix( context::DynamicPPL.AbstractContext, varinfo::AbstractVarInfo, varinfos::AbstractVarInfo... ) - return DynamicPPL.condition(DynamicPPL.condition(context, varinfo), varinfos...) + return DynamicPPL.fix(DynamicPPL.fix(context, varinfo), varinfos...) end @@ -52,7 +52,7 @@ function make_conditional(model::Model, target_varinfo::AbstractVarInfo, varinfo # TODO: Check if this is known at compile-time if `varinfos isa Tuple`. # FIXME: Revert commit 53bd7072 and use `gibbs_condition` as soon as # https://github.com/TuringLang/DynamicPPL.jl/pull/563 is merged. - return condition( + return fix( model, filter(Base.Fix1(!==, target_varinfo), varinfos)... ) diff --git a/test/mcmc/gibbs_new.jl b/test/mcmc/gibbs_new.jl index c5d0bca54..a60df0076 100644 --- a/test/mcmc/gibbs_new.jl +++ b/test/mcmc/gibbs_new.jl @@ -112,7 +112,7 @@ end # `GibbsV2` does not work with SMC samplers, e.g. `CSMC`. # FIXME: Oooor it is (see tests below). Uncertain. Random.seed!(100) - alg = GibbsV2(@varname(s) => CSMC(15), @varname(m) => ESS(:m)) + alg = GibbsV2(@varname(s) => CSMC(15), @varname(m) => ESS()) chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_gdemo(chain) end From 1c1d9b767d97fc24eeae27c733a13c83c47863a4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 21 Nov 2023 10:06:14 +0000 Subject: [PATCH 24/58] Revert "use `fix` instead of `condition`" This reverts commit f87e2d1195d603b43da4f75721a1bd6c23b2a10f. --- src/mcmc/gibbs_new.jl | 10 +++++----- test/mcmc/gibbs_new.jl | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/mcmc/gibbs_new.jl b/src/mcmc/gibbs_new.jl index b7fe8e291..0884aa02a 100644 --- a/src/mcmc/gibbs_new.jl +++ b/src/mcmc/gibbs_new.jl @@ -8,16 +8,16 @@ function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) return namedtuple_compatible ? NamedTuple : OrderedDict end -function DynamicPPL.fix(context::DynamicPPL.AbstractContext, varinfo::AbstractVarInfo) +function DynamicPPL.condition(context::DynamicPPL.AbstractContext, varinfo::AbstractVarInfo) # TODO: Determine when it's okay to use `NamedTuple` and use that instead. - return DynamicPPL.fix(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo))) + return DynamicPPL.condition(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo))) end -function DynamicPPL.fix( +function DynamicPPL.condition( context::DynamicPPL.AbstractContext, varinfo::AbstractVarInfo, varinfos::AbstractVarInfo... ) - return DynamicPPL.fix(DynamicPPL.fix(context, varinfo), varinfos...) + return DynamicPPL.condition(DynamicPPL.condition(context, varinfo), varinfos...) end @@ -52,7 +52,7 @@ function make_conditional(model::Model, target_varinfo::AbstractVarInfo, varinfo # TODO: Check if this is known at compile-time if `varinfos isa Tuple`. # FIXME: Revert commit 53bd7072 and use `gibbs_condition` as soon as # https://github.com/TuringLang/DynamicPPL.jl/pull/563 is merged. - return fix( + return condition( model, filter(Base.Fix1(!==, target_varinfo), varinfos)... ) diff --git a/test/mcmc/gibbs_new.jl b/test/mcmc/gibbs_new.jl index a60df0076..c5d0bca54 100644 --- a/test/mcmc/gibbs_new.jl +++ b/test/mcmc/gibbs_new.jl @@ -112,7 +112,7 @@ end # `GibbsV2` does not work with SMC samplers, e.g. `CSMC`. # FIXME: Oooor it is (see tests below). Uncertain. Random.seed!(100) - alg = GibbsV2(@varname(s) => CSMC(15), @varname(m) => ESS()) + alg = GibbsV2(@varname(s) => CSMC(15), @varname(m) => ESS(:m)) chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_gdemo(chain) end From 6be7ab9af2064b3555fddb1e9af487c01e9c6173 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 21 Nov 2023 10:06:36 +0000 Subject: [PATCH 25/58] rmeoved unnused symbol --- test/mcmc/gibbs_new.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mcmc/gibbs_new.jl b/test/mcmc/gibbs_new.jl index c5d0bca54..a60df0076 100644 --- a/test/mcmc/gibbs_new.jl +++ b/test/mcmc/gibbs_new.jl @@ -112,7 +112,7 @@ end # `GibbsV2` does not work with SMC samplers, e.g. `CSMC`. # FIXME: Oooor it is (see tests below). Uncertain. Random.seed!(100) - alg = GibbsV2(@varname(s) => CSMC(15), @varname(m) => ESS(:m)) + alg = GibbsV2(@varname(s) => CSMC(15), @varname(m) => ESS()) chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_gdemo(chain) end From a143cc417c2017da3e330427f5930fc4d82cbd1f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 21 Nov 2023 13:57:46 +0000 Subject: [PATCH 26/58] Revert "went back to using `DynamicPPL.condition` rather than using custom" This reverts commit 53bd7072dbbdfd52713da35b8be4100390a6a43e. --- src/mcmc/gibbs_new.jl | 124 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 120 insertions(+), 4 deletions(-) diff --git a/src/mcmc/gibbs_new.jl b/src/mcmc/gibbs_new.jl index 0884aa02a..28a211a1d 100644 --- a/src/mcmc/gibbs_new.jl +++ b/src/mcmc/gibbs_new.jl @@ -1,3 +1,103 @@ +# Basically like a `DynamicPPL.FixedContext` but +# 1. Hijacks the tilde pipeline to fix variables. +# 2. Computes the log-probability of the fixed variables. +# +# Purpose: avoid triggering resampling of variables we're conditioning on. +# - Using standard `DynamicPPL.condition` results in conditioned variables being treated +# as observations in the truest sense, i.e. we hit `DynamicPPL.tilde_observe`. +# - But `observe` is overloaded by some samplers, e.g. `CSMC`, which can lead to +# undesirable behavior, e.g. `CSMC` triggering a resampling for every conditioned variable +# rather than only for the "true" observations. +# - `GibbsContext` allows us to perform conditioning while still hit the `assume` pipeline +# rather than the `observe` pipeline for the conditioned variables. +struct GibbsContext{Values,Ctx<:DynamicPPL.AbstractContext} <: DynamicPPL.AbstractContext + values::Values + context::Ctx +end + +Gibbscontext(values) = GibbsContext(values, DynamicPPL.DefaultContext()) + +DynamicPPL.NodeTrait(::GibbsContext) = DynamicPPL.IsParent() +DynamicPPL.childcontext(context::GibbsContext) = context.context +DynamicPPL.setchildcontext(context::GibbsContext, childcontext) = GibbsContext(context.values, childcontext) + +# has and get +has_conditioned_gibbs(context::GibbsContext, vn::VarName) = DynamicPPL.hasvalue(context.values, vn) +function has_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) + return all(Base.Fix1(has_conditioned_gibbs, context), vns) +end + +get_conditioned_gibbs(context::GibbsContext, vn::VarName) = DynamicPPL.getvalue(context.values, vn) +function get_conditioned_gibbs(context::GibbsContext, vns::AbstractArray{<:VarName}) + return map(Base.Fix1(get_conditioned_gibbs, context), vns) +end + +# Tilde pipeline +function DynamicPPL.tilde_assume(context::GibbsContext, right, vn, vi) + # Short-circuits the tilde assume if `vn` is present in `context`. + if has_conditioned_gibbs(context, vn) + value = get_conditioned_gibbs(context, vn) + return value, logpdf(right, value), vi + end + + # Otherwise, falls back to the default behavior. + return DynamicPPL.tilde_assume(DynamicPPL.childcontext(context), right, vn, vi) +end + +function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, context::GibbsContext, sampler, right, vn, vi) + # Short-circuits the tilde assume if `vn` is present in `context`. + if has_conditioned_gibbs(context, vn) + value = get_conditioned_gibbs(context, vn) + return value, logpdf(right, value), vi + end + + # Otherwise, falls back to the default behavior. + return DynamicPPL.tilde_assume(rng, DynamicPPL.childcontext(context), sampler, right, vn, vi) +end + +# Some utility methods for handling the `logpdf` computations in dot-tilde the pipeline. +make_broadcastable(x) = x +make_broadcastable(dist::Distribution) = tuple(dist) + +# Need the following two methods to properly support broadcasting over columns. +broadcast_logpdf(dist, x) = sum(logpdf.(make_broadcastable(dist), x)) +function broadcast_logpdf(dist::MultivariateDistribution, x::AbstractMatrix) + return loglikelihood(dist, x) +end + +reconstruct_getvalue(dist, x) = x +function reconstruct_getvalue( + dist::MultivariateDistribution, + x::AbstractVector{<:AbstractVector{<:Real}} +) + return reduce(hcat, x[2:end]; init=x[1]) +end + +function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi) + # Short-circuits the tilde assume if `vn` is present in `context`. + if has_conditioned_gibbs(context, vns) + value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) + return value, broadcast_logpdf(right, values), vi + end + + # Otherwise, falls back to the default behavior. + return DynamicPPL.dot_tilde_assume(DynamicPPL.childcontext(context), right, left, vns, vi) +end + +function DynamicPPL.dot_tilde_assume( + rng::Random.AbstractRNG, context::GibbsContext, sampler, right, left, vns, vi +) + # Short-circuits the tilde assume if `vn` is present in `context`. + if has_conditioned_gibbs(context, vns) + values = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns)) + return values, broadcast_logpdf(right, values), vi + end + + # Otherwise, falls back to the default behavior. + return DynamicPPL.dot_tilde_assume(rng, DynamicPPL.childcontext(context), sampler, right, left, vns, vi) +end + + preferred_value_type(::AbstractVarInfo) = OrderedDict preferred_value_type(::SimpleVarInfo{<:NamedTuple}) = NamedTuple function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) @@ -8,7 +108,21 @@ function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) return namedtuple_compatible ? NamedTuple : OrderedDict end -function DynamicPPL.condition(context::DynamicPPL.AbstractContext, varinfo::AbstractVarInfo) +# No-op if no values are provided. +condition_gibbs(context::DynamicPPL.AbstractContext) = context +# For `NamedTuple` and `AbstractDict` we just construct the context. +function condition_gibbs(context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict}) + return GibbsContext(values, context) +end +# If we get more than one argument, we just recurse. +function condition_gibbs(context::DynamicPPL.AbstractContext, value, values...) + return condition_gibbs( + condition_gibbs(context, value), + values... + ) +end +# For `AbstractVarInfo` we just extract the values. +function condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::AbstractVarInfo) # TODO: Determine when it's okay to use `NamedTuple` and use that instead. return DynamicPPL.condition(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo))) end @@ -19,6 +133,10 @@ function DynamicPPL.condition( ) return DynamicPPL.condition(DynamicPPL.condition(context, varinfo), varinfos...) end +# Allow calling this on a `Model` directly. +function condition_gibbs(model::Model, values...) + return DynamicPPL.contextualize(model, condition_gibbs(model.context, values...)) +end """ @@ -50,9 +168,7 @@ true """ function make_conditional(model::Model, target_varinfo::AbstractVarInfo, varinfos) # TODO: Check if this is known at compile-time if `varinfos isa Tuple`. - # FIXME: Revert commit 53bd7072 and use `gibbs_condition` as soon as - # https://github.com/TuringLang/DynamicPPL.jl/pull/563 is merged. - return condition( + return condition_gibbs( model, filter(Base.Fix1(!==, target_varinfo), varinfos)... ) From 4d37f5fadaf532691224a75983097e7b126fdc0c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 21 Nov 2023 13:57:51 +0000 Subject: [PATCH 27/58] bump compat entry of DynamicPPL so we can overload acclogp! --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 041b2635a..2cd9b0474 100644 --- a/Project.toml +++ b/Project.toml @@ -58,7 +58,7 @@ Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.24" +DynamicPPL = "0.24.2" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3" Libtask = "0.7, 0.8" From f39e636692fa51fb2c13c628fdabccad21162aff Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 21 Nov 2023 14:01:32 +0000 Subject: [PATCH 28/58] update assume for SMC samplers to make use of new `acclogp!` --- src/mcmc/particle_mcmc.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 29deb73cc..8e67fb1d4 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -363,6 +363,8 @@ function DynamicPPL.assume( DynamicPPL.updategid!(vi, vn, spl) # Pick data from reference particle r = vi[vn] end + # TODO: Should we make this `zero(promote_type(eltype(dist), eltype(r)))` or something? + lp = 0 else # vn belongs to other sampler <=> conditioning on vn if haskey(vi, vn) r = vi[vn] @@ -371,9 +373,8 @@ function DynamicPPL.assume( push!!(vi, vn, r, dist, DynamicPPL.Selector(:invalid)) end lp = logpdf_with_trans(dist, r, istrans(vi, vn)) - acclogp!!(vi, lp) end - return r, 0, vi + return r, lp, vi end function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi) From 9cc5397a4493bd92a85139cce91fbdb8797a43ef Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 21 Nov 2023 14:18:45 +0000 Subject: [PATCH 29/58] added proper impl of acclogp!! for SMC samplers + made accessing task local varinfo and rng a bit nicer --- src/mcmc/particle_mcmc.jl | 49 +++++++++++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 12 deletions(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 8e67fb1d4..5ed2d53cd 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -327,27 +327,43 @@ end DynamicPPL.use_threadsafe_eval(::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, ::AbstractVarInfo) = false -function DynamicPPL.assume( - rng, - spl::Sampler{<:Union{PG,SMC}}, - dist::Distribution, - vn::VarName, - __vi__::AbstractVarInfo -) - local vi, trng +function trace_local_varinfo_maybe(varinfo) try trace = AdvancedPS.current_trace() - trng = trace.rng - vi = trace.model.f.varinfo + return trace.model.f.varinfo catch e # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`. if e == KeyError(:__trace) || current_task().storage isa Nothing - vi = __vi__ - trng = rng + return varinfo else rethrow(e) end end +end + +function trace_local_rng_maybe(rng::Random.AbstractRNG) + try + trace = AdvancedPS.current_trace() + return trace.rng + catch e + # NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`. + if e == KeyError(:__trace) || current_task().storage isa Nothing + return rng + else + rethrow(e) + end + end +end + +function DynamicPPL.assume( + rng, + spl::Sampler{<:Union{PG,SMC}}, + dist::Distribution, + vn::VarName, + _vi::AbstractVarInfo +) + vi = trace_local_varinfo_maybe(_vi) + trng = trace_local_rng_maybe(rng) if inspace(vn, spl) if ~haskey(vi, vn) @@ -382,6 +398,15 @@ function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, v return 0, vi end +function DynamicPPL.acclogp!!( + context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, + varinfo::AbstractVarInfo, + logp +) + varinfo_trace = trace_local_varinfo_maybe(varinfo) + DynamicPPL.acclogp!!(DynamicPPL.childcontext(context), varinfo_trace, logp) +end + # Convenient constructor function AdvancedPS.Trace( model::Model, From c93944c7ef06ad7192b9fac24316b8dd4f8c0c7d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 24 Nov 2023 08:00:50 +0000 Subject: [PATCH 30/58] added experimental module and moved gibbs to it --- src/experimental/Experimental.jl | 15 +++++++++++++++ src/{mcmc/gibbs_new.jl => experimental/gibbs.jl} | 0 2 files changed, 15 insertions(+) create mode 100644 src/experimental/Experimental.jl rename src/{mcmc/gibbs_new.jl => experimental/gibbs.jl} (100%) diff --git a/src/experimental/Experimental.jl b/src/experimental/Experimental.jl new file mode 100644 index 000000000..5d390475e --- /dev/null +++ b/src/experimental/Experimental.jl @@ -0,0 +1,15 @@ +module Experimental + +using Random: Random +using AbstractMCMC: AbstractMCMC +using DynamicPPL: DynamicPPL, VarName +using Setfield: Setfield + +using Distributions + +using ..Turing: Turing +using ..Turing.Inference: gibbs_rerun, InferenceAlgorithm + +include("gibbs.jl") + +end diff --git a/src/mcmc/gibbs_new.jl b/src/experimental/gibbs.jl similarity index 100% rename from src/mcmc/gibbs_new.jl rename to src/experimental/gibbs.jl From e6ce62f1369e50739ca191f26c927699c3c04a7d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 24 Nov 2023 08:01:22 +0000 Subject: [PATCH 31/58] fixed now-inccorect references in new gibbs file --- src/experimental/gibbs.jl | 72 +++++++++++++++++++-------------------- 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/src/experimental/gibbs.jl b/src/experimental/gibbs.jl index 28a211a1d..89ad20710 100644 --- a/src/experimental/gibbs.jl +++ b/src/experimental/gibbs.jl @@ -98,14 +98,14 @@ function DynamicPPL.dot_tilde_assume( end -preferred_value_type(::AbstractVarInfo) = OrderedDict -preferred_value_type(::SimpleVarInfo{<:NamedTuple}) = NamedTuple +preferred_value_type(::DynamicPPL.AbstractVarInfo) = DynamicPPL.OrderedDict +preferred_value_type(::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = NamedTuple function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) # We can only do this in the scenario where all the varnames are `Setfield.IdentityLens`. namedtuple_compatible = all(varinfo.metadata) do md - eltype(md.vns) <: VarName{<:Any,DynamicPPL.Setfield.IdentityLens} + eltype(md.vns) <: VarName{<:Any,Setfield.IdentityLens} end - return namedtuple_compatible ? NamedTuple : OrderedDict + return namedtuple_compatible ? NamedTuple : DynamicPPL.OrderedDict end # No-op if no values are provided. @@ -121,20 +121,20 @@ function condition_gibbs(context::DynamicPPL.AbstractContext, value, values...) values... ) end -# For `AbstractVarInfo` we just extract the values. -function condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::AbstractVarInfo) +# For `DynamicPPL.AbstractVarInfo` we just extract the values. +function condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo) # TODO: Determine when it's okay to use `NamedTuple` and use that instead. return DynamicPPL.condition(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo))) end function DynamicPPL.condition( context::DynamicPPL.AbstractContext, - varinfo::AbstractVarInfo, - varinfos::AbstractVarInfo... + varinfo::DynamicPPL.AbstractVarInfo, + varinfos::DynamicPPL.AbstractVarInfo... ) return DynamicPPL.condition(DynamicPPL.condition(context, varinfo), varinfos...) end -# Allow calling this on a `Model` directly. -function condition_gibbs(model::Model, values...) +# Allow calling this on a `DynamicPPL.Model` directly. +function condition_gibbs(model::DynamicPPL.Model, values...) return DynamicPPL.contextualize(model, condition_gibbs(model.context, values...)) end @@ -166,7 +166,7 @@ julia> result.s != 1.0 # we did NOT want to condition on varinfo with `s = 1.0` true ``` """ -function make_conditional(model::Model, target_varinfo::AbstractVarInfo, varinfos) +function make_conditional(model::DynamicPPL.Model, target_varinfo::DynamicPPL.AbstractVarInfo, varinfos) # TODO: Check if this is known at compile-time if `varinfos isa Tuple`. return condition_gibbs( model, @@ -175,31 +175,31 @@ function make_conditional(model::Model, target_varinfo::AbstractVarInfo, varinfo end wrap_algorithm_maybe(x) = x -wrap_algorithm_maybe(x::InferenceAlgorithm) = Sampler(x) +wrap_algorithm_maybe(x::InferenceAlgorithm) = DynamicPPL.Sampler(x) -struct GibbsV2{V,A} <: InferenceAlgorithm +struct Gibbs{V,A} <: InferenceAlgorithm varnames::V samplers::A end # NamedTuple -GibbsV2(; algs...) = GibbsV2(NamedTuple(algs)) -function GibbsV2(algs::NamedTuple) - return GibbsV2( +Gibbs(; algs...) = Gibbs(NamedTuple(algs)) +function Gibbs(algs::NamedTuple) + return Gibbs( map(s -> VarName{s}(), keys(algs)), map(wrap_algorithm_maybe, values(algs)), ) end # AbstractDict -function GibbsV2(algs::AbstractDict) - return GibbsV2(keys(algs), map(wrap_algorithm_maybe, values(algs))) +function Gibbs(algs::AbstractDict) + return Gibbs(keys(algs), map(wrap_algorithm_maybe, values(algs))) end -function GibbsV2(algs::Pair...) - return GibbsV2(map(first, algs), map(wrap_algorithm_maybe, map(last, algs))) +function Gibbs(algs::Pair...) + return Gibbs(map(first, algs), map(wrap_algorithm_maybe, map(last, algs))) end -struct GibbsV2State{V<:AbstractVarInfo,S} +struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S} vi::V states::S end @@ -210,9 +210,9 @@ _maybevec(x::VarName) = [x] function DynamicPPL.initialstep( rng::Random.AbstractRNG, - model::Model, - spl::Sampler{<:GibbsV2}, - vi_base::AbstractVarInfo; + model::DynamicPPL.Model, + spl::DynamicPPL.Sampler{<:Gibbs}, + vi_base::DynamicPPL.AbstractVarInfo; kwargs..., ) alg = spl.alg @@ -232,7 +232,7 @@ function DynamicPPL.initialstep( new_state_local = last(AbstractMCMC.step(rng, model_local, sampler_local; kwargs...)) # Return the new state and the invlinked `varinfo`. - vi_local_state = varinfo(new_state_local) + vi_local_state = Turing.Inference.varinfo(new_state_local) vi_local_state_linked = if DynamicPPL.istrans(vi_local_state) DynamicPPL.invlink(vi_local_state, sampler_local, model_local) else @@ -252,20 +252,20 @@ function DynamicPPL.initialstep( DynamicPPL.getlogp(last(varinfos)), ) - return Transition(model, vi), GibbsV2State(vi, states) + return Turing.Inference.Transition(model, vi), GibbsState(vi, states) end function AbstractMCMC.step( rng::Random.AbstractRNG, - model::Model, - spl::Sampler{<:GibbsV2}, - state::GibbsV2State; + model::DynamicPPL.Model, + spl::DynamicPPL.Sampler{<:Gibbs}, + state::GibbsState; kwargs..., ) alg = spl.alg samplers = alg.samplers states = state.states - varinfos = map(varinfo, state.states) + varinfos = map(Turing.Inference.varinfo, state.states) @assert length(samplers) == length(state.states) # TODO: move this into a recursive function so we can unroll when reasonable? @@ -302,14 +302,14 @@ function AbstractMCMC.step( DynamicPPL.getlogp(last(varinfos)), ) - return Transition(model, vi), GibbsV2State(vi, states) + return Turing.Inference.Transition(model, vi), GibbsState(vi, states) end function make_rerun_sampler(model::DynamicPPL.Model, sampler::DynamicPPL.Sampler, sampler_previous::DynamicPPL.Sampler) # NOTE: This is different from the implementation used in the old `Gibbs` sampler, where we specifically provide # a `gid`. Here, because `model` only contains random variables to be sampled by `sampler`, we just use the exact # same `selector` as before but now with `rerun` set to `true` if needed. - return DynamicPPL.Setfield.@set sampler.selector.rerun = gibbs_rerun(sampler_previous.alg, sampler.alg) + return Setfield.@set sampler.selector.rerun = gibbs_rerun(sampler_previous.alg, sampler.alg) end function gibbs_rerun_maybe( @@ -317,7 +317,7 @@ function gibbs_rerun_maybe( model::DynamicPPL.Model, sampler::DynamicPPL.Sampler, sampler_previous::DynamicPPL.Sampler, - varinfo::AbstractVarInfo, + varinfo::DynamicPPL.AbstractVarInfo, ) # Return early if we don't need it. gibbs_rerun(sampler, sampler_previous) || return varinfo @@ -337,7 +337,7 @@ function gibbs_rerun_maybe( end function gibbs_step_inner( rng::Random.AbstractRNG, - model::Model, + model::DynamicPPL.Model, samplers, states, varinfos, @@ -367,7 +367,7 @@ function gibbs_step_inner( # 2. Take step with local sampler. # Update the state we're about to use if need be. # If the sampler requires a linked varinfo, this should be done in `gibbs_state`. - current_state_local = gibbs_state( + current_state_local = Turing.Inference.gibbs_state( model_local, sampler_local, state_local, varinfo_local ) @@ -384,7 +384,7 @@ function gibbs_step_inner( # 3. Extract the new varinfo. # Return the resulting state and invlinked `varinfo`. - varinfo_local_state = varinfo(new_state_local) + varinfo_local_state = Turing.Inference.varinfo(new_state_local) varinfo_local_state_invlinked = if DynamicPPL.istrans(varinfo_local_state) DynamicPPL.invlink(varinfo_local_state, sampler_local, model_local) else From 475e26403da184e7e17fa93960be131f74fe8f92 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 24 Nov 2023 08:01:48 +0000 Subject: [PATCH 32/58] updated gibbs tests --- src/Turing.jl | 2 ++ src/mcmc/Inference.jl | 1 - test/mcmc/gibbs_new.jl | 20 ++++++++++---------- 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/Turing.jl b/src/Turing.jl index 450de1ba5..b79c95744 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -54,6 +54,8 @@ using .Variational include("optimisation/Optimisation.jl") using .Optimisation +include("experimental/Experimental.jl") + ########### # Exports # ########### diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index cbf1fe8cc..cae806c96 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -475,7 +475,6 @@ include("gibbs.jl") include("sghmc.jl") include("emcee.jl") include("abstractmcmc.jl") -include("gibbs_new.jl") ################ # Typing tools # diff --git a/test/mcmc/gibbs_new.jl b/test/mcmc/gibbs_new.jl index a60df0076..a60abec6e 100644 --- a/test/mcmc/gibbs_new.jl +++ b/test/mcmc/gibbs_new.jl @@ -43,11 +43,11 @@ has_dot_assume(::Model) = true end samplers = [ - GibbsV2( + Turing.Experimental.Gibbs( vns_s => NUTS(), vns_m => NUTS(), ), - GibbsV2( + Turing.Experimental.Gibbs( vns_s => NUTS(), vns_m => HMC(0.01, 4), ) @@ -58,11 +58,11 @@ has_dot_assume(::Model) = true append!( samplers, [ - GibbsV2( + Turing.Experimental.Gibbs( vns_s => HMC(0.01, 4), vns_m => MH(), ), - GibbsV2( + Turing.Experimental.Gibbs( vns_s => MH(), vns_m => HMC(0.01, 4), ) @@ -90,7 +90,7 @@ end # Sample! rng = Random.default_rng() vns = [@varname(s), @varname(m)] - sampler = GibbsV2(map(Base.Fix2(Pair, MH()), vns)...) + sampler = Turing.Experimental.Gibbs(map(Base.Fix2(Pair, MH()), vns)...) @testset "step" begin transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(sampler)) @@ -109,10 +109,10 @@ end end @testset "gdemo with CSMC & ESS" begin - # `GibbsV2` does not work with SMC samplers, e.g. `CSMC`. + # `Turing.Experimental.Gibbs` does not work with SMC samplers, e.g. `CSMC`. # FIXME: Oooor it is (see tests below). Uncertain. Random.seed!(100) - alg = GibbsV2(@varname(s) => CSMC(15), @varname(m) => ESS()) + alg = Turing.Experimental.Gibbs(@varname(s) => CSMC(15), @varname(m) => ESS()) chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_gdemo(chain) end @@ -123,7 +123,7 @@ end # With both `s` and `m` as random. model = gdemo(1.5, 2.0) vns = (@varname(s), @varname(m)) - alg = GibbsV2(vns => MH()) + alg = Turing.Experimental.Gibbs(vns => MH()) # `step` transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) @@ -140,7 +140,7 @@ end # Without `m` as random. model = gdemo(1.5, 2.0) | (m = 7 / 6,) vns = (@varname(s),) - alg = GibbsV2(vns => MH()) + alg = Turing.Experimental.Gibbs(vns => MH()) # `step` transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) @@ -154,7 +154,7 @@ end @testset "CSMC + ESS" begin rng = Random.default_rng() model = MoGtest_default - alg = GibbsV2( + alg = Turing.Experimental.Gibbs( (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), @varname(mu1) => ESS(), @varname(mu2) => ESS(), From 463f31cc23f5fde91320a527644fa29c3d43f096 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 24 Nov 2023 08:02:59 +0000 Subject: [PATCH 33/58] moved experimental gibbs tests --- test/{mcmc/gibbs_new.jl => experimental/gibbs.jl} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/{mcmc/gibbs_new.jl => experimental/gibbs.jl} (100%) diff --git a/test/mcmc/gibbs_new.jl b/test/experimental/gibbs.jl similarity index 100% rename from test/mcmc/gibbs_new.jl rename to test/experimental/gibbs.jl From a79ad92854a601503bbdfb77e95ab3ce6d9bcfea Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 24 Nov 2023 08:03:54 +0000 Subject: [PATCH 34/58] updated tests to include experiemntal tests --- test/runtests.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 5d881615d..94e0826e5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -92,10 +92,13 @@ macro timeit_include(path::AbstractString) :(@timeit TIMEROUTPUT $path include($ @timeit_include("optimisation/OptimInterface.jl") @timeit_include("ext/Optimisation.jl") end - end end + @testset "experimental" begin + @timeit_include("experimental/gibbs.jl") + end + @testset "variational optimisers" begin @timeit_include("variational/optimisers.jl") end From b2a95664cc6e7fccd11874476d1e6fd4be28540d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 24 Nov 2023 10:51:12 +0000 Subject: [PATCH 35/58] removed refrences to previews tests of experimental Gibbs sampler --- test/runtests.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 94e0826e5..9497ef99d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -74,7 +74,6 @@ macro timeit_include(path::AbstractString) :(@timeit TIMEROUTPUT $path include($ @testset "samplers" begin @timeit_include("mcmc/gibbs.jl") @timeit_include("mcmc/gibbs_conditional.jl") - @timeit_include("mcmc/gibbs_new.jl") @timeit_include("mcmc/hmc.jl") @timeit_include("mcmc/Inference.jl") @timeit_include("mcmc/sghmc.jl") From 1266639628e96cabeb74df1c02b132e8cf2de670 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 29 Jan 2024 13:16:42 +0000 Subject: [PATCH 36/58] removed solved TODO --- src/experimental/gibbs.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/experimental/gibbs.jl b/src/experimental/gibbs.jl index 89ad20710..546c949ac 100644 --- a/src/experimental/gibbs.jl +++ b/src/experimental/gibbs.jl @@ -123,7 +123,6 @@ function condition_gibbs(context::DynamicPPL.AbstractContext, value, values...) end # For `DynamicPPL.AbstractVarInfo` we just extract the values. function condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo) - # TODO: Determine when it's okay to use `NamedTuple` and use that instead. return DynamicPPL.condition(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo))) end function DynamicPPL.condition( From 7bd08314d38dcf50084a6dc903b72f7471f94692 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 29 Jan 2024 13:18:42 +0000 Subject: [PATCH 37/58] added a comment on `reconstruct_getvalue` usage --- src/experimental/gibbs.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/experimental/gibbs.jl b/src/experimental/gibbs.jl index 546c949ac..109259cf5 100644 --- a/src/experimental/gibbs.jl +++ b/src/experimental/gibbs.jl @@ -65,6 +65,7 @@ function broadcast_logpdf(dist::MultivariateDistribution, x::AbstractMatrix) return loglikelihood(dist, x) end +# Needed to support broadcasting over columns for `MultivariateDistribution`s. reconstruct_getvalue(dist, x) = x function reconstruct_getvalue( dist::MultivariateDistribution, From c74727f51a68529645b07ce3e2811f9f237a6170 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 29 Jan 2024 13:24:53 +0000 Subject: [PATCH 38/58] bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 04146896a..6833ca4b0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.30.2" +version = "0.30.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 8b46f11b3d8e5b224d0093725fd417b83da6da66 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 29 Jan 2024 13:27:02 +0000 Subject: [PATCH 39/58] added comments on future work --- src/experimental/gibbs.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/experimental/gibbs.jl b/src/experimental/gibbs.jl index 109259cf5..1f9fd82ac 100644 --- a/src/experimental/gibbs.jl +++ b/src/experimental/gibbs.jl @@ -305,6 +305,7 @@ function AbstractMCMC.step( return Turing.Inference.Transition(model, vi), GibbsState(vi, states) end +# TODO: Remove this once we've done away with the selector functionality in DynamicPPL. function make_rerun_sampler(model::DynamicPPL.Model, sampler::DynamicPPL.Sampler, sampler_previous::DynamicPPL.Sampler) # NOTE: This is different from the implementation used in the old `Gibbs` sampler, where we specifically provide # a `gid`. Here, because `model` only contains random variables to be sampled by `sampler`, we just use the exact @@ -312,6 +313,8 @@ function make_rerun_sampler(model::DynamicPPL.Model, sampler::DynamicPPL.Sampler return Setfield.@set sampler.selector.rerun = gibbs_rerun(sampler_previous.alg, sampler.alg) end +# TODO: Once we have removed all the selector stuff in DynamicPPL, replace this with an improved mechanism +# for determining whether we need to re-run the model. function gibbs_rerun_maybe( rng::Random.AbstractRNG, model::DynamicPPL.Model, From a81cf09be0355bd7708e1140eefa98a3a18bb57f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 29 Jan 2024 13:46:39 +0000 Subject: [PATCH 40/58] Update test/experimental/gibbs.jl --- test/experimental/gibbs.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/experimental/gibbs.jl b/test/experimental/gibbs.jl index a60abec6e..479d9147b 100644 --- a/test/experimental/gibbs.jl +++ b/test/experimental/gibbs.jl @@ -17,8 +17,6 @@ end # 1. (✓) Needs to be compatible with most models. # 2. (???) Restricted to usage of pairs for now to make things simple. -# TODO: Don't require usage of tuples due to potential of blowing up compilation times. - const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{ Model{typeof(DynamicPPL.TestUtils.demo_assume_index_observe)}, Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe)}, From f280caa3aad5c4740f24de0b6ef39e72ae3cbf82 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 29 Jan 2024 14:09:06 +0000 Subject: [PATCH 41/58] fixed bug where particle samplers didn't properly account for weightings of logpdf, etc. --- src/mcmc/particle_mcmc.jl | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 5ed2d53cd..7d0714c41 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -394,8 +394,8 @@ function DynamicPPL.assume( end function DynamicPPL.observe(spl::Sampler{<:Union{PG,SMC}}, dist::Distribution, value, vi) - Libtask.produce(logpdf(dist, value)) - return 0, vi + # NOTE: The `Libtask.produce` is now hit in `acclogp_observe!!`. + return logpdf(dist, value), trace_local_varinfo_maybe(vi) end function DynamicPPL.acclogp!!( @@ -407,6 +407,15 @@ function DynamicPPL.acclogp!!( DynamicPPL.acclogp!!(DynamicPPL.childcontext(context), varinfo_trace, logp) end +function DynamicPPL.acclogp_observe!!( + context::SamplingContext{<:Sampler{<:Union{PG,SMC}}}, + varinfo::AbstractVarInfo, + logp +) + Libtask.produce(logp) + return trace_local_varinfo_maybe(varinfo) +end + # Convenient constructor function AdvancedPS.Trace( model::Model, From b1dcadf664f4c417e3716c33defa71bb08a15a93 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 30 Jan 2024 16:59:41 +0000 Subject: [PATCH 42/58] relax atol for a numerical test with Gibbs a bit --- test/experimental/gibbs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/experimental/gibbs.jl b/test/experimental/gibbs.jl index 479d9147b..5fb00dd24 100644 --- a/test/experimental/gibbs.jl +++ b/test/experimental/gibbs.jl @@ -133,7 +133,7 @@ end # `sample` chain = sample(model, alg, 10_000; progress=false) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6], atol = 0.1) + check_numerical(chain, [:s, :m], [49 / 24, 7 / 6], atol = 0.2) # Without `m` as random. model = gdemo(1.5, 2.0) | (m = 7 / 6,) From 4bde75ade7f4536889c6305e90130f532cac80ec Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 10 Mar 2024 21:53:04 +0000 Subject: [PATCH 43/58] fixed bug with `AbstractDict` constructor for experimental `Gibbs` --- src/experimental/gibbs.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/experimental/gibbs.jl b/src/experimental/gibbs.jl index 1f9fd82ac..d9e6c0fa1 100644 --- a/src/experimental/gibbs.jl +++ b/src/experimental/gibbs.jl @@ -193,7 +193,7 @@ end # AbstractDict function Gibbs(algs::AbstractDict) - return Gibbs(keys(algs), map(wrap_algorithm_maybe, values(algs))) + return Gibbs(collect(keys(algs)), map(wrap_algorithm_maybe, values(algs))) end function Gibbs(algs::Pair...) return Gibbs(map(first, algs), map(wrap_algorithm_maybe, map(last, algs))) From b54e6eb215400bec1aa514229eaaeda73e059270 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 15 Apr 2024 16:53:36 +0100 Subject: [PATCH 44/58] aaaalways link the varinfo in the new Gibbs sampler, just to be sure --- src/experimental/gibbs.jl | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/experimental/gibbs.jl b/src/experimental/gibbs.jl index d9e6c0fa1..bd6d3f4f0 100644 --- a/src/experimental/gibbs.jl +++ b/src/experimental/gibbs.jl @@ -352,12 +352,20 @@ function gibbs_step_inner( state_local = states[index] varinfo_local = varinfos[index] + # Make sure that all `varinfos` are linked. + varinfos_invlinked = map(varinfos) do vi + # NOTE: This is immutable linking! + # TODO: Do we need the `istrans` check here or should we just always use `invlink`? + DynamicPPL.istrans(vi) ? DynamicPPL.invlink(vi, model) : vi + end + varinfo_local_invlinked = varinfos_invlinked[index] + # 1. Create conditional model. # Construct the conditional model. # NOTE: Here it's crucial that all the `varinfos` are in the constrained space, # otherwise we're conditioning on values which are not in the support of the # distributions. - model_local = make_conditional(model, varinfo_local, varinfos) + model_local = make_conditional(model, varinfo_local_invlinked, varinfos_invlinked) # NOTE: We use `logjoint` instead of `evaluate!!` and capturing the resulting varinfo because # the resulting varinfo might be in un-transformed space even if `varinfo_local` From e7ad6823395ba8ae0eb48c360b0e8da95bb41388 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Apr 2024 11:00:27 +0100 Subject: [PATCH 45/58] add test to cover recent improvement to `DynamicPPL.subset` ref: https://github.com/TuringLang/DynamicPPL.jl/pull/587 --- test/experimental/gibbs.jl | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/test/experimental/gibbs.jl b/test/experimental/gibbs.jl index 5fb00dd24..eb0b66a3a 100644 --- a/test/experimental/gibbs.jl +++ b/test/experimental/gibbs.jl @@ -133,7 +133,7 @@ end # `sample` chain = sample(model, alg, 10_000; progress=false) - check_numerical(chain, [:s, :m], [49 / 24, 7 / 6], atol = 0.2) + check_numerical(chain, [:s, :m], [49 / 24, 7 / 6], atol = 0.3) # Without `m` as random. model = gdemo(1.5, 2.0) | (m = 7 / 6,) @@ -152,11 +152,41 @@ end @testset "CSMC + ESS" begin rng = Random.default_rng() model = MoGtest_default - alg = Turing.Experimental.Gibbs( + vns = (@varname(z1), @varname(z2), @varname(z3), @varname(z4), @varname(mu1), @varname(mu2)) + alg_explicit = Turing.Experimental.Gibbs( (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), @varname(mu1) => ESS(), @varname(mu2) => ESS(), ) + # Here `@varname(z)` is supposed to cover all the `z`'s. + alg_z_implicit = Turing.Experimental.Gibbs( + @varname(z) => CSMC(15), + @varname(mu1) => ESS(), + @varname(mu2) => ESS(), + ) + for alg in [alg_explicit, alg_z_implicit] + # `step` + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + check_transition_varnames(transition, vns) + for _ = 1:5 + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) + check_transition_varnames(transition, vns) + end + + # Sample! + chain = sample(MoGtest_default, alg, 1000; progress=true) + check_MoGtest_default(chain, atol = 0.2) + end + end + + @testset "CSMC + ESS" begin + rng = Random.default_rng() + model = MoGtest_default + alg = Turing.Experimental.Gibbs( + @varname(z) => CSMC(15), + @varname(mu1) => ESS(), + @varname(mu2) => ESS(), + ) vns = (@varname(z1), @varname(z2), @varname(z3), @varname(z4), @varname(mu1), @varname(mu2)) # `step` transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) From 27133162d1f91b1dd3d48d818f9652529a6b2636 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Apr 2024 11:01:18 +0100 Subject: [PATCH 46/58] bump compat entry for DynamicPPL --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f13d92e0d..93396bc9c 100644 --- a/Project.toml +++ b/Project.toml @@ -58,7 +58,7 @@ Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8, 0.9" DynamicHMC = "3.4" -DynamicPPL = "0.24.7" +DynamicPPL = "0.24.10" EllipticalSliceSampling = "0.5, 1, 2" ForwardDiff = "0.10.3" Libtask = "0.7, 0.8" From fb29556e6e3e6c1ac5b0f0553cbf497ac79006b9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Apr 2024 12:50:34 +0100 Subject: [PATCH 47/58] added some docstrings --- src/experimental/gibbs.jl | 41 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/src/experimental/gibbs.jl b/src/experimental/gibbs.jl index bd6d3f4f0..e08233091 100644 --- a/src/experimental/gibbs.jl +++ b/src/experimental/gibbs.jl @@ -99,6 +99,11 @@ function DynamicPPL.dot_tilde_assume( end +""" + preferred_value_type(varinfo::DynamicPPL.AbstractVarInfo) + +Returns the preferred value type for a variable with the given `varinfo`. +""" preferred_value_type(::DynamicPPL.AbstractVarInfo) = DynamicPPL.OrderedDict preferred_value_type(::DynamicPPL.SimpleVarInfo{<:NamedTuple}) = NamedTuple function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) @@ -109,7 +114,16 @@ function preferred_value_type(varinfo::DynamicPPL.TypedVarInfo) return namedtuple_compatible ? NamedTuple : DynamicPPL.OrderedDict end -# No-op if no values are provided. +""" + condition_gibbs(context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict}...) + +Return a `GibbsContext` with the given values treated as conditioned. + +# Arguments +- `context::DynamicPPL.AbstractContext`: The context to condition. +- `values::Union{NamedTuple,AbstractDict}...`: The values to condition on. + If multiple values are provided, we recursively condition on each of them. +""" condition_gibbs(context::DynamicPPL.AbstractContext) = context # For `NamedTuple` and `AbstractDict` we just construct the context. function condition_gibbs(context::DynamicPPL.AbstractContext, values::Union{NamedTuple,AbstractDict}) @@ -122,7 +136,13 @@ function condition_gibbs(context::DynamicPPL.AbstractContext, value, values...) values... ) end + # For `DynamicPPL.AbstractVarInfo` we just extract the values. +""" + condition_gibbs(context::DynamicPPL.AbstractContext, varinfos::DynamicPPL.AbstractVarInfo...) + +Return a `GibbsContext` with the values extracted from the given `varinfos` treated as conditioned. +""" function condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo) return DynamicPPL.condition(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo))) end @@ -173,12 +193,31 @@ function make_conditional(model::DynamicPPL.Model, target_varinfo::DynamicPPL.Ab filter(Base.Fix1(!==, target_varinfo), varinfos)... ) end +# Assumes the ones given are the ones to condition on. +function make_conditional(model::DynamicPPL.Model, varinfos) + return condition_gibbs( + model, + varinfos... + ) +end +# HACK: Allows us to support either passing in an implementation of `AbstractMCMC.AbstractSampler` +# or an `AbstractInferenceAlgorithm`. wrap_algorithm_maybe(x) = x wrap_algorithm_maybe(x::InferenceAlgorithm) = DynamicPPL.Sampler(x) +""" + Gibbs + +A type representing a Gibbs sampler. + +# Fields +$(TYPEDFIELDS) +""" struct Gibbs{V,A} <: InferenceAlgorithm + "varnames representing variables for each sampler" varnames::V + "samplers for each entry in `varnames`" samplers::A end From d3a13ad7337fc2fed02cb35996a8cb4af85efde0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Apr 2024 12:53:49 +0100 Subject: [PATCH 48/58] fixed test --- test/experimental/gibbs.jl | 2 +- test/test_utils/models.jl | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/test/experimental/gibbs.jl b/test/experimental/gibbs.jl index eb0b66a3a..163dc715d 100644 --- a/test/experimental/gibbs.jl +++ b/test/experimental/gibbs.jl @@ -181,7 +181,7 @@ end @testset "CSMC + ESS" begin rng = Random.default_rng() - model = MoGtest_default + model = MoGtest_default_z_vector alg = Turing.Experimental.Gibbs( @varname(z) => CSMC(15), @varname(mu1) => ESS(), diff --git a/test/test_utils/models.jl b/test/test_utils/models.jl index d32ab3cb4..ee26d10f1 100644 --- a/test/test_utils/models.jl +++ b/test/test_utils/models.jl @@ -49,5 +49,39 @@ end MoGtest_default = MoGtest([1.0 1.0 4.0 4.0]) +@model function MoGtest_z_vector(D) + mu1 ~ Normal(1, 1) + mu2 ~ Normal(4, 1) + + z ~ Vector{Int}(undef, 4) + z[1] ~ Categorical(2) + if z[1] == 1 + D[1] ~ Normal(mu1, 1) + else + D[1] ~ Normal(mu2, 1) + end + z[2] ~ Categorical(2) + if z[2] == 1 + D[2] ~ Normal(mu1, 1) + else + D[2] ~ Normal(mu2, 1) + end + z[3] ~ Categorical(2) + if z[3] == 1 + D[3] ~ Normal(mu1, 1) + else + D[3] ~ Normal(mu2, 1) + end + z[4] ~ Categorical(2) + if z[4] == 1 + D[4] ~ Normal(mu1, 1) + else + D[4] ~ Normal(mu2, 1) + end + z[1], z[2], z[3], z[4], mu1, mu2 +end + +MoGtest_default_z_vector = MoGtest_z_vector([1.0 1.0 4.0 4.0]) + # Declare empty model to make the Sampler constructor work. @model empty_model() = x = 1 From 81bf9c0c760c5c0e56988e6824e82ba150ac30c7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Apr 2024 13:03:53 +0100 Subject: [PATCH 49/58] fixed import --- src/experimental/Experimental.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/experimental/Experimental.jl b/src/experimental/Experimental.jl index 5d390475e..fa446465e 100644 --- a/src/experimental/Experimental.jl +++ b/src/experimental/Experimental.jl @@ -5,6 +5,7 @@ using AbstractMCMC: AbstractMCMC using DynamicPPL: DynamicPPL, VarName using Setfield: Setfield +using DocStringExtensions: TYPEDFIELDS using Distributions using ..Turing: Turing From d340e5847daaf99d36ce4eca3e84d65c4dca83ca Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Apr 2024 13:37:26 +0100 Subject: [PATCH 50/58] another attempt at fixing tests --- test/test_utils/models.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_utils/models.jl b/test/test_utils/models.jl index ee26d10f1..fc392b050 100644 --- a/test/test_utils/models.jl +++ b/test/test_utils/models.jl @@ -53,7 +53,7 @@ MoGtest_default = MoGtest([1.0 1.0 4.0 4.0]) mu1 ~ Normal(1, 1) mu2 ~ Normal(4, 1) - z ~ Vector{Int}(undef, 4) + z = Vector{Int}(undef, 4) z[1] ~ Categorical(2) if z[1] == 1 D[1] ~ Normal(mu1, 1) From 6f6fe7ad765a6884eabb335dd3bdfc0217858a31 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Apr 2024 14:07:51 +0100 Subject: [PATCH 51/58] another attempt at fixing tests --- test/experimental/gibbs.jl | 34 +++++++++++++--------------------- 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/test/experimental/gibbs.jl b/test/experimental/gibbs.jl index 163dc715d..34c44e8d9 100644 --- a/test/experimental/gibbs.jl +++ b/test/experimental/gibbs.jl @@ -152,34 +152,26 @@ end @testset "CSMC + ESS" begin rng = Random.default_rng() model = MoGtest_default - vns = (@varname(z1), @varname(z2), @varname(z3), @varname(z4), @varname(mu1), @varname(mu2)) - alg_explicit = Turing.Experimental.Gibbs( + alg = Turing.Experimental.Gibbs( (@varname(z1), @varname(z2), @varname(z3), @varname(z4)) => CSMC(15), @varname(mu1) => ESS(), @varname(mu2) => ESS(), ) - # Here `@varname(z)` is supposed to cover all the `z`'s. - alg_z_implicit = Turing.Experimental.Gibbs( - @varname(z) => CSMC(15), - @varname(mu1) => ESS(), - @varname(mu2) => ESS(), - ) - for alg in [alg_explicit, alg_z_implicit] - # `step` - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + vns = (@varname(z1), @varname(z2), @varname(z3), @varname(z4), @varname(mu1), @varname(mu2)) + # `step` + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) + check_transition_varnames(transition, vns) + for _ = 1:5 + transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) check_transition_varnames(transition, vns) - for _ = 1:5 - transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg), state) - check_transition_varnames(transition, vns) - end - - # Sample! - chain = sample(MoGtest_default, alg, 1000; progress=true) - check_MoGtest_default(chain, atol = 0.2) end + + # Sample! + chain = sample(MoGtest_default, alg, 1000; progress=true) + check_MoGtest_default(chain, atol = 0.2) end - @testset "CSMC + ESS" begin + @testset "CSMC + ESS (usage of implicit varname)" begin rng = Random.default_rng() model = MoGtest_default_z_vector alg = Turing.Experimental.Gibbs( @@ -187,7 +179,7 @@ end @varname(mu1) => ESS(), @varname(mu2) => ESS(), ) - vns = (@varname(z1), @varname(z2), @varname(z3), @varname(z4), @varname(mu1), @varname(mu2)) + vns = (@varname(z[1]), @varname(z[2]), @varname(z[3]), @varname(z[4]), @varname(mu1), @varname(mu2)) # `step` transition, state = AbstractMCMC.step(rng, model, DynamicPPL.Sampler(alg)) check_transition_varnames(transition, vns) From 9c93162fafe5f8433bf65e2130f280d8ec7a055e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Apr 2024 14:41:34 +0100 Subject: [PATCH 52/58] attempt at fix tests --- test/experimental/gibbs.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/experimental/gibbs.jl b/test/experimental/gibbs.jl index 34c44e8d9..4f1384110 100644 --- a/test/experimental/gibbs.jl +++ b/test/experimental/gibbs.jl @@ -189,7 +189,7 @@ end end # Sample! - chain = sample(MoGtest_default, alg, 1000; progress=true) - check_MoGtest_default(chain, atol = 0.2) + chain = sample(model, alg, 1000; progress=true) + check_MoGtest_default_z_vector(chain, atol = 0.2) end end From cfa8fb37ffb1666b13cd7319642aad57267e24b7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Apr 2024 14:41:44 +0100 Subject: [PATCH 53/58] forgot something in previos commit --- test/test_utils/numerical_tests.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/test_utils/numerical_tests.jl b/test/test_utils/numerical_tests.jl index 090dabb31..333a3f14a 100644 --- a/test/test_utils/numerical_tests.jl +++ b/test/test_utils/numerical_tests.jl @@ -64,3 +64,10 @@ function check_MoGtest_default(chain; atol=0.2, rtol=0.0) [1.0, 1.0, 2.0, 2.0, 1.0, 4.0], atol=atol, rtol=rtol) end + +function check_MoGtest_default_z_vector(chain; atol=0.2, rtol=0.0) + check_numerical(chain, + [Symbol("z[1]"), Symbol("z[2]"), Symbol("z[3]"), Symbol("z[4]"), :mu1, :mu2], + [1.0, 1.0, 2.0, 2.0, 1.0, 4.0], + atol=atol, rtol=rtol) +end From 15e83b91386a4ade347a8556e8e2dda7669f1e78 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Apr 2024 15:25:25 +0100 Subject: [PATCH 54/58] cleaned up the experimental Gibbs sampler a bit --- src/experimental/gibbs.jl | 94 ++++++++++++++++++++++++++++----------- 1 file changed, 68 insertions(+), 26 deletions(-) diff --git a/src/experimental/gibbs.jl b/src/experimental/gibbs.jl index e08233091..2e6623da0 100644 --- a/src/experimental/gibbs.jl +++ b/src/experimental/gibbs.jl @@ -345,38 +345,75 @@ function AbstractMCMC.step( end # TODO: Remove this once we've done away with the selector functionality in DynamicPPL. -function make_rerun_sampler(model::DynamicPPL.Model, sampler::DynamicPPL.Sampler, sampler_previous::DynamicPPL.Sampler) +function make_rerun_sampler(model::DynamicPPL.Model, sampler::DynamicPPL.Sampler) # NOTE: This is different from the implementation used in the old `Gibbs` sampler, where we specifically provide # a `gid`. Here, because `model` only contains random variables to be sampled by `sampler`, we just use the exact # same `selector` as before but now with `rerun` set to `true` if needed. - return Setfield.@set sampler.selector.rerun = gibbs_rerun(sampler_previous.alg, sampler.alg) + return Setfield.@set sampler.selector.rerun = true end -# TODO: Once we have removed all the selector stuff in DynamicPPL, replace this with an improved mechanism -# for determining whether we need to re-run the model. -function gibbs_rerun_maybe( +# Interface we need a sampler to implement to work as a component in a Gibbs sampler. +""" + gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src) + +Check if the log-probability of the destination model needs to be recomputed. + +Defaults to `true` +""" +function gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src) + return true +end + +# TODO: Remove `rng`? +""" + recompute_logprob!!(rng, model, sampler, state) + +Recompute the log-probability of the `model` based on the given `state` and return the resulting state. +""" +function recompute_logprob!!( rng::Random.AbstractRNG, model::DynamicPPL.Model, sampler::DynamicPPL.Sampler, - sampler_previous::DynamicPPL.Sampler, - varinfo::DynamicPPL.AbstractVarInfo, + state ) - # Return early if we don't need it. - gibbs_rerun(sampler, sampler_previous) || return varinfo - - # Make the re-run sampler. + varinfo = Turing.Inference.varinfo(state) # NOTE: Need to do this because some samplers might need some other quantity than the log-joint, # e.g. log-likelihood in the scenario of `ESS`. # NOTE: Need to update `sampler` too because the `gid` might change in the re-run of the model. - sampler_rerun = make_rerun_sampler(model, sampler, sampler_previous) + sampler_rerun = make_rerun_sampler(model, sampler) # NOTE: If we hit `DynamicPPL.maybe_invlink_before_eval!!`, then this will result in a `invlink`ed # `varinfo`, even if `varinfo` was linked. - return last(DynamicPPL.evaluate!!( + varinfo_new = last(DynamicPPL.evaluate!!( model, varinfo, + # TODO: Check if it's safe to drop the `rng` argument, i.e. just use default RNG. DynamicPPL.SamplingContext(rng, sampler_rerun) )) + # Update the state we're about to use if need be. + # NOTE: If the sampler requires a linked varinfo, this should be done in `gibbs_state`. + return Turing.Inference.gibbs_state(model, sampler, state, varinfo_new) end + +function gibbs_step_inner( + rng::Random.AbstractRNG, + model_dst, + sampler_dst, + sampler_src, + state_dst, + state_src; + kwargs... +) + # `model_dst` might be different here, e.g. conditioned on new values, so we need to check if need to recompute the log-probability. + if gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src) + # Re-evaluate the log density of the destination model. + state_dst = recompute_logprob!!(model_dst, sampler_dst, state_dst, logprob_dst) + end + + # Step! + return AbstractMCMC.step(rng, model_dst, sampler_dst, state_dst; kwargs...) +end + + function gibbs_step_inner( rng::Random.AbstractRNG, model::DynamicPPL.Model, @@ -406,22 +443,27 @@ function gibbs_step_inner( # distributions. model_local = make_conditional(model, varinfo_local_invlinked, varinfos_invlinked) - # NOTE: We use `logjoint` instead of `evaluate!!` and capturing the resulting varinfo because - # the resulting varinfo might be in un-transformed space even if `varinfo_local` - # is in transformed space. This can occur if we hit `maybe_invlink_before_eval!!`. - - # Re-run the sampler if needed. + # Extract the previous sampler and state. sampler_previous = samplers[index == 1 ? length(samplers) : index - 1] - varinfo_local = gibbs_rerun_maybe(rng, model_local, sampler_local, sampler_previous, varinfo_local) - - # 2. Take step with local sampler. - # Update the state we're about to use if need be. - # If the sampler requires a linked varinfo, this should be done in `gibbs_state`. - current_state_local = Turing.Inference.gibbs_state( - model_local, sampler_local, state_local, varinfo_local + state_previous = states[index == 1 ? length(states) : index - 1] + + # 1. Re-run the sampler if needed. + if gibbs_requires_recompute_logprob( + model_local, + sampler_local, + sampler_previous, + state_local, + state_previous ) + current_state_local = recompute_logprob!!( + rng, + model_local, + sampler_local, + state_local, + ) + end - # Take a step. + # 2. Take step with local sampler. new_state_local = last( AbstractMCMC.step( rng, From aed307a2ee58aeca9b1b62f824f9e4394308dccc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 17 Apr 2024 15:41:25 +0100 Subject: [PATCH 55/58] removed accidentaly psuedocode inclusion --- src/experimental/gibbs.jl | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/src/experimental/gibbs.jl b/src/experimental/gibbs.jl index 2e6623da0..93414d11a 100644 --- a/src/experimental/gibbs.jl +++ b/src/experimental/gibbs.jl @@ -394,26 +394,6 @@ function recompute_logprob!!( return Turing.Inference.gibbs_state(model, sampler, state, varinfo_new) end -function gibbs_step_inner( - rng::Random.AbstractRNG, - model_dst, - sampler_dst, - sampler_src, - state_dst, - state_src; - kwargs... -) - # `model_dst` might be different here, e.g. conditioned on new values, so we need to check if need to recompute the log-probability. - if gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, state_dst, state_src) - # Re-evaluate the log density of the destination model. - state_dst = recompute_logprob!!(model_dst, sampler_dst, state_dst, logprob_dst) - end - - # Step! - return AbstractMCMC.step(rng, model_dst, sampler_dst, state_dst; kwargs...) -end - - function gibbs_step_inner( rng::Random.AbstractRNG, model::DynamicPPL.Model, From 04c177a16dcfb7007c8c545a7babe63b61a26334 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 19 Apr 2024 14:00:45 +0100 Subject: [PATCH 56/58] Apply suggestions from code review --- src/Turing.jl | 1 - src/experimental/gibbs.jl | 1 + src/mcmc/Inference.jl | 1 - test/experimental/gibbs.jl | 6 ------ 4 files changed, 1 insertion(+), 8 deletions(-) diff --git a/src/Turing.jl b/src/Turing.jl index d8ccb5170..093b0d26b 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -71,7 +71,6 @@ export @model, # modelling ESS, Gibbs, GibbsConditional, - GibbsV2, HMC, # Hamiltonian-like sampling SGLD, diff --git a/src/experimental/gibbs.jl b/src/experimental/gibbs.jl index 93414d11a..1e51cc9f1 100644 --- a/src/experimental/gibbs.jl +++ b/src/experimental/gibbs.jl @@ -412,6 +412,7 @@ function gibbs_step_inner( varinfos_invlinked = map(varinfos) do vi # NOTE: This is immutable linking! # TODO: Do we need the `istrans` check here or should we just always use `invlink`? + # FIXME: Suffers from https://github.com/TuringLang/Turing.jl/issues/2195 DynamicPPL.istrans(vi) ? DynamicPPL.invlink(vi, model) : vi end varinfo_local_invlinked = varinfos_invlinked[index] diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index efdf8fdbf..b990dac67 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -48,7 +48,6 @@ export InferenceAlgorithm, Emcee, Gibbs, # classic sampling GibbsConditional, - GibbsV2, HMC, SGLD, PolynomialStepsize, diff --git a/test/experimental/gibbs.jl b/test/experimental/gibbs.jl index 4f1384110..29713d2d5 100644 --- a/test/experimental/gibbs.jl +++ b/test/experimental/gibbs.jl @@ -13,10 +13,6 @@ function check_transition_varnames( end end -# Okay, so what do we actually need to test here. -# 1. (✓) Needs to be compatible with most models. -# 2. (???) Restricted to usage of pairs for now to make things simple. - const DEMO_MODELS_WITHOUT_DOT_ASSUME = Union{ Model{typeof(DynamicPPL.TestUtils.demo_assume_index_observe)}, Model{typeof(DynamicPPL.TestUtils.demo_assume_multivariate_observe)}, @@ -107,8 +103,6 @@ end end @testset "gdemo with CSMC & ESS" begin - # `Turing.Experimental.Gibbs` does not work with SMC samplers, e.g. `CSMC`. - # FIXME: Oooor it is (see tests below). Uncertain. Random.seed!(100) alg = Turing.Experimental.Gibbs(@varname(s) => CSMC(15), @varname(m) => ESS()) chain = sample(gdemo(1.5, 2.0), alg, 10_000) From 0acd27b8acb77994040a21d121319b2d302f51fc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 20 Apr 2024 12:04:38 +0100 Subject: [PATCH 57/58] relaxed olerance in one MH test a bit --- test/mcmc/mh.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/mcmc/mh.jl b/test/mcmc/mh.jl index afd84b2aa..d047f55a7 100644 --- a/test/mcmc/mh.jl +++ b/test/mcmc/mh.jl @@ -242,6 +242,6 @@ MH(AdvancedMH.RandomWalkProposal(filldist(Normal(), 3))), 10_000 ) - check_numerical(chain, [Symbol("x[1]"), Symbol("x[2]"), Symbol("x[3]")], [0, 0, 0], atol=0.1) + check_numerical(chain, [Symbol("x[1]"), Symbol("x[2]"), Symbol("x[3]")], [0, 0, 0], atol=0.2) end end From 0f30514386395e88780fb23c4a24e42ca9e35e16 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 20 Apr 2024 14:38:10 +0100 Subject: [PATCH 58/58] bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9f031d0a6..4f49a72b7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.30.8" +version = "0.30.9" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"