diff --git a/benchmarks/bench_append_paged_kv_cache.py b/benchmarks/bench_append_paged_kv_cache.py
index 4fd67af1..0cc7d0e1 100644
--- a/benchmarks/bench_append_paged_kv_cache.py
+++ b/benchmarks/bench_append_paged_kv_cache.py
@@ -2,10 +2,11 @@
 import dataclasses
 from typing import cast
 
-import flashinfer
 import torch
 from triton.testing import do_bench
 
+import flashinfer
+
 
 @dataclasses.dataclass(kw_only=True)
 class ModelConfig:
diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh
index c52795f8..829065a4 100644
--- a/include/flashinfer/pos_enc.cuh
+++ b/include/flashinfer/pos_enc.cuh
@@ -18,6 +18,7 @@
 #include 
 #include 
+#include 
 #include 
 
 #include "layout.cuh"
@@ -156,6 +157,55 @@ __device__ __forceinline__ vec_t vec_apply_llama_rope_cos_sin_i
   return vec;
 }
 
+template <bool interleave, uint32_t head_dim, uint32_t vec_size, uint32_t bdx, typename DType, typename IdType>
+__global__ void BatchQKApplyRotaryPosIdsCosSinCacheKernel(
+    DType* q, DType* k, DType* q_rope, DType* k_rope, float* __restrict__ cos_cache,
+    float* __restrict__ sin_cache, IdType* __restrict__ pos_ids, uint32_t nnz,
+    uint32_t num_qo_heads, uint32_t num_kv_heads, size_t q_stride_n, size_t q_stride_h,
+    size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, size_t q_rope_stride_h,
+    size_t k_rope_stride_n, size_t k_rope_stride_h) {
+  uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
+  const uint32_t bdy = blockDim.y;
+
+  vec_t<float, vec_size> cos, sin;
+  if (bx * bdy + ty < nnz) {
+    const uint32_t idx = bx * bdy + ty;
+    const IdType pos = pos_ids[idx];
+
+    cos.load(cos_cache + pos * head_dim + tx * vec_size);
+    sin.load(sin_cache + pos * head_dim + tx * vec_size);
+
+#pragma unroll 1
+    for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) {
+      DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h);
+      DType* q_rope_ptr =
+          q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h);
+      vec_t<float, vec_size> q_vec;
+      if constexpr (interleave) {
+        q_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(q_ptr, cos, sin);
+      } else {
+        q_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(q_ptr, cos, sin);
+      }
+      q_vec.cast_store(q_rope_ptr + tx * vec_size);
+    }
+
+#pragma unroll 1
+    for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) {
+      DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h);
+      DType* k_rope_ptr =
+          k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h);
+      vec_t<float, vec_size> k_vec;
+      if constexpr (interleave) {
+        k_vec = vec_apply_llama_rope_cos_sin_interleave<vec_size, bdx>(k_ptr, cos, sin);
+      } else {
+        k_vec = vec_apply_llama_rope_cos_sin<vec_size, bdx>(k_ptr, cos, sin);
+      }
+      k_vec.cast_store(k_rope_ptr + tx * vec_size);
+    }
+  }
+}
+
 template 
 __global__ void BatchQKApplyRotaryPosIdsKernel(
@@ -309,6 +359,48 @@ __global__ void BatchQKApplyRotaryKernel(
     __VA_ARGS__ \
   }
 
+template <typename DType, typename IdType>
+cudaError_t BatchQKApplyRotaryPosIdsCosSinCache(
+    DType* q, DType* k, DType* q_rope, DType* k_rope, float* cos_cache, float* sin_cache,
+    IdType* pos_ids, uint32_t nnz, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
+    size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h,
+    size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h,
+    bool interleave, cudaStream_t stream = nullptr) {
+  DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
+    DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
+      constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
+      constexpr uint32_t bdx = HEAD_DIM / vec_size;
+      uint32_t num_threads = std::max(128U, bdx);
+      uint32_t bdy = num_threads / bdx;
+      dim3 nblks((nnz + bdy - 1) / bdy);
+      dim3 nthrs(bdx, bdy);
+      auto kernel = BatchQKApplyRotaryPosIdsCosSinCacheKernel<INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>;
+      void* args[] = {(void*)&q,
+                      (void*)&k,
+                      (void*)&q_rope,
+                      (void*)&k_rope,
+                      (void*)&cos_cache,
+                      (void*)&sin_cache,
+                      (void*)&pos_ids,
+                      (void*)&nnz,
+                      (void*)&num_qo_heads,
+                      (void*)&num_kv_heads,
+                      (void*)&q_stride_n,
+                      (void*)&q_stride_h,
+                      (void*)&k_stride_n,
+                      (void*)&k_stride_h,
+                      (void*)&q_rope_stride_n,
+                      (void*)&q_rope_stride_h,
+                      (void*)&k_rope_stride_n,
+                      (void*)&k_rope_stride_h};
+      FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
+    });
+  });
+
+  return cudaSuccess;
+}
+
 template 
 cudaError_t BatchQKApplyRotaryPosIds(DType* q, DType* k, DType* q_rope, DType* k_rope,
                                      IdType* __restrict__ pos_ids, uint32_t nnz,
diff --git a/python/csrc/flashinfer_rope_ops.cu b/python/csrc/flashinfer_rope_ops.cu
index c6259968..c07be244 100644
--- a/python/csrc/flashinfer_rope_ops.cu
+++ b/python/csrc/flashinfer_rope_ops.cu
@@ -35,10 +35,17 @@ void apply_llama31_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor
                                 float rope_scale, float rope_theta, float low_freq_factor,
                                 float high_freq_factor, float old_context_length);
 
+void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                                      torch::Tensor k_rope, torch::Tensor cos_cache,
+                                      torch::Tensor sin_cache, torch::Tensor pos_ids,
+                                      bool interleave);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("apply_rope", &apply_rope, "Apply RoPE");
   m.def("apply_llama31_rope", &apply_llama31_rope, "Apply Llama 3.1 style RoPE");
   m.def("apply_rope_pos_ids", &apply_rope_pos_ids, "Apply RoPE with positional ids");
   m.def("apply_llama31_rope_pos_ids", &apply_llama31_rope_pos_ids,
         "Apply Llama 3.1 style RoPE with positional ids");
+  m.def("apply_rope_pos_ids_cos_sin_cache", &apply_rope_pos_ids_cos_sin_cache,
+        "Apply RoPE with positional ids and cosine/sine cache");
 }
diff --git a/python/csrc/rope.cu b/python/csrc/rope.cu
index 8f661da0..9f773284 100644
--- a/python/csrc/rope.cu
+++ b/python/csrc/rope.cu
@@ -22,8 +22,8 @@ using namespace flashinfer;
 void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::Tensor k_rope,
                 torch::Tensor indptr, torch::Tensor offsets, bool interleave, float rope_scale,
                 float rope_theta) {
-  CHECK_CUDA(q);  // not necessarily contiguous
-  CHECK_CUDA(k);  // not necessarily contiguous
+  CHECK_LAST_DIM_CONTIGUOUS(q);
+  CHECK_LAST_DIM_CONTIGUOUS(k);
   CHECK_INPUT(indptr);
   CHECK_INPUT(offsets);
 
@@ -69,8 +69,8 @@ void apply_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope, torch::T
 void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
                         torch::Tensor k_rope, torch::Tensor pos_ids, bool interleave,
                         float rope_scale, float rope_theta) {
-  CHECK_CUDA(q);  // not necessarily contiguous
-  CHECK_CUDA(k);  // not necessarily contiguous
+  CHECK_LAST_DIM_CONTIGUOUS(q);
+  CHECK_LAST_DIM_CONTIGUOUS(k);
   CHECK_INPUT(pos_ids);
 
   auto device = q.device();
@@ -107,6 +107,60 @@ void apply_rope_pos_ids(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
   });
 }
 
+void apply_rope_pos_ids_cos_sin_cache(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
+                                      torch::Tensor k_rope, torch::Tensor cos_cache,
+                                      torch::Tensor sin_cache, torch::Tensor pos_ids,
+                                      bool interleave) {
+  CHECK_LAST_DIM_CONTIGUOUS(q);
+  CHECK_LAST_DIM_CONTIGUOUS(k);
+  CHECK_INPUT(cos_cache);
+  CHECK_INPUT(sin_cache);
+  CHECK_INPUT(pos_ids);
+  auto device = q.device();
+  CHECK_EQ(k.device(), device);
+  CHECK_EQ(cos_cache.device(), device);
+  CHECK_EQ(sin_cache.device(), device);
+  CHECK_EQ(pos_ids.device(), device);
+  CHECK_DIM(3, q);          // q: (nnz, H_Q, D)
+  CHECK_DIM(3, k);          // k: (nnz, H_K, D)
+  CHECK_DIM(2, cos_cache);  // cos_cache: (max_seq_len, D)
+  CHECK_DIM(2, sin_cache);  // sin_cache: (max_seq_len, D)
+  CHECK_EQ(q.size(0), k.size(0));
+  CHECK_EQ(q.size(2), k.size(2));
+  CHECK_EQ(cos_cache.size(1), q.size(2));
+  CHECK_EQ(sin_cache.size(1), q.size(2));
+  CHECK_EQ(cos_cache.dtype(), torch::kFloat32);
+  CHECK_EQ(sin_cache.dtype(), torch::kFloat32);
+  unsigned int num_qo_heads = q.size(1);
+  unsigned int num_kv_heads = k.size(1);
+  unsigned int head_dim = q.size(2);
+  unsigned int nnz = q.size(0);
+  size_t q_stride_n = q.stride(0);
+  size_t q_stride_h = q.stride(1);
+  size_t k_stride_n = k.stride(0);
+  size_t k_stride_h = k.stride(1);
+  size_t q_rope_stride_n = q_rope.stride(0);
+  size_t q_rope_stride_h = q_rope.stride(1);
+  size_t k_rope_stride_n = k_rope.stride(0);
+  size_t k_rope_stride_h = k_rope.stride(1);
+  pos_ids = pos_ids.to(torch::kInt32);
+
+  cudaStream_t torch_current_stream = c10::cuda::getCurrentCUDAStream(device.index());
+  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(q.scalar_type(), c_type, [&] {
+    cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
+        static_cast<c_type*>(q.data_ptr()), static_cast<c_type*>(k.data_ptr()),
+        static_cast<c_type*>(q_rope.data_ptr()), static_cast<c_type*>(k_rope.data_ptr()),
+        static_cast<float*>(cos_cache.data_ptr()), static_cast<float*>(sin_cache.data_ptr()),
+        static_cast<int32_t*>(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, head_dim,
+        q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h,
+        k_rope_stride_n, k_rope_stride_h, interleave, torch_current_stream);
+    TORCH_CHECK(status == cudaSuccess,
+                "BatchQKApplyRotaryPosIdsCosSinCache failed with error code " +
+                    std::string(cudaGetErrorString(status)));
+    return true;
+  });
+}
+
 void apply_llama31_rope(torch::Tensor q, torch::Tensor k, torch::Tensor q_rope,
                         torch::Tensor k_rope, torch::Tensor indptr, torch::Tensor offsets,
                         bool interleave, float rope_scale, float rope_theta, float low_freq_factor,
diff --git a/python/flashinfer/__init__.py b/python/flashinfer/__init__.py
index 9baf6849..5284f268 100644
--- a/python/flashinfer/__init__.py
+++ b/python/flashinfer/__init__.py
@@ -60,10 +60,18 @@ from .quantization import segment_packbits as segment_packbits
 from .rope import apply_llama31_rope as apply_llama31_rope
 from .rope import apply_llama31_rope_inplace as apply_llama31_rope_inplace
+from .rope import apply_llama31_rope_pos_ids as apply_llama31_rope_pos_ids
+from .rope import (
+    apply_llama31_rope_pos_ids_inplace as apply_llama31_rope_pos_ids_inplace,
+)
 from .rope import apply_rope as apply_rope
 from .rope import apply_rope_inplace as apply_rope_inplace
 from .rope import apply_rope_pos_ids as apply_rope_pos_ids
 from .rope import apply_rope_pos_ids_inplace as apply_rope_pos_ids_inplace
+from .rope import apply_rope_with_cos_sin_cache as apply_rope_with_cos_sin_cache
+from .rope import (
+    apply_rope_with_cos_sin_cache_inplace as apply_rope_with_cos_sin_cache_inplace,
+)
 from .sampling import chain_speculative_sampling as chain_speculative_sampling
 from .sampling import min_p_sampling_from_probs as min_p_sampling_from_probs
 from .sampling import sampling_from_probs as sampling_from_probs
diff --git a/python/flashinfer/rope.py b/python/flashinfer/rope.py
index 60ca2c6e..e28a69d4 100644
--- a/python/flashinfer/rope.py
+++ b/python/flashinfer/rope.py
@@ -153,6 +153,45 @@ def _fake_apply_rope_pos_ids(
     pass
 
 
+@register_custom_op( + "flashinfer::apply_rope_pos_ids_cos_sin_cache", mutates_args=("q_rope", "k_rope") +) +def _apply_rope_pos_ids_cos_sin_cache( + q: torch.Tensor, + k: torch.Tensor, + q_rope: torch.Tensor, + k_rope: torch.Tensor, + cos_cache: torch.Tensor, + sin_cache: torch.Tensor, + pos_ids: torch.Tensor, + interleave: bool, +) -> None: + get_rope_module().apply_rope_pos_ids_cos_sin_cache( + q, + k, + q_rope, + k_rope, + cos_cache, + sin_cache, + pos_ids, + interleave, + ) + + +@register_fake_op("flashinfer::apply_rope_pos_ids_cos_sin_cache") +def _fake_apply_rope_pos_ids_cos_sin_cache( + q: torch.Tensor, + k: torch.Tensor, + q_rope: torch.Tensor, + k_rope: torch.Tensor, + cos_cache: torch.Tensor, + sin_cache: torch.Tensor, + pos_ids: torch.Tensor, + interleave: bool, +) -> None: + pass + + @register_custom_op( "flashinfer::apply_llama31_rope_pos_ids", mutates_args=("q_rope", "k_rope") ) @@ -211,6 +250,7 @@ def apply_rope_inplace( rope_theta: float = 1e4, ) -> None: r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace. + cos/sin values are computed on the fly inside the kernel. We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th segment the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the @@ -288,6 +328,7 @@ def apply_rope_pos_ids_inplace( rope_theta: float = 1e4, ) -> None: r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor) inplace. + cos/sin values are computed on the fly inside the kernel. We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th segment the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the @@ -333,7 +374,7 @@ def apply_llama31_rope_inplace( k: torch.Tensor, indptr: torch.Tensor, offsets: torch.Tensor, - interleave: bool = True, + interleave: bool = False, rope_scale: float = 8, rope_theta: float = 5e5, low_freq_factor: float = 1, @@ -341,7 +382,7 @@ def apply_llama31_rope_inplace( old_context_len: int = 8192, ) -> None: r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as - RaggedTensor) inplace. + RaggedTensor) inplace. cos/sin values are computed on the fly inside the kernel. We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th segment the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the @@ -433,7 +474,7 @@ def apply_llama31_rope_pos_ids_inplace( q: torch.Tensor, k: torch.Tensor, pos_ids: torch.Tensor, - interleave: bool = True, + interleave: bool = False, rope_scale: float = 8, rope_theta: float = 5e5, low_freq_factor: float = 1, @@ -441,7 +482,7 @@ def apply_llama31_rope_pos_ids_inplace( old_context_len: int = 8192, ) -> None: r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as - RaggedTensor) inplace. + RaggedTensor) inplace. cos/sin values are computed on the fly inside the kernel. We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th segment the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the @@ -510,6 +551,7 @@ def apply_rope( rope_theta: float = 1e4, ) -> Tuple[torch.Tensor, torch.Tensor]: r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor). + cos/sin values are computed on the fly inside the kernel. 
     We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th segment
     the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the
@@ -603,6 +645,7 @@ def apply_rope_pos_ids(
     rope_theta: float = 1e4,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor).
+    cos/sin values are computed on the fly inside the kernel.
 
     We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th segment
     the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the
@@ -660,7 +703,7 @@ def apply_llama31_rope(
     k: torch.Tensor,
     indptr: torch.Tensor,
     offsets: torch.Tensor,
-    interleave: bool = True,
+    interleave: bool = False,
     rope_scale: float = 8,
     rope_theta: float = 5e5,
     low_freq_factor: float = 1,
@@ -668,7 +711,7 @@ def apply_llama31_rope(
     old_context_len: int = 8192,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as
-    RaggedTensor).
+    RaggedTensor). cos/sin values are computed on the fly inside the kernel.
 
     We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th segment
     the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the
@@ -774,7 +817,7 @@ def apply_llama31_rope_pos_ids(
     q: torch.Tensor,
     k: torch.Tensor,
     pos_ids: torch.Tensor,
-    interleave: bool = True,
+    interleave: bool = False,
     rope_scale: float = 8,
     rope_theta: float = 5e5,
     low_freq_factor: float = 1,
@@ -782,7 +825,7 @@ def apply_llama31_rope_pos_ids(
     old_context_len: int = 8192,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""Apply Llama 3.1 style rotary embedding to a batch of queries/keys (stored as
-    RaggedTensor).
+    RaggedTensor). cos/sin values are computed on the fly inside the kernel.
 
     We use :attr:`indptr` to denote the start pointer of each segment in the batch, the i-th segment
     the query of the i-th segment is ``q[indptr[i]:indptr[i+1]]`` and the key of the
@@ -848,3 +891,94 @@ def apply_llama31_rope_pos_ids(
         float(old_context_len),
     )
     return q_rope, k_rope
+
+
+def apply_rope_with_cos_sin_cache(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos_cache: torch.Tensor,
+    sin_cache: torch.Tensor,
+    pos_ids: torch.Tensor,
+    interleave: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    r"""Apply rotary embedding to keys and queries with precomputed cos/sin values.
+
+    Parameters
+    ----------
+    q : torch.Tensor
+        Query tensor, shape: ``(nnz, num_q_heads, head_dim)``.
+    k : torch.Tensor
+        Key tensor, shape: ``(nnz, num_k_heads, head_dim)``.
+    cos_cache : torch.Tensor
+        Cosine cache tensor, shape: ``(max_seq_len, head_dim)``.
+    sin_cache : torch.Tensor
+        Sine cache tensor, shape: ``(max_seq_len, head_dim)``.
+    pos_ids : torch.Tensor
+        Position indices, shape: ``(nnz)``.
+    interleave : bool
+        Whether to use interleaved layout in the last dimension, default: ``False``.
+
+        * If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
+          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
+
+        * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
+          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
+          dimensions ``([..., head_dim//2:])``.
+
+    Returns
+    -------
+    q_rope : torch.Tensor
+        The rotated query tensor, shape: ``(nnz, num_q_heads, head_dim)``.
+    k_rope : torch.Tensor
+        The rotated key tensor, shape: ``(nnz, num_k_heads, head_dim)``.
+    """
+    if cos_cache.dtype != torch.float32 or sin_cache.dtype != torch.float32:
+        raise ValueError("cos_cache and sin_cache should be float32")
+    q_rope = torch.empty_like(q)
+    k_rope = torch.empty_like(k)
+    _apply_rope_pos_ids_cos_sin_cache(
+        q, k, q_rope, k_rope, cos_cache, sin_cache, pos_ids, interleave
+    )
+    return q_rope, k_rope
+
+
+def apply_rope_with_cos_sin_cache_inplace(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos_cache: torch.Tensor,
+    sin_cache: torch.Tensor,
+    pos_ids: torch.Tensor,
+    interleave: bool = False,
+) -> None:
+    r"""Apply rotary embedding to keys and queries with precomputed cos/sin values.
+    The result is stored in the input tensors inplace.
+
+    Parameters
+    ----------
+    q : torch.Tensor
+        Query tensor, shape: ``(nnz, num_q_heads, head_dim)``.
+    k : torch.Tensor
+        Key tensor, shape: ``(nnz, num_k_heads, head_dim)``.
+    cos_cache : torch.Tensor
+        Cosine cache tensor, shape: ``(max_seq_len, head_dim)``.
+        Expect float32 data type.
+    sin_cache : torch.Tensor
+        Sine cache tensor, shape: ``(max_seq_len, head_dim)``.
+        Expect float32 data type.
+    pos_ids : torch.Tensor
+        Position indices, shape: ``(nnz)``.
+    interleave : bool
+        Whether to use interleaved layout in the last dimension, default: ``False``.
+
+        * If ``True``, the last dimension of the query/key tensor is interleaved, i.e.,
+          we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
+
+        * If ``False``, the last dimension of the query/key tensor is not interleaved, i.e.,
+          we rotate the first half dimensions ``([..., :head_dim//2])`` and the second half
+          dimensions ``([..., head_dim//2:])``.
+    """
+    if cos_cache.dtype != torch.float32 or sin_cache.dtype != torch.float32:
+        raise ValueError("cos_cache and sin_cache should be float32")
+    _apply_rope_pos_ids_cos_sin_cache(
+        q, k, q, k, cos_cache, sin_cache, pos_ids, interleave
+    )
diff --git a/tests/rope_reference.py b/tests/rope_reference.py
index 494d9e91..7a176692 100644
--- a/tests/rope_reference.py
+++ b/tests/rope_reference.py
@@ -69,3 +69,31 @@ def apply_rotary_emb(
     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
     return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin):
+    cos = cos.unsqueeze(1)
+    sin = sin.unsqueeze(1)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed.to(q.dtype), k_embed.to(k.dtype)
+
+
+def rotate_half(x):
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def generate_cos_sin_f32_cache(
+    max_seq_len, head_dim, theta=1e4, use_scaled: bool = False
+):
+    position = torch.arange(max_seq_len).float().unsqueeze(1)
+    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
+    freqs = torch.cat([freqs, freqs], dim=-1).contiguous()
+    if use_scaled:
+        freqs = apply_scaling(freqs)
+    args = position * freqs
+    sin_cache = torch.sin(args)
+    cos_cache = torch.cos(args)
+    return cos_cache, sin_cache
diff --git a/tests/test_rope.py b/tests/test_rope.py
index 513cdfa8..6af2a549 100644
--- a/tests/test_rope.py
+++ b/tests/test_rope.py
@@ -27,13 +27,17 @@
 @pytest.mark.parametrize("num_kv_heads", [8])
 @pytest.mark.parametrize("offset", [0, 15, 99])
 @pytest.mark.parametrize("head_dim", [64, 128, 256])
-def test_llama_rope_inplace(
+@pytest.mark.parametrize("llama_version", ["llama", "llama31"])
+@pytest.mark.parametrize("inplace", [False, True])
+def test_rope(
     batch_size,
     qkv_len,
num_qo_heads, num_kv_heads, offset, head_dim, + llama_version, + inplace, ): nnz = batch_size * qkv_len qkv_packed = torch.randn( @@ -52,61 +56,14 @@ def test_llama_rope_inplace( offsets = torch.full((batch_size,), offset, dtype=torch.int32, device="cuda:0") # reference implementation - freqs_cis = precompute_freqs_cis( - head_dim, qkv_len + offset, 10000.0, use_scaled=False - ).to("cuda:0") - q_rope_ref, k_rope_ref = apply_rotary_emb( - q.reshape(batch_size, qkv_len, num_qo_heads, head_dim), - k.reshape(batch_size, qkv_len, num_kv_heads, head_dim), - freqs_cis[offset : offset + qkv_len], - ) - q_rope_ref = q_rope_ref.reshape(nnz, num_qo_heads, head_dim) - k_rope_ref = k_rope_ref.reshape(nnz, num_kv_heads, head_dim) - - # flashinfer implementation - flashinfer.apply_rope_inplace( - q, k, indptr, offsets, interleave=True, rope_theta=1e4 - ) - - # compare - torch.testing.assert_close(q_rope_ref, q, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(k_rope_ref, k, rtol=1e-3, atol=1e-3) - - -@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) -@pytest.mark.parametrize("qkv_len", [1, 4, 19, 204]) -@pytest.mark.parametrize("num_qo_heads", [8, 16]) -@pytest.mark.parametrize("num_kv_heads", [8]) -@pytest.mark.parametrize("offset", [0, 15, 99]) -@pytest.mark.parametrize("head_dim", [64, 128, 256]) -def test_llama_rope( - batch_size, - qkv_len, - num_qo_heads, - num_kv_heads, - offset, - head_dim, -): - nnz = batch_size * qkv_len - qkv_packed = torch.randn( - nnz, - (num_qo_heads + 2 * num_kv_heads) * head_dim, - dtype=torch.float16, - device="cuda:0", - ) - q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim) - k = qkv_packed[ - :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim - ].reshape(nnz, num_kv_heads, head_dim) - indptr = torch.tensor( - [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" - ) - offsets = torch.full((batch_size,), offset, dtype=torch.int32, device="cuda:0") - - # reference implementation - freqs_cis = precompute_freqs_cis( - head_dim, qkv_len + offset, 10000.0, use_scaled=False - ).to("cuda:0") + if llama_version == "llama": + freqs_cis = precompute_freqs_cis( + head_dim, qkv_len + offset, 10000.0, use_scaled=False + ).to("cuda:0") + else: + freqs_cis = precompute_freqs_cis( + head_dim, qkv_len + offset, 5e5, use_scaled=True + ).to("cuda:0") q_rope_ref, k_rope_ref = apply_rotary_emb( q.reshape(batch_size, qkv_len, num_qo_heads, head_dim), k.reshape(batch_size, qkv_len, num_kv_heads, head_dim), @@ -116,9 +73,26 @@ def test_llama_rope( k_rope_ref = k_rope_ref.reshape(nnz, num_kv_heads, head_dim) # flashinfer implementation - q_rope, k_rope = flashinfer.apply_rope( - q, k, indptr, offsets, interleave=True, rope_theta=1e4 - ) + if llama_version == "llama": + if inplace: + flashinfer.apply_rope_inplace( + q, k, indptr, offsets, interleave=True, rope_theta=1e4 + ) + q_rope, k_rope = q, k + else: + q_rope, k_rope = flashinfer.apply_rope( + q, k, indptr, offsets, interleave=True, rope_theta=1e4 + ) + else: + if inplace: + flashinfer.apply_llama31_rope_inplace( + q, k, indptr, offsets, interleave=True, rope_theta=5e5 + ) + q_rope, k_rope = q, k + else: + q_rope, k_rope = flashinfer.apply_llama31_rope( + q, k, indptr, offsets, interleave=True, rope_theta=5e5 + ) # compare torch.testing.assert_close(q_rope_ref, q_rope, rtol=1e-3, atol=1e-3) @@ -131,13 +105,17 @@ def test_llama_rope( @pytest.mark.parametrize("num_kv_heads", [8]) @pytest.mark.parametrize("offset", [0, 15, 99]) 
@pytest.mark.parametrize("head_dim", [64, 128, 256]) -def test_llama_rope_pos_ids( +@pytest.mark.parametrize("llama_version", ["llama", "llama31"]) +@pytest.mark.parametrize("inplace", [False, True]) +def test_rope_pos_ids( batch_size, qkv_len, num_qo_heads, num_kv_heads, offset, head_dim, + llama_version, + inplace, ): nnz = batch_size * qkv_len qkv_packed = torch.randn( @@ -162,13 +140,44 @@ def test_llama_rope_pos_ids( ] ).to("cuda:0") - q_rope, k_rope = flashinfer.apply_rope( - q, k, indptr, offsets, interleave=True, rope_theta=1e4 - ) - - q_rope_pos_ids, k_rope_pos_ids = flashinfer.apply_rope_pos_ids( - q, k, pos_ids, interleave=True, rope_theta=1e4 - ) + if llama_version == "llama": + if inplace: + q_clone, k_clone = q.clone(), k.clone() + flashinfer.apply_rope_inplace( + q, k, indptr, offsets, interleave=True, rope_theta=1e4 + ) + q_rope, k_rope = q, k + flashinfer.apply_rope_pos_ids_inplace( + q_clone, k_clone, pos_ids, interleave=True, rope_theta=1e4 + ) + q_rope_pos_ids, k_rope_pos_ids = q_clone, k_clone + else: + q_rope, k_rope = flashinfer.apply_rope( + q, k, indptr, offsets, interleave=True, rope_theta=1e4 + ) + + q_rope_pos_ids, k_rope_pos_ids = flashinfer.apply_rope_pos_ids( + q, k, pos_ids, interleave=True, rope_theta=1e4 + ) + else: + if inplace: + q_clone, k_clone = q.clone(), k.clone() + flashinfer.apply_llama31_rope_inplace( + q, k, indptr, offsets, interleave=True, rope_theta=5e5 + ) + q_rope, k_rope = q, k + flashinfer.apply_llama31_rope_pos_ids_inplace( + q_clone, k_clone, pos_ids, interleave=True, rope_theta=5e5 + ) + q_rope_pos_ids, k_rope_pos_ids = q_clone, k_clone + else: + q_rope, k_rope = flashinfer.apply_llama31_rope( + q, k, indptr, offsets, interleave=True, rope_theta=5e5 + ) + + q_rope_pos_ids, k_rope_pos_ids = flashinfer.apply_llama31_rope_pos_ids( + q, k, pos_ids, interleave=True, rope_theta=5e5 + ) # compare torch.testing.assert_close(q_rope_pos_ids, q_rope, rtol=1e-3, atol=1e-3) @@ -181,13 +190,17 @@ def test_llama_rope_pos_ids( @pytest.mark.parametrize("num_kv_heads", [8]) @pytest.mark.parametrize("offset", [0, 15, 99]) @pytest.mark.parametrize("head_dim", [64, 128, 256]) -def test_llama_rope_pos_ids_inplace( +@pytest.mark.parametrize("llama_version", ["llama", "llama31"]) +@pytest.mark.parametrize("inplace", [False, True]) +def test_rope_cos_sin_cache( batch_size, qkv_len, num_qo_heads, num_kv_heads, offset, head_dim, + llama_version, + inplace, ): nnz = batch_size * qkv_len qkv_packed = torch.randn( @@ -200,11 +213,6 @@ def test_llama_rope_pos_ids_inplace( k = qkv_packed[ :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim ].reshape(nnz, num_kv_heads, head_dim) - indptr = torch.tensor( - [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" - ) - offsets = torch.full((batch_size,), offset, dtype=torch.int32, device="cuda:0") - pos_ids = torch.cat( [ torch.arange(offset, qkv_len + offset, dtype=torch.int32) @@ -212,130 +220,55 @@ def test_llama_rope_pos_ids_inplace( ] ).to("cuda:0") - q_clone = q.clone() - k_clone = k.clone() - - flashinfer.apply_rope_inplace( - q, k, indptr, offsets, interleave=True, rope_theta=1e4 - ) - - flashinfer.apply_rope_pos_ids_inplace( - q_clone, k_clone, pos_ids, interleave=True, rope_theta=1e4 - ) + if llama_version == "llama": + cos_cache, sin_cache = generate_cos_sin_f32_cache(4096, head_dim, theta=1e4) + else: + cos_cache, sin_cache = generate_cos_sin_f32_cache( + 4096, head_dim, theta=5e5, use_scaled=True + ) + cos_cache = cos_cache.to("cuda:0") + sin_cache = 
sin_cache.to("cuda:0") + + if inplace: + q_clone, k_clone = q.clone(), k.clone() + + if llama_version == "llama": + if inplace: + flashinfer.apply_rope_pos_ids_inplace(q, k, pos_ids, interleave=False) + q_rope, k_rope = q, k + else: + q_rope, k_rope = flashinfer.apply_rope_pos_ids( + q, k, pos_ids, interleave=False + ) + else: + if inplace: + flashinfer.apply_llama31_rope_pos_ids_inplace( + q, k, pos_ids, interleave=False + ) + q_rope, k_rope = q, k + else: + q_rope, k_rope = flashinfer.apply_llama31_rope_pos_ids( + q, k, pos_ids, interleave=False + ) + + if inplace: + flashinfer.apply_rope_with_cos_sin_cache_inplace( + q_clone, k_clone, cos_cache, sin_cache, pos_ids, interleave=False + ) + q_rope_cos_sin_cache, k_rope_cos_sin_cache = q_clone, k_clone + else: + q_rope_cos_sin_cache, k_rope_cos_sin_cache = ( + flashinfer.apply_rope_with_cos_sin_cache( + q, k, cos_cache, sin_cache, pos_ids, interleave=False + ) + ) # compare - torch.testing.assert_close(q_clone, q, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(k_clone, k, rtol=1e-3, atol=1e-3) - - -@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) -@pytest.mark.parametrize("qkv_len", [1, 4, 19, 204]) -@pytest.mark.parametrize("num_qo_heads", [8, 16]) -@pytest.mark.parametrize("num_kv_heads", [8]) -@pytest.mark.parametrize("offset", [0, 15, 99]) -@pytest.mark.parametrize("head_dim", [64, 128, 256]) -def test_llama31_rope_inplace( - batch_size, - qkv_len, - num_qo_heads, - num_kv_heads, - offset, - head_dim, -): - nnz = batch_size * qkv_len - qkv_packed = torch.randn( - nnz, - (num_qo_heads + 2 * num_kv_heads) * head_dim, - dtype=torch.float16, - device="cuda:0", - ) - q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim) - k = qkv_packed[ - :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * head_dim - ].reshape(nnz, num_kv_heads, head_dim) - indptr = torch.tensor( - [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" - ) - offsets = torch.full((batch_size,), offset, dtype=torch.int32, device="cuda:0") - - # reference implementation - freqs_cis = precompute_freqs_cis( - head_dim, qkv_len + offset, 5e5, use_scaled=True - ).to("cuda:0") - q_rope_ref, k_rope_ref = apply_rotary_emb( - q.reshape(batch_size, qkv_len, num_qo_heads, head_dim), - k.reshape(batch_size, qkv_len, num_kv_heads, head_dim), - freqs_cis[offset : offset + qkv_len], - ) - q_rope_ref = q_rope_ref.reshape(nnz, num_qo_heads, head_dim) - k_rope_ref = k_rope_ref.reshape(nnz, num_kv_heads, head_dim) - - # flashinfer implementation - flashinfer.apply_llama31_rope_inplace( - q, k, indptr, offsets, interleave=True, rope_theta=5e5 - ) - - # compare - torch.testing.assert_close(q_rope_ref, q, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(k_rope_ref, k, rtol=1e-3, atol=1e-3) - - -@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) -@pytest.mark.parametrize("qkv_len", [1, 4, 19, 204]) -@pytest.mark.parametrize("num_qo_heads", [8, 16]) -@pytest.mark.parametrize("num_kv_heads", [8]) -@pytest.mark.parametrize("offset", [0, 15, 99]) -@pytest.mark.parametrize("head_dim", [64, 128, 256]) -def test_llama31_rope( - batch_size, - qkv_len, - num_qo_heads, - num_kv_heads, - offset, - head_dim, -): - nnz = batch_size * qkv_len - qkv_packed = torch.randn( - nnz, - (num_qo_heads + 2 * num_kv_heads) * head_dim, - dtype=torch.float16, - device="cuda:0", - ) - q = qkv_packed[:, : num_qo_heads * head_dim].reshape(nnz, num_qo_heads, head_dim) - k = qkv_packed[ - :, num_qo_heads * head_dim : (num_qo_heads + num_kv_heads) * 
head_dim - ].reshape(nnz, num_kv_heads, head_dim) - indptr = torch.tensor( - [i * qkv_len for i in range(batch_size + 1)], dtype=torch.int32, device="cuda:0" - ) - offsets = torch.full((batch_size,), offset, dtype=torch.int32, device="cuda:0") - - # reference implementation - freqs_cis = precompute_freqs_cis( - head_dim, qkv_len + offset, 5e5, use_scaled=True - ).to("cuda:0") - q_rope_ref, k_rope_ref = apply_rotary_emb( - q.reshape(batch_size, qkv_len, num_qo_heads, head_dim), - k.reshape(batch_size, qkv_len, num_kv_heads, head_dim), - freqs_cis[offset : offset + qkv_len], - ) - q_rope_ref = q_rope_ref.reshape(nnz, num_qo_heads, head_dim) - k_rope_ref = k_rope_ref.reshape(nnz, num_kv_heads, head_dim) - - # flashinfer implementation - q_rope, k_rope = flashinfer.apply_llama31_rope( - q, k, indptr, offsets, interleave=True, rope_theta=5e5 - ) - - # compare - torch.testing.assert_close(q_rope_ref, q_rope, rtol=1e-3, atol=1e-3) - torch.testing.assert_close(k_rope_ref, k_rope, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(q_rope, q_rope_cos_sin_cache, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(k_rope, k_rope_cos_sin_cache, rtol=1e-3, atol=1e-3) if __name__ == "__main__": - test_llama_rope_inplace(2, 1, 8, 8, 1, 128) - test_llama31_rope_inplace(1, 1, 8, 8, 0, 128) - test_llama_rope(2, 1, 8, 8, 1, 128) - test_llama31_rope(1, 1, 8, 8, 0, 128) - test_llama_rope_pos_ids(2, 1, 8, 8, 1, 128) - test_llama_rope_pos_ids_inplace(2, 1, 8, 8, 1, 128) + test_rope(2, 1, 8, 8, 1, 128, "llama31", False) + test_rope_pos_ids(2, 1, 8, 8, 1, 128, "llama31", False) + test_rope_cos_sin_cache(99, 19, 16, 8, 99, 256, "llama31", False)
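
A minimal usage sketch of the two Python entry points this patch adds (not part of the diff itself). Tensor shapes and the float32 cos/sin cache layout follow the docstrings and checks above; the cache construction mirrors generate_cos_sin_f32_cache in tests/rope_reference.py, and the concrete sizes below are illustrative assumptions.

import torch
import flashinfer

nnz, num_qo_heads, num_kv_heads, head_dim, max_seq_len = 16, 8, 8, 128, 4096
q = torch.randn(nnz, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
k = torch.randn(nnz, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0")
pos_ids = torch.arange(nnz, dtype=torch.int32, device="cuda:0")

# Precompute an f32 cos/sin cache of shape (max_seq_len, head_dim); the wrappers
# reject anything other than float32.
position = torch.arange(max_seq_len, dtype=torch.float32).unsqueeze(1)
freqs = 1.0 / (1e4 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
freqs = torch.cat([freqs, freqs], dim=-1)
cos_cache = torch.cos(position * freqs).to("cuda:0")
sin_cache = torch.sin(position * freqs).to("cuda:0")

# Out-of-place: returns freshly allocated q_rope/k_rope.
q_rope, k_rope = flashinfer.apply_rope_with_cos_sin_cache(
    q, k, cos_cache, sin_cache, pos_ids, interleave=False
)

# In-place: q and k are overwritten with the rotated values.
flashinfer.apply_rope_with_cos_sin_cache_inplace(
    q, k, cos_cache, sin_cache, pos_ids, interleave=False
)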