Force paged attention v2 for long contexts #1510

Merged 4 commits on Nov 1, 2023.
vllm/model_executor/layers/attention.py (3 additions & 1 deletion)
@@ -156,7 +156,9 @@ def single_query_cached_kv_attention(
         # sequences or heads is large, we use V1 since there is enough work
         # to parallelize.
         # TODO(woosuk): Tune this heuristic.
-        use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512
+        # For context len > 8192, use V2 kernel to avoid shared memory shortage.
+        use_v1 = input_metadata.max_context_len <= 8192 and (
+            max_num_partitions == 1 or num_seqs * num_heads > 512)
         if use_v1:
             # Run PagedAttention V1.
             attention_ops.paged_attention_v1(
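For reference, the resulting kernel choice can be summarized as a standalone sketch. This is illustrative only, not vLLM's actual API: the helper name is made up, and the 512-token partition size mirrors the partition constant the V2 kernel is assumed to use.

# Minimal sketch of the V1/V2 selection heuristic after this change.
PARTITION_SIZE = 512  # assumed V2 partition size

def use_paged_attention_v1(max_context_len: int, num_seqs: int,
                           num_heads: int) -> bool:
    # Number of partitions the V2 kernel would split each sequence into.
    max_num_partitions = (
        (max_context_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
    # V1 keeps the per-sequence softmax workspace in shared memory, so it is
    # only chosen for contexts up to 8192 tokens, and only when the sequence
    # fits in one partition or there is already enough seq*head parallelism.
    return max_context_len <= 8192 and (
        max_num_partitions == 1 or num_seqs * num_heads > 512)

# Example: a 16k-token context now always goes to the partitioned V2 kernel.
assert use_paged_attention_v1(16384, num_seqs=1, num_heads=32) is False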
vllm/worker/worker.py (1 addition & 28 deletions)
@@ -13,7 +13,7 @@
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
-from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes
+from vllm.utils import get_gpu_memory


 class Worker:
@@ -141,13 +141,6 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.block_size = cache_config.block_size
         self.sliding_window = cache_config.sliding_window

-        if self.sliding_window is None:
-            max_seq_len = self.scheduler_config.max_model_len
-        else:
-            max_seq_len = min(self.scheduler_config.max_model_len,
-                              self.sliding_window)
-        _check_if_can_support_max_seq_len(max_seq_len, self.block_size)
-
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                         self.parallel_config)
         self.cache_events = self.cache_engine.events
@@ -421,26 +414,6 @@ def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
     return x + [pad] * (max_len - len(x))


-def _check_if_can_support_max_seq_len(max_seq_len: int,
-                                      block_size: int) -> None:
-    # Follows the logic in
-    # attention_kernels.cu::single_query_cached_kv_attention_launcher
-    max_shared_mem = get_max_shared_memory_bytes()
-    float32_bytes = torch.finfo(torch.float).bits // 8
-    padded_max_seq_len = (
-        (max_seq_len + block_size - 1) / block_size) * block_size
-    # padded_max_seq_len + extra buffer
-    required_shared_mem = (padded_max_seq_len + 512) * float32_bytes
-    if padded_max_seq_len * float32_bytes > max_shared_mem:
-        raise RuntimeError(
-            f"vLLM cannot currently support max_model_len={max_seq_len} "
-            f"with block_size={block_size} on GPU with compute "
-            f"capability {torch.cuda.get_device_capability()} "
-            f"(required shared memory {required_shared_mem} > "
-            f"available shared memory {max_shared_mem}). "
-            "This will be fixed in a future release.")
-
-
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.
     if torch_dtype == torch.bfloat16:
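The startup check above is no longer needed because the 8192-token cutoff keeps the V1 kernel within its shared-memory budget. A rough back-of-the-envelope sketch, following the math in the deleted _check_if_can_support_max_seq_len; the ~48 KB per-block figure is a common CUDA default, not a guarantee for every GPU:

# Rough shared-memory estimate for the V1 kernel, mirroring the deleted check.
FLOAT32_BYTES = 4  # torch.finfo(torch.float).bits // 8

def v1_shared_mem_bytes(max_seq_len: int, block_size: int) -> int:
    # Pad the context length up to a multiple of the block size, then add the
    # kernel's extra buffer of 512 floats.
    padded = ((max_seq_len + block_size - 1) // block_size) * block_size
    return (padded + 512) * FLOAT32_BYTES

# At the new cutoff, V1 needs about (8192 + 512) * 4 bytes ~= 34 KB, which fits
# under the ~48 KB of per-block shared memory typically available by default;
# longer contexts are now routed to the partitioned V2 kernel instead of
# failing at engine initialization.
print(v1_shared_mem_bytes(8192, block_size=16))  # 34816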