From e662e6d141914c5a0b93b22e01da63f8aaf50ce3 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 5 Nov 2024 19:51:14 -0800 Subject: [PATCH 1/2] Reset Signed-off-by: Woosuk Kwon --- vllm/compilation/backends.py | 7 +- vllm/v1/attention/backends/flash_attn.py | 33 +++--- vllm/v1/worker/gpu_model_runner.py | 126 +++++++++++++++++++---- 3 files changed, 131 insertions(+), 35 deletions(-) diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 05deee7bd5473..abd1d16accaf7 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -496,8 +496,11 @@ def __call__(self, *args) -> Any: return entry.runnable(*args) if self.is_first_graph: - logger.info("Capturing a cudagraph for shape %s", - runtime_shape) + # Since we capture cudagraph for many different shapes and + # capturing is fast, we don't need to log it for every shape. + # We only log it in the debug mode. + logger.debug("Capturing a cudagraph for shape %s", + runtime_shape) input_addresses = [ x.data_ptr() for x in args if isinstance(x, torch.Tensor) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index b2af89ebf854a..3a33853ebd505 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -51,6 +51,7 @@ class FlashAttentionMetadata: # |-------------------- seq_len ---------------------| # |-- query_len ---| + num_actual_tokens: int # Number of tokens excluding padding. max_query_len: int query_start_loc: torch.Tensor max_seq_len: int @@ -134,7 +135,9 @@ def forward( assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") - output = torch.ops.vllm.unified_flash_attention( + output = torch.empty_like(query) + torch.ops.vllm.unified_flash_attention( + output, query, key, value, @@ -154,6 +157,7 @@ def forward( def unified_flash_attention( + out: torch.Tensor, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -168,17 +172,17 @@ def unified_flash_attention( window_size: Optional[List[int]] = None, alibi_slopes: Optional[torch.Tensor] = None, logits_soft_cap: Optional[float] = None, -) -> torch.Tensor: +) -> None: current_metadata = get_forward_context() if current_metadata is None: # Profiling run. - return torch.empty_like(query) + return assert current_metadata is not None assert isinstance(current_metadata, FlashAttentionMetadata) attn_metadata: FlashAttentionMetadata = current_metadata + num_actual_tokens = attn_metadata.num_actual_tokens - num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. query = query.view(-1, num_heads, head_size) key = key.view(-1, num_kv_heads, head_size) @@ -188,10 +192,10 @@ def unified_flash_attention( key_cache = kv_cache[0] value_cache = kv_cache[1] torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - kv_cache[0], - kv_cache[1], + key[:num_actual_tokens], + value[:num_actual_tokens], + key_cache, + value_cache, attn_metadata.slot_mapping, kv_cache_dtype, k_scale, @@ -199,7 +203,7 @@ def unified_flash_attention( ) output = flash_attn_varlen_func( - q=query, + q=query[:num_actual_tokens], k=key_cache, v=value_cache, cu_seqlens_q=attn_metadata.query_start_loc, @@ -213,10 +217,13 @@ def unified_flash_attention( block_table=attn_metadata.block_table, softcap=logits_soft_cap, ) - return output.view(num_tokens, hidden_size) + output = output.view(num_actual_tokens, -1) + # TODO(woosuk): Optimize this. 
+ out[:num_actual_tokens].copy_(output, non_blocking=True) def unified_flash_attention_fake( + out: torch.Tensor, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -231,13 +238,13 @@ def unified_flash_attention_fake( window_size: Optional[List[int]] = None, alibi_slopes: Optional[torch.Tensor] = None, logits_soft_cap: Optional[float] = None, -) -> torch.Tensor: - return torch.empty_like(query) +) -> None: + return direct_register_custom_op( op_name="unified_flash_attention", op_func=unified_flash_attention, - mutates_args=["kv_cache"], + mutates_args=["kv_cache", "out"], fake_impl=unified_flash_attention_fake, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ae4239f8e1fab..6ab9f00881821 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,3 +1,4 @@ +import time from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Set from unittest.mock import patch @@ -7,11 +8,16 @@ import torch.distributed import torch.nn as nn +from vllm import envs +from vllm.compilation.compile_context import set_compile_context +from vllm.compilation.config import CompilationConfig +from vllm.compilation.levels import CompilationLevel from vllm.config import VllmConfig from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model from vllm.multimodal import MultiModalDataDict +from vllm.plugins import set_compilation_config from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv, is_pin_memory_available) @@ -86,6 +92,18 @@ def __init__( pin_memory=self.pin_memory, ) + self.use_cuda_graph = (envs.VLLM_TORCH_COMPILE_LEVEL + == CompilationLevel.PIECEWISE + and not self.model_config.enforce_eager) + # TODO(woosuk): Provide an option to tune the max cudagraph batch size. + self.cudagraph_batch_sizes = [1, 2, 4] + [i for i in range(8, 513, 8)] + self.input_ids = torch.zeros(self.max_num_tokens, + dtype=torch.int32, + device=self.device) + self.positions = torch.zeros(self.max_num_tokens, + dtype=torch.int64, + device=self.device) + def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Remove stopped requests from the cached states. # Keep the states of the pre-empted requests. @@ -268,12 +286,16 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): seq_start_loc_np[0] = 0 np.cumsum(seq_lens, out=seq_start_loc_np[1:]) - input_ids = input_ids.to(self.device, non_blocking=True) - positions = positions.to(self.device, non_blocking=True).long() + self.input_ids[:total_num_scheduled_tokens].copy_(input_ids, + non_blocking=True) + self.positions[:total_num_scheduled_tokens].copy_(positions, + non_blocking=True) + query_start_loc = query_start_loc.to(self.device, non_blocking=True) seq_start_loc = seq_start_loc.to(self.device, non_blocking=True) slot_mapping = slot_mapping.to(self.device, non_blocking=True).long() attn_metadata = FlashAttentionMetadata( + num_actual_tokens=total_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens, query_start_loc=query_start_loc, max_seq_len=max_seq_len, @@ -287,7 +309,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # token from the partial request. # TODO: Support prompt logprobs. 
logits_indices = query_start_loc[1:] - 1 - return input_ids, positions, attn_metadata, logits_indices + return attn_metadata, logits_indices def _prepare_sampling( self, @@ -310,16 +332,26 @@ def execute_model( scheduler_output: "SchedulerOutput", ) -> ModelRunnerOutput: self._update_states(scheduler_output) - inputs = self._prepare_inputs(scheduler_output) - input_ids, positions, attn_metadata, logits_indices = inputs + attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) + num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens + if (self.use_cuda_graph + and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): + # Use piecewise CUDA graphs. + # Add padding to the batch size. + num_input_tokens = self._get_padded_batch_size( + num_scheduled_tokens) + else: + # Eager mode. + num_input_tokens = num_scheduled_tokens with set_forward_context(attn_metadata): hidden_states = self.model( - input_ids=input_ids, - positions=positions, + input_ids=self.input_ids[:num_input_tokens], + positions=self.positions[:num_input_tokens], kv_caches=self.kv_caches, - attn_metadata=attn_metadata, + attn_metadata=None, ) + hidden_states = hidden_states[:num_scheduled_tokens] hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(hidden_states, None) @@ -371,6 +403,18 @@ def execute_model( return model_runner_output def load_model(self) -> None: + if self.use_cuda_graph: + # FIXME(woosuk): Currently, the custom ops are not supported + # in the piecewise compilation mode. We rely on TorchInductor + # to optimize the model. + envs.VLLM_CUSTOM_OPS = "none" + set_compilation_config( + CompilationConfig( + use_cudagraph=True, + non_cudagraph_ops=["vllm.unified_flash_attention"], + use_inductor=True, + )) + logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 with patch("vllm.model_executor.layers.sampler.Sampler", Sampler): @@ -381,26 +425,61 @@ def load_model(self) -> None: self.model_memory_usage / float(2**30)) def _dummy_run(self, model: nn.Module, num_tokens: int) -> None: - input_ids = torch.zeros(num_tokens, - dtype=torch.int32, - device=self.device) - positions = torch.zeros(num_tokens, - dtype=torch.long, - device=self.device) - kv_caches = [None for _ in range(self.num_attn_layers)] - model(input_ids, positions, kv_caches, attn_metadata=None) - return + # use an empty tensor instead of `None`` to force Dynamo to pass + # it by reference, rather by specializing on the value `None`. + # the `dtype` argument does not matter, and we use `float32` as + # a placeholder (it has wide hardware support). + # it is important to create tensors inside the loop, rather than + # multiplying the list, to avoid Dynamo from treating them as + # tensor aliasing. + dummy_kv_caches = [ + torch.tensor([], dtype=torch.float32, device=self.device) + for _ in range(self.num_attn_layers) + ] + with set_forward_context(None): # noqa: SIM117 + with set_compile_context(self.cudagraph_batch_sizes): + # Trigger compilation for general shape. + model(self.input_ids, + self.positions, + dummy_kv_caches, + attn_metadata=None) @torch.inference_mode() def profile_run(self) -> None: self._dummy_run(self.model, self.max_num_tokens) torch.cuda.synchronize() - return @torch.inference_mode() def capture_model(self) -> None: - # TODO: Implement CUDA graph support. - return + if self.use_cuda_graph: + logger.warning( + "Skipping CUDA graph capture. 
Please set " + "VLLM_TORCH_COMPILE_LEVEL=%d to use CUDA graphs.", + CompilationLevel.PIECEWISE) + return + + start_time = time.perf_counter() + start_free_gpu_memory = torch.cuda.mem_get_info()[0] + + with set_forward_context(None): + # Trigger CUDA graph capture for specific shapes. + # Capture the large shapes first so that the smaller shapes + # can reuse the memory pool allocated for the large shapes. + for num_tokens in reversed(self.cudagraph_batch_sizes): + self.model( + self.input_ids[:num_tokens], + self.positions[:num_tokens], + kv_caches=self.kv_caches, + attn_metadata=None, + ) + + end_time = time.perf_counter() + end_free_gpu_memory = torch.cuda.mem_get_info()[0] + elapsed_time = end_time - start_time + cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory + # This usually takes 5~20 seconds. + logger.info("Graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, cuda_graph_size / (1 << 30)) def initialize_kv_cache(self, num_blocks: int) -> None: assert len(self.kv_caches) == 0 @@ -412,6 +491,13 @@ def initialize_kv_cache(self, num_blocks: int) -> None: dtype=self.kv_cache_dtype, device=self.device)) + def _get_padded_batch_size(self, batch_size: int) -> Optional[int]: + # TODO: Optimize this? + for size in self.cudagraph_batch_sizes: + if batch_size <= size: + return size + return None + @dataclass class CachedRequestState: From 7801bde8cb26916445af37902d1a91845ae9b9d5 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 5 Nov 2024 21:45:52 -0800 Subject: [PATCH 2/2] Address review Signed-off-by: Woosuk Kwon --- vllm/v1/attention/backends/flash_attn.py | 12 ++++++------ vllm/v1/worker/gpu_model_runner.py | 5 +++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 3a33853ebd505..906f06777a136 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -157,7 +157,7 @@ def forward( def unified_flash_attention( - out: torch.Tensor, + output: torch.Tensor, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -202,7 +202,7 @@ def unified_flash_attention( v_scale, ) - output = flash_attn_varlen_func( + attn_output = flash_attn_varlen_func( q=query[:num_actual_tokens], k=key_cache, v=value_cache, @@ -217,13 +217,13 @@ def unified_flash_attention( block_table=attn_metadata.block_table, softcap=logits_soft_cap, ) - output = output.view(num_actual_tokens, -1) + attn_output = attn_output.view(num_actual_tokens, -1) # TODO(woosuk): Optimize this. 
- out[:num_actual_tokens].copy_(output, non_blocking=True) + output[:num_actual_tokens].copy_(attn_output) def unified_flash_attention_fake( - out: torch.Tensor, + output: torch.Tensor, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -245,6 +245,6 @@ def unified_flash_attention_fake( direct_register_custom_op( op_name="unified_flash_attention", op_func=unified_flash_attention, - mutates_args=["kv_cache", "out"], + mutates_args=["kv_cache", "output"], fake_impl=unified_flash_attention_fake, ) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 6ab9f00881821..63bf7c2e605a2 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,3 +1,4 @@ +import os import time from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Set @@ -407,7 +408,7 @@ def load_model(self) -> None: # FIXME(woosuk): Currently, the custom ops are not supported # in the piecewise compilation mode. We rely on TorchInductor # to optimize the model. - envs.VLLM_CUSTOM_OPS = "none" + os.environ["VLLM_CUSTOM_OPS"] = "none" set_compilation_config( CompilationConfig( use_cudagraph=True, @@ -451,7 +452,7 @@ def profile_run(self) -> None: @torch.inference_mode() def capture_model(self) -> None: - if self.use_cuda_graph: + if not self.use_cuda_graph: logger.warning( "Skipping CUDA graph capture. Please set " "VLLM_TORCH_COMPILE_LEVEL=%d to use CUDA graphs.",
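
Both patches above hinge on one pattern: each step's inputs are copied into preallocated, fixed-address buffers (self.input_ids, self.positions, and the caller-provided attention output tensor), and the effective batch is padded up to the nearest size in self.cudagraph_batch_sizes so that a previously captured piecewise CUDA graph can be replayed; the padding rows are sliced off again after the forward pass (hidden_states[:num_scheduled_tokens]). The sketch below is a minimal, self-contained illustration of that padding/persistent-buffer pattern, not vLLM code; the names CAPTURE_SIZES, MAX_NUM_TOKENS, pad_batch_size, and prepare_step are hypothetical and chosen only for this example.

# Minimal sketch (assumed names, not vLLM code): pad the batch to the
# nearest captured size and stage inputs in persistent device buffers.
from typing import Optional, Tuple

import torch

# Batch sizes a graph would be captured for, smallest to largest
# (mirrors [1, 2, 4] + [i for i in range(8, 513, 8)] in the patch).
CAPTURE_SIZES = [1, 2, 4] + list(range(8, 513, 8))
MAX_NUM_TOKENS = CAPTURE_SIZES[-1]

def pad_batch_size(batch_size: int) -> Optional[int]:
    # Smallest captured size that fits the batch, or None if the batch is
    # too large and must fall back to eager execution.
    for size in CAPTURE_SIZES:
        if batch_size <= size:
            return size
    return None

# Persistent buffers allocated once; a captured graph replays fixed device
# addresses, so inputs must always live at these locations.
device = "cuda" if torch.cuda.is_available() else "cpu"
input_ids = torch.zeros(MAX_NUM_TOKENS, dtype=torch.int32, device=device)
positions = torch.zeros(MAX_NUM_TOKENS, dtype=torch.int64, device=device)

def prepare_step(
    step_ids: torch.Tensor,
    step_pos: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, int]:
    num_tokens = step_ids.shape[0]
    # Stage this step's inputs at the head of the persistent buffers.
    input_ids[:num_tokens].copy_(step_ids, non_blocking=True)
    positions[:num_tokens].copy_(step_pos, non_blocking=True)
    # Pad to a captured shape; the trailing rows hold stale data and are
    # sliced away from the model output after the forward pass.
    padded = pad_batch_size(num_tokens)
    num_input_tokens = padded if padded is not None else num_tokens
    return (input_ids[:num_input_tokens],
            positions[:num_input_tokens],
            num_tokens)

The same idea motivates the flash_attn.py change: unified_flash_attention now writes into a caller-provided output tensor (declared in mutates_args) instead of returning a fresh allocation, keeping the output buffer under the caller's control across padded replays, and only the first num_actual_tokens rows are ever written.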