
Replace old Gibbs sampler with the experimental one. #2328

Open
wants to merge 79 commits into master

Conversation

@mhauru (Member) commented Sep 23, 2024

Closes #2318.

Work in progress.

codecov bot commented Sep 23, 2024

Codecov Report

Attention: Patch coverage is 51.37255% with 124 lines in your changes missing coverage. Please review.

Project coverage is 74.17%. Comparing base (c0a4ee9) to head (a15ce2f).

Files with missing lines Patch % Lines
src/mcmc/gibbs.jl 53.08% 99 Missing ⚠️
src/mcmc/hmc.jl 0.00% 7 Missing ⚠️
src/mcmc/repeat_sampler.jl 65.00% 7 Missing ⚠️
src/mcmc/Inference.jl 0.00% 4 Missing ⚠️
src/mcmc/sghmc.jl 0.00% 4 Missing ⚠️
src/mcmc/emcee.jl 0.00% 1 Missing ⚠️
src/mcmc/is.jl 0.00% 1 Missing ⚠️
src/mcmc/particle_mcmc.jl 66.66% 1 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           master    #2328       +/-   ##
===========================================
+ Coverage   44.72%   74.17%   +29.45%     
===========================================
  Files          22       21        -1     
  Lines        1554     1584       +30     
===========================================
+ Hits          695     1175      +480     
+ Misses        859      409      -450     


@coveralls commented Sep 23, 2024

Pull Request Test Coverage Report for Build 12145143867

Details

  • 131 of 255 (51.37%) changed or added relevant lines in 11 files are covered.
  • 21 unchanged lines in 5 files lost coverage.
  • Overall coverage increased (+30.7%) to 68.778%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/mcmc/emcee.jl 0 1 0.0%
src/mcmc/is.jl 0 1 0.0%
src/mcmc/particle_mcmc.jl 2 3 66.67%
src/mcmc/Inference.jl 0 4 0.0%
src/mcmc/sghmc.jl 0 4 0.0%
src/mcmc/hmc.jl 0 7 0.0%
src/mcmc/repeat_sampler.jl 13 20 65.0%
src/mcmc/gibbs.jl 112 211 53.08%
Files with Coverage Reduction New Missed Lines %
src/mcmc/hmc.jl 1 0.0%
src/mcmc/ess.jl 1 94.64%
src/mcmc/Inference.jl 2 63.1%
src/mcmc/abstractmcmc.jl 8 78.72%
src/mcmc/particle_mcmc.jl 9 86.75%
Totals Coverage Status
Change from base Build 12110228665: 30.7%
Covered Lines: 1086
Relevant Lines: 1579


HISTORY.md Outdated

The old Gibbs constructor relied on being called with several subsamplers, and each of the constructors of the subsamplers would take as arguments the symbols for the variables that they are to sample, e.g. `Gibbs(HMC(:x), MH(:y))`. This constructor has been deprecated, and will be removed in the future. The new constructor works by assigning samplers to either symbols or `VarNames`, e.g. `Gibbs(; x=HMC(), y=MH())` or `Gibbs(@varname(x) => HMC(), @varname(y) => MH())`. This allows more granular specification of which sampler to use for which variable.

Likewise, the old constructor for calling one subsampler more often than another, `Gibbs((HMC(:x), 2), (MH(:y), 1))` has been deprecated. The new way to achieve this effect is to list the same sampler multiple times, e.g. as `hmc = HMC(); mh = MH(); Gibbs(@varname(x) => hmc, @varname(x) => hmc, @varname(y) => mh)`.
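
For illustration, a minimal sketch of how the new constructor forms are used; the model and the step-size values here are made up for the example:

    using Turing

    @model function demo()
        x ~ Normal()
        y ~ Normal(x, 1.0)
    end

    # New-style Gibbs: assign a component sampler to each variable.
    spl = Gibbs(; x=HMC(0.01, 4), y=MH())
    # Equivalent VarName-based form, which allows more granular targeting:
    spl = Gibbs(@varname(x) => HMC(0.01, 4), @varname(y) => MH())
    chain = sample(demo(), spl, 1000)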
Member

Gibbs(@varname(x) => hmc, @varname(x) => hmc, @varname(y) => mh)

This looks rather awkward. Can we introduce a simple wrapper, Repeated, and support:

Gibbs(@varname(x) => Repeated(hmc, n), @varname(y) => mh)

Member Author

Yes. We had a chat about a closely related issue with @torfjelde too, I'll rework the interface around this a bit.

Member Author

I introduced a wrapper sampler: Gibbs(@varname(x) => RepeatSampler(HMC(0.01, 4), 2), @varname(y) => MH())

@mhauru (Member Author) commented Sep 26, 2024

@torfjelde, if you have a moment to take a look at the one remaining test failure, I'd be interested in your thoughts. We are sampling a model with two vector variables, m and z, and we somehow end up with a VarInfo that contains only z while the sampler is also looking for m. I wonder if it's something about the interaction between particle sampling with Libtask and how the new Gibbs handles the local varinfos. The failing test is this one:

    @testset "dynamic model" begin
        @model function imm(y, alpha, ::Type{M}=Vector{Float64}) where {M}
            N = length(y)
            rpm = DirichletProcess(alpha)

            z = zeros(Int, N)
            cluster_counts = zeros(Int, N)
            fill!(cluster_counts, 0)

            for i in 1:N
                z[i] ~ ChineseRestaurantProcess(rpm, cluster_counts)
                cluster_counts[z[i]] += 1
            end

            Kmax = findlast(!iszero, cluster_counts)
            m = M(undef, Kmax)
            for k in 1:Kmax
                m[k] ~ Normal(1.0, 1.0)
            end
        end
        model = imm(Random.randn(100), 1.0)
        # https://github.com/TuringLang/Turing.jl/issues/1725
        # sample(model, Gibbs(MH(:z), HMC(0.01, 4, :m)), 100);
        sample(model, Gibbs(; z=PG(10), m=HMC(0.01, 4; adtype=adbackend)), 100)
    end

@torfjelde (Member)

Will have a look at this in a bit @mhauru (just need to do some grocery shopping 😬 )

@mhauru (Member Author) commented Sep 26, 2024

Collecting links to old relevant PRs so I don't have to look for them again: #2231, #2099

@torfjelde (Member)

Think I found the error: if the number of m increases, say, from length(m) = 2 to length(m) = 3 during the PG step, then the lines

    if has_conditioned_gibbs(context, vn)
        value = get_conditioned_gibbs(context, vn)
        return value, logpdf(right, value), vi
    end
    # Otherwise, falls back to the default behavior.
    return DynamicPPL.tilde_assume(
        rng, DynamicPPL.childcontext(context), sampler, right, vn, vi
    )

don't hit the Gibbs branch, since @varname(m[3]) is not present in the GibbsContext 😕

@torfjelde (Member)

don't hit the Gibbs branch, since @varname(m[3]) is not present in the GibbsContext

I'm a bit uncertain how we should best handle this @yebai @mhauru

The first partially viable idea that comes to mind is to subset the varinfo to make sure that it only contains the correct variables. If we do this, then m[3] will just be "ignored" (in the varinfos) until we're actually sampling the m variables, in which case it would be captured correctly.

But this would not be quite equivalent to the current implementation of Gibbs, which, AFAIK, keeps the very first occurrence of m around rather than resampling every time. And naively, I would expect this to be incorrect.
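
For concreteness, a rough sketch of what the subsetting idea could look like, assuming a DynamicPPL.subset-style helper; the function name restrict_to_targets is made up for illustration:

    using DynamicPPL

    # Keep only the variables the current component sampler targets; a newly
    # appearing variable like m[3] would then simply be absent from this
    # varinfo until the sampler responsible for m runs.
    function restrict_to_targets(vi::DynamicPPL.AbstractVarInfo, target_vns)
        kept = filter(vn -> any(t -> DynamicPPL.subsumes(t, vn), target_vns), keys(vi))
        return DynamicPPL.subset(vi, kept)
    end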

Another way is to explicitly add the varinfos to the GibbsContext itself, and then, when we encounter a value that should in fact go into a different varinfo, we add it there. But this has a few issues:

  1. Requires the VarInfo to be mutable.
  2. Requires the VarInfo to have a container that can keep the new incoming value m[3].
  3. Implementation of Gibbs does end up being more complicated than the current approach. However, it might be worth it.

Thoughts?

@yebai (Member) commented Sep 27, 2024

Another way is to explicitly add the varinfos to the GibbsContext itself, and then, when we encounter a value that should in fact go into a different varinfo, we add it there. But this has a few issues:

Requires the VarInfo to be mutable.
Requires the VarInfo to have a container that can keep the new incoming value m[3].
Implementation of Gibbs does end up being more complicated than the current approach. However, it might be worth it.

I lean towards the above approach and (maybe later) providing explicit APIs to inference algorithms. This will enable us to handle reversible jumps (varying model dimensions) in MCMC more flexibly. At the moment, this is only possible in particle Gibbs; if it happens in HMC/MH, inference will likely fail (silently).

EDIT: we can keep VarInfos immutable by default, and require inference developers to hook into specific APIs to mutate VarInfos.

@torfjelde (Member)

This does however complicate the new Gibbs sampling procedure quite drastically 😕

And it makes me bring up a question I really didn't think I'd be asking: is it then actually preferable to the current Gibbs with keeping it all in a single VarInfo with a flag to specify whether it should be sampled or not? 😬

I guess we should first have a go at implementing this for the new Gibbs and then we can see 👍

Another point to add to the conversation that @mhauru brought to my attention the other day: we also want to support stuff like Gibbs(@varname(m) => NUTS(), @varname(m) => HMC()), i.e. multiple samplers targeting the same variables. This adds a few "complications" (beyond addressing the growing model problem discussed above):

  1. Need to determine which varinfo to pick from varinfos based on the varnames present / targeted.
  2. A naive implementation will result in duplicated entries in varinfos. We can however address this if we really feel like it's worth it, so probably a non-issue atm.

So all in all, immediate things we need to address with Gibbs:

  1. Support changing dimensions.
  2. Support picking a varinfo to condition on based on the varnames present rather than based on ===.
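
A sketch of what point 2 could look like; find_varinfo is a hypothetical helper, not part of the PR:

    using DynamicPPL

    # Pick the varinfo that covers the targeted varnames by subsumption,
    # instead of relying on object identity (===).
    function find_varinfo(varinfos, target_vns)
        covers(vi) = all(t -> any(vn -> DynamicPPL.subsumes(vn, t), keys(vi)), target_vns)
        idx = findfirst(covers, varinfos)
        idx === nothing && error("no varinfo covers the target variables")
        return varinfos[idx]
    end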

@mhauru (Member Author) commented Oct 10, 2024

I've been trying to think of a way to fix this that would also fix the problem where different Gibbs subsamplers can't sample the same variables (e.g. you can't first sample x and y using one sampler, and then y and z with a different one). My best thought at the moment is the following design:

  1. There is only one, global VarInfo, call it vi.
  2. make_conditional takes that vi and a list of VarNames that the current subsampler samples. It hijacks the tilde pipeline to condition all other variables on their current values in vi (see the sketch at the end of this comment).
  3. vi may have some variables linked, some not.
  4. Every time we call a subsampler we can hand it vi as the VarInfo. It won’t mess with any of the variables it’s not supposed to touch, because of the tilde pipeline hijack from point 2.

Point 3 is maybe undesirable, but I think it’s minor compared to all the Selector/gibbsid stuff, which we would still get rid of.

The only problem I see with this is combining the local state from the previous iteration of the current subsampler with the global vi. Somehow we would need to join up-to-date information from the global vi with state information from the previous iteration, specific to this subsampler. The right way to do this depends on the state, which is a different type of object for different subsamplers. EDIT: Actually, maybe this is okay, because we seem to already assume that every state object has a field called vi (i.e. state.vi); we could just reset that.

The great benefit of sticking to one global VarInfo is never having to worry about moving data between the local VarInfos. That would have to happen in both cases: when a new variable is introduced by one sampler (the failing test in this PR) and when two samplers sample the same variable. It sounds like a pain to implement.
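
To make point 2 above concrete, here is a rough sketch of the kind of context-based hijack being proposed; all names are illustrative and this is not the PR's actual implementation:

    using Distributions
    using DynamicPPL

    # A context that conditions every non-target variable on its current
    # value in the single global varinfo.
    struct ConditionOthersContext{T,V,C} <: DynamicPPL.AbstractContext
        target_varnames::T
        global_varinfo::V
        context::C
    end

    DynamicPPL.NodeTrait(::ConditionOthersContext) = DynamicPPL.IsParent()
    DynamicPPL.childcontext(ctx::ConditionOthersContext) = ctx.context

    is_target(ctx::ConditionOthersContext, vn) =
        any(t -> DynamicPPL.subsumes(t, vn), ctx.target_varnames)

    function DynamicPPL.tilde_assume(ctx::ConditionOthersContext, right, vn, vi)
        if is_target(ctx, vn)
            # The subsampler handles its own variables as usual.
            return DynamicPPL.tilde_assume(DynamicPPL.childcontext(ctx), right, vn, vi)
        end
        # Everything else is treated as fixed at its current global value.
        value = ctx.global_varinfo[vn]
        return value, logpdf(right, value), vi
    end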

@mhauru (Member Author) commented Oct 10, 2024

I can imagine two different philosophies to implementing a Gibbs sampler:

  1. Every subsampler is doing its own sampling process on a low-dimensional model (a conditioned version of the full model), independent of the others. The log-probability function it's sampling from just keeps changing between iterations, because the other variables change and thus the conditioned model changes, but otherwise it's blind to the existence of the variables it isn't sampling. This is what the new Gibbs sampler does.
  2. Every subsampler is working with the same, full model, with all the variables, but only makes the changes to a subset of those variables. It still "sees" the whole model. This is what the old Gibbs sampler did.

My above proposal would essentially be doing 2., but using code that's very much like the new sampler, where the information about which sampler modifies which variables is in the sampler/GibbsContext, and not in VarInfo like it was in the old Gibbs.

The reason I'm leaning towards 2. is that 1. seems to run into some fundamental issues in cases where either

  • Variables appear and disappear based on values of other variables,
  • Two samplers want to modify the value of the same variable.

Both of those situations quite deeply violate the idea that the different subsamplers can operate mostly independently of each other.

Any thoughts very welcome, I'm still very much trying to understand the landscape of the problem.

@yebai (Member) commented Oct 10, 2024

Thanks, @mhauru, for the excellent summary of the problem and proposals. Storing conditioned variables in a context, like the GibbsContext you suggested, is very sensible. The consequence is that VarInfo and Context will have overlapping model parameters, i.e. conditioned variables will be found in both VarInfo and Context, which is fine.

In addition, it's worth mentioning that we currently have two mechanisms for passing observations to models, i.e.

(1) via model arguments, e.g. gdemo(x, y).
(2) via condition API, e.g. condition(model, (x=1,y=2)).

Among these options, (1) hardcodes the observations directly in the model, while (2) stores them in a context. You could look at the DynamicPPL codebase for a more detailed picture of how it works. We want to unify these options, perhaps towards using (2) only.
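
To illustrate the two mechanisms side by side (a minimal sketch using a toy model):

    using Turing

    # (1) Observations hardcoded via model arguments:
    @model function gdemo(x, y)
        s ~ InverseGamma(2, 3)
        x ~ Normal(0, sqrt(s))
        y ~ Normal(0, sqrt(s))
    end
    m1 = gdemo(1.0, 2.0)

    # (2) Observations supplied via the condition API and stored in a context:
    @model function gdemo2()
        s ~ InverseGamma(2, 3)
        x ~ Normal(0, sqrt(s))
        y ~ Normal(0, sqrt(s))
    end
    m2 = condition(gdemo2(), (x=1.0, y=2.0))  # equivalently: gdemo2() | (x=1.0, y=2.0)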

This Gibbs refactoring could be an excellent starting point for a design_notes repo to record these thoughts and discussions.

@torfjelde (Member)

Every subsampler is working with the same, full model, with all the variables, but only makes the changes to a subset of those variables. It still "sees" the whole model. This is what the old Gibbs sampler did.

Overall, I'm also in favour of this @mhauru 👍 I think your reasoning is solid here.

The only other "option" I'm seeing is to keep track of which variables correpond to which varinfos (with each varinfo only containing the relevant information), but then we're effectively just re-implementing a lot of the functionality that is already provided in varinfo 😕

The only "issue" is that this does mean we have to support this "link / transform only part of the varinfo, which does mean we need something "equivalent" to all the getindex(varinfo, sampler) stuff that we've been trying to move away from (since we need a way to extract the vectorized part relevant only for the specific sampler we're going to use in that particular step) 😕

That said, I think we can make this much nicer than the current approach by simply making all these getindex(varinfo, sampler) calls take the relevant varnames instead of the samplers themselves, which should make it all less painful.

But yeah, don't see how we can take approach (1) in a "nice" way, and so I'm also in favour of just trying to make (2) as painless as possible to maintain.

@mhauru (Member Author) commented Oct 11, 2024

Thanks for the comments both, this is very helpful.

That said, I think we can make this much nicer than the current approach by simply making all these getindex(varinfo, sampler) calls take the relevant varnames instead of the samplers themselves, which should make it all less painful.

Yeah, I think this is the way to go.

I tried using ESS instead, because I thought it would test behavior a
bit more broadly, given similarities between HMC and NUTS. It worked
locally, but the KS test fails in one or two cases on CI.
@torfjelde (Member)

Will do buddio!

@mhauru (Member Author) commented Nov 28, 2024

Maybe give it a moment so I can check that tests pass again; I messed a bit with the constructors.

@mhauru (Member Author) commented Nov 28, 2024

I got everything except the slowest tests to pass locally, so I think we are good now, @torfjelde.

@torfjelde (Member) left a comment

I've added some comments and whatnot, but, as we discussed on Slack, we'll defer these changes to later / another PR :)

Just make sure to add some issues so we can keep track of it 👍

Awesome, awesome stuff @mhauru! I imagine this one has been quite painful, so appreciate you sticking with it until the end 👏

Comment on lines +48 to +52
The naive implementation of `GibbsContext` would simply have a field `target_varnames` that
would be a collection of `VarName`s that the current component sampler is sampling. The
reason we instead have a `Tuple` type parameter listing `Symbol`s is to allow
`is_target_varname` to benefit from compile time constant propagation. This is important
for type stability of `tilde_assume`.
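
A minimal sketch of the mechanism this comment describes; the struct here is simplified and its name made up for illustration:

    using DynamicPPL

    # The target symbols live in a Tuple type parameter, so the membership
    # test below can be folded away at compile time by constant propagation.
    struct SymbolTargets{Syms} end

    is_target_varname(::SymbolTargets{Syms}, ::VarName{S}) where {Syms,S} = S in Syms

    targets = SymbolTargets{(:x, :y)}()
    is_target_varname(targets, @varname(x))  # true, resolvable at compile time
    is_target_varname(targets, @varname(z))  # false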
Member

Hmm. Are we sure we want to do this?

I'm happy with doing this for now, but it does seem a bit overly restrictive.

There's nothing stopping us from taking the same approach as we do in VarInfo, handling both the case where compile-time resolution is possible and the case where it's not.

Member

As in, this is technically breaking functionality of Turing.Experimental.Gibbs, one of whose motivations was to allow more flexible conditioning, e.g. @varname(x[1]) => ..., similar to condition and fix.

Member Author

Yeah, I agree that this is an unfortunate regression of the abilities of the experimental Gibbs, if not of the old main Gibbs sampler. Tor and I discussed this on Slack, and the reason the regression happened is that the parts that handled all this target_varnames business had to be essentially rewritten to deal with the issues with dynamic models, and thus I first went for something that works and is type stable when the targets are just Symbols, which covers the vast majority of use cases.

I agree that this should be improved, and, as again discussed on Slack, it shouldn't be impossible to do. I would like to merge this PR though, because it does not regress non-experimental functionality, and this has already gotten too big and long-lived as a PR.

Member Author

Issue to track: #2403

@mhauru (Member Author) commented Nov 29, 2024

I'm done making the changes I had in mind. I may still experiment with some performance improvements, but I'm not sure any will make it in here. I'll also try to reduce the iteration counts in some tests to make them faster; the only CI failure is because one job timed out at 6h.

Since both Tor and I seem to be happy, I'm going to ping others in case they want to take a look: @penelopeysm, @willtebbutt, @sunxd3, @yebai. I think we can rely on @torfjelde giving an expert review; everyone else can judge for themselves how thorough a look they want to take, but everyone should at least be aware that this somewhat major change is happening. If you want to give this PR a review but haven't yet had time, self-request a review and we'll make sure to wait before merging.

For help in reviewing: This PR does a few things:

  1. Deletes the old src/mcmc/gibbs.jl, and the related src/mcmc/gibbs_conditional.jl.
  2. Moves src/experimental/gibbs.jl to be the new src/mcmc/gibbs.jl, and merges test/experimental/gibbs.jl and test/mcmc/gibbs.jl.
  3. Makes a lot of edits to the experimental/new Gibbs to accommodate dynamic models and some other things.
  4. Adds more, new tests to test/mcmc/gibbs.jl.
  5. Introduces RepeatSampler and its tests. This has to be done in the same PR because the old Gibbs had repeat functionality built-in, whereas the new Gibbs doesn't.
  6. Makes a bunch of small changes to various samplers to accommodate the new Gibbs.

Points 4-6 can be reviewed as usual, as a diff of a few hundred lines. Points 2-3 are, I think, better viewed as a new Gibbs sampler written from scratch. The changes in point 3 are so extensive that reading them as a diff doesn't make much sense unless you know the old code really well.

@penelopeysm (Member)

I'm happy to take a look next week, but doubt I'll get to it today as my head is already several layers deep in DynamicPPL stuff 😄

@mhauru (Member Author) commented Nov 29, 2024

I managed to decrease the iteration counts on a lot of the heaviest tests, the total runtime should be reduced substantially now. They seem to still pass somewhat robustly, i.e. I tried at least two random seeds.

Also did some quick checks of performance overheads, and the previous large overheads are gone in my example cases. Now, rather than being e.g. 100-500% slower than the old Gibbs, we are more like 0-50% slower. This is for models dominated by overheads from outside model evaluation, i.e. fast models where performance is not a big deal.

@mhauru (Member Author) commented Dec 2, 2024

The Mooncake stack overflows are something @willtebbutt is aware of and knows the reason for, so we can ignore them for now. I would still hold off on merging until they are fixed.

Successfully merging this pull request may close these issues.

Remove old Gibbs sampler, make the experimental one the default