From b6237d038c999e9243c2f1acf3ccbfb738bb937e Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 30 Oct 2023 14:20:08 -0700 Subject: [PATCH 1/4] Force paged attention v2 for long contexts --- vllm/model_executor/input_metadata.py | 2 ++ vllm/model_executor/layers/attention.py | 3 ++- vllm/worker/worker.py | 21 +++++++++------------ 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index b3b5852e48769..4823ab506b340 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -32,6 +32,7 @@ def __init__( selected_token_indices: torch.Tensor, categorized_sample_indices: Dict[SamplingType, torch.Tensor], sliding_window: Optional[int] = None, + force_paged_attention_v2: bool = False, ) -> None: self.seq_groups = seq_groups self.seq_data = seq_data @@ -42,6 +43,7 @@ def __init__( self.block_tables = block_tables self.selected_token_indices = selected_token_indices self.categorized_sample_indices = categorized_sample_indices + self.force_paged_attention_v2 = force_paged_attention_v2 self.max_prompt_len = max(prompt_lens) if prompt_lens else 0 self.to_cache = None diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 58f868d407bf7..4fa7160a0ab4e 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -156,7 +156,8 @@ def single_query_cached_kv_attention( # sequences or heads is large, we use V1 since there is enough work # to parallelize. # TODO(woosuk): Tune this heuristic. - use_v1 = max_num_partitions == 1 or num_seqs * num_heads > 512 + use_v1 = not input_metadata.force_paged_attention_v2 and ( + max_num_partitions == 1 or num_seqs * num_heads > 512) if use_v1: # Run PagedAttention V1. attention_ops.paged_attention_v1( diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index fd6faecccbfb2..4409a55deedd7 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -47,6 +47,8 @@ def __init__( self.cache_events = None self.gpu_cache = None + self.force_paged_attention_v2 = False + def init_model(self): # This env var set by Ray causes exceptions with graph building. 
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) @@ -146,7 +148,8 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None: else: max_seq_len = min(self.scheduler_config.max_model_len, self.sliding_window) - _check_if_can_support_max_seq_len(max_seq_len, self.block_size) + self.force_paged_attention_v2 = _should_force_paged_attention_v2( + max_seq_len, self.block_size) self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.parallel_config) @@ -332,6 +335,7 @@ def _prepare_inputs( selected_token_indices=selected_token_indices, categorized_sample_indices=categorized_sample_indices, sliding_window=self.sliding_window, + force_paged_attention_v2=self.force_paged_attention_v2, ) return tokens_tensor, positions_tensor, input_metadata @@ -421,24 +425,17 @@ def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: return x + [pad] * (max_len - len(x)) -def _check_if_can_support_max_seq_len(max_seq_len: int, - block_size: int) -> None: +def _should_force_paged_attention_v2(max_seq_len: int, + block_size: int) -> bool: # Follows the logic in - # attention_kernels.cu::single_query_cached_kv_attention_launcher + # attention_kernels.cu::paged_attention_kernel max_shared_mem = get_max_shared_memory_bytes() float32_bytes = torch.finfo(torch.float).bits // 8 padded_max_seq_len = ( (max_seq_len + block_size - 1) / block_size) * block_size # padded_max_seq_len + extra buffer required_shared_mem = (padded_max_seq_len + 512) * float32_bytes - if padded_max_seq_len * float32_bytes > max_shared_mem: - raise RuntimeError( - f"vLLM cannot currently support max_model_len={max_seq_len} " - f"with block_size={block_size} on GPU with compute " - f"capability {torch.cuda.get_device_capability()} " - f"(required shared memory {required_shared_mem} > " - f"available shared memory {max_shared_mem}). " - "This will be fixed in a future release.") + return required_shared_mem > max_shared_mem def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): From b3c22acc1e624fa6cf0a824445dc2087b9b8d664 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 1 Nov 2023 11:19:37 -0700 Subject: [PATCH 2/4] Apply suggestion from code review --- vllm/model_executor/input_metadata.py | 2 -- vllm/model_executor/layers/attention.py | 2 +- vllm/worker/worker.py | 18 ------------------ 3 files changed, 1 insertion(+), 21 deletions(-) diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index 4823ab506b340..b3b5852e48769 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -32,7 +32,6 @@ def __init__( selected_token_indices: torch.Tensor, categorized_sample_indices: Dict[SamplingType, torch.Tensor], sliding_window: Optional[int] = None, - force_paged_attention_v2: bool = False, ) -> None: self.seq_groups = seq_groups self.seq_data = seq_data @@ -43,7 +42,6 @@ def __init__( self.block_tables = block_tables self.selected_token_indices = selected_token_indices self.categorized_sample_indices = categorized_sample_indices - self.force_paged_attention_v2 = force_paged_attention_v2 self.max_prompt_len = max(prompt_lens) if prompt_lens else 0 self.to_cache = None diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 4fa7160a0ab4e..3037acc5b26ca 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -156,7 +156,7 @@ def single_query_cached_kv_attention( # sequences or heads is large, we use V1 since there is enough work # to parallelize. 
# TODO(woosuk): Tune this heuristic. - use_v1 = not input_metadata.force_paged_attention_v2 and ( + use_v1 = input_metadata.max_context_len <= 8192 and ( max_num_partitions == 1 or num_seqs * num_heads > 512) if use_v1: # Run PagedAttention V1. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 4409a55deedd7..416505b8debc0 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -47,8 +47,6 @@ def __init__( self.cache_events = None self.gpu_cache = None - self.force_paged_attention_v2 = False - def init_model(self): # This env var set by Ray causes exceptions with graph building. os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) @@ -148,8 +146,6 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None: else: max_seq_len = min(self.scheduler_config.max_model_len, self.sliding_window) - self.force_paged_attention_v2 = _should_force_paged_attention_v2( - max_seq_len, self.block_size) self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.parallel_config) @@ -335,7 +331,6 @@ def _prepare_inputs( selected_token_indices=selected_token_indices, categorized_sample_indices=categorized_sample_indices, sliding_window=self.sliding_window, - force_paged_attention_v2=self.force_paged_attention_v2, ) return tokens_tensor, positions_tensor, input_metadata @@ -425,19 +420,6 @@ def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: return x + [pad] * (max_len - len(x)) -def _should_force_paged_attention_v2(max_seq_len: int, - block_size: int) -> bool: - # Follows the logic in - # attention_kernels.cu::paged_attention_kernel - max_shared_mem = get_max_shared_memory_bytes() - float32_bytes = torch.finfo(torch.float).bits // 8 - padded_max_seq_len = ( - (max_seq_len + block_size - 1) / block_size) * block_size - # padded_max_seq_len + extra buffer - required_shared_mem = (padded_max_seq_len + 512) * float32_bytes - return required_shared_mem > max_shared_mem - - def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): # Check if the GPU supports the dtype. 
if torch_dtype == torch.bfloat16: From 7023e5e223481c85222a180a0a759f5a60f5bb86 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 1 Nov 2023 11:23:24 -0700 Subject: [PATCH 3/4] Lint --- vllm/worker/worker.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 416505b8debc0..4c52b46f55a30 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -13,7 +13,7 @@ from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine -from vllm.utils import get_gpu_memory, get_max_shared_memory_bytes +from vllm.utils import get_gpu_memory class Worker: @@ -141,12 +141,6 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None: self.block_size = cache_config.block_size self.sliding_window = cache_config.sliding_window - if self.sliding_window is None: - max_seq_len = self.scheduler_config.max_model_len - else: - max_seq_len = min(self.scheduler_config.max_model_len, - self.sliding_window) - self.cache_engine = CacheEngine(self.cache_config, self.model_config, self.parallel_config) self.cache_events = self.cache_engine.events From 2997044363ede8832db5403d07b3c029a658c79b Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 1 Nov 2023 16:16:34 -0700 Subject: [PATCH 4/4] Update vllm/model_executor/layers/attention.py Co-authored-by: Woosuk Kwon --- vllm/model_executor/layers/attention.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 3037acc5b26ca..7aa01ffe14bab 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -156,6 +156,7 @@ def single_query_cached_kv_attention( # sequences or heads is large, we use V1 since there is enough work # to parallelize. # TODO(woosuk): Tune this heuristic. + # For context len > 8192, use V2 kernel to avoid shared memory shortage. use_v1 = input_metadata.max_context_len <= 8192 and ( max_num_partitions == 1 or num_seqs * num_heads > 512) if use_v1:
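
Summary note for reviewers (not part of the patches): the series removes the hard startup failure for long contexts and instead routes them to the PagedAttention V2 kernel. V1 is the single-pass kernel that keeps the padded context's softmax logits in shared memory, which is why the old check compared (padded_max_seq_len + 512) * sizeof(float) against the GPU's shared-memory limit; V2 splits the context into fixed-size partitions and reduces across them, so patch 2 simply sends anything above 8192 tokens to V2 rather than probing shared memory. The standalone Python sketch below restates that logic outside the diffs. It is an illustration only: the helper names and the _PARTITION_SIZE constant are assumptions made for the example, not identifiers from the patches.

_PARTITION_SIZE = 512   # tokens handled per V2 partition (assumed to match the kernel)
_FLOAT32_BYTES = 4      # V1 keeps softmax logits as float32 in shared memory


def required_shared_mem_v1(max_seq_len: int, block_size: int) -> int:
    # Same estimate the removed _should_force_paged_attention_v2 used:
    # pad the context length up to a multiple of block_size, then add a
    # 512-element buffer, all stored as float32.
    padded = ((max_seq_len + block_size - 1) // block_size) * block_size
    return (padded + 512) * _FLOAT32_BYTES


def use_paged_attention_v1(max_context_len: int, num_seqs: int,
                           num_heads: int) -> bool:
    # Heuristic as it stands after patch 4: prefer V1 only for short contexts,
    # and only when there is a single partition or already enough parallel
    # work (num_seqs * num_heads) to keep the GPU busy without partitioning.
    max_num_partitions = (
        (max_context_len + _PARTITION_SIZE - 1) // _PARTITION_SIZE)
    return max_context_len <= 8192 and (
        max_num_partitions == 1 or num_seqs * num_heads > 512)


if __name__ == "__main__":
    print(use_paged_attention_v1(256, num_seqs=1, num_heads=32))     # True: single partition
    print(use_paged_attention_v1(4096, num_seqs=64, num_heads=32))   # True: plenty of parallel work
    print(use_paged_attention_v1(16384, num_seqs=64, num_heads=32))  # False: long context, V2 forced
    print(required_shared_mem_v1(16384, block_size=16))              # 67584 bytes V1 would need

With the 8192-token cutoff, the per-GPU shared-memory probe in worker.py becomes unnecessary, which is why patches 2 and 3 delete _should_force_paged_attention_v2 and its get_max_shared_memory_bytes import rather than plumbing a flag through InputMetadata.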