-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SMC: refactor, speed-up and run multiple chains in parallel for diagnostics #3981
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very interesting NB, thanks @aloctavodia ! I left my comments and questions below 😉
Thanks for the comments @AlexAndorra, I think I addressed all of them :-) |
log_R = np.log(np.random.rand(self.n_steps, self.draws)) | ||
|
||
for n_step in range(self.n_steps): | ||
proposal = floatX(self.posterior + proposals[n_step]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess we no longer support discrete RV?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Still works for discrete variables, this is not new. We need this. otherwise it fails for float32.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @aloctavodia ! Did another review: I think there were 2 comments from my previous review that you didn't address. I also spotted some other typos and added other questions / suggestions.
@@ -4,7 +4,7 @@ | |||
"cell_type": "markdown", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand why you're saying the ESS is 2 -- I see about 50 in the plot above
Reply via ReviewNB
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catcht .there are different ESSs. ESS bulk (the ess for the "central part" of the distribution is around 2, but the plots are showing local ESSs . I will clarify this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah ok! But how can local ESS be higher than ESS for the whole bulk? I would expect the opposite if bulk includes several local neighborhoods 🤔
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added a few lines. I will improve the explanations in the ArviZ educational resources
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good now 👌 Thanks for your work on this @aloctavodia !
* Update GP NBs to use standard notebook style (pymc-devs#3978) * update gp-latent nb to use arviz * rerun, run black * rerun after fixes from comments * rerun black * rewrite radon notebook using ArviZ and xarray (pymc-devs#3963) * rewrite radon notebook using ArviZ and xarray Roughly half notebook has been updated * add comments on xarray usage * rewrite 2n half of notebook * minor fix * rerun notebook and minor changes * rerun notebook on pymc3.9.2 and ArviZ 0.9.0 * remove unused import * add change to release notes * SMC: refactor, speed-up and run multiple chains in parallel for diagnostics (pymc-devs#3981) * first attempt to vectorize smc kernel * add ess, remove multiprocessing * run multiple chains * remove unused imports * add more info to report * minor fix * test log * fix type_num error * remove unused imports update BF notebook * update notebook with diagnostics * update notebooks * update notebook * update notebook * Honor discard_tuned_samples during KeyboardInterrupt (pymc-devs#3785) * Honor discard_tuned_samples during KeyboardInterrupt * Do not compute convergence checks without samples * Add time values as sampler stats for NUTS (pymc-devs#3986) * Add time values as sampler stats for NUTS * Use float time counters for nuts stats * Add timing sampler stats to release notes * Improve doc of time related sampler stats Co-authored-by: Alexandre ANDORRA <[email protected]> Co-authored-by: Alexandre ANDORRA <[email protected]> * Drop support for py3.6 (pymc-devs#3992) * Drop support for py3.6 * Update RELEASE-NOTES.md Co-authored-by: Colin <[email protected]> Co-authored-by: Colin <[email protected]> * Fix Mixture distribution mode computation and logp dimensions Closes pymc-devs#3994. * Add more info to divergence warnings (pymc-devs#3990) * Add more info to divergence warnings * Add dataclasses as requirement for py3.6 * Fix tests for extra divergence info * Remove py3.6 requirements * follow-up of py36 drop (pymc-devs#3998) * Revert "Drop support for py3.6 (pymc-devs#3992)" This reverts commit 1bf867e. * Update README.rst * Update setup.py * Update requirements.txt * Update requirements.txt Co-authored-by: Adrian Seyboldt <[email protected]> * Show pickling issues in notebook on windows (pymc-devs#3991) * Merge close remote connection * Manually pickle step method in multiprocess sampling * Fix tests for extra divergence info * Add test for remote process crash * Better formatting in test_parallel_sampling Co-authored-by: Junpeng Lao <[email protected]> * Use mp_ctx forkserver on MacOS * Add test for pickle with dill Co-authored-by: Junpeng Lao <[email protected]> * Fix keep_size for arviz structures. (pymc-devs#4006) * Fix posterior pred. sampling keep_size w/ arviz input. Previously posterior predictive sampling functions did not properly handle the `keep_size` keyword argument when getting an xarray Dataset as parameter. Also extended these functions to accept InferenceData object as input. * Reformatting. * Check type errors. Make errors consistent across sample_posterior_predictive and fast_sample_posterior_predictive, and add 2 tests. * Add changelog entry. Co-authored-by: Robert P. Goldman <[email protected]> * SMC-ABC add distance, refactor and update notebook (pymc-devs#3996) * update notebook * move dist functions out of simulator class * fix docstring * add warning and test for automatic selection of sort sum_stat when using wassertein and energy distances * update release notes * fix typo * add sim_data test * update and add tests * update and add tests * add docs for interpretation of length scales in periodic kernel (pymc-devs#3989) * fix the expression of periodic kernel * revert change and add doc * FIXUP: add suggested doc string * FIXUP: revertchanges in .gitignore * Fix Matplotlib type error for tests (pymc-devs#4023) * Fix for issue 4022. Check for support for `warn` argument in `matplotlib.use()` call. Drop it if it causes an error. * Alternative fix. * Switch from pm.DensityDist to pm.Potential to describe the likelihood in MLDA notebooks and script examples. This is done because of the bug described in arviz-devs/arviz#1279. The commit also changes a few parameters in the MLDA .py example to match the ones in the equivalent notebook. * Remove Dirichlet distribution type restrictions (pymc-devs#4000) * Remove Dirichlet distribution type restrictions Closes pymc-devs#3999. * Add missing Dirichlet shape parameters to tests * Remove Dirichlet positive concentration parameter constructor tests This test can't be performed in the constructor if we're allowing Theano-type distribution parameters. * Add a hack to statically infer Dirichlet argument shapes Co-authored-by: Brandon T. Willard <[email protected]> Co-authored-by: Bill Engels <[email protected]> Co-authored-by: Oriol Abril-Pla <[email protected]> Co-authored-by: Osvaldo Martin <[email protected]> Co-authored-by: Adrian Seyboldt <[email protected]> Co-authored-by: Alexandre ANDORRA <[email protected]> Co-authored-by: Colin <[email protected]> Co-authored-by: Brandon T. Willard <[email protected]> Co-authored-by: Junpeng Lao <[email protected]> Co-authored-by: rpgoldman <[email protected]> Co-authored-by: Robert P. Goldman <[email protected]> Co-authored-by: Tirth Patel <[email protected]> Co-authored-by: Brandon T. Willard <[email protected]>
By default this will run more than one "smc chain" in parallel. According to my tests r-hat and ess (as implemented in ArviZ) seems to be useful diagnostics after all (I was skeptic due to the differences between MCMC and SMC). Both ess and R-hat clearly show when something goes wrong! When sampling goes right I have the impression that ess is higher than it should be. I will keep exploring this and will try with an alternative way of computing ess (but that is more on the ArviZ side).
The
log_marginal_likelihood
is no longer stored as a model's attribute instead is saved as atrace.report.log_marginal_likelihood
By looping over the
s_steps
instead of overdraws
I get a ~2X speed-up (this is following @ColCarroll vectorization blog-post!). Using multiprocessing can be slower than not using it for simple models (this was already the case before this PR), but it brings some extra speed-up for more expensive models, so It is set True by default.Tests are updated, but there is a problem with ABC test. If
parallel=True
I getCan't pickle local object 'test_one_gaussian.<locals>.normal_sim'
. The same model runs ok in a Jupyter notebook or script.I still need to update the examples notebooks to reflect these changes, including one example of a diagnostic showing the sampler is failing.