
feat: batched sampling for MCMC #1176

Merged 81 commits on Jul 30, 2024
Conversation

@manuelgloeckler (Contributor) commented Jun 18, 2024

What does this implement/fix? Explain your changes

This pull request aims to implement the sample_batched method for MCMC.

Current problem

  • BasePotential can either "allow_iid" or not. Hence, each batch dimension will be interpreted as IID samples.
    • Replace allow_iid with a mutable attribute (or optional input argument) interpret_as_iid.
    • Remove warning for batched x and default to batched evaluation
  • Refactor all MCMC initialization methods to work with batch dim.
    • resample should break
    • SIR should break
    • proposal should work
  • Add tests to check that the correct samples are returned in each dimension (currently, only shapes are checked)
    • The problem is currently not caught by tests...

The current implementation lets you sample the correct shape BUT outputs the wrong solution. This is because the potential function broadcasts, repeats, and finally sums over the first dimension, which is incorrect.
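The failure mode can be sketched with a toy Gaussian potential. Both functions below are hypothetical illustrations of the two interpretations, not sbi's actual API:

```python
import torch

# Toy Gaussian potential illustrating the IID-vs-batched mismatch.
# Both functions are hypothetical sketches, not sbi's actual API.

def potential_iid(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """Interprets the batch dim of x as IID trials: broadcasts, then sums
    over the x batch, coupling all observations into every potential."""
    # x: (batch_x, dim), theta: (batch_theta, dim)
    log_probs = -0.5 * (x.unsqueeze(0) - theta.unsqueeze(1)) ** 2
    return log_probs.sum(dim=(1, 2))  # (batch_theta,)

def potential_batched(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """Batched interpretation: the i-th theta is scored only against the
    i-th x; nothing is summed across the batch dimension."""
    return (-0.5 * (x - theta) ** 2).sum(dim=-1)  # (batch,)

x = torch.tensor([[0.0], [10.0]])
theta = torch.tensor([[0.0], [10.0]])
print(potential_iid(x, theta))      # tensor([-50., -50.]) -- both x coupled in
print(potential_batched(x, theta))  # tensor([0., 0.]) -- per-pair potentials
```

With the IID interpretation, each theta is scored against *both* observations, so MCMC chains that should target different posteriors all target the same (joint) one.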

manuelgloeckler and others added 30 commits April 29, 2024 09:04
@gmoss13 (Contributor) commented Jun 27, 2024

I've made some progress now towards this PR, and would like some feedback before I continue.

BasePotential can either "allow_iid" or not.

Given batch_dim_theta != batch_dim_x, we need to decide how to evaluate potential(x, theta). We could return (batch_dim_x, batch_dim_theta) potentials (i.e., every combination), but I am worried this can add a lot of computational overhead, especially when sampling. Instead, the current implementation assumes that batch_dim_theta is a multiple of batch_dim_x (i.e., for sampling, we have n chains in theta for each x). In this case we expand the batch dim of x to batch_dim_theta and match which x goes to which theta. If we are happy with this approach, I'll go ahead and apply it also to the MCMC init_strategy, etc., and make sure it is consistent with other calls.
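A minimal sketch of this matching rule (the helper name and the contiguous-block ordering are assumptions, not the PR's actual code):

```python
import torch

# Assumed helper illustrating the proposed rule: batch_theta must be a
# multiple of batch_x, and each x is paired with a contiguous block of
# batch_theta // batch_x chains. Name and ordering are assumptions.
def match_x_to_theta(x: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    batch_x, batch_theta = x.shape[0], theta.shape[0]
    if batch_theta % batch_x != 0:
        raise ValueError("theta batch must be a multiple of x batch")
    chains_per_x = batch_theta // batch_x
    # (batch_x, x_dim) -> (batch_theta, x_dim): repeat each x for its chains
    return x.repeat_interleave(chains_per_x, dim=0)

x = torch.tensor([[1.0], [2.0]])  # 2 observations
theta = torch.zeros(6, 3)         # 3 chains per observation
x_expanded = match_x_to_theta(x, theta)
print(x_expanded.squeeze(-1))     # tensor([1., 1., 1., 2., 2., 2.])
```

After the expansion, potential(x_expanded, theta) reduces to a plain element-wise batched evaluation with matching batch dims.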

Remove warning for batched x and default to batched evaluation
Not sure if we want batched evaluation as the default. I think it's easier to do batched evaluation when sample_batched or log_prob_batched is called, and otherwise assume iid (and warn if batch dim >1 as before).

@gmoss13 gmoss13 requested a review from janfb June 27, 2024 16:04
@manuelgloeckler (Contributor, Author)
Great, it looks good. I like that the choice of IID or not can now be made via the set_x method, which makes a lot of sense.

I would also opt for your suggested option. The question arises because we squeeze the batch_shape into a single dimension, right? With PyTorch broadcasting, one would expect something like (1, batch_x_dim, x_dim) and (batch_theta_dim, batch_x_dim, theta_dim) -> (batch_theta_dim, batch_x_dim), so by squeezing the xs and thetas into 2d one always gets a dimension that is a multiple of batch_x_dim (otherwise it cannot be represented by a fixed-size tensor).

For (1,batch_x_dim,x_dim) and (batch_theta_dim, 1, theta_dim), PyTorch broadcasting semantics would compute all combinations. Unfortunately, after squeezing, these distinctions between cases can no longer be fully preserved.
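For illustration, the all-combinations case under standard PyTorch broadcasting, using a toy potential with assumed shapes:

```python
import torch

# Toy illustration of PyTorch broadcasting computing every (theta, x)
# combination when the batch dims are kept separate (not squeezed).
x = torch.randn(1, 4, 2)      # (1, batch_x_dim, x_dim)
theta = torch.randn(5, 1, 2)  # (batch_theta_dim, 1, theta_dim); dims match here
pairwise = -0.5 * ((x - theta) ** 2).sum(dim=-1)
print(pairwise.shape)         # torch.Size([5, 4]): all combinations
```

Once x and theta are squeezed to 2d, the singleton dims that drive this outer-product behavior are gone, which is exactly why the cases can no longer be distinguished.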

@janfb (Contributor) left a comment
Great effort, thanks a lot for tackling this 👏

I do have a couple of comments and questions. Happy to discuss in person if needed.

Review comments on:
sbi/inference/posteriors/mcmc_posterior.py
sbi/utils/conditional_density_utils.py
sbi/utils/potentialutils.py
sbi/utils/sbiutils.py
tests/posterior_nn_test.py
@gmoss13 (Contributor) commented Jul 19, 2024

Great effort, thanks a lot for tackling this 👏

I do have a couple of comments and questions. Happy to discuss in person if needed.

Thanks for the review! I implemented your suggestions.

An additional point: for posterior_based_potential, indeed we should not allow for iid_x, as this is handled by the PermutationInvariantNetwork. Instead, we now always treat x batches as not IID. If the user tries to set potential.set_x(x, x_is_iid=True) with a PosteriorBasedPotential, we raise an error stating this. I added a few test cases in embedding_net_test.py::test_embedding_api_with_multiple_trials to test whether batches of x are interpreted correctly when we use a PermutationInvariantNetwork.
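The guard described above might look roughly like this (the class name, method signature, and attribute are assumptions, not sbi's exact internals):

```python
# Hypothetical sketch of the guard described above; the class and
# attribute names are assumptions, not sbi's exact internals.
class PosteriorBasedPotential:
    def set_x(self, x, x_is_iid: bool = False) -> None:
        if x_is_iid:
            raise NotImplementedError(
                "Interpreting x as IID is not supported for posterior-based "
                "potentials; IID trials are handled by a permutation-"
                "invariant embedding net instead."
            )
        self._x = x  # always treated as a batch of (non-IID) observations

potential = PosteriorBasedPotential()
potential.set_x([0.5])  # fine: batched, non-IID interpretation
```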

@gmoss13 gmoss13 requested a review from janfb July 19, 2024 15:51
@janfb (Contributor) left a comment
Looks great! I added just a couple of last questions.

Review comments on:
sbi/inference/posteriors/mcmc_posterior.py
tests/embedding_net_test.py
@gmoss13 gmoss13 requested a review from janfb July 30, 2024 08:11
@janfb (Contributor) left a comment
Looks good! Thanks a lot, great effort!

@janfb janfb self-assigned this Jul 30, 2024
@janfb janfb added this to the Hackathon and release 2024 milestone Jul 30, 2024
@janfb janfb added the enhancement New feature or request label Jul 30, 2024
@janfb (Contributor) commented Jul 30, 2024

closes #990
closes #944

@janfb janfb merged commit 81fffcf into main Jul 30, 2024
5 of 6 checks passed
@janfb janfb deleted the amortized_sample_mcmc branch July 30, 2024 09:24