-
-
Notifications
You must be signed in to change notification settings - Fork 5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Kernel] Layernorm performance optimization #3662
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mawong-amd Thanks for submitting the PR! This optimization seems to be necessary for MI300x GPUs.
Unfortunately, I didn't see noticeable e2e performance boost for A100 GPUs. Is this expected? Also, I'm a bit worried about whether the new kernels keep the semantics of the current kernels. Could you double check?
scalar_t z = input[blockIdx.x * hidden_size + idx]; | ||
z += residual[blockIdx.x * hidden_size + idx]; | ||
float x = (float) z; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't this change the semantics of the kernel since we do the addition in FP16/BF16 instead of FP32?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It does in theory, however I've not noticed any observable effects from doing the addition in lower precision so far (even the logprobs of generated sequences are identical).
In terms of a possible increase in rounding error, this is likely still negligible compared to typical errors incurred during the reduction phase and in the approximate rsqrt.
The benefit of doing the addition in FP16/BF16 is that it can be implemented as a packed operation. But this step shouldn't be a bottleneck in any case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see, makes sense. Thanks for the explanation!
list(REMOVE_ITEM GPU_FLAGS | ||
"-D__CUDA_NO_HALF_OPERATORS__" | ||
"-D__CUDA_NO_HALF_CONVERSIONS__" | ||
"-D__CUDA_NO_BFLOAT16_CONVERSIONS__" | ||
"-D__CUDA_NO_HALF2_OPERATORS__") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this affect other CUDA kernels?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It could, but I haven't noticed any side effects and neither have the tests. The existing defines seem to originate from Torch's default defines as a legacy item and it's not clear to me if there's a good reason to retain them nowadays (e.g. seems like the recently added Punica extension similarly disables these defines).
If this is a concern, we could either limit the scope of removing these defines to this file or use free functions instead of operators (e.g. __hadd/__hadd2 for __half/__half2 operator+). But this increases code bloat and non-portability even further: the current implementation is already compromised to an extent by the (deficient) headers provided by CUDA/HIP (neither __hadd/__hadd2 as free functions or "heterogeneous" operators like float2::operator*(float) are consistently implemented in CUDA, while conversion operators/constructors are not consistently implemented by both).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got it. Thanks for the explanation!
Hi, I managed to run a few performance tests on H100 last night and also observed that there was no speed up. I looked at the PTX and SASS assembly and NVCC was not fusing the loads/stores as expected. It appears NVCC needs to know these global memory ops are aligned on a 16 byte boundary to unlock the full 128-bit coalesced op; I've added this alignment requirement to the vector struct and now I'm observing similar speedups on H100. Preliminary numbers I'm seeing on H100 are:
One "drawback" of this change is we can now only enable optimizations when the hidden_size is a multiple of 8 and the tensor pointers are aligned on a 16 byte boundary. But these conditions should be met essentially all the time. As for the changed semantics, I'll discuss it in the relevant review comment thread. Thanks! |
Bulk conversions (packed halfs into half2, using vectors of half2); block and warp reduce with AMD wavesize 64 (vs 32); using smaller block sizes for improved block occupancy on CUs Use larger block sizes for decode; optimize warp and block reduce fully Refactor vector to use half to maintain same alignment as c10::Half; move packed logic into member functions Add a few missing unroll directives Fix blockReduce stall caused by warp divergence on CUDA (vLLM uses universal masks) Refactor vector type to enable optimizations for bf16 Re-apply the blockReduceSum fix for warp divergence Hotfix: Disable BF16 opts due to ROCm 5.7 incompatibility Remove redundant inline specifiers; preparing for upstream
4f94b87
to
a1bbdc4
Compare
Quick update on end-to-end runtime numbers. With the latest changes, I'm seeing small but observable improvements on H100. Specifically, on the latency benchmark (50 iters on each test):
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@mawong-amd LGTM! Thanks for the optimization! Didn't know that RMSNorm can affect the performance this much.
I realized that this pr breaks cuda 11.8 support because of the usage of |
I think we can hotfix in a define guard to enable these optimizations only when the cuda version is > 11.8. Let me prepare a diff that does that. |
EDIT: Hotfix created as the following PR #3782 |
@mawong-amd Can you send a PR to land that patch? |
This PR primarily creates optimized specializations of fused_add_rms_norm_kernel, used in many layernorms. It also includes a slightly optimized version of blockReduceSum/warpReduceSum which slightly reduce the number of shuffles done when the max block size is <=512 and known at compile time.
It is observed that fused_add_rms_norm is memory latency bound under many scenarios. The optimized implementation primarily derives its benefits by
The same ideas contained here can be applied to other relatively simple kernels which should be memory bound (e.g. some activation kernels).
More performance numbers can be provided as they become available or if requested. The existing test suite appears sufficient, but additional tests can be created on request.
Some examples of the speed up, as obtained by profiling via benchmark_latency on Llama2-70B (hidden size 8192), FP16, TP = 1, on MI300X:
Another optimization attempted was the use of shared memory, which effectively converts a global memory load into a shared memory load/store pair per item. While this improves performance when applied to baseline, it was not observed to improve performance on top of the current optimizations.
BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE
PR Checklist (Click to Expand)
Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.
PR Title and Classification
Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:
[Bugfix]
for bug fixes.[CI/Build]
for build or continuous integration improvements.[Doc]
for documentation fixes and improvements.[Model]
for adding a new model or improving an existing model. Model name should appear in the title.[Frontend]
For changes on the vLLM frontend (e.g., OpenAI API server,LLM
class, etc.)[Kernel]
for changes affecting CUDA kernels or other compute kernels.[Core]
for changes in the core vLLM logic (e.g.,LLMEngine
,AsyncLLMEngine
,Scheduler
, etc.)[Hardware][Vendor]
for hardware-specific changes. Vendor name should appear in the prefix (e.g.,[Hardware][AMD]
).[Misc]
for PRs that do not fit the above categories. Please use this sparingly.Note: If the PR spans more than one category, please include all relevant prefixes.
Code Quality
The PR need to meet the following code quality standards:
format.sh
to format your code.docs/source/
if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.Notes for Large Changes
Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with
rfc-required
and might not go through the PR.What to Expect for the Reviews
The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:
action-required
label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.Thank You
Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!