
Commit

[Kernel] Add env variable to force flashinfer backend to enable tensor cores (#9497)

Signed-off-by: Thomas Parnell <[email protected]>
Co-authored-by: Chih-Chieh Yang <[email protected]>
Co-authored-by: Cody Yu <[email protected]>
3 people authored Oct 19, 2024
1 parent d11bf43 commit 0c9a525
Showing 2 changed files with 11 additions and 2 deletions.
7 changes: 5 additions & 2 deletions vllm/attention/backends/flashinfer.py
@@ -17,6 +17,7 @@

 import torch

+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata,
@@ -124,7 +125,8 @@ def _get_decode_wrapper(self):
                 self.runner.parallel_config))
             num_kv_heads = self.runner.model_config.get_num_kv_heads(
                 self.runner.parallel_config)
-            use_tensor_cores = num_qo_heads // num_kv_heads > 4
+            use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
+                num_qo_heads // num_kv_heads > 4)
             self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
                 self._get_workspace_buffer(),
                 "NHD",
@@ -183,7 +185,8 @@ def graph_capture_get_metadata_for_batch(
             self.runner.parallel_config))
         num_kv_heads = self.runner.model_config.get_num_kv_heads(
             self.runner.parallel_config)
-        use_tensor_cores = num_qo_heads // num_kv_heads > 4
+        use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
+            num_qo_heads // num_kv_heads > 4)
         self._graph_decode_wrapper = \
             CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
                 self._graph_decode_workspace_buffer, _indptr_buffer,
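Both call sites change in the same way: FlashInfer's decode wrappers take a use_tensor_cores flag that was previously enabled only when the ratio of query heads to KV heads (the GQA group size) exceeded 4, and can now also be forced on through the new environment variable. A minimal sketch of the combined condition (the helper name and the Llama-3-8B example are illustrative, not part of this commit):

import os

def should_use_tensor_cores(num_qo_heads: int, num_kv_heads: int) -> bool:
    # Mirrors the logic added in this commit: the environment variable
    # overrides the architecture-based heuristic.
    force = bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0")))
    return force or (num_qo_heads // num_kv_heads > 4)

# Llama-3-8B has 32 query heads and 8 KV heads, so the ratio is exactly 4
# and the heuristic alone leaves tensor cores off (4 > 4 is False);
# setting VLLM_FLASHINFER_FORCE_TENSOR_CORES=1 enables them regardless.
print(should_use_tensor_cores(32, 8))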
6 changes: 6 additions & 0 deletions vllm/envs.py
@@ -32,6 +32,7 @@
     VLLM_ATTENTION_BACKEND: Optional[str] = None
     VLLM_USE_FLASHINFER_SAMPLER: bool = False
     VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
+    VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
     VLLM_PP_LAYER_PARTITION: Optional[str] = None
     VLLM_CPU_KVCACHE_SPACE: int = 0
     VLLM_CPU_OMP_THREADS_BIND: str = ""
@@ -286,6 +287,11 @@ def get_default_config_root():
     "VLLM_USE_FLASHINFER_SAMPLER":
     lambda: bool(int(os.getenv("VLLM_USE_FLASHINFER_SAMPLER", "0"))),

+    # If set, vllm will force flashinfer to use tensor cores;
+    # otherwise will use heuristic based on model architecture.
+    "VLLM_FLASHINFER_FORCE_TENSOR_CORES":
+    lambda: bool(int(os.getenv("VLLM_FLASHINFER_FORCE_TENSOR_CORES", "0"))),
+
     # Pipeline stage partition strategy
     "VLLM_PP_LAYER_PARTITION":
     lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
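Like the neighbouring entries above, the new flag is read lazily through a lambda that does bool(int(os.getenv(...))), so it should be set to an integer value such as 1. A usage sketch (the offline LLM entry point and the model name are illustrative choices, not taken from this commit):

import os

# Set before vLLM reads its environment configuration.
os.environ["VLLM_FLASHINFER_FORCE_TENSOR_CORES"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM

llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
print(llm.generate("Hello, my name is")[0].outputs[0].text)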
