From 9b3a23f6464f7a0d4f527499ea803990ca1134a9 Mon Sep 17 00:00:00 2001 From: Matt Wong <156021403+mawong-amd@users.noreply.github.com> Date: Sat, 20 Jul 2024 11:39:07 -0500 Subject: [PATCH] [Bugfix][CI/Build][Hardware][AMD] Fix AMD tests, add HF cache, update CK FA, add partially supported model notes (#6543) --- .buildkite/run-amd-test.sh | 7 +++ .buildkite/test-pipeline.yaml | 3 +- CMakeLists.txt | 4 +- Dockerfile.rocm | 60 +++++++++++-------- .../getting_started/amd-installation.rst | 7 ++- requirements-rocm.txt | 4 ++ tests/basic_correctness/test_cpu_offload.py | 9 ++- tests/models/test_paligemma.py | 18 +++++- tests/models/test_phi3v.py | 9 ++- vllm/attention/backends/rocm_flash_attn.py | 8 +++ vllm/model_executor/models/__init__.py | 17 +++++- vllm/spec_decode/draft_model_runner.py | 9 ++- 12 files changed, 116 insertions(+), 39 deletions(-) diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 363bc07fc2de4..618d712b0279b 100644 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -66,11 +66,18 @@ trap remove_docker_container EXIT echo "--- Running container" +HF_CACHE="$(realpath ~)/huggingface" +mkdir -p ${HF_CACHE} +HF_MOUNT="/root/.cache/huggingface" + docker run \ --device /dev/kfd --device /dev/dri \ --network host \ + --shm-size=16gb \ --rm \ -e HF_TOKEN \ + -v ${HF_CACHE}:${HF_MOUNT} \ + -e HF_HOME=${HF_MOUNT} \ --name ${container_name} \ ${image_name} \ /bin/bash -c "${@}" diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index ae2b36653bad1..e7dd1fdb2e660 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -44,7 +44,8 @@ steps: mirror_hardwares: [amd] fast_check: true commands: - - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl + # This flashinfer installation will fail on AMD ROCm, so it is set as optional. + - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl || true - pytest -v -s basic_correctness/test_basic_correctness.py - pytest -v -s basic_correctness/test_cpu_offload.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py diff --git a/CMakeLists.txt b/CMakeLists.txt index d6bef748516e0..e83e478ab068e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,7 +33,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11 # versions are derived from Dockerfile.rocm # set(TORCH_SUPPORTED_VERSION_CUDA "2.3.1") -set(TORCH_SUPPORTED_VERSION_ROCM "2.4.0") +set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0") # # Try to find python package with an executable that exactly matches @@ -101,7 +101,7 @@ elseif(HIP_FOUND) # ROCm 5.X and 6.X if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM}) - message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM} " + message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} " "expected for ROCm build, saw ${Torch_VERSION} instead.") endif() else() diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 85dfda8dbb532..ff39791456398 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -4,18 +4,21 @@ ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging" # Default ROCm ARCHes to build vLLM for. ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100" -# Whether to build CK-based flash-attention -# If 0, will not build flash attention -# This is useful for gfx target where flash-attention is not supported -# (i.e. those that do not appear in `FA_GFX_ARCHS`) -# Triton FA is used by default on ROCm now so this is unnecessary. +# Whether to install CK-based flash-attention +# If 0, will not install flash-attention ARG BUILD_FA="1" +# If `TRY_FA_WHEEL=1`, we will try installing flash-attention from `FA_WHEEL_URL` +# If this succeeds, we use the downloaded wheel and skip building flash-attention. +# Otherwise, ROCm flash-attention from `FA_BRANCH` will be built for the +# architectures specified in `FA_GFX_ARCHS` +ARG TRY_FA_WHEEL="1" +ARG FA_WHEEL_URL="https://github.com/ROCm/flash-attention/releases/download/v2.5.9post1-cktile-vllm/flash_attn-2.5.9.post1-cp39-cp39-linux_x86_64.whl" ARG FA_GFX_ARCHS="gfx90a;gfx942" -ARG FA_BRANCH="ae7928c" +ARG FA_BRANCH="23a2b1c2" # Whether to build triton on rocm ARG BUILD_TRITON="1" -ARG TRITON_BRANCH="0ef1848" +ARG TRITON_BRANCH="e0fc12c" ### Base image build stage FROM $BASE_IMAGE AS base @@ -43,15 +46,15 @@ RUN apt-get update && apt-get install -y \ ARG APP_MOUNT=/vllm-workspace WORKDIR ${APP_MOUNT} -RUN pip install --upgrade pip +RUN python3 -m pip install --upgrade pip # Remove sccache so it doesn't interfere with ccache # TODO: implement sccache support across components -RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)" +RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)" # Install torch == 2.5.0 on ROCm RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ *"rocm-6.1"*) \ - pip uninstall -y torch torchaudio torchvision \ - && pip install --no-cache-dir --pre \ + python3 -m pip uninstall -y torch torchaudio torchvision \ + && python3 -m pip install --no-cache-dir --pre \ torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \ torchvision==0.20.0.dev20240710 \ --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \ @@ -70,24 +73,31 @@ ENV CCACHE_DIR=/root/.cache/ccache FROM base AS build_amdsmi # Build amdsmi wheel always RUN cd /opt/rocm/share/amd_smi \ - && pip wheel . --wheel-dir=/install + && python3 -m pip wheel . --wheel-dir=/install ### Flash-Attention wheel build stage FROM base AS build_fa ARG BUILD_FA +ARG TRY_FA_WHEEL +ARG FA_WHEEL_URL ARG FA_GFX_ARCHS ARG FA_BRANCH # Build ROCm flash-attention wheel if `BUILD_FA = 1` RUN --mount=type=cache,target=${CCACHE_DIR} \ if [ "$BUILD_FA" = "1" ]; then \ - mkdir -p libs \ - && cd libs \ - && git clone https://github.com/ROCm/flash-attention.git \ - && cd flash-attention \ - && git checkout "${FA_BRANCH}" \ - && git submodule update --init \ - && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \ + if [ "${TRY_FA_WHEEL}" = "1" ] && python3 -m pip install "${FA_WHEEL_URL}"; then \ + # If a suitable wheel exists, we download it instead of building FA + mkdir -p /install && wget -N "${FA_WHEEL_URL}" -P /install; \ + else \ + mkdir -p libs \ + && cd libs \ + && git clone https://github.com/ROCm/flash-attention.git \ + && cd flash-attention \ + && git checkout "${FA_BRANCH}" \ + && git submodule update --init \ + && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \ + fi; \ # Create an empty directory otherwise as later build stages expect one else mkdir -p /install; \ fi @@ -126,7 +136,7 @@ RUN case "$(which python3)" in \ # Package upgrades for useful functionality or to avoid dependency issues RUN --mount=type=cache,target=/root/.cache/pip \ - pip install --upgrade numba scipy huggingface-hub[cli] + python3 -m pip install --upgrade numba scipy huggingface-hub[cli] # Make sure punica kernels are built (for LoRA) ENV VLLM_INSTALL_PUNICA_KERNELS=1 @@ -137,7 +147,7 @@ ENV TOKENIZERS_PARALLELISM=false RUN --mount=type=cache,target=${CCACHE_DIR} \ --mount=type=cache,target=/root/.cache/pip \ - pip install -U -r requirements-rocm.txt \ + python3 -m pip install -Ur requirements-rocm.txt \ && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ *"rocm-6.1"*) \ # Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM @@ -153,7 +163,7 @@ RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \ mkdir -p libs \ && cp /install/*.whl libs \ # Preemptively uninstall to avoid same-version no-installs - && pip uninstall -y amdsmi; + && python3 -m pip uninstall -y amdsmi; # Copy triton wheel(s) into final image if they were built RUN --mount=type=bind,from=build_triton,src=/install,target=/install \ @@ -161,7 +171,7 @@ RUN --mount=type=bind,from=build_triton,src=/install,target=/install \ && if ls /install/*.whl; then \ cp /install/*.whl libs \ # Preemptively uninstall to avoid same-version no-installs - && pip uninstall -y triton; fi + && python3 -m pip uninstall -y triton; fi # Copy flash-attn wheel(s) into final image if they were built RUN --mount=type=bind,from=build_fa,src=/install,target=/install \ @@ -169,11 +179,11 @@ RUN --mount=type=bind,from=build_fa,src=/install,target=/install \ && if ls /install/*.whl; then \ cp /install/*.whl libs \ # Preemptively uninstall to avoid same-version no-installs - && pip uninstall -y flash-attn; fi + && python3 -m pip uninstall -y flash-attn; fi # Install wheels that were built to the final image RUN --mount=type=cache,target=/root/.cache/pip \ if ls libs/*.whl; then \ - pip install libs/*.whl; fi + python3 -m pip install libs/*.whl; fi CMD ["/bin/bash"] diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst index 1f9e4fabc4fc9..61efad2013b2a 100644 --- a/docs/source/getting_started/amd-installation.rst +++ b/docs/source/getting_started/amd-installation.rst @@ -90,12 +90,12 @@ Alternatively, you can install PyTorch using PyTorch wheels. You can check PyTor Install ROCm's Triton flash attention (the default triton-mlir branch) following the instructions from `ROCm/triton `_ -2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm `_ +2. Optionally, if you choose to use CK flash attention, you can install `flash attention for ROCm `_ -Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/flash-attention `_ +Install ROCm's flash attention (v2.5.9.post1) following the instructions from `ROCm/flash-attention `_ +Alternatively, wheels intended for vLLM use can be accessed under the releases. .. note:: - - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention. - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) 3. Build vLLM. @@ -110,5 +110,6 @@ Install ROCm's flash attention (v2.0.4) following the instructions from `ROCm/fl .. tip:: - Triton flash attention is used by default. For benchmarking purposes, it is recommended to run a warm up step before collecting perf numbers. + - Triton flash attention does not currently support sliding window attention. If using half precision, please use CK flash-attention for sliding window support. - To use CK flash-attention or PyTorch naive attention, please use this flag ``export VLLM_USE_TRITON_FLASH_ATTN=0`` to turn off triton flash attention. - The ROCm version of PyTorch, ideally, should match the ROCm driver version. diff --git a/requirements-rocm.txt b/requirements-rocm.txt index cc42839a975d0..cc955e279a845 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -2,5 +2,9 @@ -r requirements-common.txt # Dependencies for AMD GPUs +awscli +boto3 +botocore ray >= 2.10.0 +peft pytest-asyncio diff --git a/tests/basic_correctness/test_cpu_offload.py b/tests/basic_correctness/test_cpu_offload.py index 3ab01d52277d7..9ebcc48a9b93e 100644 --- a/tests/basic_correctness/test_cpu_offload.py +++ b/tests/basic_correctness/test_cpu_offload.py @@ -1,8 +1,13 @@ +from vllm.utils import is_hip + from ..utils import compare_two_settings def test_cpu_offload(): compare_two_settings("meta-llama/Llama-2-7b-hf", [], ["--cpu-offload-gb", "4"]) - compare_two_settings("nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", - [], ["--cpu-offload-gb", "1"]) + if not is_hip(): + # compressed-tensors quantization is currently not supported in ROCm. + compare_two_settings( + "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t", [], + ["--cpu-offload-gb", "1"]) diff --git a/tests/models/test_paligemma.py b/tests/models/test_paligemma.py index 81afd11a6e697..e1c39ee6fecb6 100644 --- a/tests/models/test_paligemma.py +++ b/tests/models/test_paligemma.py @@ -1,3 +1,4 @@ +import os from typing import List, Optional, Tuple, Type import pytest @@ -5,6 +6,7 @@ from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs +from vllm.utils import is_hip from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets from .utils import check_logprobs_close @@ -22,6 +24,12 @@ models = ["google/paligemma-3b-mix-224"] +# ROCm Triton FA can run into compilation issues with these models due to, +# excessive use of shared memory. Use other backends in the meantime. +# FIXME (mattwong, gshtrasb, hongxiayan) +if is_hip(): + os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" + def vllm_to_hf_output(vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]], @@ -130,7 +138,15 @@ def run_test( [0.25, 0.5, 1.0], ], ) -@pytest.mark.parametrize("dtype", ["float", "half"]) +@pytest.mark.parametrize("dtype", [ + pytest.param( + "float", + marks=pytest.mark.skipif( + is_hip(), + reason= + "ROCm FA does not yet fully support 32-bit precision on PaliGemma") + ), "half" +]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py index 636a9d3f1a65e..9da25ab8d78fe 100644 --- a/tests/models/test_phi3v.py +++ b/tests/models/test_phi3v.py @@ -1,3 +1,4 @@ +import os import re from typing import List, Optional, Tuple, Type @@ -6,7 +7,7 @@ from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs -from vllm.utils import is_cpu +from vllm.utils import is_cpu, is_hip from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets from .utils import check_logprobs_close @@ -47,6 +48,12 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, if is_cpu(): target_dtype = "bfloat16" +# ROCm Triton FA can run into shared memory issues with these models, +# use other backends in the meantime +# FIXME (mattwong, gshtrasb, hongxiayan) +if is_hip(): + os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0" + def run_test( hf_runner: Type[HfRunner], diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 17c3b25034bf3..058c8df0eaf8b 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -275,6 +275,12 @@ def __init__( triton_attention) self.attn_func = triton_attention logger.debug("Using Triton FA in ROCmBackend") + if self.sliding_window != (-1, -1): + logger.warning("ROCm Triton FA does not currently support " + "sliding window attention. If using half " + "precision, please try using the ROCm CK " + "FA backend instead by setting the env var " + "`VLLM_USE_TRITON_FLASH_ATTN=0`") else: # if not using triton, navi3x/navi21/navi10 do not use flash-attn # either @@ -434,6 +440,8 @@ def forward( max_seqlen_k=prefill_meta.max_prefill_seq_len, softmax_scale=self.scale, causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, ) # common code for prefill diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index aa5a70757b31c..f3c3fe31c68a3 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -87,13 +87,24 @@ # Models partially supported by ROCm. # Architecture -> Reason. +_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in " + "Triton flash attention. For half-precision SWA support, " + "please use CK flash attention by setting " + "`VLLM_USE_TRITON_FLASH_ATTN=0`") _ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = { "Qwen2ForCausalLM": - "Sliding window attention is not yet supported in ROCm's flash attention", + _ROCM_SWA_REASON, "MistralForCausalLM": - "Sliding window attention is not yet supported in ROCm's flash attention", + _ROCM_SWA_REASON, "MixtralForCausalLM": - "Sliding window attention is not yet supported in ROCm's flash attention", + _ROCM_SWA_REASON, + "PaliGemmaForConditionalGeneration": + ("ROCm flash attention does not yet " + "fully support 32-bit precision on PaliGemma"), + "Phi3VForCausalLM": + ("ROCm Triton flash attention may run into compilation errors due to " + "excessive use of shared memory. If this happens, disable Triton FA " + "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") } diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index d2c7e6e3710a8..95071ecb6c8da 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -3,7 +3,14 @@ import torch from vllm import _custom_ops as ops -from vllm.attention.backends.flash_attn import FlashAttentionMetadata + +try: + from vllm.attention.backends.flash_attn import FlashAttentionMetadata +except ModuleNotFoundError: + # vllm_flash_attn is not installed, use the identical ROCm FA metadata + from vllm.attention.backends.rocm_flash_attn import ( + ROCmFlashAttentionMetadata as FlashAttentionMetadata) + from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig)