Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: batched mcmc reshaping #1210

Merged
merged 1 commit into from
Aug 2, 2024
Merged

fix: batched mcmc reshaping #1210

merged 1 commit into from
Aug 2, 2024

Conversation

janfb
Copy link
Contributor

@janfb janfb commented Aug 1, 2024

when the num_chains was higher or not a multiple of the number of samples to be generated with MCMC sample_batched, the samples were reshaped incorrectly: The remaining samples were pushed into the last dimension here:

return samples.reshape((batch_size, *sample_shape, -1)).permute( # type: ignore

This is a fix where we first collect all the generated samples from the chains (possibly more than needed), then select as many as we need and then reshape into the desired (*sample_shape, batch_size, input_shape).

I also added tests that cover all the cases, which makes them quite slow.

@janfb janfb added the bug Something isn't working label Aug 1, 2024
@janfb janfb requested a review from gmoss13 August 1, 2024 15:51
@janfb janfb self-assigned this Aug 1, 2024
@janfb
Copy link
Contributor Author

janfb commented Aug 1, 2024

@gmoss13 I had to change one thing in your reshape-permute magic to make it work for len(sample_shape) > 1. Can you please double check that this does not mess up the collection of samples from the chains?

Copy link

codecov bot commented Aug 1, 2024

Codecov Report

Attention: Patch coverage is 91.66667% with 1 line in your changes missing coverage. Please review.

Project coverage is 75.86%. Comparing base (ba19688) to head (7e76605).
Report is 9 commits behind head on main.

Files Patch % Lines
sbi/inference/posteriors/mcmc_posterior.py 91.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1210      +/-   ##
==========================================
- Coverage   84.55%   75.86%   -8.70%     
==========================================
  Files          96       97       +1     
  Lines        7603     7682      +79     
==========================================
- Hits         6429     5828     -601     
- Misses       1174     1854     +680     
Flag Coverage Δ
unittests 75.86% <91.66%> (-8.70%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files Coverage Δ
sbi/inference/posteriors/mcmc_posterior.py 85.93% <91.66%> (-0.26%) ⬇️

... and 28 files with indirect coverage changes

Copy link
Contributor

@gmoss13 gmoss13 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot @janfb, I think these changes fix the issues we discussed! I added some suggestions, happy to discuss further (also happy to implement them myself if you agree).

sbi/inference/posteriors/mcmc_posterior.py Show resolved Hide resolved
sbi/inference/posteriors/mcmc_posterior.py Show resolved Hide resolved
sbi/inference/posteriors/mcmc_posterior.py Show resolved Hide resolved
sbi/inference/posteriors/mcmc_posterior.py Outdated Show resolved Hide resolved
sbi/inference/posteriors/mcmc_posterior.py Outdated Show resolved Hide resolved
sbi/inference/posteriors/mcmc_posterior.py Show resolved Hide resolved
tests/posterior_nn_test.py Show resolved Hide resolved
tests/posterior_nn_test.py Outdated Show resolved Hide resolved
@janfb janfb force-pushed the fix-batched-mcmc-reshaping branch from 62c24db to 295ff42 Compare August 2, 2024 06:53
@janfb janfb requested a review from gmoss13 August 2, 2024 06:55
@janfb janfb changed the title fix: batched mcmc reshaping' fix: batched mcmc reshaping Aug 2, 2024
Copy link
Contributor

@gmoss13 gmoss13 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Jan! Added a couple of very minor comments.

@janfb janfb added this to the Hackathon and release 2024 milestone Aug 2, 2024
@janfb janfb force-pushed the fix-batched-mcmc-reshaping branch from 295ff42 to 7e76605 Compare August 2, 2024 14:08
@janfb janfb merged commit d9e6a34 into main Aug 2, 2024
6 checks passed
@janfb janfb deleted the fix-batched-mcmc-reshaping branch August 2, 2024 14:59
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants