[Kernel] Unify the kernel used in flash attention backend #6052
Conversation
I think the direction makes sense! It is also a more CUDA-graph-friendly approach.
Yeah, the PR should be ready for review. Here are some kernel benchmark numbers on a single A100; all numbers are in ms.
Only in one case do we see significant performance degradation.
In all other cases, the performance is quite similar.
Overall LGTM and it's much cleaner now!
cc @rkooo567 and @cadedaniel to have a final pass.
# Fields that are not used in flash attention backend,
# but used in other backends
context_lens_tensor: Optional[torch.Tensor] = None
seq_lens_tensor: Optional[torch.Tensor] = None
max_prefill_seq_len: Optional[int] = None
max_decode_seq_len: Optional[int] = None
Good finding! I'll remove them after refactoring prepare input.
if kv_cache is None or (attn_metadata.block_tables is not None
                        and attn_metadata.block_tables.numel() == 0):
    k = key
    v = value
    block_tables = None
This should be for pure prefill or memory profiling? It would be better to add a comment for it.
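For reference, a minimal sketch of what such a comment could look like (it mirrors the snippet above; the comment wording is an assumption, not the PR author's):

if kv_cache is None or (attn_metadata.block_tables is not None
                        and attn_metadata.block_tables.numel() == 0):
    # kv_cache is None during the memory-profiling run; an empty block table
    # means a pure prefill batch that does not read from the paged KV cache,
    # so we attend over the current key/value tensors directly.
    k = key
    v = value
    block_tables = None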
@@ -151,8 +151,8 @@ def execute_model(
# Currently cuda graph is only supported by the decode phase.
assert model_input.attn_metadata is not None
prefill_meta = model_input.attn_metadata.prefill_metadata
decode_meta = model_input.attn_metadata.decode_metadata
if prefill_meta is None and decode_meta.use_cuda_graph:
if prefill_meta is None and \
Note: This code snippet is removed by #6338 so this isn't a problem anymore.
vllm/worker/model_runner.py
@@ -655,6 +655,7 @@ def _prepare_model_input_tensors(
input_positions.append(0)
slot_mapping.append(_PAD_SLOT_ID)
seq_lens.append(1)
query_lens.append(1)
Not used?
This is used when calculating query_start_loc, which is the input for the flash attention backend when using the unified kernel.
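For illustration, query_start_loc is essentially a zero-prefixed cumulative sum of the per-sequence query lengths; a minimal sketch with hypothetical numbers:

import torch

# Hypothetical mixed batch: two prefill sequences (5 and 3 query tokens)
# followed by three decode sequences (1 query token each).
query_lens = [5, 3, 1, 1, 1]

# query_start_loc[i] is the offset of sequence i's first query token in the
# flattened token dimension; the last entry is the total number of tokens.
query_start_loc = torch.zeros(len(query_lens) + 1, dtype=torch.int32)
query_start_loc[1:] = torch.cumsum(
    torch.tensor(query_lens, dtype=torch.int32), dim=0)

print(query_start_loc)  # [0, 5, 8, 9, 10, 11]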
for attr_expected, attr_actual in zip(vars(attn_metadata.decode_metadata).items(),
                                      vars(decode_meta_actual).items()):
    assert attr_expected[1] == attr_actual[1]
if attn_metadata.prefill_metadata:
Is it always None for flash attention backend now?
The review ETA is tonight! Besides, I'd like to know the e2e performance improvement (or confirmation that it matches the current performance). Is it possible to run some e2e benchmark with and without the PR and share the result?
(I'd prefer to see the e2e result before merging it, but the PR looks beautiful :))
Looks like the model output is garbled and totally different after unifying the kernel...
Thanks for reporting, will take a look.
@jjjjohnson Could you provide the model/prompt you used for testing? The results seem correct for basic_correctness. Thanks!
@rkooo567
I can now reproduce the bug with
Hmm, that's pretty odd. There's nothing LoRA-related in this kernel, IIUC.
BTW, I saw a CI failure in LM Eval Small Models as follows.
Looks like the
I tried Qwen/Qwen-14B-Chat without LoRA; it can be any prompt. The result is totally different with or without enforce-eager.
Looks like short prompts are OK; if you change to
I tried the example_long_prompts with Qwen and it did fail. But after looking into it, it fails for both eager and non-eager mode. It also failed for other backends such as XFORMERS. Therefore, it seems like a numerical issue in that case. Did you observe similar behavior?
Could you provide the exact prompt and the hardware you use? After some manual checking on H100 with Qwen/Qwen-14B-Chat, setting enforce-eager or not gives the same output. It might also be possible that bugs in CUDA graph preparation do not reproduce consistently. Thanks!
I use A800 TP1. If I change
@jjjjohnson, when you say the test fails, was the output gibberish or still something reasonable? Changing the kernel may change the numerics slightly.
I think long prompts are more likely to accumulate numerical error, so this checks out?
Updates for this PR:
Based on your current test case defined in tests/kernels/test_flash_attn.py, here is a modified version: test_varlen_cg.py. It should pass the given case for mixed prefill and decode now, with vllm_flash_attn v2.6.2. The major modifications are the following when using
Feel free to try it out. Hope it helps :)
This pull request has merge conflicts that must be resolved before it can be merged.
Currently, we are using different kernels for different phases. Concretely, we use flash_attn_with_kvcache for the decoding phase and flash_attn_varlen_func for the prefill phase and prefix caching. For chunked prefill, we launch both kernels to handle prefill tokens and decoding tokens separately. The current approach has some drawbacks:
- We have to maintain separate prefill_metadata and decode_metadata.
- More data is passed from model_runner to the backend than needed, because flash_attn_with_kvcache and flash_attn_varlen_func have different requirements for the input.
- We have to build prefill_metadata and decode_metadata on the fly. But this might be minor since we cache the two metadata.

Moreover, flash_attn_with_kvcache and flash_attn_varlen_func have similar performance as they share the same underlying implementation.

Ideally, we should use a single kernel to handle all cases, including the prefill phase, the decoding phase, and prefix caching. For chunked prefill, we should just launch a single kernel to handle both the prefill tokens and the decoding tokens.

This PR tries to simplify the logic in the attention backend and use a single kernel. This is also needed for the MQA scorer (#5691) for speculative decoding.
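For context, a rough sketch of the unified call this PR moves toward: a single flash_attn_varlen_func launch that covers prefill tokens, decode tokens (query length 1), and prefix caching via the paged block table. The import path, function name unified_attention, keyword names, and shapes below are assumptions based on the public varlen API and may differ from the PR's exact code or a given vllm_flash_attn version:

import torch
from vllm_flash_attn import flash_attn_varlen_func  # assumed import path


def unified_attention(query, key_cache, value_cache, query_start_loc,
                      seq_start_loc, max_query_len, max_seq_len,
                      block_tables, softmax_scale):
    # One varlen kernel launch for the whole batch: per-sequence query and key
    # extents are described by the cumulative start locations, so prefill and
    # decode sequences can be mixed without separate prefill/decode metadata.
    return flash_attn_varlen_func(
        q=query,                       # [num_tokens, num_heads, head_size]
        k=key_cache,                   # paged KV cache
        v=value_cache,
        cu_seqlens_q=query_start_loc,  # [num_seqs + 1], int32
        cu_seqlens_k=seq_start_loc,    # [num_seqs + 1], int32
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_seq_len,
        softmax_scale=softmax_scale,
        causal=True,
        block_table=block_tables,      # maps each sequence to its KV blocks
    )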