[V1] Integrate Piecewise CUDA graphs #10058

Merged 2 commits on Nov 6, 2024
7 changes: 5 additions & 2 deletions vllm/compilation/backends.py
@@ -496,8 +496,11 @@ def __call__(self, *args) -> Any:
return entry.runnable(*args)

if self.is_first_graph:
logger.info("Capturing a cudagraph for shape %s",
runtime_shape)
# Since we capture cudagraph for many different shapes and
# capturing is fast, we don't need to log it for every shape.
# We only log it in debug mode.
logger.debug("Capturing a cudagraph for shape %s",
runtime_shape)

input_addresses = [
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
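For context on the piecewise backend above: each runtime shape gets its own CUDA graph, and the input data pointers are recorded because a captured graph can only be replayed against the exact buffers it saw at capture time. Below is a minimal standalone sketch of that constraint; it is illustrative only, not the vLLM implementation, and the buffer names are made up.

    import torch

    def capture_and_replay_sketch():
        # Persistent input buffer: the graph bakes in this tensor's address.
        static_x = torch.zeros(8, device="cuda")
        # Warm-up run outside the graph (loads kernels, initializes handles).
        _ = static_x * 2 + 1

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_y = static_x * 2 + 1   # work recorded into the graph

        # To reuse the graph, write new data into the same buffer and replay;
        # reallocating static_x would silently break the replay.
        static_x.copy_(torch.arange(8, dtype=torch.float32, device="cuda"))
        graph.replay()
        return static_y                   # now holds results for the new input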
35 changes: 21 additions & 14 deletions vllm/v1/attention/backends/flash_attn.py
@@ -51,6 +51,7 @@ class FlashAttentionMetadata:
# |-------------------- seq_len ---------------------|
# |-- query_len ---|

num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
@@ -134,7 +135,9 @@ def forward(
assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")

output = torch.ops.vllm.unified_flash_attention(
output = torch.empty_like(query)
torch.ops.vllm.unified_flash_attention(
output,
query,
key,
value,
@@ -154,6 +157,7 @@


def unified_flash_attention(
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
@@ -168,17 +172,17 @@ def unified_flash_attention(
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
) -> None:
current_metadata = get_forward_context()
if current_metadata is None:
# Profiling run.
return torch.empty_like(query)
return

assert current_metadata is not None
assert isinstance(current_metadata, FlashAttentionMetadata)
attn_metadata: FlashAttentionMetadata = current_metadata
num_actual_tokens = attn_metadata.num_actual_tokens

num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors.
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
@@ -188,18 +192,18 @@
key_cache = kv_cache[0]
value_cache = kv_cache[1]
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[0],
kv_cache[1],
key[:num_actual_tokens],
value[:num_actual_tokens],
key_cache,
value_cache,
attn_metadata.slot_mapping,
kv_cache_dtype,
k_scale,
v_scale,
)

output = flash_attn_varlen_func(
q=query,
attn_output = flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
cu_seqlens_q=attn_metadata.query_start_loc,
@@ -213,10 +217,13 @@
block_table=attn_metadata.block_table,
softcap=logits_soft_cap,
)
return output.view(num_tokens, hidden_size)
attn_output = attn_output.view(num_actual_tokens, -1)
# TODO(woosuk): Optimize this.
output[:num_actual_tokens].copy_(attn_output)


def unified_flash_attention_fake(
output: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
@@ -231,13 +238,13 @@ def unified_flash_attention_fake(
window_size: Optional[List[int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
logits_soft_cap: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(query)
) -> None:
return


direct_register_custom_op(
op_name="unified_flash_attention",
op_func=unified_flash_attention,
mutates_args=["kv_cache"],
mutates_args=["kv_cache", "output"],
fake_impl=unified_flash_attention_fake,
)
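The net effect of this file's changes: unified_flash_attention becomes a destination-passing op that fills a caller-allocated output buffer and returns None, with mutates_args telling the custom-op registry which arguments it writes to; only the first num_actual_tokens rows are real, the rest is CUDA-graph padding. A small hedged sketch of that pattern with a toy op follows (names, shapes, and the "attention" math are made up):

    import torch

    def toy_out_param_attention(output: torch.Tensor,
                                query: torch.Tensor,
                                num_actual_tokens: int) -> None:
        # Compute only on the real (unpadded) tokens.
        result = query[:num_actual_tokens] * 2.0      # stand-in for attention
        # Write into the persistent buffer instead of returning a new tensor,
        # so the output address stays fixed across CUDA graph replays.
        output[:num_actual_tokens].copy_(result)

    padded = 16                                       # padded token count
    query = torch.randn(padded, 4)
    output = torch.empty_like(query)
    toy_out_param_attention(output, query, num_actual_tokens=10)
    # Rows 10..15 of `output` are untouched padding and are ignored downstream.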
127 changes: 107 additions & 20 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1,3 +1,5 @@
import os
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set
from unittest.mock import patch
@@ -7,11 +9,16 @@
import torch.distributed
import torch.nn as nn

from vllm import envs
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.config import CompilationConfig
from vllm.compilation.levels import CompilationLevel
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MultiModalDataDict
from vllm.plugins import set_compilation_config
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
is_pin_memory_available)
@@ -86,6 +93,18 @@ def __init__(
pin_memory=self.pin_memory,
)

self.use_cuda_graph = (envs.VLLM_TORCH_COMPILE_LEVEL
== CompilationLevel.PIECEWISE
and not self.model_config.enforce_eager)
# TODO(woosuk): Provide an option to tune the max cudagraph batch size.
self.cudagraph_batch_sizes = [1, 2, 4] + [i for i in range(8, 513, 8)]
self.input_ids = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=self.device)
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=self.device)
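A quick check of the capture schedule defined above: three small sizes plus every multiple of 8 up to 512, i.e. 67 candidate graph shapes, with input_ids and positions preallocated at the maximum size so every shape is just a prefix slice of the same buffers. The snippet below only verifies the arithmetic and is not part of the PR:

    sizes = [1, 2, 4] + [i for i in range(8, 513, 8)]
    assert sizes[:6] == [1, 2, 4, 8, 16, 24]
    assert sizes[-1] == 512 and len(sizes) == 67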

def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
# Remove stopped requests from the cached states.
# Keep the states of the pre-empted requests.
@@ -268,12 +287,16 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
seq_start_loc_np[0] = 0
np.cumsum(seq_lens, out=seq_start_loc_np[1:])

input_ids = input_ids.to(self.device, non_blocking=True)
positions = positions.to(self.device, non_blocking=True).long()
self.input_ids[:total_num_scheduled_tokens].copy_(input_ids,
non_blocking=True)
self.positions[:total_num_scheduled_tokens].copy_(positions,
non_blocking=True)

query_start_loc = query_start_loc.to(self.device, non_blocking=True)
seq_start_loc = seq_start_loc.to(self.device, non_blocking=True)
slot_mapping = slot_mapping.to(self.device, non_blocking=True).long()
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
@@ -287,7 +310,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
# token from the partial request.
# TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1
return input_ids, positions, attn_metadata, logits_indices
return attn_metadata, logits_indices
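The copies above replace the previous device transfers: host-prepared tensors are written into a prefix slice of the preallocated GPU buffers, so the addresses the CUDA graphs were captured with never change. A minimal sketch of the same pattern, with hypothetical sizes and token counts:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    max_num_tokens = 512
    input_ids_buf = torch.zeros(max_num_tokens, dtype=torch.int32, device=device)

    # 37 tokens scheduled this step: copy them into the buffer's prefix.
    new_ids = torch.randint(0, 32000, (37,), dtype=torch.int32)
    input_ids_buf[:37].copy_(new_ids, non_blocking=True)
    # The tail input_ids_buf[37:] keeps stale values; execute_model later slices
    # hidden_states[:num_scheduled_tokens] so the padding never affects outputs.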

def _prepare_sampling(
self,
@@ -310,16 +333,26 @@ def execute_model(
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput:
self._update_states(scheduler_output)
inputs = self._prepare_inputs(scheduler_output)
input_ids, positions, attn_metadata, logits_indices = inputs
attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
# Add padding to the batch size.
num_input_tokens = self._get_padded_batch_size(
num_scheduled_tokens)
else:
# Eager mode.
num_input_tokens = num_scheduled_tokens

with set_forward_context(attn_metadata):
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
kv_caches=self.kv_caches,
attn_metadata=attn_metadata,
attn_metadata=None,
)
hidden_states = hidden_states[:num_scheduled_tokens]
hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(hidden_states, None)
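Note that the model call now receives attn_metadata=None while the real metadata travels through set_forward_context, so the captured graph sees only tensor arguments and the attention op fetches per-step metadata out-of-band (that is what get_forward_context does in flash_attn.py above). The following is a rough illustration of that mechanism, not vLLM's actual implementation; all names are invented:

    import contextlib

    _forward_context = None            # stand-in for vLLM's forward context

    @contextlib.contextmanager
    def set_forward_context_sketch(metadata):
        global _forward_context
        prev, _forward_context = _forward_context, metadata
        try:
            yield
        finally:
            _forward_context = prev

    def attention_op_sketch(query):
        # Called from inside the captured model; reads metadata out-of-band
        # instead of taking it as an argument the graph would specialize on.
        metadata = _forward_context
        if metadata is None:           # e.g. a profiling or dummy run
            return query
        return query                   # real attention would use metadata here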

@@ -371,6 +404,18 @@ def execute_model(
return model_runner_output

def load_model(self) -> None:
if self.use_cuda_graph:
# FIXME(woosuk): Currently, the custom ops are not supported
# in the piecewise compilation mode. We rely on TorchInductor
# to optimize the model.
os.environ["VLLM_CUSTOM_OPS"] = "none"
set_compilation_config(
CompilationConfig(
use_cudagraph=True,
non_cudagraph_ops=["vllm.unified_flash_attention"],
use_inductor=True,
))
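In other words, when VLLM_TORCH_COMPILE_LEVEL selects the piecewise level, the runner disables custom ops (relying on Inductor instead) and registers a compilation config that keeps vllm.unified_flash_attention out of the captured graphs, so attention runs eagerly between graph segments. A hedged sketch of how this path would be enabled from the environment (exact values and behavior may differ across versions):

    import os
    from vllm.compilation.levels import CompilationLevel  # import path as in this diff

    # Select the piecewise compile level before constructing the model runner.
    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(int(CompilationLevel.PIECEWISE))
    # load_model() above then sets VLLM_CUSTOM_OPS="none" and the
    # CompilationConfig itself; no further user action is needed.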

logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117
with patch("vllm.model_executor.layers.sampler.Sampler", Sampler):
@@ -381,26 +426,61 @@ def load_model(self) -> None:
self.model_memory_usage / float(2**30))

def _dummy_run(self, model: nn.Module, num_tokens: int) -> None:
input_ids = torch.zeros(num_tokens,
dtype=torch.int32,
device=self.device)
positions = torch.zeros(num_tokens,
dtype=torch.long,
device=self.device)
kv_caches = [None for _ in range(self.num_attn_layers)]
model(input_ids, positions, kv_caches, attn_metadata=None)
return
# use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather than specializing on the value `None`.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
# it is important to create tensors inside the loop, rather than
# multiplying the list, to prevent Dynamo from treating them as
# aliased tensors.
dummy_kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
for _ in range(self.num_attn_layers)
]
with set_forward_context(None): # noqa: SIM117
with set_compile_context(self.cudagraph_batch_sizes):
# Trigger compilation for general shape.
model(self.input_ids,
self.positions,
dummy_kv_caches,
attn_metadata=None)
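A small check of the aliasing point made in the comment above: multiplying a list repeats the same tensor object, whereas a comprehension creates distinct tensors, which is what keeps Dynamo from treating the per-layer KV-cache placeholders as aliases of one another. This snippet is illustrative only:

    import torch

    aliased = [torch.tensor([])] * 4                   # one tensor, four references
    distinct = [torch.tensor([]) for _ in range(4)]    # four separate tensors
    assert all(t is aliased[0] for t in aliased)
    assert distinct[0] is not distinct[1]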

@torch.inference_mode()
def profile_run(self) -> None:
self._dummy_run(self.model, self.max_num_tokens)
torch.cuda.synchronize()
return

@torch.inference_mode()
def capture_model(self) -> None:
# TODO: Implement CUDA graph support.
return
if not self.use_cuda_graph:
logger.warning(
"Skipping CUDA graph capture. Please set "
"VLLM_TORCH_COMPILE_LEVEL=%d to use CUDA graphs.",
CompilationLevel.PIECEWISE)
return

start_time = time.perf_counter()
start_free_gpu_memory = torch.cuda.mem_get_info()[0]

with set_forward_context(None):
# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes.
for num_tokens in reversed(self.cudagraph_batch_sizes):
self.model(
self.input_ids[:num_tokens],
self.positions[:num_tokens],
kv_caches=self.kv_caches,
attn_metadata=None,
)

end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
elapsed_time = end_time - start_time
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
# This usually takes 5~20 seconds.
logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time, cuda_graph_size / (1 << 30))
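The reversed() loop relies on the captured graphs sharing a memory pool, so the largest shape reserves the allocator blocks and the smaller shapes fit inside them. Below is a standalone sketch of that idea using PyTorch's graph-pool API; it is illustrative only, since the piecewise backend manages pools internally rather than like this, and the names are made up:

    import torch

    static_x = torch.zeros(512, device="cuda")       # shared persistent input
    _ = static_x * 2.0                               # warm-up outside any graph
    pool = torch.cuda.graph_pool_handle()            # one pool for all graphs
    graphs = {}
    for num_tokens in [512, 256, 8]:                 # largest shape first
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g, pool=pool):
            _ = static_x[:num_tokens] * 2.0          # stand-in for the model forward
        graphs[num_tokens] = g
    graphs[8].replay()                               # replay any captured shape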

def initialize_kv_cache(self, num_blocks: int) -> None:
assert len(self.kv_caches) == 0
@@ -412,6 +492,13 @@ def initialize_kv_cache(self, num_blocks: int) -> None:
dtype=self.kv_cache_dtype,
device=self.device))

def _get_padded_batch_size(self, batch_size: int) -> Optional[int]:
# TODO: Optimize this?
for size in self.cudagraph_batch_sizes:
if batch_size <= size:
return size
return None
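A worked example of this lookup, plus one possible answer to the TODO above using bisect over the sorted size list (illustrative only, not part of this PR):

    from bisect import bisect_left

    sizes = [1, 2, 4] + list(range(8, 513, 8))

    def padded_batch_size(batch_size: int):
        # Smallest captured size >= batch_size, or None to fall back to eager mode.
        idx = bisect_left(sizes, batch_size)
        return sizes[idx] if idx < len(sizes) else None

    assert padded_batch_size(37) == 40
    assert padded_batch_size(512) == 512
    assert padded_batch_size(600) is None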


@dataclass
class CachedRequestState: