jonmorton/pytorch_jax_benchmarks
================================

Comparing GPU performance of jax and torch.compile on some standard ML models/layers. Mostly for fun.

The benchmark runs a full forward + backward pass and enables TensorFloat-32 (TF32) matmuls in both frameworks. A full device synchronization is performed after each forward + backward step.
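For context, a minimal sketch of what such a timing loop could look like on the PyTorch side (the function, model, and placeholder loss are illustrative assumptions, not the repo's actual harness):

```python
import time

import torch

# Assumption: TF32 matmuls are enabled roughly like this on the PyTorch side.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

def bench(model, x, steps=50):
    """Time `steps` full forward + backward passes, syncing the GPU after each one."""
    compiled = torch.compile(model)
    times = []
    for _ in range(steps):
        model.zero_grad(set_to_none=True)
        start = time.perf_counter()
        loss = compiled(x).square().mean()  # placeholder loss, just to drive backward
        loss.backward()
        torch.cuda.synchronize()            # full sync after forward + backward
        times.append(time.perf_counter() - start)
    return times
```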

Results on my setup (YMMV, the usual benchmarking caveats apply):

attn_seq1024_dim512_f16
-----------------------
pytorch:  9.5ms ± 0.20ms p90=9.7ms
jax:     11.5ms ± 0.72ms p90=12.1ms

attn_seq1024_dim512_tf32
------------------------
pytorch: 19.6ms ± 1.60ms p90=20.6ms
jax:     23.0ms ± 2.20ms p90=26.5ms

attn_seq2048_dim256_tf32
------------------------
pytorch: 23.6ms ± 0.99ms p90=23.8ms
jax:     25.7ms ± 2.22ms p90=29.7ms

resnet50
--------
pytorch: 58.5ms ± 2.89ms p90=61.6ms
jax:     62.2ms ± 6.87ms p90=64.1ms

Note: the attention implementation is written from scratch in both frameworks (in a similar manner), without any built-in attention modules, so that this tests the compiler rather than the library.
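The repo's exact code isn't reproduced here, but a from-scratch scaled dot-product attention in PyTorch might look roughly like the sketch below (shapes and names are illustrative assumptions):

```python
import math

import torch

def attention(q, k, v):
    """Scaled dot-product attention using plain tensor ops
    (no torch.nn.MultiheadAttention or F.scaled_dot_product_attention)."""
    # q, k, v: (batch, heads, seq, head_dim)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    weights = torch.softmax(scores, dim=-1)
    return weights @ v
```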
