Enable sampling in chunks with external jax samplers #7465
base: main
Conversation
The test failure seems a bit random; I haven't been able to trigger a failure locally. I get some acceptance_rates pretty far from 0.5, so I'm not sure how stable it's expected to be.
tests/sampling/test_jax.py
Outdated
```diff
@@ -229,7 +229,7 @@ def test_get_log_likelihood():
     b_true = trace.log_likelihood.b.values
     a = np.array(trace.posterior.a)
     sigma_log_ = np.log(np.array(trace.posterior.sigma))
-    b_jax = _get_log_likelihood(model, [a, sigma_log_])["b"]
+    b_jax = jax.vmap(_get_log_likelihood_fn(model))([a, sigma_log_])["b"]
```
Why did the behavior change (or have to change)?
For postprocessing I needed to be able to calculate the log_likelihood without the final wrapping vmap. It's possible to have it just calculate the likelihood instead of returning a function; however, the extra vmap will still be necessary.
Changed to calculate the likelihood instead of using a returned likelihood calculator function.
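The refactor described above can be sketched with a toy stand-in (this is not PyMC's actual `_get_log_likelihood_fn`; the model and helper below are invented for illustration): the helper returns a per-draw log-likelihood function, and the caller applies `jax.vmap` over the draw axis.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a _get_log_likelihood_fn-style helper: it returns
# a function computing the pointwise log-likelihood for ONE posterior draw.
def get_log_likelihood_fn(obs):
    def loglik(params):
        mu, log_sigma = params
        sigma = jnp.exp(log_sigma)
        # Normal log-density of each observation under this draw's parameters
        return (
            -0.5 * ((obs - mu) / sigma) ** 2
            - jnp.log(sigma)
            - 0.5 * jnp.log(2 * jnp.pi)
        )
    return loglik

obs = jnp.array([0.0, 1.0, 2.0])
# Four posterior draws of (mu, log_sigma); vmap maps over the leading axis.
mus = jnp.zeros(4)
log_sigmas = jnp.zeros(4)
ll = jax.vmap(get_log_likelihood_fn(obs))((mus, log_sigmas))
print(ll.shape)  # (4, 3): (draws, observations)
```

Returning the inner function (rather than the vmapped result) lets postprocessing code decide whether to wrap it in `vmap`, which is what the chunked path needs.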
I don't think this was failing before, so it might be related to the changes.
@ferrine any opinion on the removal of postprocessing_backend?
```python
import warnings

warnings.warn(
    "postprocessing_backend={'cpu', 'gpu'} will be removed in a future release, "
```
We should deprecate before rendering the argument useless, or raise already. Also, can the message mention that the alternative is now num_chunks?
Ok, makes sense. I can add back that functionality; it's just a few extra branches to keep track of.
Added back logic for postprocessing_backend. It doesn't have any integration with chunked sampling, but it restores postprocessing all at once on a different backend.
It was splitting the random key one extra time compared to the current behavior. Removing the extra split fixes the failure, and I believe the numpyro samples generated will now be identical to those currently generated. However, this means that choosing the wrong key can still trigger the test failure. Based on pyro-ppl/numpyro#1786, it seems like acceptance_rates won't be super stable.
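The key-splitting issue can be illustrated with a small sketch (toy code, not the PR's implementation): one extra `jax.random.split` yields a different key, so every downstream draw changes.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k_once, _ = jax.random.split(key)
k_twice, _ = jax.random.split(k_once)  # the "extra" split

# Different keys drive different sample streams, so an accidental extra
# split silently changes which draws a sampler generates.
draw_a = jax.random.normal(k_once, (3,))
draw_b = jax.random.normal(k_twice, (3,))  # almost surely != draw_a
print(jnp.array_equal(k_once, k_twice))
```

This is why removing the extra split restores bit-identical samples: the draws are a deterministic function of the key.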
Okay, if it's not stable feel free to choose the best code and pick a seed that happens to work.
What will be different there, and what will the memory consumption be? What overhead is put on the gpu/ram?
What if a single sample does not compile on the gpu? Is that realistic? What about a num_samples_in_chunk parameter?
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

```
@@           Coverage Diff            @@
##            main    #7465     +/-  ##
=========================================
- Coverage   92.44%   92.40%    -0.04%
=========================================
  Files         103      103
  Lines       17119    17153      +34
=========================================
+ Hits        15825    15850      +25
- Misses       1294     1303       +9
```
I don't know in general what the memory complexity of the postprocessing transformations can be. Practically speaking, the postprocessing is jit compiled together with the sampling step, so if sampling starts then memory is sufficient (for numpyro, if tuning starts then memory is sufficient). I'm not sure I can see a case where the postprocessing memory requirements are very high and cpu memory so dominates gpu memory that forcing the postprocessing backend would help.
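The chunked approach discussed here can be sketched in a few lines (a toy stand-in, not the PR's code; `sample_chunk` is an invented placeholder for one block of sampler draws): each finished chunk is copied to host memory, so only one chunk needs to fit on the sampling device at a time.

```python
import jax
import numpy as np

def sample_chunk(key, n):
    # Placeholder for one chunk of sampler output living on the device.
    return jax.random.normal(key, (n,))

def sample_in_chunks(key, num_chunks, chunk_size):
    chunks = []
    for _ in range(num_chunks):
        key, subkey = jax.random.split(key)
        draws = sample_chunk(subkey, chunk_size)
        # np.asarray forces a device-to-host copy, freeing device memory
        # for the next chunk.
        chunks.append(np.asarray(draws))
    return np.concatenate(chunks)

samples = sample_in_chunks(jax.random.PRNGKey(0), num_chunks=4, chunk_size=250)
print(samples.shape)  # (1000,)
```

Device memory then scales with `chunk_size` rather than the total number of draws, which is the point of the feature.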
I'm not sure if that happens or what the current resolution would be. The parameterization is with
Is there a proper way to run tests with a gpu backend enabled? My test for
No, GitHub Actions doesn't include gpu in the free plan.
What's the status here, can we merge?
Let me know if anything else needs to be done on my end.
I am leaning a bit towards "this is too much complexity on our side".
I believe you're talking about higher-level complexity. But iirc for blackjax the
Either way, let me know what you decide.
@andrewdipper wanna give a try at that simpler approach?
Sure, I'll give it a go.
Apologies for the delay, I got caught up. Switched to using blackjax.util.run_inference_algorithm and tried to clarify things a bit. Let me know if you think it's viable. I removed the postprocessing test as it doesn't get run, and blackjax chunked sampling will no longer be identical to single-chunk sampling. I plan to swap in some other sampling tests so the code has test coverage.
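For context, a helper in the spirit of blackjax's `run_inference_algorithm` can be sketched generically (this is a rough sketch of the pattern, not blackjax's actual API or signature): scan a sampler's step function over a stream of rng keys, collecting one state per step.

```python
import jax
import jax.numpy as jnp

def run_inference(rng_key, initial_state, step_fn, num_steps):
    """Scan step_fn over num_steps fresh keys, stacking the visited states."""
    keys = jax.random.split(rng_key, num_steps)

    def one_step(state, key):
        state = step_fn(key, state)
        return state, state  # carry the state, also record it

    final_state, states = jax.lax.scan(one_step, initial_state, keys)
    return final_state, states

# Toy "sampler": a Gaussian random-walk step (invented for illustration).
step = lambda key, x: x + 0.1 * jax.random.normal(key)
final, trail = run_inference(jax.random.PRNGKey(0), jnp.float32(0.0), step, 100)
print(trail.shape)  # (100,)
```

Driving the sampler through one generic loop like this is what makes the chunked variant simpler: the chunking only has to wrap this single entry point.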
Initial take on extending the blackjax and numpyro samplers to be able to sequentially sample multiple chunks. This eliminates the requirement that the gpu have sufficient memory to store all samples at once; they just need to fit in cpu memory.
Changes / features:
- With num_chunks==1, samples are stored on the sampling device, consistent with current behavior. With multiple chunks they are transferred to cpu memory.

Some question marks:
- The postprocessing_backend option is removed. I think this is reasonable as any postprocessing memory requirements should be dominated by the already necessary transpose of the chains and samples dimensions (this is due to vmap(scan) materializing the scan dimension first and subsequently transposing). Unless I'm missing another reason to force the postprocessing backend?

Checklist
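The vmap(scan) layout point from the description can be seen in a small toy example (illustration only, not the PyMC implementation):

```python
import jax
import jax.numpy as jnp

# Inside each vmapped call, lax.scan stacks its per-step outputs along the
# leading axis (draws); vmap then adds the chain axis on the outside, so the
# stacked result is (chains, draws). Reordering to (draws, chains) afterwards
# materializes the full array, which is the transpose cost referred to above.
def chain_draws(init):
    def step(carry, _):
        carry = carry + 1.0
        return carry, carry  # emit one draw per step
    _, draws = jax.lax.scan(step, init, None, length=5)
    return draws  # shape (5,) = (draws,)

out = jax.vmap(chain_draws)(jnp.zeros(3))  # shape (3, 5) = (chains, draws)
swapped = jnp.swapaxes(out, 0, 1)          # shape (5, 3) = (draws, chains)
print(out.shape, swapped.shape)
```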
Type of change
📚 Documentation preview 📚: https://pymc--7465.org.readthedocs.build/en/7465/