This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

[Test] Add xformer and flash attn tests (vllm-project#3961)
Co-authored-by: Simon Mo <[email protected]>
2 people authored and SageMoore committed Apr 11, 2024
1 parent 782c1cf commit 6f60fdb
Showing 2 changed files with 15 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tests/basic_correctness/test_basic_correctness.py
@@ -4,6 +4,8 @@
"""
import pytest

from vllm.attention.selector import VLLM_ATTENTION_BACKEND

MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
@@ -14,6 +16,7 @@
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False, True])
@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"])
def test_models(
    hf_runner,
    vllm_runner,
@@ -22,7 +25,10 @@ def test_models(
    dtype: str,
    max_tokens: int,
    enforce_eager: bool,
    attn_backend: str,
    monkeypatch,
) -> None:
    monkeypatch.setenv(VLLM_ATTENTION_BACKEND, attn_backend)
    hf_model = hf_runner(model, dtype=dtype)
    hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
    del hf_model
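
The parametrization above runs every model/dtype/eager combination once per attention backend, with monkeypatch scoping the environment variable to each test case. As a minimal standalone sketch of the same pytest pattern (the pick_backend helper is hypothetical and stands in for vLLM's selector; only the variable name and backend strings come from the diff):

import os

import pytest


def pick_backend() -> str:
    # Hypothetical stand-in for the selector logic: honor the env var,
    # otherwise fall back to the FLASH_ATTN default.
    return os.getenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")


@pytest.mark.parametrize("attn_backend", ["XFORMERS", "FLASH_ATTN"])
def test_pick_backend(attn_backend: str, monkeypatch) -> None:
    # monkeypatch.setenv sets the variable for this test only and undoes it afterwards.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
    assert pick_backend() == attn_backend
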
9 changes: 9 additions & 0 deletions vllm/attention/selector.py
@@ -1,4 +1,5 @@
import enum
import os
from functools import lru_cache
from typing import Type

@@ -10,6 +11,8 @@

logger = init_logger(__name__)

VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"


class _Backend(enum.Enum):
    FLASH_ATTN = enum.auto()
@@ -75,4 +78,10 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
"Cannot use FlashAttention backend because the flash_attn package "
"is not found. Please install it for better performance.")
return _Backend.XFORMERS

backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
if backend_by_env_var is not None:
return _Backend[backend_by_env_var]

# Default case.
return _Backend.FLASH_ATTN
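
With the selector change in place, a specific backend can also be forced outside the test suite by setting the variable before backend selection runs (the result is cached via lru_cache, so it must be set up front). A minimal sketch, assuming a CUDA-capable setup where both backends are installed:

import os

# VLLM_ATTENTION_BACKEND is the variable name introduced in selector.py above.
# Set it before vLLM picks an attention implementation; _which_attn_to_use()
# then returns _Backend[value] instead of the FLASH_ATTN default.
os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"

Note that an unrecognized value raises a KeyError, since the diff indexes the _Backend enum by name without validating the string.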
