Force paged attention v2 for long contexts (#1510)
Yard1 authored Nov 1, 2023
1 parent 1fe0990 commit 9738b84
Showing 2 changed files with 4 additions and 29 deletions.
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/attention.py
@@ -156,7 +156,9 @@ def single_query_cached_kv_attention(
         # sequences or heads is large, we use V1 since there is enough work
         # to parallelize.
         # TODO(woosuk): Tune this heuristic.
-        use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
+        # For context len > 8192, use V2 kernel to avoid shared memory shortage.
+        use_v1 = input_metadata.max_context_len <= 8192 and (
+            max_num_partitions == 1 or num_seqs * num_heads > 512)
         if use_v1:
             # Run PagedAttention V1.
             attention_ops.paged_attention_v1(
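The added condition simply gates the existing V1 heuristic on context length. As a rough illustration (the helper name and the call below are not part of the commit; they only mirror the values used in the diff), the selection logic amounts to:

def _use_paged_attention_v1(max_context_len: int, max_num_partitions: int,
                            num_seqs: int, num_heads: int) -> bool:
    # Contexts longer than 8192 tokens must fall back to the V2 kernel,
    # since V1's shared-memory usage grows with the context length.
    if max_context_len > 8192:
        return False
    # Otherwise keep the prior heuristic: V1 when there is no partitioning
    # to reduce over, or when there is already enough parallel work.
    return max_num_partitions == 1 or num_seqs * num_heads > 512

For example, _use_paged_attention_v1(16384, 4, 8, 32) returns False, so a 16K-token context is routed to paged_attention_v2 even though num_seqs * num_heads is only 256.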
29 changes: 1 addition & 28 deletions vllm/worker/worker.py
@@ -13,7 +13,7 @@
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
-from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes
+from vllm.utils import get_gpu_memory
 
 
 class Worker:
@@ -141,13 +141,6 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.block_size = cache_config.block_size
         self.sliding_window = cache_config.sliding_window
 
-        if self.sliding_window is None:
-            max_seq_len = self.scheduler_config.max_model_len
-        else:
-            max_seq_len = min(self.scheduler_config.max_model_len,
-                              self.sliding_window)
-        _check_if_can_support_max_seq_len(max_seq_len, self.block_size)
-
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                         self.parallel_config)
         self.cache_events = self.cache_engine.events
@@ -421,26 +414,6 @@ def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
     return x + [pad] * (max_len - len(x))
 
 
-def _check_if_can_support_max_seq_len(max_seq_len: int,
-                                      block_size: int) -> None:
-    # Follows the logic in
-    # attention_kernels.cu::single_query_cached_kv_attention_launcher
-    max_shared_mem = get_max_shared_memory_bytes()
-    float32_bytes = torch.finfo(torch.float).bits // 8
-    padded_max_seq_len = (
-        (max_seq_len + block_size - 1) / block_size) * block_size
-    # padded_max_seq_len + extra buffer
-    required_shared_mem = (padded_max_seq_len + 512) * float32_bytes
-    if padded_max_seq_len * float32_bytes > max_shared_mem:
-        raise RuntimeError(
-            f"vLLM cannot currently support max_model_len={max_seq_len} "
-            f"with block_size={block_size} on GPU with compute "
-            f"capability {torch.cuda.get_device_capability()} "
-            f"(required shared memory {required_shared_mem} > "
-            f"available shared memory {max_shared_mem}). "
-            "This will be fixed in a future release.")
-
-
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.
     if torch_dtype == torch.bfloat16:
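Forcing V2 above 8192 tokens is what makes the deleted _check_if_can_support_max_seq_len guard unnecessary: the V1 kernel's shared-memory footprint no longer has to cover arbitrarily long contexts at initialization time. As a sketch of what that guard computed, the snippet below re-does its arithmetic with integer division; the helper name, the default block size of 16, and the sample lengths are illustrative, not part of the commit.

FLOAT32_BYTES = 4  # torch.finfo(torch.float).bits // 8

def v1_shared_mem_bytes(max_seq_len: int, block_size: int = 16) -> int:
    # Round the context length up to a whole number of KV-cache blocks,
    # then add the 512-element extra buffer used by the removed check.
    padded = ((max_seq_len + block_size - 1) // block_size) * block_size
    return (padded + 512) * FLOAT32_BYTES

print(v1_shared_mem_bytes(8192))   # 34816 bytes (~34 KB) at the new cutoff
print(v1_shared_mem_bytes(32768))  # 133120 bytes (~130 KB), more than many GPUs expose per block

With V2 always selected past the cutoff, the remaining V1 paths stay within the budget reported by get_max_shared_memory_bytes(), so the start-up check and its import can be dropped.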
