Showing 6 changed files with 417 additions and 37 deletions.
@@ -0,0 +1,120 @@
from typing import Optional

import argparse
import torch
import nvtx
from itertools import accumulate
from vllm.model_executor.layers.rotary_embedding import get_rope


def benchmark_rope_kernels_multi_lora(
    is_neox_style: bool,
    batch_size: int,
    seq_len: int,
    num_heads: int,
    head_size: int,
    rotary_dim: Optional[int],
    dtype: torch.dtype,
    seed: int,
    device: str,
    max_position: int = 8192,
    base: int = 10000,
) -> None:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    torch.set_default_device(device)
    if rotary_dim is None:
        rotary_dim = head_size
    # simulate serving 4 LoRAs
    scaling_factors = [1, 2, 4, 8]
    # batched RoPE can take multiple scaling factors
    batched_rope = get_rope(head_size, rotary_dim, max_position, base,
                            is_neox_style, {
                                "type": "linear",
                                "factor": tuple(scaling_factors)
                            })
    # non-batched RoPE takes only one scaling factor, so we create multiple
    # instances to simulate the same behavior
    non_batched_ropes = []
    for scaling_factor in scaling_factors:
        non_batched_ropes.append(
            get_rope(head_size, rotary_dim, max_position, base, is_neox_style,
                     {
                         "type": "linear",
                         "factor": (scaling_factor, )
                     }))
    positions = torch.randint(0, max_position, (batch_size, seq_len))
    query = torch.randn(batch_size,
                        seq_len,
                        num_heads * head_size,
                        dtype=dtype)
    key = torch.randn_like(query)

    # create query offsets for batched RoPE: multiple kv caches are
    # concatenated together, and each query needs to find the right kv cache
    # of its type
    offset_map = torch.tensor(
        list(
            accumulate([0] + [
                max_position * scaling_factor * 2
                for scaling_factor in scaling_factors[:-1]
            ])))
    query_types = torch.randint(0,
                                len(scaling_factors), (batch_size, seq_len),
                                device=device)
    # map query types to offsets
    query_offsets = offset_map[query_types]
    # the kernel takes flattened offsets
    flatten_offsets = query_offsets.flatten()

    # batch queries of the same type together for non-batched RoPE
    queries = [query[query_types == i] for i in range(len(scaling_factors))]
    keys = [key[query_types == i] for i in range(len(scaling_factors))]
    packed_qkr = zip(queries, keys, non_batched_ropes)
    # synchronize before timing starts
    torch.cuda.synchronize()
    with nvtx.annotate("non-batched", color="yellow"):
        for q, k, r in packed_qkr:
            r.forward(positions, q, k)
    torch.cuda.synchronize()
    with nvtx.annotate("batched", color="green"):
        batched_rope.forward(positions, query, key, flatten_offsets)
    torch.cuda.synchronize()
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Benchmark the rotary embedding kernels.")
    parser.add_argument("--is-neox-style", type=bool, default=True)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument("--num-heads", type=int, default=8)
    parser.add_argument("--head-size",
                        type=int,
                        choices=[64, 80, 96, 112, 128, 256],
                        default=128)
    parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
    parser.add_argument("--dtype",
                        type=str,
                        choices=["bfloat16", "float"],
                        default="float")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--device",
                        type=str,
                        choices=["cuda:0", "cuda:1"],
                        default="cuda:0")
    args = parser.parse_args()
    print(args)

    benchmark_rope_kernels_multi_lora(
        is_neox_style=args.is_neox_style,
        batch_size=args.batch_size,
        seq_len=args.seq_len,
        num_heads=args.num_heads,
        head_size=args.head_size,
        rotary_dim=args.rotary_dim,
        dtype=getattr(torch, args.dtype),
        seed=args.seed,
        device=args.device,
    )
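For intuition about how the per-token offsets are derived: the offset map above only turns each token's scaling-factor index into the starting position of that factor's region, using cumulative sums of max_position * scaling_factor * 2. Below is a minimal standalone sketch of that arithmetic, reusing the benchmark's default max_position and scaling_factors; it is illustrative only and not part of the committed file.

from itertools import accumulate

import torch

max_position = 8192
scaling_factors = [1, 2, 4, 8]

# cumulative starting offset of each scaling factor's region, computed
# exactly as in the benchmark above
offset_map = torch.tensor(
    list(
        accumulate([0] + [
            max_position * scaling_factor * 2
            for scaling_factor in scaling_factors[:-1]
        ])))
print(offset_map)  # tensor([0, 16384, 49152, 114688])

# a token assigned scaling-factor index 2 (i.e. factor 4) starts at 49152
query_types = torch.tensor([[0, 2, 3]])
print(offset_map[query_types])  # tensor([[0, 49152, 114688]])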