Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ROCm][Hardware][AMD] Use Triton Kernel for default FA on ROCm #3643

Merged

Conversation

jpvillam-amd
Copy link
Contributor

@jpvillam-amd jpvillam-amd commented Mar 26, 2024

This PR creates and makes default new triton exclusive backend for attention. Additionally removes some unsupported arguments from AMD's version of the flash_attn_varlen_func function.

  • Changed selector to allow picking between multiple backends rather than just between two.
  • Added a new triton FA backend available to ROCm.
  • Added new VLLM_USE_FLASH_ATTN_TRITON option to be able to swap between Triton and Default FA
  • Removed unsupported attributes from AMD's flash_attn_varlen_func
  • Added latest build of AMD's triton to Dockerfile.rocm

Appreciate any and all feedback, and apologies for the long PR

Co-authored-by: Vinayak Gokhale [email protected]

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

Copy link
Collaborator

@hongxiayang hongxiayang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have verified the changes on MI250x and MI300x.
lgtm.

@hongxiayang
Copy link
Collaborator

Since latest ray has problem with tp>1 when running on AMD system now, can we also do this in Dockerfile.rocm:
RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3

@gshtras
Copy link
Contributor

gshtras commented Mar 28, 2024

Since latest ray has problem with tp>1 when running on AMD system now, can we also do this in Dockerfile.rocm: RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3

Also in requirements-rocm.txt

Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we get rid of the CK-based FlashAttention, as we discussed offline in today's meeting?

vllm/attention/backends/flash_attn_triton.py Outdated Show resolved Hide resolved
@hongxiayang
Copy link
Collaborator

hongxiayang commented Mar 28, 2024

Can we get rid of the CK-based FlashAttention, as we discussed offline in today's meeting?

We can leave the clean up task in a different pull request, and I will re-open the cleanup-xformer PR (#3558) for that purpose.
There is also concern regarding different data type support.

Now users can use the flag to choose not to install FA during building the docker image using Dockerfile.rocm.

@jpvillam-amd jpvillam-amd requested a review from WoosukKwon March 29, 2024 20:34
@WoosukKwon WoosukKwon self-assigned this Mar 29, 2024
Comment on lines 196 to 222
if is_hip():
output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=attn_metadata.seq_start_loc,
cu_seqlens_k=attn_metadata.seq_start_loc,
max_seqlen_q=attn_metadata.max_prompt_len,
max_seqlen_k=attn_metadata.max_prompt_len,
softmax_scale=self.scale,
causal=True,
)
else:
output = flash_attn_varlen_func(
q=query,
k=key,
v=value,
cu_seqlens_q=attn_metadata.seq_start_loc,
cu_seqlens_k=attn_metadata.seq_start_loc,
max_seqlen_q=attn_metadata.max_prompt_len,
max_seqlen_k=attn_metadata.max_prompt_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)

Copy link
Collaborator

@WoosukKwon WoosukKwon Mar 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: Actually in #3648, we are planning to use FlashAttention's recent APIs that support attention keys and values stored in paged KV cache. I believe the new APIs are incompatible with AMD GPUs.

Can we just use Triton FlashAttention at all times?

@WoosukKwon
Copy link
Collaborator

@jpvillam-amd Can we have ROCmAttentionBackend that selects one of the four different implementations of prefill Attention (i.e., Triton, CK, xFormers, naive) while using PagedAttention for decoding?

I feel like mixing the code for NVIDIA GPUs and AMD GPUs for attention is not good at the moment, since the APIs are not unified yet. For example, as I mentioned above, FlashAttentionBackend will use flash-attn's paged KV cache APIs, which are not supported for AMD GPUs at the moment.

@jpvillam-amd
Copy link
Contributor Author

@jpvillam-amd Can we have ROCmAttentionBackend that selects one of the four different implementations of prefill Attention (i.e., Triton, CK, xFormers, naive) while using PagedAttention for decoding?

I feel like mixing the code for NVIDIA GPUs and AMD GPUs for attention is not good at the moment, since the APIs are not unified yet. For example, as I mentioned above, FlashAttentionBackend will use flash-attn's paged KV cache APIs, which are not supported for AMD GPUs at the moment.

I will rename the backend and have it container all the selector logic for ROCm. I will leave xFormers out of the backend and have it selected with the current selector code same as CUDA

@WoosukKwon Does that sounds ok?

@jpvillam-amd jpvillam-amd requested a review from WoosukKwon April 2, 2024 22:05
@WoosukKwon
Copy link
Collaborator

@jpvillam-amd Thanks for updating the PR. Do you mind if I directly edit this PR? I refactored the PR a bit and wanted to directly upstream it to this PR for faster integration. If you don't mind, could you allow editing the PR? Your branch is protected now.

Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jpvillam-amd LGTM. Thanks for updating the PR. Could you please submit another PR to test the backend? Thanks.

@WoosukKwon WoosukKwon merged commit 6c0b045 into vllm-project:main Apr 9, 2024
67 checks passed
SageMoore pushed a commit to neuralmagic/nm-vllm that referenced this pull request Apr 11, 2024
…project#3643)

Co-authored-by: jpvillam <[email protected]>
Co-authored-by: Gregory Shtrasberg <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
andy-neuma pushed a commit to neuralmagic/nm-vllm that referenced this pull request Apr 12, 2024
…project#3643)

Co-authored-by: jpvillam <[email protected]>
Co-authored-by: Gregory Shtrasberg <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
z103cb pushed a commit to z103cb/opendatahub_vllm that referenced this pull request Apr 22, 2024
…project#3643)

Co-authored-by: jpvillam <[email protected]>
Co-authored-by: Gregory Shtrasberg <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
Temirulan pushed a commit to Temirulan/vllm-whisper that referenced this pull request Sep 6, 2024
…project#3643)

Co-authored-by: jpvillam <[email protected]>
Co-authored-by: Gregory Shtrasberg <[email protected]>
Co-authored-by: Woosuk Kwon <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants