feat: batched sampling for MCMC #1176
Conversation
I've made some progress now towards this PR, and would like some feedback before I continue.
Given
Great, it looks good. I like that the choice on iid or not can now be made at the potential level; I would also opt for your suggested option. The question arises because we squeeze the batch_shape into a single dimension, right? For PyTorch broadcasting, one would expect something like (1, batch_x_dim, x_dim) and (batch_theta_dim, batch_x_dim, theta_dim) -> (batch_theta_dim, batch_x_dim), so by squeezing the xs and thetas into 2d one would always get a leading dimension that is a multiple of batch_x_dim (otherwise it cannot be represented by a fixed-size tensor). For (1, batch_x_dim, x_dim) and (batch_theta_dim, 1, theta_dim), PyTorch broadcasting semantics would compute all combinations. Unfortunately, after squeezing, these distinctions between the cases can no longer be fully preserved.
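For illustration, here is a minimal PyTorch sketch of the two broadcasting cases discussed above; the dimension names and sizes are placeholder assumptions, not code from this PR:

```python
import torch

# Placeholder sizes for the dimensions discussed above.
batch_x_dim, batch_theta_dim, x_dim, theta_dim = 3, 5, 2, 4

# Paired case: each theta batch entry is matched with the corresponding x batch entry.
xs = torch.randn(1, batch_x_dim, x_dim)
thetas_paired = torch.randn(batch_theta_dim, batch_x_dim, theta_dim)
# Batch shapes (1, batch_x_dim) and (batch_theta_dim, batch_x_dim) broadcast to
# (batch_theta_dim, batch_x_dim).
print(torch.broadcast_shapes(xs.shape[:-1], thetas_paired.shape[:-1]))  # torch.Size([5, 3])

# All-combinations case: every theta is evaluated against every x.
thetas_all = torch.randn(batch_theta_dim, 1, theta_dim)
print(torch.broadcast_shapes(xs.shape[:-1], thetas_all.shape[:-1]))     # torch.Size([5, 3])

# After squeezing the batch dimensions into a single leading dimension, both cases
# become plain 2d tensors with the same leading size, so the pairing information is gone.
xs_flat = xs.expand(batch_theta_dim, -1, -1).reshape(-1, x_dim)          # (15, 2)
thetas_flat = thetas_paired.reshape(-1, theta_dim)                       # (15, 4)
print(xs_flat.shape, thetas_flat.shape)
```

In the flattened view both cases collapse to 2d tensors with identical leading dimensions, which is exactly why the paired vs. all-combinations distinction can no longer be recovered after squeezing.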
Great effort, thanks a lot for tackling this 👏
I do have a couple of comments and questions. Happy to discuss in person if needed.
Thanks for the review! I implemented your suggestions. An additional point - For
Looks great! I added just a couple of last questions.
Looks good! Thanks a lot, great effort!
What does this implement/fix? Explain your changes
This pull request aims to implement the `sample_batched` method for MCMC.

Current problem

`BasePotential` can either "allow_iid" or not. Hence, each batch dimension will be interpreted as IID samples. This could be addressed by replacing `allow_iid` with a mutable attribute (or optional input argument) `interpret_as_iid`.

The current implementation will let you sample the correct shape, BUT will output the wrong solution. This is because the potential function will broadcast, repeat, and finally sum up the first dimension, which is incorrect.
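To make the failure mode concrete, here is a minimal, hypothetical sketch (not the actual `sbi` implementation; `ToyPotential`, `likelihood_fn`, and the shapes are illustrative assumptions) of how an `interpret_as_iid` switch on the potential could distinguish summing over IID observations from returning one value per (theta, x) pair:

```python
import torch
from torch.distributions import MultivariateNormal


class ToyPotential:
    """Hypothetical stand-in for a potential; only illustrates the IID switch."""

    def __init__(self, likelihood_fn, x_o: torch.Tensor, interpret_as_iid: bool = True):
        self.likelihood_fn = likelihood_fn   # maps theta -> a torch.distributions object
        self.x_o = x_o                       # observations, shape (num_x, x_dim)
        self.interpret_as_iid = interpret_as_iid

    def __call__(self, theta: torch.Tensor) -> torch.Tensor:
        # theta: (num_theta, theta_dim); broadcast against x_o via an extra dim.
        log_probs = self.likelihood_fn(theta.unsqueeze(1)).log_prob(self.x_o)  # (num_theta, num_x)
        if self.interpret_as_iid:
            # IID observations: sum per-observation log-likelihoods -> (num_theta,)
            return log_probs.sum(dim=1)
        # Batched case: keep one potential value per (theta, x) pair -> (num_theta, num_x)
        return log_probs


# Toy likelihood: x | theta ~ N(theta, I)
def likelihood_fn(theta):
    return MultivariateNormal(theta, covariance_matrix=torch.eye(theta.shape[-1]))


x_o = torch.randn(3, 2)    # three observations
theta = torch.randn(5, 2)  # five parameter sets

iid_potential = ToyPotential(likelihood_fn, x_o, interpret_as_iid=True)
batched_potential = ToyPotential(likelihood_fn, x_o, interpret_as_iid=False)
print(iid_potential(theta).shape)      # torch.Size([5])    -- summed over x (IID)
print(batched_potential(theta).shape)  # torch.Size([5, 3]) -- one value per (theta, x)
```

If the batched case is routed through the IID branch, the sum over the observation dimension produces values of the right shape downstream but the wrong solution, which matches the behavior described above.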