[Misc] Fix flash attention backend log (#4368)
esmeetu authored Apr 25, 2024
1 parent b5b4a39 commit b6dcb4d
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions vllm/attention/selector.py
@@ -25,7 +25,7 @@ class _Backend(enum.Enum):
 def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
     backend = _which_attn_to_use(dtype)
     if backend == _Backend.FLASH_ATTN:
-        logger.info("Using FlashAttention backend.")
+        logger.info("Using FlashAttention-2 backend.")
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
         return FlashAttentionBackend
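
For context, a minimal usage sketch (not part of this commit) of how the renamed message surfaces when the selector is called directly. It assumes a CUDA GPU of compute capability 8.0 or newer with the flash_attn package installed, so the FLASH_ATTN branch above is taken; the import path mirrors the file being changed, vllm/attention/selector.py.

    import torch
    from vllm.attention.selector import get_attn_backend

    # On supported hardware this logs "Using FlashAttention-2 backend."
    # (the message corrected by this commit) via vLLM's logger and returns
    # the FlashAttentionBackend class.
    backend_cls = get_attn_backend(torch.float16)
    print(backend_cls.__name__)  # FlashAttentionBackend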
@@ -62,21 +62,21 @@ def _which_attn_to_use(dtype: torch.dtype) -> _Backend:
     # NVIDIA GPUs.
     if torch.cuda.get_device_capability()[0] < 8:
         # Volta and Turing NVIDIA GPUs.
-        logger.info("Cannot use FlashAttention backend for Volta and Turing "
+        logger.info("Cannot use FlashAttention-2 backend for Volta and Turing "
                     "GPUs.")
         return _Backend.XFORMERS

     if dtype not in (torch.float16, torch.bfloat16):
-        logger.info("Cannot use FlashAttention backend for dtype other than "
+        logger.info("Cannot use FlashAttention-2 backend for dtype other than "
                     "torch.float16 or torch.bfloat16.")
         return _Backend.XFORMERS

     try:
         import flash_attn  # noqa: F401
     except ImportError:
         logger.info(
-            "Cannot use FlashAttention backend because the flash_attn package "
-            "is not found. Please install it for better performance.")
+            "Cannot use FlashAttention-2 backend because the flash_attn "
+            "package is not found. Please install it for better performance.")
         return _Backend.XFORMERS

     backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
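
The tail of this hunk shows that _which_attn_to_use also reads the VLLM_ATTENTION_BACKEND environment variable before settling on a default. A hedged sketch of overriding the automatic choice, assuming the variable's name matches the constant shown above and that it accepts _Backend member names such as XFORMERS:

    import os

    # Set before the selector runs; an unset variable leaves the automatic
    # choice (FlashAttention-2 on capable GPUs) in place.
    os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"

    import torch
    from vllm.attention.selector import get_attn_backend

    # With the override in place, the selector returns the xFormers backend
    # even where FlashAttention-2 would otherwise have been chosen.
    backend_cls = get_attn_backend(torch.float16)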
