From 456bcbc396d6cdaf2bd426e79114b7826080a6b2 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 8 May 2024 09:59:31 -0700 Subject: [PATCH] [Misc] Add `get_name` method to attention backends (#4685) --- vllm/attention/backends/abstract.py | 5 +++++ vllm/attention/backends/flash_attn.py | 4 ++++ vllm/attention/backends/flashinfer.py | 16 +++++++--------- vllm/attention/backends/rocm_flash_attn.py | 4 ++++ vllm/attention/backends/torch_sdpa.py | 4 ++++ vllm/attention/backends/xformers.py | 4 ++++ vllm/worker/model_runner.py | 5 ++--- 7 files changed, 30 insertions(+), 12 deletions(-) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index b2b6e7ac810e3..02a2fd603faa8 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -9,6 +9,11 @@ class AttentionBackend(ABC): """Abstract class for attention backends.""" + @staticmethod + @abstractmethod + def get_name() -> str: + raise NotImplementedError + @staticmethod @abstractmethod def get_impl_cls() -> Type["AttentionImpl"]: diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index da672d5df6161..bee482c3431c4 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -19,6 +19,10 @@ class FlashAttentionBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "flash-attn" + @staticmethod def get_impl_cls() -> Type["FlashAttentionImpl"]: return FlashAttentionImpl diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 2851cbe2396b2..67b99ba2eade4 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -1,16 +1,10 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional, Set, Tuple, Type -try: - import flashinfer - from flash_attn import flash_attn_varlen_func - from flashinfer import BatchDecodeWithPagedKVCacheWrapper -except ImportError: - flashinfer = None - flash_attn_varlen_func = None - BatchDecodeWithPagedKVCacheWrapper = None - +import flashinfer import torch +from flash_attn import flash_attn_varlen_func +from flashinfer import BatchDecodeWithPagedKVCacheWrapper from vllm import _custom_ops as ops from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, @@ -20,6 +14,10 @@ class FlashInferBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "flashinfer" + @staticmethod def get_impl_cls() -> Type["FlashInferImpl"]: return FlashInferImpl diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index c3b522e63b4b8..10c94f02ff05b 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -17,6 +17,10 @@ class ROCmFlashAttentionBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "rocm-flash-attn" + @staticmethod def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]: return ROCmFlashAttentionImpl diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 03825f6023f4c..c1c07abef0ce6 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -15,6 +15,10 @@ class TorchSDPABackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "torch-sdpa" + @staticmethod def get_impl_cls() -> Type["TorchSDPABackendImpl"]: return TorchSDPABackendImpl diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 4c7fa71a2c78e..2a9150dea5875 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -20,6 +20,10 @@ class XFormersBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + return "xformers" + @staticmethod def get_impl_cls() -> Type["XFormersImpl"]: return XFormersImpl diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ab248596490f6..c96f13c590fc4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -9,7 +9,6 @@ from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage, get_attn_backend) -from vllm.attention.backends.flashinfer import FlashInferBackend from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce @@ -395,7 +394,7 @@ def _prepare_prompt( dtype=seq_start_loc.dtype, out=seq_start_loc[1:]) - if self.attn_backend is FlashInferBackend: + if self.attn_backend.get_name() == "flashinfer": attn_metadata = self.attn_backend.make_metadata( is_prompt=True, use_cuda_graph=False, @@ -556,7 +555,7 @@ def _prepare_decode( device=self.device, ) - if self.attn_backend is FlashInferBackend: + if self.attn_backend.get_name() == "flashinfer": if not hasattr(self, "flashinfer_workspace_buffer"): # Allocate 16MB workspace buffer # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html