From d0105aca34be5868ab0f51e2b868371cdaf6ca04 Mon Sep 17 00:00:00 2001 From: tterrysun Date: Wed, 28 Feb 2024 13:55:20 -0800 Subject: [PATCH 1/9] add batched rope kernel --- csrc/ops.h | 10 ++ csrc/pos_encoding_kernels.cu | 100 +++++++++++++ csrc/pybind.cpp | 5 + tests/kernels/test_pos_encoding.py | 135 +++++++++++++++++- .../model_executor/layers/rotary_embedding.py | 57 +++++--- 5 files changed, 285 insertions(+), 22 deletions(-) diff --git a/csrc/ops.h b/csrc/ops.h index 2bcd0c2efc5c6..6c5ace4d46c0e 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -53,6 +53,16 @@ void rotary_embedding( torch::Tensor& cos_sin_cache, bool is_neox); +void batched_rotary_embedding( + torch::Tensor& positions, + torch::Tensor& query, + torch::Tensor& key, + int head_size, + torch::Tensor& cos_sin_cache, + bool is_neox, + int rot_dim, + torch::Tensor& cos_sin_cache_offsets); + void silu_and_mul( torch::Tensor& out, torch::Tensor& input); diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 5f522795619e1..282fdf23189b6 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -77,6 +77,48 @@ __global__ void rotary_embedding_kernel( } } +template +__global__ void batched_rotary_embedding_kernel( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] or [num_tokens] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; + const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; + + const int embed_dim = rot_dim / 2; + const scalar_t* cos_ptr = cache_ptr; + const scalar_t* sin_ptr = cache_ptr + embed_dim; + + const int nq = num_heads * embed_dim; + for (int i = threadIdx.x; i < nq; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * query_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_rotary_embedding(query + token_head, cos_ptr, + sin_ptr, rot_offset, embed_dim); + } + + const int nk = num_kv_heads * embed_dim; + for (int i = threadIdx.x; i < nk; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_rotary_embedding(key + token_head, cos_ptr, + sin_ptr, rot_offset, embed_dim); + } +} + } // namespace vllm void rotary_embedding( @@ -128,3 +170,61 @@ void rotary_embedding( } }); } + +/* +Batched version of rotary embedding, pack multiple LoRAs together +and process in batched manner. 
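Each entry of cos_sin_cache_offsets gives the starting row of that token's LoRA
(i.e. its rope scaling factor) inside the concatenated cos/sin cache; the kernel
then reads the cache at row (offset + position). As in the non-batched kernel,
one thread block handles one token.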
+*/ +void batched_rotary_embedding( + torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] + torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size] + torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size] + int head_size, + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + bool is_neox, + int rot_dim, + torch::Tensor& cos_sin_cache_offsets // [num_tokens] +) { + int64_t num_tokens = cos_sin_cache_offsets.size(0); + int num_heads = query.size(-1) / head_size; + int num_kv_heads = key.size(-1) / head_size; + int64_t query_stride = query.stride(-2); + int64_t key_stride = key.stride(-2); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + query.scalar_type(), + "rotary_embedding", + [&] { + if (is_neox) { + vllm::batched_rotary_embedding_kernel<<>>( + positions.data_ptr(), + query.data_ptr(), + key.data_ptr(), + cos_sin_cache.data_ptr(), + cos_sin_cache_offsets.data_ptr(), + rot_dim, + query_stride, + key_stride, + num_heads, + num_kv_heads, + head_size); + } else { + vllm::batched_rotary_embedding_kernel<<>>( + positions.data_ptr(), + query.data_ptr(), + key.data_ptr(), + cos_sin_cache.data_ptr(), + cos_sin_cache_offsets.data_ptr(), + rot_dim, + query_stride, + key_stride, + num_heads, + num_kv_heads, + head_size); + } + }); +} diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index b36d259697167..d7d5918d2266d 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -48,6 +48,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &rotary_embedding, "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); + ops.def( + "batched_rotary_embedding", + &batched_rotary_embedding, + "Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)"); + // Quantization ops #ifndef USE_ROCM ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index 0d27bbaff9fc5..ffdcc1e8c80fd 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -1,8 +1,9 @@ -from typing import Optional +from typing import List, Optional import pytest import torch from allclose_default import get_default_atol, get_default_rtol +from itertools import accumulate from vllm.model_executor.layers.rotary_embedding import get_rope IS_NEOX_STYLE = [True, False] @@ -72,3 +73,135 @@ def test_rotary_embedding( ref_key, atol=get_default_atol(out_key), rtol=get_default_rtol(out_key)) + + +@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("seq_len", SEQ_LENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_batched_rotary_embedding( + is_neox_style: bool, + batch_size: int, + seq_len: int, + num_heads: int, + head_size: int, + rotary_dim: Optional[int], + dtype: torch.dtype, + seed: int, + device: str, + max_position: int = 8192, + base: int = 10000, +) -> None: + torch.random.manual_seed(seed) + if 
torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + if rotary_dim is None: + rotary_dim = head_size + rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { + "type": "linear", + "factor": (1, ) + }) + rope = rope.to(dtype=dtype) + + positions = torch.randint(0, max_position, (batch_size, seq_len)) + query = torch.randn(batch_size, + seq_len, + num_heads * head_size, + dtype=dtype) + key = torch.randn_like(query) + + # NOTE(woosuk): The reference implementation should be executed first + # because the custom kernel is in-place. + ref_query, ref_key = rope._forward(positions, query, key) + out_query, out_key = rope.forward(positions, + query, + key, + offsets=torch.zeros(batch_size * seq_len, + dtype=int, + device=device)) + # Compare the results. + assert torch.allclose(out_query, + ref_query, + atol=get_default_atol(out_query), + rtol=get_default_rtol(out_query)) + assert torch.allclose(out_key, + ref_key, + atol=get_default_atol(out_key), + rtol=get_default_rtol(out_key)) + + +@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("seq_len", SEQ_LENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_batched_rotary_embedding_multi_lora( + is_neox_style: bool, + batch_size: int, + seq_len: int, + num_heads: int, + head_size: int, + rotary_dim: Optional[int], + dtype: torch.dtype, + seed: int, + device: str, + max_position: int = 8192, + base: int = 10000, +) -> None: + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + if rotary_dim is None: + rotary_dim = head_size + scaling_factors: List[int] = [1, 2, 4] + rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { + "type": "linear", + "factor": tuple(scaling_factors) + }) + rope = rope.to(dtype=dtype) + + positions = torch.randint(0, max_position, (batch_size, seq_len)) + query = torch.randn(batch_size, + seq_len, + num_heads * head_size, + dtype=dtype) + key = torch.randn_like(query) + + offset_map = torch.tensor( + list( + accumulate([0] + [ + max_position * scaling_factor * 2 + for scaling_factor in scaling_factors[:-1] + ]))) + query_types = torch.randint(0, + len(scaling_factors), (batch_size, seq_len), + device=device) + query_offsets = offset_map[query_types] + + # NOTE(woosuk): The reference implementation should be executed first + # because the custom kernel is in-place. + ref_query, ref_key = rope._forward(positions, query, key, query_offsets) + out_query, out_key = rope.forward(positions, query, key, + query_offsets.flatten()) + # Compare the results. + assert torch.allclose(out_query, + ref_query, + atol=get_default_atol(out_query), + rtol=get_default_rtol(out_query)) + assert torch.allclose(out_key, + ref_key, + atol=get_default_atol(out_key), + rtol=get_default_rtol(out_key)) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 93ec5c12536fb..bc43422a963d4 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -22,7 +22,7 @@ # limitations under the License. 
"""Rotary Positional Embeddings.""" import math -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -96,6 +96,7 @@ def _forward( positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """PyTorch-native implementation equivalent to forward().""" query = query.view(*query.shape[:-1], -1, self.head_size) @@ -107,7 +108,10 @@ def _forward( query_pass = query[..., self.rotary_dim:] key_pass = key[..., self.rotary_dim:] - cos_sin = self.cos_sin_cache[positions] + self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device()) + cos_sin = self.cos_sin_cache[torch.add(positions, offsets) + if offsets is not None else positions] + # breakpoint() cos, sin = cos_sin.chunk(2, dim=-1) if self.is_neox_style: # NOTE(woosuk): Here we assume that the positions tensor has the @@ -137,11 +141,19 @@ def forward( positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - # ops.rotary_embedding() is an in-place operation that - # updates the query and key tensors. - ops.rotary_embedding(positions, query, key, self.head_size, - self.cos_sin_cache, self.is_neox_style) + self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device()) + # ops.rotary_embedding()/batched_rotary_embedding() are in-place operations that + # update the query and key tensors. + if offsets is not None: + ops.batched_rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, + self.is_neox_style, self.rotary_dim, + offsets) + else: + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) return query, key @@ -158,27 +170,30 @@ def __init__( max_position_embeddings: int, base: int, is_neox_style: bool, - scaling_factor: float, + scaling_factors: List[float], ) -> None: - self.scaling_factor = scaling_factor + self.scaling_factors = scaling_factors super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style) def _compute_cos_sin_cache(self) -> torch.Tensor: inv_freq = self._compute_inv_freq(self.base) - # NOTE(woosuk): self.max_position_embeddings is the original - # maximum length before applying the rope scaling. - # Thus, the maximum length after applying the rope scaling is - # self.max_position_embeddings * self.scaling_factor. - max_len = self.max_position_embeddings * self.scaling_factor - t = torch.arange(max_len, dtype=torch.float) - t = t / self.scaling_factor - - freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() - sin = freqs.sin() - cache = torch.cat((cos, sin), dim=-1) - return cache + cache_list = [] + for scaling_factor in self.scaling_factors: + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. + # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. 
+ max_len = self.max_position_embeddings * scaling_factor + t = torch.arange(max_len, dtype=torch.float) + t = t / scaling_factor + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + cache_list.append(cache) + return torch.cat(cache_list, dim=0) class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): From 09399b9914015641a66b0b7910392516c565168e Mon Sep 17 00:00:00 2001 From: tterrysun Date: Fri, 1 Mar 2024 15:33:10 -0800 Subject: [PATCH 2/9] refactor kernel --- csrc/pos_encoding_kernels.cu | 68 +++++++++++++++++------------------- 1 file changed, 32 insertions(+), 36 deletions(-) diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 282fdf23189b6..d80cb6973fad6 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -8,7 +8,7 @@ namespace vllm { template -inline __device__ void apply_rotary_embedding( +inline __device__ void apply_token_rotary_embedding( scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr, const scalar_t* __restrict__ sin_ptr, @@ -38,22 +38,18 @@ inline __device__ void apply_rotary_embedding( } template -__global__ void rotary_embedding_kernel( - const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] +inline __device__ void apply_rotary_embedding( scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] - const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] - const int rot_dim, - const int64_t query_stride, - const int64_t key_stride, + const scalar_t* cache_ptr, + const int head_size, const int num_heads, const int num_kv_heads, - const int head_size) { - // Each thread block is responsible for one token. 
- const int token_idx = blockIdx.x; - int64_t pos = positions[token_idx]; - const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; - + const int rot_dim, + const int token_idx, + const int64_t query_stride, + const int64_t key_stride) +{ const int embed_dim = rot_dim / 2; const scalar_t* cos_ptr = cache_ptr; const scalar_t* sin_ptr = cache_ptr + embed_dim; @@ -63,7 +59,7 @@ __global__ void rotary_embedding_kernel( const int head_idx = i / embed_dim; const int64_t token_head = token_idx * query_stride + head_idx * head_size; const int rot_offset = i % embed_dim; - apply_rotary_embedding(query + token_head, cos_ptr, + apply_token_rotary_embedding(query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); } @@ -72,11 +68,31 @@ __global__ void rotary_embedding_kernel( const int head_idx = i / embed_dim; const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int rot_offset = i % embed_dim; - apply_rotary_embedding(key + token_head, cos_ptr, + apply_token_rotary_embedding(key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); } } +template +__global__ void rotary_embedding_kernel( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] + const int rot_dim, + const int64_t query_stride, + const int64_t key_stride, + const int num_heads, + const int num_kv_heads, + const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); +} + template __global__ void batched_rotary_embedding_kernel( const int64_t* __restrict__ positions, // [batch_size, seq_len] or [num_tokens] @@ -96,27 +112,7 @@ __global__ void batched_rotary_embedding_kernel( int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; const scalar_t* cache_ptr = cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; - const int embed_dim = rot_dim / 2; - const scalar_t* cos_ptr = cache_ptr; - const scalar_t* sin_ptr = cache_ptr + embed_dim; - - const int nq = num_heads * embed_dim; - for (int i = threadIdx.x; i < nq; i += blockDim.x) { - const int head_idx = i / embed_dim; - const int64_t token_head = token_idx * query_stride + head_idx * head_size; - const int rot_offset = i % embed_dim; - apply_rotary_embedding(query + token_head, cos_ptr, - sin_ptr, rot_offset, embed_dim); - } - - const int nk = num_kv_heads * embed_dim; - for (int i = threadIdx.x; i < nk; i += blockDim.x) { - const int head_idx = i / embed_dim; - const int64_t token_head = token_idx * key_stride + head_idx * head_size; - const int rot_offset = i % embed_dim; - apply_rotary_embedding(key + token_head, cos_ptr, - sin_ptr, rot_offset, embed_dim); - } + apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, token_idx, query_stride, key_stride); } } // namespace vllm From 98f0c7a0468c3a7bdc4d7bd79ed6c68b553a850c Mon Sep 17 00:00:00 2001 From: tterrysun Date: Tue, 5 Mar 2024 23:12:38 -0800 Subject: [PATCH 3/9] benchmarking script wip --- benchmarks/kernels/benchmark_rope.py | 102 
+++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 benchmarks/kernels/benchmark_rope.py diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py new file mode 100644 index 0000000000000..52a254628013c --- /dev/null +++ b/benchmarks/kernels/benchmark_rope.py @@ -0,0 +1,102 @@ +from typing import Optional + +import argparse +import torch +import nvtx +from itertools import accumulate +from vllm.model_executor.layers.rotary_embedding import get_rope + + +def benchmark_rope_kernels_multi_lora( + is_neox_style: bool, + batch_size: int, + seq_len: int, + num_heads: int, + head_size: int, + rotary_dim: Optional[int], + dtype: torch.dtype, + seed: int, + device: str, + max_position: int = 8192, + base: int = 10000, +) -> None: + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + if rotary_dim is None: + rotary_dim = head_size + scaling_factors = [1, 2, 4, 8] + rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { + "type": "linear", + "factor": tuple(scaling_factors) + }) + rope = rope.to(dtype=dtype) + + positions = torch.randint(0, max_position, (batch_size, seq_len)) + query = torch.randn(batch_size, + seq_len, + num_heads * head_size, + dtype=dtype) + key = torch.randn_like(query) + + offset_map = torch.tensor( + list( + accumulate([0] + [ + max_position * scaling_factor * 2 + for scaling_factor in scaling_factors[:-1] + ]))) + query_types = torch.randint(0, + len(scaling_factors), (batch_size, seq_len), + device=device) + query_offsets = offset_map[query_types].flatten() + + torch.cuda.synchronize() + with nvtx.annotate("batched", color="green"): + rope.forward(positions, query, key, query_offsets) + torch.cuda.synchronize() + + queries = [query[query_types == i] for i in range(len(scaling_factors))] + keys = [key[query_types == i] for i in range(len(scaling_factors))] + packed_qk = zip(queries, keys) + torch.cuda.synchronize() + with nvtx.annotate("non-batched", color="yellow"): + for query, key in packed_qk: + # the value here is actually wrong because we don't pass any offsets + # but we are only interested in the time it takes to execute the kernel + rope.forward(positions, query, key) + torch.cuda.synchronize() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description="Benchmark the rottery embedding kernels.") + parser.add_argument("--is-neox-style", type=bool, default=True) + parser.add_argument("--batch-size", type=int, default=16) + parser.add_argument("--seq-len", type=int, default=512) + parser.add_argument("--num-heads", type=int, default=8) + parser.add_argument("--head-size", + type=int, + choices=[64, 80, 96, 112, 128, 256], + default=128) + parser.add_argument("--rottery-dim", type=int, choices=[16, 32], default=32) + parser.add_argument("--dtype", + type=str, + choices=["half", "bfloat16", "float"], + default="half") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0") + args = parser.parse_args() + print(args) + + benchmark_rope_kernels_multi_lora( + is_neox_style=args.is_neox_style, + batch_size=args.batch_size, + seq_len=args.seq_len, + num_heads=args.num_heads, + head_size=args.head_size, + rotary_dim=args.rottery_dim, + dtype=getattr(torch, args.dtype), + seed=args.seed, + device=args.device, + ) From d7f886916ce59b520b25e53633e0cf0d73c7e55c Mon Sep 17 00:00:00 2001 From: tterrysun Date: Wed, 6 
Mar 2024 12:54:45 -0800 Subject: [PATCH 4/9] benchmarking script on --- benchmarks/kernels/benchmark_rope.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 52a254628013c..87adba19a3ebe 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -56,16 +56,14 @@ def benchmark_rope_kernels_multi_lora( rope.forward(positions, query, key, query_offsets) torch.cuda.synchronize() - queries = [query[query_types == i] for i in range(len(scaling_factors))] - keys = [key[query_types == i] for i in range(len(scaling_factors))] - packed_qk = zip(queries, keys) + packed_qk = zip(query, key) torch.cuda.synchronize() with nvtx.annotate("non-batched", color="yellow"): - for query, key in packed_qk: + for q, k in packed_qk: # the value here is actually wrong because we don't pass any offsets # but we are only interested in the time it takes to execute the kernel - rope.forward(positions, query, key) - torch.cuda.synchronize() + rope.forward(positions, q, k) + torch.cuda.synchronize() if __name__ == '__main__': From ccb3c74870a75e6804c2a0ce37496ee7e72d48eb Mon Sep 17 00:00:00 2001 From: tterrysun Date: Wed, 6 Mar 2024 13:05:29 -0800 Subject: [PATCH 5/9] formatting --- benchmarks/kernels/benchmark_rope.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 87adba19a3ebe..211ee1ef4e543 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -77,13 +77,19 @@ def benchmark_rope_kernels_multi_lora( type=int, choices=[64, 80, 96, 112, 128, 256], default=128) - parser.add_argument("--rottery-dim", type=int, choices=[16, 32], default=32) + parser.add_argument("--rottery-dim", + type=int, + choices=[16, 32], + default=32) parser.add_argument("--dtype", type=str, choices=["half", "bfloat16", "float"], default="half") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0") + parser.add_argument("--device", + type=str, + choices=["cuda:0", "cuda:1"], + default="cuda:0") args = parser.parse_args() print(args) From bfbe4db6e95eb274cf6cd828422cd2aaf7387b82 Mon Sep 17 00:00:00 2001 From: tterrysun Date: Thu, 7 Mar 2024 15:19:47 -0800 Subject: [PATCH 6/9] update benchmarking script --- benchmarks/kernels/benchmark_rope.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 211ee1ef4e543..c46ddbc8ef92f 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -49,21 +49,22 @@ def benchmark_rope_kernels_multi_lora( query_types = torch.randint(0, len(scaling_factors), (batch_size, seq_len), device=device) - query_offsets = offset_map[query_types].flatten() + query_offsets = offset_map[query_types] + flatten_offsets = query_offsets.flatten() - torch.cuda.synchronize() - with nvtx.annotate("batched", color="green"): - rope.forward(positions, query, key, query_offsets) - torch.cuda.synchronize() - - packed_qk = zip(query, key) + queries = [query[query_types == i] for i in range(len(scaling_factors))] + keys = [key[query_types == i] for i in range(len(scaling_factors))] + packed_qk = zip(queries, keys) torch.cuda.synchronize() with nvtx.annotate("non-batched", color="yellow"): for q, k in packed_qk: # the 
value here is actually wrong because we don't pass any offsets # but we are only interested in the time it takes to execute the kernel rope.forward(positions, q, k) - torch.cuda.synchronize() + torch.cuda.synchronize() + with nvtx.annotate("batched", color="green"): + rope.forward(positions, query, key, flatten_offsets) + torch.cuda.synchronize() if __name__ == '__main__': From 870fcf28d91312ab200de08a0dc5189d122981a0 Mon Sep 17 00:00:00 2001 From: tterrysun Date: Thu, 7 Mar 2024 15:23:02 -0800 Subject: [PATCH 7/9] remove breakpoint --- vllm/model_executor/layers/rotary_embedding.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 9ed078db43e25..ea3ba4f9070e4 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -111,7 +111,6 @@ def _forward( self.cos_sin_cache = self.cos_sin_cache.to(positions.get_device()) cos_sin = self.cos_sin_cache[torch.add(positions, offsets) if offsets is not None else positions] - # breakpoint() cos, sin = cos_sin.chunk(2, dim=-1) if self.is_neox_style: # NOTE(woosuk): Here we assume that the positions tensor has the From 77b0da541f240bf4325066b737172d25fb1aa256 Mon Sep 17 00:00:00 2001 From: tterrysun Date: Fri, 8 Mar 2024 11:46:54 -0800 Subject: [PATCH 8/9] align bm behavior --- benchmarks/kernels/benchmark_rope.py | 42 +++++++++++++++++++--------- 1 file changed, 29 insertions(+), 13 deletions(-) diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index c46ddbc8ef92f..2cb67ccc20f75 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -26,12 +26,24 @@ def benchmark_rope_kernels_multi_lora( torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size + # silulating serving 4 LoRAs scaling_factors = [1, 2, 4, 8] - rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { - "type": "linear", - "factor": tuple(scaling_factors) - }) - rope = rope.to(dtype=dtype) + # batched RoPE can take multiple scaling factors + batched_rope = get_rope(head_size, rotary_dim, max_position, base, + is_neox_style, { + "type": "linear", + "factor": tuple(scaling_factors) + }) + # non-batched RoPE takes only one scaling factor, we create multiple + # instances to simulate the same behavior + non_batched_ropes = [] + for scaling_factor in scaling_factors: + non_batched_ropes.append( + get_rope(head_size, rotary_dim, max_position, base, is_neox_style, + { + "type": "linear", + "factor": (scaling_factor, ) + })) positions = torch.randint(0, max_position, (batch_size, seq_len)) query = torch.randn(batch_size, @@ -40,6 +52,8 @@ def benchmark_rope_kernels_multi_lora( dtype=dtype) key = torch.randn_like(query) + # create query offsets for batched RoPE, we concat multiple kv cache + # together and each query needs to find the right kv cache of its type offset_map = torch.tensor( list( accumulate([0] + [ @@ -49,21 +63,23 @@ def benchmark_rope_kernels_multi_lora( query_types = torch.randint(0, len(scaling_factors), (batch_size, seq_len), device=device) + # map query types to offsets query_offsets = offset_map[query_types] + # the kernel takes flattened offsets flatten_offsets = query_offsets.flatten() + # batched queries of the same type together for non-batched RoPE queries = [query[query_types == i] for i in range(len(scaling_factors))] keys = [key[query_types == i] for i in range(len(scaling_factors))] - packed_qk = 
zip(queries, keys) + packed_qkr = zip(queries, keys, non_batched_ropes) + # synchronize before start timing torch.cuda.synchronize() with nvtx.annotate("non-batched", color="yellow"): - for q, k in packed_qk: - # the value here is actually wrong because we don't pass any offsets - # but we are only interested in the time it takes to execute the kernel - rope.forward(positions, q, k) + for q, k, r in packed_qkr: + r.forward(positions, q, k) torch.cuda.synchronize() with nvtx.annotate("batched", color="green"): - rope.forward(positions, query, key, flatten_offsets) + batched_rope.forward(positions, query, key, flatten_offsets) torch.cuda.synchronize() @@ -84,8 +100,8 @@ def benchmark_rope_kernels_multi_lora( default=32) parser.add_argument("--dtype", type=str, - choices=["half", "bfloat16", "float"], - default="half") + choices=["bfloat16", "float"], + default="float") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--device", type=str, From d7f691ea9b435a69b7b1ef9eca9a0dadb74a49ef Mon Sep 17 00:00:00 2001 From: tterrysun Date: Fri, 8 Mar 2024 15:36:21 -0800 Subject: [PATCH 9/9] minor polishing --- benchmarks/kernels/benchmark_rope.py | 9 +++------ vllm/model_executor/layers/rotary_embedding.py | 4 +++- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 2cb67ccc20f75..f9564dd9588f0 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -85,7 +85,7 @@ def benchmark_rope_kernels_multi_lora( if __name__ == '__main__': parser = argparse.ArgumentParser( - description="Benchmark the rottery embedding kernels.") + description="Benchmark the rotary embedding kernels.") parser.add_argument("--is-neox-style", type=bool, default=True) parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--seq-len", type=int, default=512) @@ -94,10 +94,7 @@ def benchmark_rope_kernels_multi_lora( type=int, choices=[64, 80, 96, 112, 128, 256], default=128) - parser.add_argument("--rottery-dim", - type=int, - choices=[16, 32], - default=32) + parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) parser.add_argument("--dtype", type=str, choices=["bfloat16", "float"], @@ -116,7 +113,7 @@ def benchmark_rope_kernels_multi_lora( seq_len=args.seq_len, num_heads=args.num_heads, head_size=args.head_size, - rotary_dim=args.rottery_dim, + rotary_dim=args.rotary_dim, dtype=getattr(torch, args.dtype), seed=args.seed, device=args.device, diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index ea3ba4f9070e4..db5c7080b50b0 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -169,8 +169,10 @@ def __init__( max_position_embeddings: int, base: int, is_neox_style: bool, - scaling_factors: List[float], + scaling_factors: Union[List[float], float], ) -> None: + if isinstance(scaling_factors, float): + scaling_factors = [scaling_factors] self.scaling_factors = scaling_factors super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style)
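Usage sketch: a minimal, illustrative example of driving the batched kernel through the Python layer, assuming vLLM is built from this patch series and a CUDA device is available. The sizes, dtype, and token-to-factor mapping below are arbitrary placeholders; the offsets follow the concatenated cache layout built by _compute_cos_sin_cache, where each scaling factor contributes max_position * factor rows.

from itertools import accumulate

import torch

from vllm.model_executor.layers.rotary_embedding import get_rope

torch.set_default_device("cuda")
head_size, rotary_dim, max_position, base = 128, 128, 8192, 10000
scaling_factors = (1, 4)  # e.g. two LoRAs with different linear rope scaling
rope = get_rope(head_size, rotary_dim, max_position, base, True, {
    "type": "linear",
    "factor": scaling_factors,
}).to(dtype=torch.float16)

batch_size, seq_len, num_heads = 2, 16, 8
positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size, seq_len, num_heads * head_size,
                    dtype=torch.float16)
key = torch.randn_like(query)

# Starting row of each factor's sub-cache in the concatenated cos/sin cache.
cache_starts = torch.tensor(
    list(accumulate([0] + [max_position * f for f in scaling_factors[:-1]])))
# Assign every token to one of the scaling factors (a stand-in for the real
# LoRA-to-request mapping), then look up its cache offset.
token_types = torch.randint(0, len(scaling_factors), (batch_size, seq_len))
offsets = cache_starts[token_types].flatten()  # kernel expects a flat tensor

# In-place batched rotary embedding; query and key are updated and returned.
query, key = rope.forward(positions, query, key, offsets)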