fix cudagraph
SolitaryThinker committed Aug 28, 2024
1 parent 32ba641 commit ce80011
Showing 4 changed files with 41 additions and 27 deletions.
2 changes: 1 addition & 1 deletion csrc/prepare_inputs/advance_step.cu
@@ -240,7 +240,7 @@ void advance_step_flashinfer(
if (logging) {
printf("launching kernel with %d blocks\n", blocks);
}

// TODO(will): support arbitrary block_tables stride
if ((blocks * threads) / block_tables.stride(0) < num_queries) {
TORCH_CHECK(false,
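For orientation, the guard shown above compares the kernel launch size against `block_tables.stride(0)`; for a contiguous 2-D block table of shape [num_seqs, max_blocks_per_seq], that stride is the padded row length, which appears to be what the new TODO about arbitrary strides refers to. A rough standalone sketch of that relationship, with made-up sizes and names (not taken from the commit):

    import torch

    # Illustrative block table: one padded row of block ids per sequence.
    block_tables = torch.zeros(4, 8, dtype=torch.int32)
    print(block_tables.stride(0))  # 8 == padded row length for a contiguous tensor

    blocks, threads, num_queries = 2, 32, 4
    # Mirrors the shape of the CUDA guard: enough threads to cover one
    # padded block-table row per query, otherwise the kernel bails out.
    assert (blocks * threads) // block_tables.stride(0) >= num_queries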
27 changes: 16 additions & 11 deletions vllm/_custom_ops.py
@@ -171,18 +171,23 @@ def advance_step(num_seqs: int, num_queries: int, block_size: int,
input_positions, seq_lens, slot_mapping,
block_tables)

def advance_step_flashinfer( num_seqs: int, num_queries: int, block_size: int,
input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor, seq_lens: torch.Tensor,
slot_mapping: torch.Tensor, block_tables: torch.Tensor,
paged_kv_indices: torch.Tensor, paged_kv_indptr: torch.Tensor,
paged_kv_last_page_len: torch.Tensor, block_table_bound: torch.Tensor
) -> None:

return torch.ops._C.advance_step_flashinfer(num_seqs, num_queries,
block_size, input_tokens, sampled_token_ids, input_positions,
seq_lens, slot_mapping, block_tables, paged_kv_indices, paged_kv_indptr,
paged_kv_last_page_len, block_table_bound)
def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
input_tokens: torch.Tensor,
sampled_token_ids: torch.Tensor,
input_positions: torch.Tensor,
seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
block_tables: torch.Tensor,
paged_kv_indices: torch.Tensor,
paged_kv_indptr: torch.Tensor,
paged_kv_last_page_len: torch.Tensor,
block_table_bound: torch.Tensor) -> None:

return torch.ops._C.advance_step_flashinfer(
num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
input_positions, seq_lens, slot_mapping, block_tables,
paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len,
block_table_bound)


# quantization ops
24 changes: 15 additions & 9 deletions vllm/attention/backends/flashinfer.py
@@ -306,6 +306,8 @@ def begin_forward(self):
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
assert self.block_table_bound is not None
assert self.seq_lens_tensor is not None
batch_size = self.query_start_loc.shape[0] - 1
assert batch_size >= 0
# We will use flash attention for profiling to
@@ -325,16 +327,18 @@
self.num_qo_heads, self.num_kv_heads, self.head_dim,
self.page_size)
else:
#if not self.use_cuda_graph:
assert self.paged_kv_indices is not None
assert self.paged_kv_indptr is not None
assert self.paged_kv_last_page_len is not None
self.paged_kv_indices = self.paged_kv_indices.to(self.device)
self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
self.device)
self.block_table_bound = self.block_table_bound.to(self.device)
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
# handle model warmup path
if self.block_table_bound is not None:
self.block_table_bound = self.block_table_bound.to(self.device)
if self.seq_lens_tensor is not None:
self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)

assert self.decode_wrapper is not None
self.decode_wrapper.end_forward()
@@ -524,7 +528,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],

if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] for _ in range(cuda_graph_pad_size))
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size

# The shape of graph_block_tables is
@@ -574,8 +578,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
if len(self.paged_kv_indptr) > 0:
# extend to the maximum number of blocks as returned by the
# scheduler
self.paged_kv_indices.extend([0] *
(self.total_blocks - len(self.paged_kv_indices)))
self.paged_kv_indices.extend(
[0] * (self.total_blocks - len(self.paged_kv_indices)))
paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
device="cpu",
dtype=torch.int)
@@ -584,14 +588,16 @@
dtype=torch.int)
paged_kv_last_page_len_tensor = torch.tensor(
self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
block_table_bound_tensor = torch.zeros(
len(self.paged_kv_indptr) - 1, device="cpu", dtype=torch.int)
block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
1,
device="cpu",
dtype=torch.int)
else:
paged_kv_indices_tensor = None
paged_kv_indptr_tensor = None
paged_kv_last_page_len_tensor = None
block_table_bound_tensor = None

kv_cache_dtype = get_kv_cache_torch_dtype(
self.runner.kv_cache_dtype, self.runner.model_config.dtype)
return FlashInferMetadata(
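One line in the `build()` hunk above is easy to misread: `[] * cuda_graph_pad_size` and `([] for _ in range(cuda_graph_pad_size))` are not equivalent. Multiplying an empty list by any count is still an empty list, so that `extend` adds nothing, whereas the generator appends one empty block table per padded slot. Which behavior the CUDA-graph padding path actually needs is determined by the surrounding code; the snippet below only demonstrates the language-level difference (plain Python, unrelated to vLLM):

    pad = 3
    a = [[1, 2], [3]]
    a.extend([] * pad)                # [] * 3 == [], nothing appended
    print(a)                          # [[1, 2], [3]]

    b = [[1, 2], [3]]
    b.extend([] for _ in range(pad))  # appends three distinct empty lists
    print(b)                          # [[1, 2], [3], [], [], []]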
15 changes: 9 additions & 6 deletions vllm/worker/multi_step_model_runner.py
@@ -346,7 +346,7 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs,
assert seq_group.query_len is None # Decode

def _advance_step_flashattn(self, model_input: StatefulModelInput,
out: SamplerOutput) -> StatefulModelInput:
out: SamplerOutput) -> StatefulModelInput:
frozen_model_input = model_input.frozen_model_input
assert frozen_model_input is not None
assert frozen_model_input.attn_metadata is not None
@@ -378,7 +378,7 @@ def _advance_step_flashattn(self, model_input: StatefulModelInput,
frozen_model_input.seq_lens[i] = attn_metadata.seq_lens[i]

return model_input

def _advance_step_flashinfer(
self,
model_input: StatefulModelInput,
@@ -394,7 +394,10 @@ def _advance_step_flashinfer(
num_queries = model_input.num_queries

sampled_tokens = model_input.cached_outputs[-1].sampled_token_ids
frozen_model_input.input_tokens[:num_queries] = sampled_tokens.flatten()
assert sampled_tokens is not None
assert frozen_model_input.input_tokens is not None
frozen_model_input.input_tokens[:num_queries] = sampled_tokens.flatten(
)

# Update GPU tensors
ops.advance_step_flashinfer(
@@ -411,18 +414,18 @@ def _advance_step_flashinfer(
paged_kv_indptr=attn_metadata.paged_kv_indptr,
paged_kv_last_page_len=attn_metadata.paged_kv_last_page_len,
block_table_bound=attn_metadata.block_table_bound)
#frozen_model_input.seq_lens[:num_queries] = [x + 1 for x in frozen_model_input.seq_lens[:num_queries]]

return model_input

def _advance_step(self, model_input: StatefulModelInput,
out: SamplerOutput) -> StatefulModelInput:
if self.attn_backend.get_name() == "flash-attn":
return self._advance_step_flashattn(model_input, out)
elif self.attn_backend.get_name() == "flashinfer":
return self._advance_step_flashinfer(model_input, out)
else:
raise ValueError(f"Unsupported attention backend: {self.attn_backend}")
raise ValueError(
f"Unsupported attention backend: {self.attn_backend}")

def load_model(self) -> None:
return self._base_model_runner.load_model()
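The new assertions in `_advance_step_flashinfer` guard an in-place copy of the previous step's sampled token ids into the front of the persistent input-token buffer. The tensor manipulation itself is just a flatten plus a slice assignment; a toy reproduction with invented values (names and sizes are illustrative, not taken from the runner):

    import torch

    num_queries = 3
    sampled_token_ids = torch.tensor([[11], [12], [13]])  # one sampled id per query

    # Persistent buffer padded beyond num_queries (e.g. for CUDA-graph batch padding).
    input_tokens = torch.zeros(8, dtype=torch.long)
    input_tokens[:num_queries] = sampled_token_ids.flatten()
    print(input_tokens)  # tensor([11, 12, 13,  0,  0,  0,  0,  0])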
