From ca6e667603a1825e9feabb1450368bbb9b5b611a Mon Sep 17 00:00:00 2001 From: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com> Date: Mon, 22 Apr 2024 00:57:24 -0400 Subject: [PATCH] [AMD][Hardware][Misc][Bugfix] xformer cleanup and light navi logic and CI fixes and refactoring (#4129) --- .buildkite/test-pipeline.yaml | 2 - Dockerfile.rocm | 5 +- patch_xformers.rocm.sh | 33 ---- .../commonpy_xformers-0.0.23.rocm.patch | 13 -- rocm_patch/flashpy_xformers-0.0.23.rocm.patch | 152 ------------------ vllm/attention/backends/rocm_flash_attn.py | 31 ++-- 6 files changed, 19 insertions(+), 217 deletions(-) delete mode 100644 patch_xformers.rocm.sh delete mode 100644 rocm_patch/commonpy_xformers-0.0.23.rocm.patch delete mode 100644 rocm_patch/flashpy_xformers-0.0.23.rocm.patch diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 0f920c7ec1442..f7c1569696249 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -15,10 +15,8 @@ steps: commands: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py - - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_basic_correctness.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - - VLLM_ATTENTION_BACKEND=ROCM_FLASH pytest -v -s basic_correctness/test_chunked_prefill.py - label: Core Test command: pytest -v -s core diff --git a/Dockerfile.rocm b/Dockerfile.rocm index b1c5fac9d78ef..3f84b949481d1 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -14,7 +14,7 @@ RUN echo "Base image is $BASE_IMAGE" ARG FA_GFX_ARCHS="gfx90a;gfx942" RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS" -ARG FA_BRANCH="3d2b6f5" +ARG FA_BRANCH="ae7928c" RUN echo "FA_BRANCH is $FA_BRANCH" # whether to build flash-attention @@ -92,13 +92,10 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \ COPY ./ /app/vllm RUN python3 -m pip install --upgrade pip numba -RUN python3 -m pip install xformers==0.0.23 --no-deps RUN cd /app \ && cd vllm \ && pip install -U -r requirements-rocm.txt \ - && if [ "$BUILD_FA" = "1" ]; then \ - bash patch_xformers.rocm.sh; fi \ && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \ && python3 setup.py install \ && cd .. diff --git a/patch_xformers.rocm.sh b/patch_xformers.rocm.sh deleted file mode 100644 index de427b24d306f..0000000000000 --- a/patch_xformers.rocm.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/bin/bash -set -e - -XFORMERS_VERSION="0.0.23" - -export XFORMERS_INSTALLED_VERSION=$(python -c 'import xformers; print(xformers.__version__)') - -if [ "$XFORMERS_INSTALLED_VERSION" != "$XFORMERS_VERSION" ]; then - echo "ERROR: xformers version must be ${XFORMERS_VERSION}. ${XFORMERS_INSTALLED_VERSION} is installed" - exit 1 -fi - -export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)') -export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)') - -echo "XFORMERS_FMHA_FLASH_PATH = ${XFORMERS_FMHA_FLASH_PATH}" -echo "XFORMERS_FMHA_COMMON_PATH = ${XFORMERS_FMHA_COMMON_PATH}" - -if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then - echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}" - patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch" - echo "Successfully patch ${XFORMERS_FMHA_FLASH_PATH}" -else - echo "${XFORMERS_FMHA_FLASH_PATH} was patched before" -fi - -if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then - echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}" - patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch" - echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}" -else - echo "${XFORMERS_FMHA_COMMON_PATH} was patched before" -fi diff --git a/rocm_patch/commonpy_xformers-0.0.23.rocm.patch b/rocm_patch/commonpy_xformers-0.0.23.rocm.patch deleted file mode 100644 index 4d7495cf13e1d..0000000000000 --- a/rocm_patch/commonpy_xformers-0.0.23.rocm.patch +++ /dev/null @@ -1,13 +0,0 @@ ---- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/common.py 2023-11-29 03:17:03.930103539 +0000 -+++ common.py 2023-11-28 16:14:19.846233146 +0000 -@@ -298,8 +298,8 @@ - dtype = d.query.dtype - if device_type not in cls.SUPPORTED_DEVICES: - reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})") -- if device_type == "cuda" and not _built_with_cuda: -- reasons.append("xFormers wasn't build with CUDA support") -+ #if device_type == "cuda" and not _built_with_cuda: -+ # reasons.append("xFormers wasn't build with CUDA support") - if device_type == "cuda": - device_capability = torch.cuda.get_device_capability(d.device) - if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY: diff --git a/rocm_patch/flashpy_xformers-0.0.23.rocm.patch b/rocm_patch/flashpy_xformers-0.0.23.rocm.patch deleted file mode 100644 index ac846728a7a91..0000000000000 --- a/rocm_patch/flashpy_xformers-0.0.23.rocm.patch +++ /dev/null @@ -1,152 +0,0 @@ ---- flash_ori.py 2023-12-13 05:43:31.530752623 +0000 -+++ flash_patch.py 2023-12-13 06:00:45.962403104 +0000 -@@ -36,44 +36,44 @@ - - FLASH_VERSION = "0.0.0" - try: -- try: -- from ... import _C_flashattention # type: ignore[attr-defined] -- from ..._cpp_lib import _build_metadata -- -- if _build_metadata is not None: -- FLASH_VERSION = _build_metadata.flash_version -- except ImportError: -- import flash_attn -- from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention -- -- FLASH_VERSION = flash_attn.__version__ -- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3]) -- if ( -- flash_ver_parsed != (2, 3, 6) -- and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1" -- ): -- raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api") -+ #try: -+ # from ... import _C_flashattention # type: ignore[attr-defined] -+ # from ..._cpp_lib import _build_metadata -+ -+ # if _build_metadata is not None: -+ # FLASH_VERSION = _build_metadata.flash_version -+ #except ImportError: -+ import flash_attn -+ from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention -+ -+ FLASH_VERSION = flash_attn.__version__ -+ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3]) -+ # if ( -+ # flash_ver_parsed != (2, 3, 6) -+ # and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1" -+ # ): -+ # raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api") - - # create library so that flash-attn goes through the PyTorch Dispatcher -- _flash_lib = torch.library.Library("xformers_flash", "DEF") -- -- _flash_lib.define( -- "flash_fwd(Tensor query, Tensor key, Tensor value, " -- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, " -- "int max_seqlen_q, int max_seqlen_k, " -- "float p, float softmax_scale, " -- "bool is_causal, int window_left, " -- "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)" -- ) -+ #_flash_lib = torch.library.Library("xformers_flash", "DEF") - -- _flash_lib.define( -- "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, " -- "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, " -- "Tensor cu_seqlens_q, Tensor cu_seqlens_k, " -- "int max_seqlen_q, int max_seqlen_k, " -- "float p, float softmax_scale, bool is_causal, " -- "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)" -- ) -+ #_flash_lib.define( -+ # "flash_fwd(Tensor query, Tensor key, Tensor value, " -+ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, " -+ # "int max_seqlen_q, int max_seqlen_k, " -+ # "float p, float softmax_scale, " -+ # "bool is_causal, int window_left, " -+ # "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)" -+ #) -+ -+ #_flash_lib.define( -+ # "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, " -+ # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, " -+ # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, " -+ # "int max_seqlen_q, int max_seqlen_k, " -+ # "float p, float softmax_scale, bool is_causal, " -+ # "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)" -+ #) - - def _flash_fwd( - query, -@@ -111,8 +111,8 @@ - p, - softmax_scale, - is_causal, -- window_left, # window_size_left -- window_right, # window_size_right -+ # window_left, # window_size_left -+ # window_right, # window_size_right - return_softmax, - None, # rng - ) -@@ -134,15 +134,15 @@ - out, - cu_seq_lens_q, - cu_seq_lens_k, -- seqused_k, -+ # seqused_k, - max_seq_len_q, - max_seq_len_k, - p, - softmax_scale, - False, - is_causal, -- window_left, -- window_right, -+ # window_left, -+ # window_right, - return_softmax, - None, - ) -@@ -184,8 +184,8 @@ - p, - softmax_scale, - is_causal, -- window_left, -- window_right, -+ # window_left, -+ # window_right, - None, - rng_state, - ) -@@ -208,15 +208,15 @@ - softmax_scale, - False, # zero_tensors - is_causal, -- window_left, -- window_right, -+ # window_left, -+ # window_right, - None, - rng_state, - ) - return dq, dk, dv - -- _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA") -- _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA") -+ #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA") -+ #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA") - except ImportError: - pass - -@@ -400,7 +400,7 @@ - implementation. - """ - -- OPERATOR = get_operator("xformers_flash", "flash_fwd") -+ OPERATOR = _flash_fwd # get_operator("xformers_flash", "flash_fwd") - SUPPORTED_DEVICES: Set[str] = {"cuda"} - CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) - SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16} diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index c42660fb8f74f..dbaa71fd16add 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -154,25 +154,30 @@ def __init__( f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {suppored_head_sizes}.") - self.use_naive_attn = torch.cuda.get_device_capability()[0] != 9 + self.use_naive_attn = False # NOTE: Allow for switching between Triton and CK. Defaulting to triton. self.use_triton_flash_attn = (os.environ.get( "VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")) - if self.use_naive_attn: - # AMD Radeon 7900 series (gfx1100) currently does not support - # xFormers nor FlashAttention. As a temporary workaround, we use - # naive PyTorch implementation of attention. - self.attn_fuc = _naive_attention - logger.debug("Using naive attention in ROCmBackend") - elif self.use_triton_flash_attn: + if self.use_triton_flash_attn: from vllm.attention.ops.triton_flash_attention import ( # noqa: F401 triton_attention) self.attn_func = triton_attention logger.debug("Using Triton FA in ROCmBackend") else: - from flash_attn import flash_attn_varlen_func # noqa: F401 - self.attn_func = flash_attn_varlen_func - logger.debug("Using CK FA in ROCmBackend") + # if not using triton, navi3x not use flash-attn either + if torch.cuda.get_device_capability()[0] == 11: + self.use_naive_attn = True + else: + try: + from flash_attn import flash_attn_varlen_func # noqa: F401 + self.attn_func = flash_attn_varlen_func + logger.debug("Using CK FA in ROCmBackend") + except ModuleNotFoundError: + self.use_naive_attn = True + + if self.use_naive_attn: + self.attn_func = _naive_attention + logger.debug("Using naive attention in ROCmBackend") def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" @@ -247,13 +252,13 @@ def forward( # triton attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - if self.use_naive_attn or self.use_triton_flash_attn: + if self.use_triton_flash_attn or self.use_naive_attn: if self.num_kv_heads != self.num_heads: # Interleave for MQA workaround. key = self.repeat_kv(key, self.num_queries_per_kv) value = self.repeat_kv(value, self.num_queries_per_kv) if self.use_naive_attn: - out = self.attn_fuc( + out = self.attn_func( query, key, value,