Create a benchmarks module #470
Conversation
@fehiepsi - we have made a number of changes to the interface, and jax's caching has changed too since the last version, so if you notice any benchmarks that look off, let me know.
Yup, so far there is a bit of a regression in the hmm benchmark. FYI, Pyro is 1.5x faster than before (with PyTorch 1.3.1), but it is still very slow... In 32-bit, NumPyro has a slightly smaller n_eff than Stan (566 vs 603) and a much higher number of divergences (282 vs 42). But in 64-bit, NumPyro's n_eff is higher than Stan's (752 vs 603) with no divergences, and it is 1.5x slower than 32-bit (0.17 vs 0.12 ms/leapfrog).
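For reference, the "1.5x slower" figure above follows directly from the two quoted per-leapfrog timings (it works out to about 1.42x, i.e. roughly 1.5x):

```python
# Quick arithmetic check of the per-leapfrog timings quoted above.
ms_64 = 0.17  # ms per leapfrog step in 64-bit mode
ms_32 = 0.12  # ms per leapfrog step in 32-bit mode
slowdown = ms_64 / ms_32
print(f"64-bit is {slowdown:.2f}x slower per leapfrog step")
# -> 64-bit is 1.42x slower per leapfrog step
```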
I think it's fine to use
```sh
device=cpu
N=100
benchmark_dir=$( cd $(dirname "$0") ; pwd -P )
```
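As a sketch, an equivalent driver could also be written in Python: resolve the benchmarks directory, then build one command per benchmark script. The script names and CLI flags below are hypothetical illustrations, not taken from this PR.

```python
import os
import shlex

def build_commands(benchmark_dir, scripts, device="cpu", n=100):
    """Build one shell command per benchmark script, mirroring the
    device=cpu / N=100 settings of the shell driver (flags are illustrative)."""
    return [
        "python {} --device {} --num-samples {}".format(
            shlex.quote(os.path.join(benchmark_dir, script)), device, n
        )
        for script in scripts
    ]

# Hypothetical usage:
commands = build_commands("/tmp/bench", ["hmm.py", "covtype.py"])
for cmd in commands:
    print(cmd)
```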
@fehiepsi - here's a driver script to run the benchmarks. Could you run this with the same system config as the remaining ones? Let me fix the issue with the progress bar first.
I'm afraid I haven't much intuition about hyperparameters... By the way, in Pyro why don't you merge half-steps when you run multiple Verlet steps? (i.e. combine line 46 with line 54)
I recall that we merged them previously (for HMC) but separated them out during refactoring. Updating those
Oh, I'm not making a suggestion. I just would have imagined that merging could give a noticeable perf bump, but maybe not.
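For readers following along, here is a minimal sketch of the two equivalent formulations being discussed, using a toy quadratic potential (illustrative, not Pyro's actual integrator code). Merging the adjacent half-step momentum updates of velocity Verlet cuts the gradient evaluations per trajectory from 2L to L+1:

```python
def grad_U(q):
    # Gradient of a toy quadratic potential U(q) = 0.5 * q**2.
    return q

def leapfrog_naive(q, p, eps, L):
    # Two half-kicks per step: 2L gradient evaluations in total.
    for _ in range(L):
        p = p - 0.5 * eps * grad_U(q)
        q = q + eps * p
        p = p - 0.5 * eps * grad_U(q)
    return q, p

def leapfrog_merged(q, p, eps, L):
    # Adjacent half-kicks between steps fused into full kicks:
    # L + 1 gradient evaluations in total.
    p = p - 0.5 * eps * grad_U(q)
    for i in range(L):
        q = q + eps * p
        if i < L - 1:
            p = p - eps * grad_U(q)
    p = p - 0.5 * eps * grad_U(q)
    return q, p
```

Both produce the same trajectory (up to floating-point noise), since the second half-kick of one step and the first half-kick of the next are evaluated at the same position.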
Instead of copying the old experimental Edward NUTS implementation to your repository, why not import the TFP NUTS implementation (https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/mcmc/nuts.py)? It is also an iterative implementation and you can compile it to XLA+GPU.
@junpenglao - Thanks for the pointer. We will definitely add the tfp NUTS implementation to the benchmarks. This branch is basically resurrecting some of our old code from a few months back (I think the tfp one was experimental back then, but I could be wrong).
@junpenglao Could you give me an example of how to use TFP NUTS so I can give it a try (I am just curious, so please take your time)? As @neerajprad pointed out, this is code from a few months ago and we put it here for reproducibility. You know, it is not easy to keep up with new package releases for benchmarking. FYI, this
Thank you both! I think the best place to start is the test case for tfp nuts and my introductory colab. I am more than happy to review and help in subsequent PRs if you are interested in substituting the Edward_nuts with the TFP_nuts.
Awesome! Thanks, I’ll give it a try tonight!
@neerajprad @junpenglao It seems that tfp does a slightly better job.
Edit: Sorry, I compared tfp with the old numpyro benchmark result. The updated difference (tfp is 1.05x faster) might come from how the Bernoulli log-prob is computed or some other small tensor ops; probably not worth investigating. :)
@fehiepsi - For large datasets, I suppose that most of the compute will be dominated by tensor ops in the backward pass. I don't think this is worth investigating unless we have a few more examples. In any case, I think the unrolled NUTS implementation using XLA should be pretty much the same in terms of run time, even on CPU. It would be surprising if that's not the case.
Thanks a lot for the early feedback! Otherwise, once we compile to XLA, everything should give similar speed, as TF and JAX are basically different interfaces to XLA. However, we might see some differences when we compare large batch sizes (e.g., num_chain=1, 10, 100, 500). @fehiepsi the slowness on CPU is pretty strange; maybe it has something to do with compiling to XLA - I usually run once and then run again to do the timing.
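The run-once-then-time pattern mentioned above can be sketched directly with JAX: the first call to a jitted function pays the tracing and XLA compilation cost, so only a subsequent call measures execution. The `potential` function here is a made-up example, not one of the benchmark models.

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def potential(q):
    # Toy quadratic potential standing in for a model's log-density.
    return 0.5 * jnp.sum(q ** 2)

q = jnp.ones(1000)
potential(q).block_until_ready()    # warm-up call: triggers tracing + XLA compile

t0 = time.perf_counter()
potential(q).block_until_ready()    # timed call: compiled execution only
elapsed = time.perf_counter() - t0
print(f"compiled call took {elapsed * 1e3:.3f} ms")
```

`block_until_ready()` matters here because JAX dispatches asynchronously; without it the timer would stop before the computation finishes.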
@junpenglao Yes, numpyro
TF got a lot simpler with TF2 😉
```python
from numpyro.examples.datasets import COVTYPE, load_dataset

# pyro
import torch
```
Just a note in case we face this issue in the future: with jaxlib 0.1.37, `import torch` might cause a GPU memory error when running numpyro mcmc on this example (I don't have an explanation for this and also can't come up with an isolated code example to raise upstream).
Closing this since all tasks are resolved, and all changes will remain in the benchmarks branch. We will periodically update this branch, and add git tags to pin its status at specific points in time - the last one is pinned at

@junpenglao - Thanks for sharing your thoughts on the possible differences in the tfp implementation which could result in a different perf profile; it makes sense. Please feel free to use models from this branch for profiling if you find it useful. Please share any useful findings with us!
In NumPyro, parallel chains work by either using device parallelism (
Thanks @neerajprad! Both the current TFP NUTS and Alexey's auto-batched NUTS are doing the same thing as numpyro then: wait for all chains to meet termination (u-turn or divergence), finalize one sample, then start the next sample (resample momentum, etc.).
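As an illustrative sketch (not numpyro's actual internals) of the vectorized-chains strategy discussed above, chains can be batched with `jax.vmap` over a per-chain transition function; the toy random-walk step below stands in for one HMC transition:

```python
import jax
import jax.numpy as jnp

def transition(q, key):
    # Toy random-walk step standing in for one HMC transition.
    return q + 0.1 * jax.random.normal(key, q.shape)

num_chains = 4
keys = jax.random.split(jax.random.PRNGKey(0), num_chains)  # one key per chain
qs = jnp.zeros((num_chains, 2))                             # 4 chains, dim-2 state

# Vectorized chains: a single device advances all chains in one batched op.
qs_next = jax.vmap(transition)(qs, keys)
print(qs_next.shape)  # (4, 2)
```

The device-parallel alternative would use `jax.pmap` with one chain per device, which requires as many devices as chains; on a single-device CPU setup, `vmap` is the applicable option.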
Thanks for explaining, @junpenglao.
That's something that we can also easily explore with our current setup - since we waste less computation, it is possible that the higher number of drawn samples results in a higher effective sample size despite early termination. It is worth investigating!
This creates a benchmarks module in the main repo. Currently this has changes from #469, which will be merged in once that PR lands. This branch should only contain changes to the benchmarks module.
We should ensure the following:
- Add a `requirements.txt` with pinned dependencies so that the benchmarks are reproducible.
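As a sketch of what such a pinned file could look like - `jaxlib 0.1.37` and `torch 1.3.1` are versions mentioned in this thread, while the remaining pins are placeholders to fill in when actually pinning:

```
jaxlib==0.1.37
torch==1.3.1
numpyro==<pinned version>
pyro-ppl==<pinned version>
```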