diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py new file mode 100644 index 0000000000000..f9564dd9588f0 --- /dev/null +++ b/benchmarks/kernels/benchmark_rope.py @@ -0,0 +1,120 @@ +from typing import Optional + +import argparse +import torch +import nvtx +from itertools import accumulate +from vllm.model_executor.layers.rotary_embedding import get_rope + + +def benchmark_rope_kernels_multi_lora( + is_neox_style: bool, + batch_size: int, + seq_len: int, + num_heads: int, + head_size: int, + rotary_dim: Optional[int], + dtype: torch.dtype, + seed: int, + device: str, + max_position: int = 8192, + base: int = 10000, +) -> None: + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + if rotary_dim is None: + rotary_dim = head_size + # silulating serving 4 LoRAs + scaling_factors = [1, 2, 4, 8] + # batched RoPE can take multiple scaling factors + batched_rope = get_rope(head_size, rotary_dim, max_position, base, + is_neox_style, { + "type": "linear", + "factor": tuple(scaling_factors) + }) + # non-batched RoPE takes only one scaling factor, we create multiple + # instances to simulate the same behavior + non_batched_ropes = [] + for scaling_factor in scaling_factors: + non_batched_ropes.append( + get_rope(head_size, rotary_dim, max_position, base, is_neox_style, + { + "type": "linear", + "factor": (scaling_factor, ) + })) + + positions = torch.randint(0, max_position, (batch_size, seq_len)) + query = torch.randn(batch_size, + seq_len, + num_heads * head_size, + dtype=dtype) + key = torch.randn_like(query) + + # create query offsets for batched RoPE, we concat multiple kv cache + # together and each query needs to find the right kv cache of its type + offset_map = torch.tensor( + list( + accumulate([0] + [ + max_position * scaling_factor * 2 + for scaling_factor in scaling_factors[:-1] + ]))) + query_types = torch.randint(0, + len(scaling_factors), (batch_size, seq_len), + device=device) + # map query types to offsets + query_offsets = offset_map[query_types] + # the kernel takes flattened offsets + flatten_offsets = query_offsets.flatten() + + # batched queries of the same type together for non-batched RoPE + queries = [query[query_types == i] for i in range(len(scaling_factors))] + keys = [key[query_types == i] for i in range(len(scaling_factors))] + packed_qkr = zip(queries, keys, non_batched_ropes) + # synchronize before start timing + torch.cuda.synchronize() + with nvtx.annotate("non-batched", color="yellow"): + for q, k, r in packed_qkr: + r.forward(positions, q, k) + torch.cuda.synchronize() + with nvtx.annotate("batched", color="green"): + batched_rope.forward(positions, query, key, flatten_offsets) + torch.cuda.synchronize() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description="Benchmark the rotary embedding kernels.") + parser.add_argument("--is-neox-style", type=bool, default=True) + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--seq-len", type=int, default=512) + parser.add_argument("--num-heads", type=int, default=8) + parser.add_argument("--head-size", + type=int, + choices=[64, 80, 96, 112, 128, 256], + default=128) + parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) + parser.add_argument("--dtype", + type=str, + choices=["bfloat16", "float"], + default="float") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--device", + type=str, + choices=["cuda:0", "cuda:1"], + default="cuda:0") + args = parser.parse_args() + print(args) + + benchmark_rope_kernels_multi_lora( + is_neox_style=args.is_neox_style, + batch_size=args.batch_size, + seq_len=args.seq_len, + num_heads=args.num_heads, + head_size=args.head_size, + rotary_dim=args.rotary_dim, + dtype=getattr(torch, args.dtype), + seed=args.seed, + device=args.device, + ) diff --git a/csrc/ops.h b/csrc/ops.h index 53222972abb70..d5d6e240da7c4 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -53,6 +53,16 @@ void rotary_embedding( torch::Tensor& cos_sin_cache, bool is_neox); +void batched_rotary_embedding( + torch::Tensor& positions, + torch::Tensor& query, + torch::Tensor& key, + int head_size, + torch::Tensor& cos_sin_cache, + bool is_neox, + int rot_dim, + torch::Tensor& cos_sin_cache_offsets); + void silu_and_mul( torch::Tensor& out, torch::Tensor& input); diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 5f522795619e1..d80cb6973fad6 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -8,7 +8,7 @@ namespace vllm { template -inline __device__ void apply_rotary_embedding( +inline __device__ void apply_token_rotary_embedding( scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr, const scalar_t* __restrict__ sin_ptr, @@ -38,22 +38,18 @@ inline __device__ void apply_rotary_embedding( } template -__global__ void rotary_embedding_kernel( - const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] +inline __device__ void apply_rotary_embedding( scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] - const int rot_dim, - const int64_t query_stride, - const int64_t key_stride, + const scalar_t* cache_ptr, + const int head_size, const int num_heads, const int num_kv_heads, - const int head_size) { - // Each thread block is responsible for one token. - const int token_idx = blockIdx.x; - int64_t pos = positions[token_idx]; - const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; - + const int rot_dim, + const int token_idx, + const int64_t query_stride, + const int64_t key_stride) +{ const int embed_dim = rot_dim / 2; const scalar_t* cos_ptr = cache_ptr; const scalar_t* sin_ptr = cache_ptr + embed_dim; @@ -63,7 +59,7 @@ __global__ void rotary_embedding_kernel( const int head_idx = i / embed_dim; const int64_t token_head = token_idx * query_stride + head_idx * head_size; const int rot_offset = i % embed_dim; - apply_rotary_embedding(query + token_head, cos_ptr, + apply_token_rotary_embedding(query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); } @@ -72,11 +68,53 @@ __global__ void rotary_embedding_kernel( const int head_idx = i / embed_dim; const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int rot_offset = i % embed_dim; - apply_rotary_embedding(key + token_head, cos_ptr, + apply_token_rotary_embedding(key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); } } +template +__global__ void rotary_embedding_kernel( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + +template +__global__ void batched_rotary_embedding_kernel( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; + const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; + + apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + } // namespace vllm void rotary_embedding( @@ -128,3 +166,61 @@ void rotary_embedding( } }); } + +/* +Batched version of rotary embedding, pack multiple LoRAs together +and process in batched manner. +*/ +void batched_rotary_embedding( + torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] + torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] + torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] + int head_size, + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + bool is_neox, + int rot_dim, + torch::Tensor& cos_sin_cache_offsets // [num_tokens] +) { + int64_t num_tokens = cos_sin_cache_offsets.size(0); + int num_heads = query.size(-1) / head_size; + int num_kv_heads = key.size(-1) / head_size; + int64_t query_stride = query.stride(-2); + int64_t key_stride = key.stride(-2); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + query.scalar_type(), + "rotary_embedding", + [&] { + if (is_neox) { + vllm::batched_rotary_embedding_kernel<<>>( + positions.data_ptr(), + query.data_ptr(), + key.data_ptr(), + cos_sin_cache.data_ptr(), + cos_sin_cache_offsets.data_ptr(), + rot_dim, + query_stride, + key_stride, + num_heads, + num_kv_heads, + head_size); + } else { + vllm::batched_rotary_embedding_kernel<<>>( + positions.data_ptr(), + query.data_ptr(), + key.data_ptr(), + cos_sin_cache.data_ptr(), + cos_sin_cache_offsets.data_ptr(), + rot_dim, + query_stride, + key_stride, + num_heads, + num_kv_heads, + head_size); + } + }); +} diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 39384f08d928c..a5c6439fd6909 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -56,6 +56,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &rotary_embedding, "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + ops.def( + "batched_rotary_embedding", + &batched_rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)"); + // Quantization ops #ifndef USE_ROCM ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index 0d27bbaff9fc5..ffdcc1e8c80fd 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -1,8 +1,9 @@ -from typing import Optional +from typing import List, Optional import pytest import torch from allclose_default import get_default_atol, get_default_rtol +from itertools import accumulate from vllm.model_executor.layers.rotary_embedding import get_rope IS_NEOX_STYLE = [True, False] @@ -72,3 +73,135 @@ def test_rotary_embedding( ref_key, atol=get_default_atol(out_key), rtol=get_default_rtol(out_key)) + + +@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("seq_len", SEQ_LENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_batched_rotary_embedding( + is_neox_style: bool, + batch_size: int, + seq_len: int, + num_heads: int, + head_size: int, + rotary_dim: Optional[int], + dtype: torch.dtype, + seed: int, + device: str, + max_position: int = 8192, + base: int = 10000, +) -> None: + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + if rotary_dim is None: + rotary_dim = head_size + rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { + "type": "linear", + "factor": (1, ) + }) + rope = rope.to(dtype=dtype) + + positions = torch.randint(0, max_position, (batch_size, seq_len)) + query = torch.randn(batch_size, + seq_len, + num_heads * head_size, + dtype=dtype) + key = torch.randn_like(query) + + # NOTE(woosuk): The reference implementation should be executed first + # because the custom kernel is in-place. + ref_query, ref_key = rope._forward(positions, query, key) + out_query, out_key = rope.forward(positions, + query, + key, + offsets=torch.zeros(batch_size * seq_len, + dtype=int, + device=device)) + # Compare the results. + assert torch.allclose(out_query, + ref_query, + atol=get_default_atol(out_query), + rtol=get_default_rtol(out_query)) + assert torch.allclose(out_key, + ref_key, + atol=get_default_atol(out_key), + rtol=get_default_rtol(out_key)) + + +@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("seq_len", SEQ_LENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_batched_rotary_embedding_multi_lora( + is_neox_style: bool, + batch_size: int, + seq_len: int, + num_heads: int, + head_size: int, + rotary_dim: Optional[int], + dtype: torch.dtype, + seed: int, + device: str, + max_position: int = 8192, + base: int = 10000, +) -> None: + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + if rotary_dim is None: + rotary_dim = head_size + scaling_factors: List[int] = [1, 2, 4] + rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { + "type": "linear", + "factor": tuple(scaling_factors) + }) + rope = rope.to(dtype=dtype) + + positions = torch.randint(0, max_position, (batch_size, seq_len)) + query = torch.randn(batch_size, + seq_len, + num_heads * head_size, + dtype=dtype) + key = torch.randn_like(query) + + offset_map = torch.tensor( + list( + accumulate([0] + [ + max_position * scaling_factor * 2 + for scaling_factor in scaling_factors[:-1] + ]))) + query_types = torch.randint(0, + len(scaling_factors), (batch_size, seq_len), + device=device) + query_offsets = offset_map[query_types] + + # NOTE(woosuk): The reference implementation should be executed first + # because the custom kernel is in-place. + ref_query, ref_key = rope._forward(positions, query, key, query_offsets) + out_query, out_key = rope.forward(positions, query, key, + query_offsets.flatten()) + # Compare the results. + assert torch.allclose(out_query, + ref_query, + atol=get_default_atol(out_query), + rtol=get_default_rtol(out_query)) + assert torch.allclose(out_key, + ref_key, + atol=get_default_atol(out_key), + rtol=get_default_rtol(out_key)) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 13749570f28a2..db5c7080b50b0 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -22,7 +22,7 @@ # limitations under the License. """Rotary Positional Embeddings.""" import math -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -96,6 +96,7 @@ def _forward( positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """PyTorch-native implementation equivalent to forward().""" query = query.view(*query.shape[:-1], -1, self.head_size) @@ -107,7 +108,9 @@ def _forward( query_pass = query[..., self.rotary_dim:] key_pass = key[..., self.rotary_dim:] - cos_sin = self.cos_sin_cache[positions] + self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device()) + cos_sin = self.cos_sin_cache[torch.add(positions, offsets) + if offsets is not None else positions] cos, sin = cos_sin.chunk(2, dim=-1) if self.is_neox_style: # NOTE(woosuk): Here we assume that the positions tensor has the @@ -137,11 +140,19 @@ def forward( positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - # ops.rotary_embedding() is an in-place operation that - # updates the query and key tensors. - ops.rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, self.is_neox_style) + self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device()) + # ops.rotary_embedding()/batched_rotary_embedding() are in-place operations that + # update the query and key tensors. + if offsets is not None: + ops.batched_rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, + self.is_neox_style, self.rotary_dim, + offsets) + else: + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) return query, key @@ -158,27 +169,32 @@ def __init__( max_position_embeddings: int, base: int, is_neox_style: bool, - scaling_factor: float, + scaling_factors: Union[List[float], float], ) -> None: - self.scaling_factor = scaling_factor + if isinstance(scaling_factors, float): + scaling_factors = [scaling_factors] + self.scaling_factors = scaling_factors super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style) def _compute_cos_sin_cache(self) -> torch.Tensor: inv_freq = self._compute_inv_freq(self.base) - # NOTE(woosuk): self.max_position_embeddings is the original - # maximum length before applying the rope scaling. - # Thus, the maximum length after applying the rope scaling is - # self.max_position_embeddings * self.scaling_factor. - max_len = self.max_position_embeddings * self.scaling_factor - t = torch.arange(max_len, dtype=torch.float) - t = t / self.scaling_factor - - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() - sin = freqs.sin() - cache = torch.cat((cos, sin), dim=-1) - return cache + cache_list = [] + for scaling_factor in self.scaling_factors: + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. + # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. + max_len = self.max_position_embeddings * scaling_factor + t = torch.arange(max_len, dtype=torch.float) + t = t / scaling_factor + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + cache_list.append(cache) + return torch.cat(cache_list, dim=0) class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):