
Poor performance on Ampere vs. Ada with bitpacked weights #4906

Open
mobicham opened this issue Oct 14, 2024 · 2 comments


mobicham commented Oct 14, 2024

I am writing a library that implements different low-bit matmul kernels in Triton/CUDA.
The Triton kernels work great on Ada GPUs like the RTX 4090 and the A6000 Ada - on par with Marlin on large matrices. However, the same kernels perform poorly on Ampere GPUs (3090, A100) - and also on the H100, by the way.
I also tried older GPUs like the Titan RTX and the 2080 Ti, and I get an OK speed-up (~2.5x) with INT4 data. So the issue is specific to Ampere and Hopper.

So I debugged every single op, and it turns out there are two strange issues:

  1. `tl.load()` with repeated indices `offs_k[:, None] // elements_per_sample` is much slower on the A100 than on the 4090. Running a GEMM with an int4 16k x 16k packed matrix is 3.4x faster than fp16 on the 4090 - as expected for a memory-bound scenario - but only 1.13x faster on the A100.

  2. The shift operator used to unpack the weights can heavily impact performance depending on the content of the shifter. For example, `b >> (offs_k[:, None] % 8)` is much slower end-to-end on the A100 than `b >> offs_k[:, None]`, while this makes no difference on the 4090. (Even hoisting the shifter computation outside the loop barely changes the result, so it really is the content of the shifter that affects the speed.)

I put a full gist with numbers here: https://gist.github.com/mobicham/06acdeb343490ab4ed1159526cdfa509
I suspect problem 1 is the main blocker making the whole kernel slower.
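For reference, here is a minimal sketch of the two patterns in isolation (simplified from the gist; the kernel name, strides, and the int4-in-int32 packing with `elements_per_sample = 8` are illustrative assumptions, not the actual kernel):

```python
import triton
import triton.language as tl

@triton.jit
def unpack_int4_sketch(b_ptr, out_ptr,
                       stride_bk, stride_bn, stride_ok, stride_on,
                       BLOCK_K: tl.constexpr, BLOCK_N: tl.constexpr,
                       elements_per_sample: tl.constexpr):
    offs_k = tl.arange(0, BLOCK_K)
    offs_n = tl.arange(0, BLOCK_N)

    # Issue 1: repeated indices -- consecutive k offsets map to the same
    # packed word, so the load contains duplicated addresses.
    b_ptrs = b_ptr + (offs_k[:, None] // elements_per_sample) * stride_bk \
                   + offs_n[None, :] * stride_bn
    b = tl.load(b_ptrs)

    # Issue 2: a data-dependent shifter; (offs_k % 8) * 4 selects the 4-bit
    # nibble. Replacing the "% 8" pattern in the shifter is reported to be
    # much faster on the A100, with no difference on the 4090.
    shift = (offs_k[:, None] % elements_per_sample) * 4
    w = (b >> shift) & 0xF

    out_ptrs = out_ptr + offs_k[:, None] * stride_ok + offs_n[None, :] * stride_on
    tl.store(out_ptrs, w)
```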

I would really appreciate any insights, since this hinders using Triton as a solution for low-bit kernels. I'm open to any suggestions to make this work!

Thank you very much in advance 🙏


mobicham commented Oct 15, 2024

Is this potentially related to the loop? When I use `for k in range(0, total_blocks_k):`, it's faster for fp16 x fp16, but it crashes with this error for lower bits (`elements_per_sample > 1`) on both the A100 and the H100:

`LLVM ERROR: mma16816 data type not supported`

When I use `for k in tl.range(0, total_blocks_k, 1, num_stages=1)` instead, it doesn't crash, but it's slower for fp16 x fp16 🤔
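For concreteness, a minimal sketch of the two loop forms side by side (the loop body is a stand-in reduction, not the actual matmul):

```python
import triton
import triton.language as tl

@triton.jit
def k_loop_sketch(x_ptr, out_ptr, total_blocks_k,
                  BLOCK: tl.constexpr, USE_TL_RANGE: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    acc = tl.zeros((BLOCK,), dtype=tl.float32)
    if USE_TL_RANGE:
        # Compiles for the bitpacked path, but slower for fp16 x fp16,
        # since num_stages=1 disables software pipelining for this loop.
        for k in tl.range(0, total_blocks_k, 1, num_stages=1):
            acc += tl.load(x_ptr + k * BLOCK + offs)
    else:
        # Faster for fp16 x fp16, but with bitpacked weights
        # (elements_per_sample > 1) the pipelined lowering fails with
        # "LLVM ERROR: mma16816 data type not supported" on A100/H100.
        for k in range(0, total_blocks_k):
            acc += tl.load(x_ptr + k * BLOCK + offs)
    tl.store(out_ptr + offs, acc)
```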

This error does not happen on older GPUs like the Titan RTX or the 2080 Ti.
I reported it in a separate issue: #4922


mobicham commented Oct 16, 2024

Update:
I was able to get good performance on the 3090 by limiting the number of stages to 1.
However, the same trick doesn't work for the A100 and H100: performance is still very poor with `tl.load` and bitpacked data, including with the latest build from October 14th.
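For context, limiting the stage count can be done at launch time; the kernel and shapes below are a trivial stand-in, just to show where `num_stages=1` goes:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    tl.store(y_ptr + offs, tl.load(x_ptr + offs, mask=mask), mask=mask)

x = torch.randn(4096, device='cuda')
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 1024),)
# num_stages=1 pins the software-pipeline depth: this recovered good
# bitpacked performance on the 3090, but not on the A100/H100.
copy_kernel[grid](x, y, x.numel(), BLOCK=1024, num_stages=1, num_warps=4)
```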
