Create a benchmarks module #470

Closed
wants to merge 37 commits

Conversation

Member

@neerajprad neerajprad commented Nov 28, 2019

This creates a benchmarks module in the main repo. Currently this has changes from #469, which will be merged in once that PR lands. This branch should only contain changes to the benchmarks module.

We should ensure the following:

  • Use float64 for CPU if possible and float32 for CUDA, especially when comparing with Stan.
  • Report both time per leapfrog and time per effective sample.
  • Exclude compilation + warmup from the times (a rough timing sketch illustrating this is given below).
  • Add requirements.txt with pinned dependencies so that the benchmarks are reproducible.
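
As a rough illustration of the timing points above, here is a minimal sketch (not the benchmark code in this PR; the model, sample counts, and the MCMC.warmup/run split are placeholder assumptions based on recent NumPyro) that keeps compilation and warmup out of the measured time and reports both ms/leapfrog and time per effective sample:

```python
import time

import numpy as np
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.diagnostics import effective_sample_size
from numpyro.infer import MCMC, NUTS

numpyro.enable_x64()  # float64 on CPU, per the first checklist item


def model(data):
    # placeholder model: infer a Gaussian mean
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    numpyro.sample("obs", dist.Normal(mu, 1.0), obs=data)


data = np.random.randn(1000)
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000, progress_bar=False)

# JIT compilation and warmup adaptation happen here and are excluded from timing;
# run() then continues from the adapted state.
mcmc.warmup(random.PRNGKey(0), data)

t0 = time.time()
mcmc.run(random.PRNGKey(1), data, extra_fields=("num_steps",))
elapsed = time.time() - t0

num_leapfrog = int(np.sum(np.asarray(mcmc.get_extra_fields()["num_steps"])))
ess = effective_sample_size(np.asarray(mcmc.get_samples()["mu"])[None, :])

print(f"{1e3 * elapsed / num_leapfrog:.3f} ms per leapfrog step")
print(f"{elapsed / float(ess):.3f} s per effective sample")
```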

@neerajprad neerajprad mentioned this pull request Nov 28, 2019
@neerajprad
Member Author

@fehiepsi - we have made a number of changes to the interface, and jax's caching has changed too since the last version, so if you notice any benchmarks that look off, let me know.

@fehiepsi
Member

Yup, so far there is a bit of a regression in the hmm benchmark: 0.09 -> 0.12 ms/leapfrog because we use progress_bar=True here. But that is not important IMO. The benefit of enabling progress_bar=True is that we don't need to set num_samples=100,000 (to reduce the contribution of compilation time).

FYI, Pyro is 1.5x faster than before (with PyTorch 1.3.1), but it is still very slow... In 32-bit, NumPyro has a slightly smaller n_eff than Stan (566 vs 603) and a much higher number of divergences (282 vs 42). But in 64-bit, NumPyro's n_eff is higher than Stan's (752 vs 603) with no divergences, though it is 1.5x slower than 32-bit (0.17 vs 0.12 ms/leapfrog).

@neerajprad
Member Author

Yup, so far there is a bit of a regression in the hmm benchmark: 0.09 -> 0.12 ms/leapfrog because we use progress_bar=True here.

I think it's fine to use progress_bar=False for benchmarks. One issue I noticed is that progress_bar=False currently does not give identical results (it is likely wrong). I'll try to debug and push a fix for that.


device=cpu
N=100
benchmark_dir=$( cd $(dirname "$0") ; pwd -P )
Member Author


@fehiepsi - here's a driver script to run the benchmarks. Could you run this with the same system config as the remaining ones? Let me fix the issue with the progress bar first.

@martinjankowiak
Collaborator

i'm afraid i haven't much intuition about hyperparameters....

btw in pyro why don't you merge half-steps when you run multiple verlet steps?

(i.e. combine line 46 with line 54)

@fehiepsi
Member

fehiepsi commented Dec 4, 2019

merge half-steps when you run multiple verlet steps

I recall that we merged them previously (for HMC) but separated them out during refactoring. Updating those r sites given a known z_grad is cheap IMO. But we can reinstate it if it is necessary for you.
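
For readers following along, here is a small illustrative sketch (plain NumPy, not Pyro's or NumPyro's actual integrator) of what merging the half-steps means: the trailing half momentum update of one velocity Verlet step and the leading half update of the next are fused into a single full update, saving one gradient evaluation per interior step while producing the same trajectory (up to floating point):

```python
import numpy as np


def grad_potential(z):
    # placeholder: gradient of U(z) = z^2 / 2
    return z


def leapfrog_unmerged(z, r, step_size, num_steps):
    # each step does: half momentum, full position, half momentum
    for _ in range(num_steps):
        r = r - 0.5 * step_size * grad_potential(z)
        z = z + step_size * r
        r = r - 0.5 * step_size * grad_potential(z)
    return z, r


def leapfrog_merged(z, r, step_size, num_steps):
    # same trajectory, but interior half-steps are fused into full steps
    r = r - 0.5 * step_size * grad_potential(z)
    for i in range(num_steps):
        z = z + step_size * r
        if i < num_steps - 1:
            r = r - step_size * grad_potential(z)
    r = r - 0.5 * step_size * grad_potential(z)
    return z, r


z0, r0 = np.array(1.0), np.array(0.5)
print(leapfrog_unmerged(z0, r0, 0.1, 5))
print(leapfrog_merged(z0, r0, 0.1, 5))  # matches, with one fewer gradient call per interior step
```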

@martinjankowiak
Collaborator

oh i'm not making a suggestion. i just would have imagined that merging could give a noticeable perf bump but maybe not

@junpenglao

Instead of copying the old experimental Edward NUTS implementation into your repository, why not import the TFP NUTS implementation (https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/mcmc/nuts.py)? It is also an iterative implementation, and you can compile it to XLA+GPU.

@neerajprad
Member Author

@junpenglao - Thanks for the pointer. We will definitely add the tfp NUTS implementation to the benchmarks. This branch is basically resurrecting some of our old code from a few months back (I think the tfp one was experimental back then, but I could be wrong).

@fehiepsi
Member

@junpenglao Could you give me an example of how to use TFP NUTS so I can give it a try (I am just curious, so please take your time)? As @neerajprad pointed out, this is code from a few months ago and we keep it here for reproducibility. It is not easy to keep up with newly released packages for benchmarking. FYI, this benchmarks branch is not meant to showcase the state of the art.

@junpenglao

Thank you both! I think the best place to start is the test case for tfp nuts and my introductory colab.

I am more than happy to review and help in subsequent PRs if you are interested in replacing the Edward_nuts with the TFP_nuts.

@fehiepsi
Member

fehiepsi commented Dec 10, 2019 via email

@fehiepsi
Member

fehiepsi commented Dec 19, 2019

@neerajprad @junpenglao It seems that tfp does a slightly better job (1.05x faster; originally reported as 1.3x, see the edit below) in this GPU example (see the gist) but is slower on CPU (even slower than ed2 - I'm not sure why - probably I missed some configs to make it work, or TF XLA is just at an experimental stage).

It would be nice to see if we can improve the speed on GPU. I don't have an explanation for this difference. Probably the small operators add up? @junpenglao, given your expertise with tfp, do you have any insight into this difference? I think tfp also uses the iterative algorithm, runs in XLA, and most of the computation time (I may be wrong about this assumption) should lie in the leapfrog step... so the performance on this example should be similar.

@neerajprad I don't want to chase the GPU "speed" gap unless it provides clear benefits. It is good to keep the benchmarks as-is because we only want to benchmark against the recursive algorithms. However, we should create an issue to track down this problem (I suspect lax.cond plays a role here) in the future.

Edit: Sorry, I compared tfp with the old numpyro benchmark result. The updated difference (tfp is 1.05x faster) might come from how the bernoulli logprob is computed or some other small tensor ops - probably not worth investigating. :)

@neerajprad
Member Author

@fehiepsi - For large datasets, I suppose that most of the compute will be dominated by tensor ops on the backward pass. I don't think this is worth investigating unless we have a few more examples. In any case, I think the unrolled NUTS implementation using XLA should be pretty much the same in terms of run time, even on CPU. It would be surprising if that's not the case.

@junpenglao

Thanks a lot for the early feedback!
I wouldn't expect a large difference either. One of the main differences (please correct me if I am wrong on the numpyro end) is that the TFP version batches at each leapfrog step, while the numpyro version batches at the end using vectorized_map. Since I need to implement things in a way that makes sure batching works, there is quite a lot of shape handling that might under-perform or out-perform autobatching in terms of memory (my largest concern), and I may be using TF operations that are slow.

Otherwise, once we compile to XLA, everything should give similar speed, as TF and JAX are basically different interfaces to XLA. However, we might see some differences when we compare large batch sizes (e.g., num_chain=1, 10, 100, 500).

@fehiepsi the slowness on CPU is pretty strange; maybe it has something to do with compiling to XLA - I usually run once and then run again to do the timing.

@fehiepsi
Member

@junpenglao Yes, numpyro chain_method='vectorized' batches at the end, so it wastes computation if some chains finish their trajectories earlier than others :( (but for drawing a lot of chains on GPU, this might not be a big problem). For the CPU benchmark, I followed your suggestion - you can take a look at this gist (edward2 gives 60 ms/leapfrog but tfp+XLA gives 80 ms/leapfrog - but as I said, this probably comes from the experimental stage of tf.xla.experimental.compile; your colab with tf_nightly and tf.function(..., autograph=False, experimental_compile=True) might also do better). Anyway, thanks for sharing your suggestions! Using tfp seems not as complicated as I had thought. :D
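
For anyone who wants to reproduce this kind of comparison, here is a minimal sketch of the tf.function(..., autograph=False, experimental_compile=True) route mentioned above; the standard-normal target, step size, and chain count are placeholders, and the API names are those of the TF/TFP releases around this time, so they may have changed since:

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# placeholder target: a standard normal in 2 dimensions
target = tfd.MultivariateNormalDiag(loc=tf.zeros(2))

kernel = tfp.mcmc.NoUTurnSampler(
    target_log_prob_fn=target.log_prob,
    step_size=0.3)


@tf.function(autograph=False, experimental_compile=True)
def run_chain(init_state):
    # trace_fn=None returns only the chain states; burn-in steps are discarded
    return tfp.mcmc.sample_chain(
        num_results=1000,
        num_burnin_steps=500,
        current_state=init_state,
        kernel=kernel,
        trace_fn=None)


samples = run_chain(tf.zeros([4, 2]))  # 4 vectorized chains of dimension 2
print(samples.shape)  # (1000, 4, 2)
```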

@junpenglao

TF got a lot simpler with TF2😉
Also, the edward nuts implemented its own leapfrog with a for loop that built a larger graph, if I remember correctly; that might be one reason as well.

from numpyro.examples.datasets import COVTYPE, load_dataset

# pyro
import torch
Member


Just a note in case we face this issue in the future: with jaxlib 0.1.37, import torch might cause a GPU memory error when running numpyro mcmc on this example (I don't have an explanation for this and also can't come up with an isolated reproduction to raise upstream).

@neerajprad
Member Author

Closing this since all tasks are resolved, and all changes will remain in the benchmarks branch. We will periodically update this branch and add git tags to pin its state at specific points in time - the latest is benchmarks-20191222.

@junpenglao - Thanks for sharing your thoughts on the possible differences in the tfp implementation that could result in a different perf profile; it makes sense. Please feel free to use models from this branch for profiling if you find them useful, and please share any useful findings with us!

I wouldn't expect a large difference either. One of the main differences (please correct me if I am wrong on the numpyro end) is that the TFP version batches at each leapfrog step, while the numpyro version batches at the end using vectorized_map.

In NumPyro, parallel chains work either by using device parallelism (pmap) or by vectorization. In the latter case, as @fehiepsi mentioned, we need to wait for all the K chains to meet the terminating condition before we can collect the K samples and then repeat. Is this the same in tfp, or is Alexey's auto-batching solution incorporated in the tfp implementation? 🙂
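
To make the two modes concrete, here is a small usage sketch (toy model as a placeholder): chain_method='parallel' maps chains over devices with pmap, while chain_method='vectorized' batches all chains through a single kernel, so each NUTS iteration runs until the longest trajectory among the chains terminates.

```python
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS


def model():
    # placeholder model
    numpyro.sample("x", dist.Normal(0.0, 1.0))


kernel = NUTS(model)

# Device parallelism: one chain per device via pmap.
mcmc_pmap = MCMC(kernel, num_warmup=500, num_samples=1000,
                 num_chains=4, chain_method="parallel")

# Vectorization: all chains are batched through a single kernel, so each
# NUTS iteration waits for the slowest trajectory among the chains.
mcmc_vec = MCMC(kernel, num_warmup=500, num_samples=1000,
                num_chains=4, chain_method="vectorized")

mcmc_vec.run(random.PRNGKey(0))
print(mcmc_vec.get_samples(group_by_chain=True)["x"].shape)  # (4, 1000)
```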

@neerajprad neerajprad closed this Dec 24, 2019
@junpenglao

Thanks @neerajprad!

Both the current TFP NUTS and Alexey's auto-batched NUTS are doing the same thing as numpyro then: wait for all chains to meet termination (u-turn or divergence), finalize one sample, then start the next sample (resample momentum, etc.).
Potentially we can add a flag to have chains not wait for each other - terminate when any chain terminates. This would increase the speed but reduce the effective sample size, as tree building would terminate too early for most chains. Not sure how it would affect num_effective_samples per second, but it is certainly an interesting idea to explore.
We don't have a pmap version (which IIUC would mean the chains won't wait for each other), but maybe it is possible to use the TF pmap to do that.

@neerajprad
Member Author

Thanks for explaining, @junpenglao.

Potentially we can add a flag to have chains not wait for each other - terminate when any chain terminates. This would increase the speed but reduce the effective sample size, as tree building would terminate too early for most chains. Not sure how it would affect num_effective_samples per second, but it is certainly an interesting idea to explore.

That's something that we can also easily explore with our current setup - since we are wasting less computation, it is possible that the higher number of drawn samples results in a higher effective sample size, despite early termination. It is worth investigating!
