
[AMD][Hardware][Misc][Bugfix] xformer cleanup and light navi logic and CI fixes and refactoring #4129

Merged: 6 commits, Apr 22, 2024
2 changes: 0 additions & 2 deletions .buildkite/test-pipeline.yaml
@@ -15,10 +15,8 @@ steps:
   commands:
   - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py
   - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
-  - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_basic_correctness.py
   - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
   - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
-  - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_chunked_prefill.py
 
 - label: Core Test
   command: pytest -v -s core
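Note: the CI commands above select the attention implementation purely through the `VLLM_ATTENTION_BACKEND` environment variable. As a hedged sketch (the model name and sampling settings below are illustrative placeholders, not part of this PR), the same switch can be exercised locally before constructing an engine:

```python
# Minimal sketch: pin the attention backend the way the CI commands above do.
# The model name and sampling settings are illustrative placeholders.
import os

# Set before constructing the LLM so the backend choice is picked up.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"  # or "XFORMERS" / "ROCM_FLASH"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)
```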
5 changes: 1 addition & 4 deletions Dockerfile.rocm
@@ -14,7 +14,7 @@ RUN echo "Base image is $BASE_IMAGE"
 ARG FA_GFX_ARCHS="gfx90a;gfx942"
 RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
 
-ARG FA_BRANCH="3d2b6f5"
+ARG FA_BRANCH="ae7928c"
 RUN echo "FA_BRANCH is $FA_BRANCH"
 
 # whether to build flash-attention
@@ -92,13 +92,10 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \
 COPY ./ /app/vllm
 
 RUN python3 -m pip install --upgrade pip numba
-RUN python3 -m pip install xformers==0.0.23 --no-deps
 
 RUN cd /app \
     && cd vllm \
     && pip install -U -r requirements-rocm.txt \
-    && if [ "$BUILD_FA" = "1" ]; then \
-       bash patch_xformers.rocm.sh; fi \
     && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \
     && python3 setup.py install \
     && cd ..
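Note: with the xformers pin and the `patch_xformers.rocm.sh` step removed from `Dockerfile.rocm`, a quick sanity check inside the built image can confirm what is importable. This is a hedged sketch; which modules are present depends on the `BUILD_FA`/`BUILD_TRITON` build args, and the module list is only illustrative:

```python
# Hedged post-build check: report which attention-related modules import
# cleanly in the ROCm image. Availability of triton/flash_attn depends on
# the BUILD_TRITON/BUILD_FA build args; the list here is illustrative.
import importlib

for name in ("torch", "triton", "flash_attn"):
    try:
        mod = importlib.import_module(name)
        print(f"{name}: {getattr(mod, '__version__', 'version unknown')}")
    except ModuleNotFoundError:
        print(f"{name}: not installed")
```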
33 changes: 0 additions & 33 deletions patch_xformers.rocm.sh

This file was deleted.

13 changes: 0 additions & 13 deletions rocm_patch/commonpy_xformers-0.0.23.rocm.patch

This file was deleted.

152 changes: 0 additions & 152 deletions rocm_patch/flashpy_xformers-0.0.23.rocm.patch

This file was deleted.

31 changes: 18 additions & 13 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -154,25 +154,30 @@ def __init__(
                 f"Head size {head_size} is not supported by PagedAttention. "
                 f"Supported head sizes are: {suppored_head_sizes}.")
 
-        self.use_naive_attn = torch.cuda.get_device_capability()[0] != 9
+        self.use_naive_attn = False
         # NOTE: Allow for switching between Triton and CK. Defaulting to triton.
         self.use_triton_flash_attn = (os.environ.get(
             "VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1"))
-        if self.use_naive_attn:
-            # AMD Radeon 7900 series (gfx1100) currently does not support
-            # xFormers nor FlashAttention. As a temporary workaround, we use
-            # naive PyTorch implementation of attention.
-            self.attn_fuc = _naive_attention
-            logger.debug("Using naive attention in ROCmBackend")
-        elif self.use_triton_flash_attn:
+        if self.use_triton_flash_attn:
             from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                 triton_attention)
             self.attn_func = triton_attention
             logger.debug("Using Triton FA in ROCmBackend")
         else:
-            from flash_attn import flash_attn_varlen_func  # noqa: F401
-            self.attn_func = flash_attn_varlen_func
-            logger.debug("Using CK FA in ROCmBackend")
+            # if not using triton, navi3x not use flash-attn either
+            if torch.cuda.get_device_capability()[0] == 11:
+                self.use_naive_attn = True
+            else:
+                try:
+                    from flash_attn import flash_attn_varlen_func  # noqa: F401
+                    self.attn_func = flash_attn_varlen_func
+                    logger.debug("Using CK FA in ROCmBackend")
+                except ModuleNotFoundError:
+                    self.use_naive_attn = True
+
+            if self.use_naive_attn:
+                self.attn_func = _naive_attention
+                logger.debug("Using naive attention in ROCmBackend")
 
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
         """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
@@ -247,13 +252,13 @@ def forward(
                 # triton attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
-                if self.use_naive_attn or self.use_triton_flash_attn:
+                if self.use_triton_flash_attn or self.use_naive_attn:
                     if self.num_kv_heads != self.num_heads:
                         # Interleave for MQA workaround.
                         key = self.repeat_kv(key, self.num_queries_per_kv)
                         value = self.repeat_kv(value, self.num_queries_per_kv)
                     if self.use_naive_attn:
-                        out = self.attn_fuc(
+                        out = self.attn_func(
                             query,
                             key,
                             value,
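Note: the `__init__` change above reduces backend selection to a three-way choice: prefer the Triton kernel, otherwise try CK flash-attn, and fall back to naive attention on Navi3x (device capability major 11) or when `flash_attn` is not installed. A standalone, hedged restatement of that order (the helper name and return labels are illustrative, not vLLM API):

```python
# Standalone sketch of the selection order introduced in __init__ above.
# The function name and return labels are illustrative, not part of vLLM.
import os


def pick_rocm_attn_backend(device_major: int) -> str:
    use_triton = os.environ.get("VLLM_USE_TRITON_FLASH_ATTN",
                                "True").lower() in ("true", "1")
    if use_triton:
        return "triton"
    if device_major == 11:  # Navi3x (gfx11xx): skip CK flash-attn entirely
        return "naive"
    try:
        import flash_attn  # noqa: F401
        return "ck_flash_attn"
    except ModuleNotFoundError:
        return "naive"


print(pick_rocm_attn_backend(device_major=9))   # gfx90a / gfx942
print(pick_rocm_attn_backend(device_major=11))  # gfx1100 (Navi3x)
```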
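Note: when the Triton or naive path is taken for an MQA/GQA model, the `forward` hunk above interleaves key/value heads via `repeat_kv` so they match the query head count before calling the kernel. A hedged sketch of that transformation on dummy tensors (shapes are illustrative):

```python
# Sketch of the MQA/GQA interleave workaround referenced above: repeat each
# KV head n_rep times along the head dimension, equivalent to
# torch.repeat_interleave(x, repeats=n_rep, dim=1). Shapes are illustrative.
import torch


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    tokens, n_kv_heads, head_dim = x.shape
    return (x[:, :, None, :]
            .expand(tokens, n_kv_heads, n_rep, head_dim)
            .reshape(tokens, n_kv_heads * n_rep, head_dim))


key = torch.randn(4, 2, 64)            # 4 tokens, 2 KV heads, head_dim 64
print(repeat_kv(key, n_rep=4).shape)   # torch.Size([4, 8, 64]) -> 8 query heads
```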