[ROCm] support Radeon™ 7900 series (gfx1100) without using flash-attention #2768

Merged
4 commits, merged on Feb 11, 2024
15 changes: 12 additions & 3 deletions Dockerfile.rocm
@@ -17,6 +12,12 @@ RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"
ARG FA_BRANCH="3d2b6f5"
RUN echo "FA_BRANCH is $FA_BRANCH"

# Whether to build flash-attention.
# If 0, flash-attention will not be built.
# This is useful for gfx targets where flash-attention is not supported.
# In that case, vLLM falls back to its Python reference attention implementation.
ARG BUILD_FA="1"

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y

@@ -50,7 +56,8 @@ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:

# Install ROCm flash-attention
RUN mkdir libs \
RUN if [ "$BUILD_FA" == "1" ]; then \
mkdir libs \
&& cd libs \
&& git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
&& cd flash-attention \
@@ -60,7 +67,8 @@ RUN mkdir libs \
&& if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
&& python3 setup.py install \
&& cd ..
&& cd ..; \
fi

COPY ./ /app/vllm

@@ -75,7 +83,8 @@ RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"
RUN cd /app \
&& cd vllm \
&& pip install -U -r requirements-rocm.txt \
&& bash patch_xformers.rocm.sh \
&& if [ "$BUILD_FA" == "1" ]; then \
bash patch_xformers.rocm.sh; fi \
&& patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \
&& python3 setup.py install \
&& cd ..
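Note that the new `BUILD_FA` argument only controls whether flash-attention is built into the image; at runtime, vLLM chooses between flash-attention and the Python reference attention by checking whether the package is importable (see the attention.py changes further down in this diff). A minimal standalone sketch of that check, with the printout purely illustrative:

    import importlib.util

    # Fall back to vLLM's pure-PyTorch reference attention when the ROCm
    # flash-attention package ("flash_attn") is not installed.
    use_ref_attention = importlib.util.find_spec("flash_attn") is None
    print("reference attention" if use_ref_attention else "flash-attention")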
3 changes: 2 additions & 1 deletion docs/source/getting_started/amd-installation.rst
@@ -12,7 +12,7 @@ Requirements

* OS: Linux
* Python: 3.8 -- 3.11
* GPU: MI200s (gfx90a), MI300 (gfx942)
* GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100)
* PyTorch 2.0.1/2.1.1/2.2
* ROCm 5.7 (verified on Python 3.10) or ROCm 6.0 (verified on Python 3.9)

@@ -105,6 +105,7 @@ The `Dockerfile.rocm` is designed to support both ROCm 5.7 and ROCm 6.0 and later
* `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`
* `FA_GFX_ARCHS`: specifies the GFX architecture that is used to build flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`
* `FA_BRANCH`: specifies the branch used to build flash-attention in `ROCmSoftwarePlatform's flash-attention repo <https://github.com/ROCmSoftwarePlatform/flash-attention>`_. The default is `3d2b6f5`
* `BUILD_FA`: specifies whether to build flash-attention. For `Radeon RX 7900 series (gfx1100) <https://rocm.docs.amd.com/projects/radeon/en/latest/index.html>`_, this should be set to 0 until flash-attention supports this target.

Their values can be passed in when running ``docker build`` with ``--build-arg`` options.

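As a concrete illustration of the build arguments described above, a gfx1100 image build that skips flash-attention could be kicked off roughly as follows (the `-t vllm-rocm` tag is illustrative; the other build args keep their defaults):

    docker build -f Dockerfile.rocm --build-arg BUILD_FA="0" -t vllm-rocm .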
2 changes: 1 addition & 1 deletion setup.py
@@ -19,7 +19,7 @@

# Supported NVIDIA GPU architectures.
NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942"}
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx942", "gfx1100"}
# SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)


45 changes: 45 additions & 0 deletions vllm/model_executor/layers/attention.py
@@ -1,6 +1,7 @@
"""Multi-head attention."""
from typing import List, Optional

import importlib.util
import torch
import torch.nn as nn
from xformers import ops as xops
@@ -58,6 +59,40 @@ def __init__(
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

        self.use_ref_attention = self.check_use_ref_attention()

    def check_use_ref_attention(self) -> bool:
        if not is_hip():
            return False
        # On ROCm, check whether flash-attention is installed;
        # if not, fall back to the reference attention implementation.
        return importlib.util.find_spec("flash_attn") is None

    def ref_masked_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        seq_len, _, _ = query.shape
        attn_mask = torch.triu(torch.ones(seq_len,
                                          seq_len,
                                          dtype=query.dtype,
                                          device=query.device),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(query.dtype).min

        attn_weights = self.scale * torch.einsum("qhd,khd->hqk", query,
                                                 key).float()
        attn_weights = attn_weights + attn_mask.float()
        attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
        out = torch.einsum("hqk,khd->qhd", attn_weights, value)
        return out

    def forward(
        self,
        query: torch.Tensor,
@@ -137,6 +172,16 @@ def forward(
                        self.alibi_slopes, self.num_kv_heads, batch_size,
                        seq_len, query.dtype)

            if self.use_ref_attention:
                output = self.ref_masked_attention(
                    query,
                    key,
                    value,
                )
                # Using .view() here raised "RuntimeError: view size is not compatible
                # with input tensor's size and stride (at least one dimension spans
                # across two contiguous subspaces)", so reshape() is used instead.
                return output.reshape(batch_size, seq_len, hidden_size)

            # TODO(woosuk): Too many view operations. Let's try to reduce
            # them in the future for code readability.
            if self.alibi_slopes is None:
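To clarify the einsum convention used by `ref_masked_attention`, here is a small self-contained sketch (not part of the PR) that redoes the same computation on random tensors and cross-checks it against PyTorch's built-in scaled-dot-product attention; the shapes and tolerance are illustrative:

    import torch

    # Illustrative shapes, matching the PR's view() calls: (seq_len, num_heads, head_size).
    seq_len, num_heads, head_size = 8, 4, 64
    scale = head_size ** -0.5
    q, k, v = (torch.randn(seq_len, num_heads, head_size) for _ in range(3))

    # Additive causal mask: future positions get the dtype's minimum value.
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=q.dtype), diagonal=1)
    mask = mask * torch.finfo(q.dtype).min

    # Same contraction as ref_masked_attention: scores have shape (heads, q_pos, k_pos).
    weights = scale * torch.einsum("qhd,khd->hqk", q, k).float()
    weights = torch.softmax(weights + mask.float(), dim=-1).to(v.dtype)
    out = torch.einsum("hqk,khd->qhd", weights, v)  # (seq_len, num_heads, head_size)

    # Cross-check against the fused kernel, which expects (heads, seq, head_size) here.
    expected = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1), is_causal=True
    ).transpose(0, 1)
    print(torch.allclose(out, expected, atol=1e-5))  # expected: True

The reshape (rather than view) on the reference path's output matters because the einsum result is generally not laid out so that a view to (batch_size, seq_len, hidden_size) is possible; reshape copies when the strides do not allow a zero-copy view.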