Skip to content

Commit

Permalink
[multi-step] add flashinfer backend (vllm-project#7928)
Browse files Browse the repository at this point in the history
Signed-off-by: Amit Garg <[email protected]>
  • Loading branch information
SolitaryThinker authored and garg-amit committed Oct 28, 2024
1 parent 60379ec commit 20e9b56
Showing 1 changed file with 13 additions and 2 deletions.
15 changes: 13 additions & 2 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {

// prepare_inputs advance_step
ops.def(
"advance_step(int num_seqs, int num_queries, int block_size, "
"advance_step_flashattn(int num_seqs, int num_queries, int block_size, "
"Tensor! input_tokens, Tensor sampled_token_ids, "
"Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping, "
"Tensor block_tables) -> ()");
ops.impl("advance_step", torch::kCUDA, &advance_step);
ops.impl("advance_step_flashattn", torch::kCUDA, &advance_step_flashattn);

ops.def(
"advance_step_flashinfer("
" int num_seqs, int num_queries, int block_size,"
" Tensor! input_tokens, Tensor sampled_token_ids,"
" Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping,"
" Tensor block_tables, Tensor! paged_kv_indices,"
" Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len,"
" Tensor! block_table_bounds"
") -> ()");
ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer);

// Layernorm
// Apply Root Mean Square (RMS) Normalization to the input tensor.
Expand Down

0 comments on commit 20e9b56

Please sign in to comment.