Add batched RoPE kernel #3095
Merged
Changes from all commits (10 commits):
d0105ac  add batched rope kernel (tterrysun)
09399b9  refactor kernel (tterrysun)
98f0c7a  benchmarking script wip (tterrysun)
d7f8869  benchmarking script on (tterrysun)
d3fa2c1  Merge branch 'main' into batched_rope (tterrysun)
ccb3c74  formatting (tterrysun)
bfbe4db  update benchmarking script (tterrysun)
870fcf2  remove breakpoint (tterrysun)
77b0da5  align bm behavior (tterrysun)
d7f691e  minor polishing (tterrysun)
New file (@@ -0,0 +1,120 @@): benchmarking script

from typing import Optional

import argparse
import torch
import nvtx
from itertools import accumulate
from vllm.model_executor.layers.rotary_embedding import get_rope


def benchmark_rope_kernels_multi_lora(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
    # simulating serving 4 LoRAs
    scaling_factors = [1, 2, 4, 8]
    # batched RoPE can take multiple scaling factors
    batched_rope = get_rope(head_size, rotary_dim, max_position, base,
                            is_neox_style, {
                                "type": "linear",
                                "factor": tuple(scaling_factors)
                            })
    # non-batched RoPE takes only one scaling factor, so we create multiple
    # instances to simulate the same behavior
    non_batched_ropes = []
    for scaling_factor in scaling_factors:
        non_batched_ropes.append(
            get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
                     {
                         "type": "linear",
                         "factor": (scaling_factor, )
                     }))

    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)

    # create query offsets for batched RoPE; we concatenate multiple kv caches
    # together and each query needs to find the right kv cache for its type
    offset_map = torch.tensor(
        list(
            accumulate([0] + [
                max_position * scaling_factor * 2
                for scaling_factor in scaling_factors[:-1]
            ])))
    query_types = torch.randint(0,
                                len(scaling_factors), (batch_size, seq_len),
                                device=device)
    # map query types to offsets
    query_offsets = offset_map[query_types]
    # the kernel takes flattened offsets
    flatten_offsets = query_offsets.flatten()

    # batch queries of the same type together for non-batched RoPE
    queries = [query[query_types == i] for i in range(len(scaling_factors))]
    keys = [key[query_types == i] for i in range(len(scaling_factors))]
    packed_qkr = zip(queries, keys, non_batched_ropes)
    # synchronize before starting timing
    torch.cuda.synchronize()
    with nvtx.annotate("non-batched", color="yellow"):
        for q, k, r in packed_qkr:
            r.forward(positions, q, k)
    torch.cuda.synchronize()
    with nvtx.annotate("batched", color="green"):
        batched_rope.forward(positions, query, key, flatten_offsets)
    torch.cuda.synchronize()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Benchmark the rotary embedding kernels.")
    parser.add_argument("--is-neox-style", type=bool, default=True)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--num-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 256],
                        default=128)
    parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
    parser.add_argument("--dtype",
                        type=str,
                        choices=["bfloat16", "float"],
                        default="float")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device",
                        type=str,
                        choices=["cuda:0", "cuda:1"],
                        default="cuda:0")
    args = parser.parse_args()
    print(args)

    benchmark_rope_kernels_multi_lora(
        is_neox_style=args.is_neox_style,
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        num_heads=args.num_heads,
        head_size=args.head_size,
        rotary_dim=args.rotary_dim,
        dtype=getattr(torch, args.dtype),
        seed=args.seed,
        device=args.device,
    )
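
To make the offset bookkeeping above concrete, here is a minimal standalone sketch (not part of the PR) that evaluates the same offset_map expression for the benchmark's defaults, max_position=8192 and scaling_factors=[1, 2, 4, 8], and shows how per-token types turn into the flattened offsets handed to the batched kernel:

from itertools import accumulate

import torch

max_position = 8192
scaling_factors = [1, 2, 4, 8]

# Running start offset of each scaling factor's slab in the concatenated
# cache, computed exactly as in the benchmark above.
offset_map = torch.tensor(
    list(
        accumulate([0] + [
            max_position * s * 2 for s in scaling_factors[:-1]
        ])))
print(offset_map)  # offsets: 0, 16384, 49152, 114688

# Each token gets a type (which scaling factor / LoRA it belongs to); the type
# is mapped to the start offset of the matching slab and then flattened.
query_types = torch.randint(0, len(scaling_factors), (2, 4))
flatten_offsets = offset_map[query_types].flatten()
print(flatten_offsets.shape)  # torch.Size([8])

If the script is saved as, say, benchmark_rope.py (the filename is an assumption here), it can be run with e.g. python benchmark_rope.py --batch-size 16 --seq-len 512 and profiled under Nsight Systems, where the nvtx.annotate calls show up as the "non-batched" and "batched" ranges.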
Review comment:
This kernel is almost exactly the same as rotary_embedding_kernel, and you can make them the same by adding a const int64_t* __restrict__ cos_sin_cache_offsets argument there (it will be a null pointer if it is not set) and then, down below, doing:

Author reply:
cos_sin_cache_offset is passed as a pointer, so we don't have a good way to determine whether it is empty without an auxiliary flag, and we also try to avoid runtime branching in kernel code for performance. Agreed that these two kernels are pretty much the same, so I refactored the code to avoid too much duplication.
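
For intuition only, the relationship being discussed can be sketched in plain PyTorch rather than CUDA: the batched path looks up cache rows at position plus a per-token offset into a concatenated cache, and the non-batched path is the special case where every offset is zero (the null-pointer case mentioned above). All shapes and names below are illustrative, not the PR's kernel code.

import torch

# Hypothetical concatenated cache: one slab per scaling factor, stacked
# along dim 0. Shapes are illustrative only.
cos_sin_cache = torch.randn(16384 + 32768, 64)

positions = torch.randint(0, 8192, (4, 16))      # (batch, seq)
offsets = torch.randint(0, 2, (4, 16)) * 16384   # 0 -> slab 0, 16384 -> slab 1

# Batched lookup: each token indexes the slab selected by its own offset.
batched_rows = cos_sin_cache[(positions + offsets).flatten()]

# Non-batched lookup: equivalent to offsets being zero everywhere.
non_batched_rows = cos_sin_cache[positions.flatten()]

# For tokens whose offset is zero, the two lookups agree.
mask = offsets.flatten() == 0
assert torch.equal(batched_rows[mask], non_batched_rows[mask])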