I am writing a library that implements various low-bit matmul kernels in Triton/CUDA.
The Triton kernels work great on Ada GPUs like the RTX 4090 and the A6000 Ada - on par with Marlin on large matrices. However, the same kernels perform poorly on Ampere GPUs (3090, A100) - and on the H100 as well, by the way.
I also tried older GPUs like the Titan RTX and the 2080 Ti, and I get an OK speed-up (~2.5x) with INT4 data. So the issue is specific to Ampere and Hopper.
So I debugged every single op, and it turns out there are two strange issues:

1. `tl.load()` with repeated indices (`offs_k[:, None] // elements_per_sample`) is much slower on the A100 than on the 4090. Running a GEMM with an int4 16k x 16k packed matrix is 3.4x faster than fp16 on the 4090 - as expected for a memory-bound scenario - but only 1.13x faster on the A100.
2. The shift operator used to unpack the weights can heavily impact performance depending on the content of the shifter. For example, `b >> offs_k[:, None] % 8` is much slower end-to-end on the A100 than `b >> offs_k[:, None]`, while this makes no difference on the 4090. (Even hoisting the shifter out of the loop doesn't change the result much, so it really is the content of the shifter that seems to impact the speed.)
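For readers unfamiliar with the pattern, here is a minimal NumPy sketch of the unpacking arithmetic the two issues above refer to - the repeated-index gather followed by a data-dependent shift. The names (`elements_per_sample`, `offs_k`) follow the post; the sizes and the 4-bit layout are illustrative assumptions, not the library's actual kernel:

```python
import numpy as np

# Assumed layout: 8 x int4 values packed little-endian into each uint32 word.
elements_per_sample = 8
K = 16
offs_k = np.arange(K)

# Original int4 values (0..15), packed into uint32 words.
vals = np.arange(K, dtype=np.uint32) % 16
packed = np.zeros(K // elements_per_sample, dtype=np.uint32)
for i, v in enumerate(vals):
    packed[i // elements_per_sample] |= v << (4 * (i % elements_per_sample))

# Issue 1's load pattern: repeated indices into the packed buffer
# (each packed word is gathered elements_per_sample times).
b = packed[offs_k // elements_per_sample]

# Issue 2's unpack pattern: a shift whose amount depends on offs_k.
shift = 4 * (offs_k % elements_per_sample)
unpacked = (b >> shift) & 0xF

assert np.array_equal(unpacked, vals)
```

In the Triton kernel the same gather and shift run per tile on the GPU; the sketch only makes the index reuse and the data-dependent shift amounts explicit.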
I'd really appreciate any insights, since this really hinders using Triton for low-bit kernels. I'm open to any suggestions to make this work!
Thank you very much in advance 🙏
Is this potentially related to the loop? When I use `for k in range(0, total_blocks_k):`, it's faster for fp16 x fp16, but it crashes with this error for lower bits (`elements_per_sample > 1`) on both the A100 and the H100:

```
LLVM ERROR: mma16816 data type not supported
```

When I use `for k in tl.range(0, total_blocks_k, 1, num_stages=1)`, it doesn't crash, but it's slower for fp16 x fp16 🤔
This error does not happen on older GPUs like the Titan RTX or the 2080 Ti.
I reported this in a separate issue: #4922
Update:
I was able to get good performance on the 3090 by limiting the number of stages to 1.
However, the same trick doesn't work on the A100 and H100: performance is still very poor with `tl.load` and bit-packed data, including with the latest build from October 14th.
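For reference, a minimal sketch of one way to pin the pipelining depth for the whole kernel is to fix `num_stages` in the autotuner configs (the kernel name, block sizes, and key list below are placeholders, not the library's actual code):

```python
import triton
import triton.language as tl

# Hypothetical autotune setup: a single config with num_stages=1, which is
# the "limit the number of stages" trick described above applied globally
# rather than per-loop via tl.range(..., num_stages=1).
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32},
                      num_stages=1, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def gemm_lowbit_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                       BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                       BLOCK_K: tl.constexpr):
    ...  # kernel body elided
```

This only changes how aggressively the compiler software-pipelines the K loop; it doesn't address the slow repeated-index `tl.load` itself.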
I put a full gist here with numbers: https://gist.github.com/mobicham/06acdeb343490ab4ed1159526cdfa509
I suspect problem 1 is the main blocker making the whole kernel slower.