[Core] CUDA Graphs for Multi-Step + Chunked-Prefill (vllm-project#8645)
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
varun-sundar-rabindranath and Varun Sundar Rabindranath authored Oct 2, 2024
1 parent 7f60520 commit afb050b
Showing 3 changed files with 97 additions and 34 deletions.
11 changes: 11 additions & 0 deletions csrc/prepare_inputs/advance_step.cu
@@ -17,6 +17,17 @@ __global__ void advance_step_flashattn_kernel(
long const* sampled_token_ids_ptr, long* input_positions_ptr,
int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
int64_t const block_tables_stride) {
int const n_pad = num_seqs - num_queries;
if (n_pad && blockIdx.x == 0) {
// Handle cuda graph padding
int const offset = num_queries;
for (int i = threadIdx.x; i < n_pad; i += blockDim.x) {
input_tokens_ptr[offset + i] = 0;
input_positions_ptr[offset + i] = 0;
slot_mapping_ptr[offset + i] = -1;
}
}

int num_query_blocks = div_ceil(num_queries, num_threads);

if (blockIdx.x >= num_query_blocks) {
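The new block at the top of `advance_step_flashattn_kernel` handles CUDA-graph padding: entries in the range [num_queries, num_seqs) exist only to reach a captured batch size, so the kernel gives them token 0, position 0, and a slot mapping of -1 (vLLM's `PAD_SLOT_ID`), ensuring they never write a real KV-cache slot. A minimal host-side sketch of the same idea in Python (numpy stand-ins for the device buffers; names are illustrative, not the kernel's API):

```python
import numpy as np

def pad_advance_step_inputs(input_tokens: np.ndarray,
                            input_positions: np.ndarray,
                            slot_mapping: np.ndarray,
                            num_queries: int,
                            num_seqs: int) -> None:
    """Host-side sketch of the padding the kernel performs on device.

    Entries in [num_queries, num_seqs) exist only to reach a captured
    CUDA-graph batch size; they get a dummy token/position and a slot
    mapping of -1 so no KV-cache slot is ever written for them.
    """
    n_pad = num_seqs - num_queries
    if n_pad > 0:
        input_tokens[num_queries:num_seqs] = 0
        input_positions[num_queries:num_seqs] = 0
        slot_mapping[num_queries:num_seqs] = -1  # PAD_SLOT_ID
```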
48 changes: 28 additions & 20 deletions vllm/attention/backends/flash_attn.py
@@ -500,6 +500,30 @@ def _add_seq_group(
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)

def _get_graph_runner_block_tables(
self, num_seqs: int,
block_tables: List[List[int]]) -> torch.Tensor:
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
assert max_batch_size >= num_seqs

graph_block_tables = self.runner.graph_block_tables[:num_seqs]
for i, block_table in enumerate(block_tables):
if block_table:
num_blocks = len(block_table)
if num_blocks <= max_blocks:
graph_block_tables[i, :num_blocks] = block_table
else:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
graph_block_tables[
i, :max_blocks] = block_table[:max_blocks]

return torch.from_numpy(graph_block_tables).to(
device=self.runner.device, non_blocking=True)
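The new `_get_graph_runner_block_tables` helper factors out the copy of the ragged per-sequence block tables into the runner's preallocated `graph_block_tables` buffer, truncating any table wider than the captured shape (the extra blocks from multi-step lookahead slots are never read). A rough standalone sketch of the same logic, assuming a preallocated numpy buffer and an illustrative device string:

```python
from typing import List

import numpy as np
import torch

def to_graph_block_tables(graph_block_tables: np.ndarray,
                          block_tables: List[List[int]],
                          device: str = "cuda") -> torch.Tensor:
    """Copy ragged block tables into the fixed [max_batch, max_blocks] buffer."""
    num_seqs = len(block_tables)
    max_batch_size, max_blocks = graph_block_tables.shape
    assert max_batch_size >= num_seqs

    out = graph_block_tables[:num_seqs]
    for i, table in enumerate(block_tables):
        if table:
            # Tables longer than the captured width are truncated; the extra
            # (lookahead) blocks are never read during decode anyway.
            out[i, :min(len(table), max_blocks)] = table[:max_blocks]
    return torch.from_numpy(out).to(device=device, non_blocking=True)
```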

def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.
@@ -533,29 +557,13 @@ def build(self, seq_lens: List[int], query_lens: List[int],
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens

num_seqs = len(seq_lens)
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size

# The shape of graph_block_tables is
# [max batch size, max context len // block size].
input_block_tables = self.runner.graph_block_tables[:batch_size]
max_blocks = input_block_tables.shape[1]
for i, block_table in enumerate(self.block_tables):
if block_table:
num_blocks = len(block_table)
if num_blocks <= max_blocks:
input_block_tables[i, :num_blocks] = block_table
else:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
input_block_tables[
i, :max_blocks] = block_table[:max_blocks]

block_tables = torch.from_numpy(input_block_tables).to(
device=device, non_blocking=True)
num_decode_tokens = batch_size - self.num_prefill_tokens
block_tables = self._get_graph_runner_block_tables(
num_seqs, self.block_tables)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
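The substantive change to `build` is the decode-token accounting: with chunked prefill present in the first multi-step iteration, the padded batch also contains prefill tokens, so `num_decode_tokens` becomes `batch_size - num_prefill_tokens` rather than the whole padded batch, while the padding itself is appended as `PAD_SLOT_ID` slots and empty block tables. A worked sketch of that arithmetic (all numbers invented; graph batch rounding assumed to land on 8):

```python
# Worked example for the first multi-step + chunked-prefill step.
PAD_SLOT_ID = -1

num_prefill_tokens = 48                              # from 2 chunked prefills
num_prefill_seqs, num_decode_seqs = 2, 5
num_seqs = num_prefill_seqs + num_decode_seqs        # 7
graph_batch_size = 8                                 # next captured size >= 7
cuda_graph_pad_size = graph_batch_size - num_seqs    # 1 padding sequence

# What the attention-metadata builder ends up with:
slot_mapping_padding = [PAD_SLOT_ID] * cuda_graph_pad_size
batch_size = num_prefill_tokens + num_decode_seqs + cuda_graph_pad_size   # 54
num_decode_tokens = batch_size - num_prefill_tokens  # 6 = 5 decodes + 1 pad
assert num_decode_tokens == num_decode_seqs + cuda_graph_pad_size
```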
72 changes: 58 additions & 14 deletions vllm/worker/model_runner.py
@@ -712,14 +712,62 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):

def _use_captured_graph(self,
batch_size: int,
decode_only: bool,
max_decode_seq_len: int,
max_encoder_seq_len: int = 0) -> bool:
return (self.decode_only and not self.runner.model_config.enforce_eager
return (decode_only and not self.runner.model_config.enforce_eager
and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
and max_decode_seq_len <= self.runner.max_seq_len_to_capture
and max_encoder_seq_len <= self.runner.max_seq_len_to_capture
and batch_size <= self.runner.max_batchsize_to_capture)

def _get_cuda_graph_pad_size(self,
num_seqs: int,
max_decode_seq_len: int,
max_encoder_seq_len: int = 0) -> int:
"""
Determine the number of padding sequences required for running in
CUDA graph mode. Returns -1 if CUDA graphs cannot be used.
In the multi-step + chunked-prefill case, only the first step
has Prefills (if any). The rest of the steps are guaranteed to be all
decodes. In this case, we set up the padding as if all the sequences
are decodes so we may run all steps except the first step in CUDA graph
mode. The padding is accounted for in the multi-step `advance_step`
family of functions.
Args:
num_seqs (int): Number of sequences scheduled to run.
max_decode_seq_len (int): Greatest of all the decode sequence
lengths. Used only in checking the viability of using
CUDA graphs.
max_encoder_seq_len (int, optional): Greatest of all the encode
sequence lengths. Defaults to 0. Used only in checking the
viability of using CUDA graphs.
Returns:
int: Returns the determined number of padding sequences. If
CUDA graphs is not viable, returns -1.
"""
is_mscp: bool = self.runner.scheduler_config.is_multi_step and \
self.runner.scheduler_config.chunked_prefill_enabled
decode_only = self.decode_only or is_mscp
if not decode_only:
# Early exit so we can treat num_seqs as the batch_size below.
return -1

# batch_size out of this function refers to the number of input
# tokens being scheduled. This conflation of num_seqs as batch_size
# is valid as this is a decode-only case.
batch_size = num_seqs
if not self._use_captured_graph(batch_size, decode_only,
max_decode_seq_len,
max_encoder_seq_len):
return -1

graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
return graph_batch_size - batch_size
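`_get_cuda_graph_pad_size` thus boils down to: return -1 unless the (effectively) decode-only batch is CUDA-graph eligible, otherwise round the sequence count up to the next captured batch size and return the difference. A condensed sketch, assuming `_get_graph_batch_size` rounds up to 1, 2, 4 and then multiples of 8:

```python
def _round_to_graph_batch_size(batch_size: int, alignment: int = 8) -> int:
    """Stand-in for _get_graph_batch_size: 1, 2, 4, then multiples of 8."""
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return (batch_size + alignment - 1) // alignment * alignment

def cuda_graph_pad_size(num_seqs: int, can_use_graph: bool) -> int:
    """Padding needed to reach a captured batch size; -1 disables CUDA graphs."""
    if not can_use_graph:
        return -1
    return _round_to_graph_batch_size(num_seqs) - num_seqs

assert cuda_graph_pad_size(7, True) == 1     # 7 sequences -> captured size 8
assert cuda_graph_pad_size(13, True) == 3    # 13 -> 16
assert cuda_graph_pad_size(3, False) == -1   # eager fallback
```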

def build(self) -> ModelInputForGPU:
"""Finalize the builder intermediate data and
create on-device tensors.
@@ -778,21 +826,17 @@ def build(self) -> ModelInputForGPU:
for data in self.inter_data_list
}

batch_size = len(input_tokens)
use_captured_graph = self._use_captured_graph(
batch_size,
max_decode_seq_len,
cuda_graph_pad_size = self._get_cuda_graph_pad_size(
num_seqs=len(seq_lens),
max_decode_seq_len=max_decode_seq_len,
max_encoder_seq_len=max_encoder_seq_len)

# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
# vLLM uses cuda graph only for decoding requests.
cuda_graph_pad_size = -1
if use_captured_graph:
graph_batch_size = _get_graph_batch_size(batch_size)
assert graph_batch_size >= batch_size
cuda_graph_pad_size = graph_batch_size - batch_size
batch_size = graph_batch_size
batch_size = len(input_tokens)
if cuda_graph_pad_size != -1:
# If cuda graph can be used, pad tensors accordingly.
# See `capture_model` API for more details.
# vLLM uses cuda graph only for decoding requests.
batch_size += cuda_graph_pad_size

# Tokens and positions.
if cuda_graph_pad_size:
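Back in the model runner's `build`, the decision is now carried by the returned pad size instead of a separate `use_captured_graph` flag: the pad is computed from the number of sequences, and only afterwards is `batch_size` (the token count) bumped so that tokens and positions can be padded to the captured shape. A compressed, illustrative view of that control flow:

```python
from typing import List

def finalize_batch_size(input_tokens: List[int], cuda_graph_pad_size: int) -> int:
    """Sketch of the new flow: a pad size of -1 means 'run eagerly, no padding'."""
    batch_size = len(input_tokens)
    if cuda_graph_pad_size != -1:
        # CUDA graphs will be used: tokens/positions get padded up to the
        # captured batch size (graphs are used only for decoding).
        batch_size += cuda_graph_pad_size
    return batch_size

assert finalize_batch_size(list(range(7)), 1) == 8      # padded for capture
assert finalize_batch_size(list(range(30)), -1) == 30   # eager, unchanged
```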
