[Misc] Add get_name method to attention backends (vllm-project#4685)
WoosukKwon authored and tjohnson31415 committed May 16, 2024
1 parent 1bb5e89 commit 456bcbc
Showing 7 changed files with 30 additions and 12 deletions.
5 changes: 5 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -9,6 +9,11 @@
 class AttentionBackend(ABC):
     """Abstract class for attention backends."""
 
+    @staticmethod
+    @abstractmethod
+    def get_name() -> str:
+        raise NotImplementedError
+
     @staticmethod
     @abstractmethod
     def get_impl_cls() -> Type["AttentionImpl"]:
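
The pattern introduced here is small: backends now report a stable string identifier through a static abstract method, so callers can identify the selected backend without importing its class. A minimal self-contained sketch of the idea (toy class names, not vLLM's actual backends):

from abc import ABC, abstractmethod


class ToyAttentionBackend(ABC):
    """Toy stand-in for AttentionBackend, used only for illustration."""

    @staticmethod
    @abstractmethod
    def get_name() -> str:
        raise NotImplementedError


class ToyFlashBackend(ToyAttentionBackend):

    @staticmethod
    def get_name() -> str:
        return "flash-attn"


# Callers can identify a backend by name instead of by class or module:
assert ToyFlashBackend.get_name() == "flash-attn"

Each concrete backend in the files below fills in get_name() the same way.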
4 changes: 4 additions & 0 deletions vllm/attention/backends/flash_attn.py
@@ -19,6 +19,10 @@
 
 class FlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "flash-attn"
+
     @staticmethod
     def get_impl_cls() -> Type["FlashAttentionImpl"]:
         return FlashAttentionImpl
16 changes: 7 additions & 9 deletions vllm/attention/backends/flashinfer.py
@@ -1,16 +1,10 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Set, Tuple, Type
 
-try:
-    import flashinfer
-    from flash_attn import flash_attn_varlen_func
-    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-except ImportError:
-    flashinfer = None
-    flash_attn_varlen_func = None
-    BatchDecodeWithPagedKVCacheWrapper = None
-
+import flashinfer
 import torch
+from flash_attn import flash_attn_varlen_func
+from flashinfer import BatchDecodeWithPagedKVCacheWrapper
 
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -20,6 +14,10 @@
 
 class FlashInferBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "flashinfer"
+
     @staticmethod
     def get_impl_cls() -> Type["FlashInferImpl"]:
         return FlashInferImpl
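
Note on the import change above: the try/except fallback is gone, so this module now imports flashinfer and flash_attn unconditionally, presumably acceptable because the module is only loaded once the FlashInfer backend has actually been selected (the model_runner.py change below removes the last unconditional import). A sketch of that kind of name-keyed, lazy selection (hypothetical helper, not vLLM's actual get_attn_backend):

import importlib


def load_backend_cls(name: str):
    # Hypothetical selector: resolve a backend class from its get_name()
    # string and import its module only on demand, so optional dependencies
    # such as flashinfer are not required unless that backend is chosen.
    registry = {
        "flash-attn": ("vllm.attention.backends.flash_attn", "FlashAttentionBackend"),
        "flashinfer": ("vllm.attention.backends.flashinfer", "FlashInferBackend"),
        "torch-sdpa": ("vllm.attention.backends.torch_sdpa", "TorchSDPABackend"),
        "xformers": ("vllm.attention.backends.xformers", "XFormersBackend"),
    }
    module_name, cls_name = registry[name]
    module = importlib.import_module(module_name)
    return getattr(module, cls_name)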
4 changes: 4 additions & 0 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -17,6 +17,10 @@
 
 class ROCmFlashAttentionBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "rocm-flash-attn"
+
     @staticmethod
     def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
         return ROCmFlashAttentionImpl
4 changes: 4 additions & 0 deletions vllm/attention/backends/torch_sdpa.py
@@ -15,6 +15,10 @@
 
 class TorchSDPABackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "torch-sdpa"
+
     @staticmethod
     def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
         return TorchSDPABackendImpl
4 changes: 4 additions & 0 deletions vllm/attention/backends/xformers.py
@@ -20,6 +20,10 @@
 
 class XFormersBackend(AttentionBackend):
 
+    @staticmethod
+    def get_name() -> str:
+        return "xformers"
+
     @staticmethod
     def get_impl_cls() -> Type["XFormersImpl"]:
         return XFormersImpl
5 changes: 2 additions & 3 deletions vllm/worker/model_runner.py
@@ -9,7 +9,6 @@
 
 from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
                             get_attn_backend)
-from vllm.attention.backends.flashinfer import FlashInferBackend
 from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
@@ -395,7 +394,7 @@ def _prepare_prompt(
                      dtype=seq_start_loc.dtype,
                      out=seq_start_loc[1:])
 
-        if self.attn_backend is FlashInferBackend:
+        if self.attn_backend.get_name() == "flashinfer":
             attn_metadata = self.attn_backend.make_metadata(
                 is_prompt=True,
                 use_cuda_graph=False,
@@ -556,7 +555,7 @@ def _prepare_decode(
             device=self.device,
         )
 
-        if self.attn_backend is FlashInferBackend:
+        if self.attn_backend.get_name() == "flashinfer":
            if not hasattr(self, "flashinfer_workspace_buffer"):
                 # Allocate 16MB workspace buffer
                 # Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
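
The two hunks above replace a class-identity check (self.attn_backend is FlashInferBackend) with a comparison against the new get_name() string, which is what allows the FlashInferBackend import at the top of model_runner.py to be deleted. A simplified sketch of the resulting dispatch (hypothetical helper and argument names, not the runner's real signature):

def build_prompt_metadata(attn_backend, **common_kwargs):
    # Branch on the backend's reported name rather than on its class,
    # so no backend-specific import is needed at this call site.
    if attn_backend.get_name() == "flashinfer":
        return attn_backend.make_metadata(is_prompt=True,
                                          use_cuda_graph=False,
                                          **common_kwargs)
    return attn_backend.make_metadata(is_prompt=True, **common_kwargs)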
