Skip to content

Commit

Permalink
[V1] Integrate Piecewise CUDA graphs (vllm-project#10058)
Browse files Browse the repository at this point in the history
Signed-off-by: Woosuk Kwon <[email protected]>
Signed-off-by: Sumit Dubey <[email protected]>
  • Loading branch information
WoosukKwon authored and sumitd2 committed Nov 14, 2024
1 parent c5fbe44 commit f619c6a
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 36 deletions.
7 changes: 5 additions & 2 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,8 +496,11 @@ def __call__(self, *args) -> Any:
return entry.runnable(*args)

if self.is_first_graph:
logger.info("Capturing a cudagraph for shape %s",
runtime_shape)
# Since we capture cudagraph for many different shapes and
# capturing is fast, we don't need to log it for every shape.
# We only log it in the debug mode.
logger.debug("Capturing a cudagraph for shape %s",
runtime_shape)

input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
Expand Down
35 changes: 21 additions & 14 deletions vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class FlashAttentionMetadata:
# |-------------------- seq_len ---------------------|
# |-- query_len ---|

num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
Expand Down Expand Up @@ -134,7 +135,9 @@ def forward(
assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")

output = torch.ops.vllm.unified_flash_attention(
output = torch.empty_like(query)
torch.ops.vllm.unified_flash_attention(
output,
query,
key,
value,
Expand All @@ -154,6 +157,7 @@ def forward(


def unified_flash_attention(
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
Expand All @@ -168,17 +172,17 @@ def unified_flash_attention(
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
) -> None:
current_metadata = get_forward_context()
if current_metadata is None:
# Profiling run.
return torch.empty_like(query)
return

assert current_metadata is not None
assert isinstance(current_metadata, FlashAttentionMetadata)
attn_metadata: FlashAttentionMetadata = current_metadata
num_actual_tokens = attn_metadata.num_actual_tokens

num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
Expand All @@ -188,18 +192,18 @@ def unified_flash_attention(
key_cache = kv_cache[0]
value_cache = kv_cache[1]
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[0],
kv_cache[1],
key[:num_actual_tokens],
value[:num_actual_tokens],
key_cache,
value_cache,
attn_metadata.slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
)

output = flash_attn_varlen_func(
q=query,
attn_output = flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
cu_seqlens_q=attn_metadata.query_start_loc,
Expand All @@ -213,10 +217,13 @@ def unified_flash_attention(
block_table=attn_metadata.block_table,
softcap=logits_soft_cap,
)
return output.view(num_tokens, hidden_size)
attn_output = attn_output.view(num_actual_tokens, -1)
# TODO(woosuk): Optimize this.
output[:num_actual_tokens].copy_(attn_output)


def unified_flash_attention_fake(
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
Expand All @@ -231,13 +238,13 @@ def unified_flash_attention_fake(
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(query)
) -> None:
return


direct_register_custom_op(
op_name="unified_flash_attention",
op_func=unified_flash_attention,
mutates_args=["kv_cache"],
mutates_args=["kv_cache", "output"],
fake_impl=unified_flash_attention_fake,
)
127 changes: 107 additions & 20 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set
from unittest.mock import patch
Expand All @@ -7,11 +9,16 @@
import torch.distributed
import torch.nn as nn

from vllm import envs
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.config import CompilationConfig
from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MultiModalDataDict
from vllm.plugins import set_compilation_config
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
is_pin_memory_available)
Expand Down Expand Up @@ -86,6 +93,18 @@ def __init__(
pin_memory=self.pin_memory,
)

self.use_cuda_graph = (envs.VLLM_TORCH_COMPILE_LEVEL
== CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
self.cudagraph_batch_sizes = [1, 2, 4] + [i for i in range(8, 513, 8)]
self.input_ids = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=self.device)
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=self.device)

def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove stopped requests from the cached states.
# Keep the states of the pre-empted requests.
Expand Down Expand Up @@ -268,12 +287,16 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
seq_start_loc_np[0] = 0
np.cumsum(seq_lens, out=seq_start_loc_np[1:])

input_ids = input_ids.to(self.device, non_blocking=True)
positions = positions.to(self.device, non_blocking=True).long()
self.input_ids[:total_num_scheduled_tokens].copy_(input_ids,
non_blocking=True)
self.positions[:total_num_scheduled_tokens].copy_(positions,
non_blocking=True)

query_start_loc = query_start_loc.to(self.device, non_blocking=True)
seq_start_loc = seq_start_loc.to(self.device, non_blocking=True)
slot_mapping = slot_mapping.to(self.device, non_blocking=True).long()
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
Expand All @@ -287,7 +310,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
# token from the partial request.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
return input_ids, positions, attn_metadata, logits_indices
return attn_metadata, logits_indices

def _prepare_sampling(
self,
Expand All @@ -310,16 +333,26 @@ def execute_model(
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
self._update_states(scheduler_output)
inputs = self._prepare_inputs(scheduler_output)
input_ids, positions, attn_metadata, logits_indices = inputs
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self._get_padded_batch_size(
num_scheduled_tokens)
else:
# Eager mode.
num_input_tokens = num_scheduled_tokens

with set_forward_context(attn_metadata):
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
kv_caches=self.kv_caches,
attn_metadata=attn_metadata,
attn_metadata=None,
)
hidden_states = hidden_states[:num_scheduled_tokens]
hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(hidden_states, None)

Expand Down Expand Up @@ -371,6 +404,18 @@ def execute_model(
return model_runner_output

def load_model(self) -> None:
if self.use_cuda_graph:
# FIXME(woosuk): Currently, the custom ops are not supported
# in the piecewise compilation mode. We rely on TorchInductor
# to optimize the model.
os.environ["VLLM_CUSTOM_OPS"] = "none"
set_compilation_config(
CompilationConfig(
use_cudagraph=True,
non_cudagraph_ops=["vllm.unified_flash_attention"],
use_inductor=True,
))

logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117
with patch("vllm.model_executor.layers.sampler.Sampler", Sampler):
Expand All @@ -381,26 +426,61 @@ def load_model(self) -> None:
self.model_memory_usage / float(2**30))

def _dummy_run(self, model: nn.Module, num_tokens: int) -> None:
input_ids = torch.zeros(num_tokens,
dtype=torch.int32,
device=self.device)
positions = torch.zeros(num_tokens,
dtype=torch.long,
device=self.device)
kv_caches = [None for _ in range(self.num_attn_layers)]
model(input_ids, positions, kv_caches, attn_metadata=None)
return
# use an empty tensor instead of `None`` to force Dynamo to pass
# it by reference, rather by specializing on the value `None`.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
# it is important to create tensors inside the loop, rather than
# multiplying the list, to avoid Dynamo from treating them as
# tensor aliasing.
dummy_kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
for _ in range(self.num_attn_layers)
]
with set_forward_context(None): # noqa: SIM117
with set_compile_context(self.cudagraph_batch_sizes):
# Trigger compilation for general shape.
model(self.input_ids,
self.positions,
dummy_kv_caches,
attn_metadata=None)

@torch.inference_mode()
def profile_run(self) -> None:
self._dummy_run(self.model, self.max_num_tokens)
torch.cuda.synchronize()
return

@torch.inference_mode()
def capture_model(self) -> None:
# TODO: Implement CUDA graph support.
return
if not self.use_cuda_graph:
logger.warning(
"Skipping CUDA graph capture. Please set "
"VLLM_TORCH_COMPILE_LEVEL=%d to use CUDA graphs.",
CompilationLevel.PIECEWISE)
return

start_time = time.perf_counter()
start_free_gpu_memory = torch.cuda.mem_get_info()[0]

with set_forward_context(None):
# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
for num_tokens in reversed(self.cudagraph_batch_sizes):
self.model(
self.input_ids[:num_tokens],
self.positions[:num_tokens],
kv_caches=self.kv_caches,
attn_metadata=None,
)

end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
elapsed_time = end_time - start_time
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
# This usually takes 5~20 seconds.
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, cuda_graph_size / (1 << 30))

def initialize_kv_cache(self, num_blocks: int) -> None:
assert len(self.kv_caches) == 0
Expand All @@ -412,6 +492,13 @@ def initialize_kv_cache(self, num_blocks: int) -> None:
dtype=self.kv_cache_dtype,
device=self.device))

def _get_padded_batch_size(self, batch_size: int) -> Optional[int]:
# TODO: Optimize this?
for size in self.cudagraph_batch_sizes:
if batch_size <= size:
return size
return None


@dataclass
class CachedRequestState:
Expand Down

0 comments on commit f619c6a

Please sign in to comment.