From 535faf16565452547fb28e1c0b29fe5661abf108 Mon Sep 17 00:00:00 2001
From: Xin Yao
Date: Thu, 12 Oct 2023 01:41:18 -0700
Subject: [PATCH 1/4] fused rope

Signed-off-by: Xin Yao
---
 .../fused_rotary_positional_embedding.cpp     |  66 +++++++++
 .../fused_rotary_positional_embedding.h       | 129 ++++++++++++++++++
 .../fused_rotary_positional_embedding_cuda.cu |  64 +++++++++
 setup.py                                      |  18 +++
 4 files changed, 277 insertions(+)
 create mode 100644 csrc/megatron/fused_rotary_positional_embedding.cpp
 create mode 100644 csrc/megatron/fused_rotary_positional_embedding.h
 create mode 100644 csrc/megatron/fused_rotary_positional_embedding_cuda.cu

diff --git a/csrc/megatron/fused_rotary_positional_embedding.cpp b/csrc/megatron/fused_rotary_positional_embedding.cpp
new file mode 100644
index 000000000..952709b96
--- /dev/null
+++ b/csrc/megatron/fused_rotary_positional_embedding.cpp
@@ -0,0 +1,66 @@
+/* coding=utf-8
+ * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+
+namespace fused_rope {
+
+torch::Tensor fwd_cuda(
+    torch::Tensor const& input,
+    torch::Tensor const& cos,
+    torch::Tensor const& sin);
+
+torch::Tensor bwd_cuda(
+    torch::Tensor const& output_grads,
+    torch::Tensor const& cos,
+    torch::Tensor const& sin);
+
+torch::Tensor fwd(
+    torch::Tensor & input,
+    torch::Tensor & cos,
+    torch::Tensor & sin) {
+  if (!input.is_contiguous())
+    input = input.contiguous();
+  if (!cos.is_contiguous())
+    cos = cos.contiguous();
+  if (!sin.is_contiguous())
+    sin = sin.contiguous();
+
+  return fwd_cuda(input, cos, sin);
+}
+
+torch::Tensor bwd(
+    torch::Tensor & output_grads,
+    torch::Tensor & cos,
+    torch::Tensor & sin) {
+  if (!output_grads.is_contiguous())
+    output_grads = output_grads.contiguous();
+  if (!cos.is_contiguous())
+    cos = cos.contiguous();
+  if (!sin.is_contiguous())
+    sin = sin.contiguous();
+
+  return bwd_cuda(output_grads, cos, sin);
+}
+
+}  // end namespace fused_rope
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &fused_rope::fwd,
+        "Fused Rotary Positional Embedding -- Forward.");
+  m.def("backward", &fused_rope::bwd,
+        "Fused Rotary Positional Embedding -- Backward.");
+}
diff --git a/csrc/megatron/fused_rotary_positional_embedding.h b/csrc/megatron/fused_rotary_positional_embedding.h
new file mode 100644
index 000000000..c3d50fb06
--- /dev/null
+++ b/csrc/megatron/fused_rotary_positional_embedding.h
@@ -0,0 +1,129 @@
+/* coding=utf-8
+ * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+namespace {
+
+template <typename scalar_t>
+__global__ void fused_rope_forward(int sq, int b, int np, int hn, int hn2,
+                                   const scalar_t* src, const scalar_t* cos,
+                                   const scalar_t* sin, scalar_t* dst) {
+  int sq_id = blockIdx.x, b_id = blockIdx.y;
+  int offset_block = sq_id * b * np * hn + b_id * np * hn;
+#pragma unroll
+  for (int hn_id = threadIdx.x; hn_id < hn2; hn_id += blockDim.x) {
+    scalar_t v_cos = cos[sq_id * hn2 + hn_id];
+    scalar_t v_sin = sin[sq_id * hn2 + hn_id];
+#pragma unroll
+    for (int head_id = 0; head_id < np; head_id += 1) {
+      int offset_src_dst = offset_block + head_id * hn + hn_id;
+      scalar_t v_src = src[offset_src_dst];
+      scalar_t v_src_rotate = (hn_id + hn2 / 2 < hn2)
+                                  ? -src[offset_src_dst + hn2 / 2]
+                                  : src[offset_src_dst + hn2 / 2 - hn2];
+      dst[offset_src_dst] = v_src * v_cos + v_src_rotate * v_sin;
+    }
+  }
+
+  // copy the rest
+  if (hn > hn2) {
+#pragma unroll
+    for (int head_id = 0; head_id < np; head_id += 1) {
+      int offset_head = offset_block + head_id * hn;
+#pragma unroll
+      for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) {
+        int offset_src_dst = offset_head + hn_id;
+        dst[offset_src_dst] = src[offset_src_dst];
+      }
+    }
+  }
+}
+
+template <typename scalar_t>
+__global__ void fused_rope_backward(int sq, int b, int np, int hn, int hn2,
+                                    const scalar_t* src, const scalar_t* cos,
+                                    const scalar_t* sin, scalar_t* dst) {
+  int sq_id = blockIdx.x, b_id = blockIdx.y;
+  int offset_block = sq_id * b * np * hn + b_id * np * hn;
+#pragma unroll
+  for (int hn_id = threadIdx.x; hn_id < hn2; hn_id += blockDim.x) {
+    scalar_t v_cos = cos[sq_id * hn2 + hn_id];
+    scalar_t v_sin = (hn_id + hn2 / 2 < hn2)
+                         ? sin[sq_id * hn2 + hn_id + hn2 / 2]
+                         : -sin[sq_id * hn2 + hn_id + hn2 / 2 - hn2];
+#pragma unroll
+    for (int head_id = 0; head_id < np; head_id += 1) {
+      int offset_src_dst = offset_block + head_id * hn + hn_id;
+      scalar_t v_src = src[offset_src_dst];
+      scalar_t v_src_rotate = (hn_id + hn2 / 2 < hn2)
+                                  ? src[offset_src_dst + hn2 / 2]
+                                  : src[offset_src_dst + hn2 / 2 - hn2];
+      dst[offset_src_dst] = v_src * v_cos + v_src_rotate * v_sin;
+    }
+  }
+
+  // handle the tail
+  if (hn > hn2) {
+#pragma unroll
+    for (int head_id = 0; head_id < np; head_id += 1) {
+      int offset_head = offset_block + head_id * hn;
+#pragma unroll
+      for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) {
+        dst[offset_head + hn_id] = src[offset_head + hn_id];
+      }
+    }
+  }
+}
+
+}  // end of anonymous namespace
+
+template <typename scalar_t>
+void dispatch_fused_rope_forward(int sq, int b, int np, int hn, int hn2,
+                                 const scalar_t* input, const scalar_t* cos,
+                                 const scalar_t* sin, scalar_t* output) {
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  constexpr int threads_per_block = 256;
+  dim3 blocks(sq, b);
+  dim3 threads(threads_per_block);
+
+  fused_rope_forward<<<blocks, threads, 0, stream>>>(sq, b, np, hn, hn2, input,
+                                                     cos, sin, output);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+template <typename scalar_t>
+void dispatch_fused_rope_backward(int sq, int b, int np, int hn, int hn2,
+                                  const scalar_t* output_grads,
+                                  const scalar_t* cos, const scalar_t* sin,
+                                  scalar_t* input_grads) {
+  auto stream = at::cuda::getCurrentCUDAStream();
+
+  constexpr int threads_per_block = 256;
+  dim3 blocks(sq, b);
+  dim3 threads(threads_per_block);
+
+  fused_rope_backward<<<blocks, threads, 0, stream>>>(
+      sq, b, np, hn, hn2, output_grads, cos, sin, input_grads);
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
diff --git a/csrc/megatron/fused_rotary_positional_embedding_cuda.cu b/csrc/megatron/fused_rotary_positional_embedding_cuda.cu
new file mode 100644
index 000000000..965c468c9
--- /dev/null
+++ b/csrc/megatron/fused_rotary_positional_embedding_cuda.cu
@@ -0,0 +1,64 @@
+/* coding=utf-8
+ * Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+
+#include "fused_rotary_positional_embedding.h"
+#include "type_shim.h"
+
+namespace fused_rope {
+
+torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& cos,
+                       torch::Tensor const& sin) {
+  const int sq = input.size(0);
+  const int b = input.size(1);
+  const int np = input.size(2);
+  const int hn = input.size(3);
+  const int hn2 = cos.size(3);
+
+  // output
+  auto act_options = input.options().requires_grad(false);
+  torch::Tensor output = torch::empty({sq, b, np, hn}, act_options);
+
+  DISPATCH_FLOAT_HALF_AND_BFLOAT(
+      input.scalar_type(), 0, "dispatch_fused_rope_forward",
+      dispatch_fused_rope_forward(
+          sq, b, np, hn, hn2, input.data_ptr<scalar_t_0>(),
+          cos.data_ptr<scalar_t_0>(), sin.data_ptr<scalar_t_0>(),
+          output.data_ptr<scalar_t_0>()););
+  return output;
+}
+
+torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
+                       torch::Tensor const& cos, torch::Tensor const& sin) {
+  const int sq = output_grads.size(0);
+  const int b = output_grads.size(1);
+  const int np = output_grads.size(2);
+  const int hn = output_grads.size(3);
+  const int hn2 = cos.size(3);
+
+  auto act_options = output_grads.options().requires_grad(false);
+  torch::Tensor input_grads = torch::empty({sq, b, np, hn}, act_options);
+
+  DISPATCH_FLOAT_HALF_AND_BFLOAT(
+      output_grads.scalar_type(), 0, "dispatch_fused_rope_backward",
+      dispatch_fused_rope_backward(
+          sq, b, np, hn, hn2, output_grads.data_ptr<scalar_t_0>(),
+          cos.data_ptr<scalar_t_0>(), sin.data_ptr<scalar_t_0>(),
+          input_grads.data_ptr<scalar_t_0>());)
+  return input_grads;
+}
+}  // end namespace fused_rope
diff --git a/setup.py b/setup.py
index 329f85646..bbfbba738 100644
--- a/setup.py
+++ b/setup.py
@@ -329,6 +329,24 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
             )
         )
 
+    ext_modules.append(
+        CUDAExtension(
+            name="fused_rotary_positional_embedding",
+            sources=["csrc/megatron/fused_rotary_positional_embedding.cpp", "csrc/megatron/fused_rotary_positional_embedding_cuda.cu"],
+            include_dirs=[os.path.join(this_dir, "csrc")],
+            extra_compile_args={
+                "cxx": ["-O3"] + version_dependent_macros,
+                "nvcc": [
+                    "-O3",
+                    "-U__CUDA_NO_HALF_OPERATORS__",
+                    "-U__CUDA_NO_HALF_CONVERSIONS__",
+                    "--expt-relaxed-constexpr",
+                    "--expt-extended-lambda",
+                ] + version_dependent_macros,
+            },
+        )
+    )
+
     if bare_metal_version >= Version("11.0"):
         cc_flag = []
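For reference, the operation that fused_rope_forward and fused_rope_backward implement is the usual rotate-half formulation of RoPE, applied to the leading hn2 elements of each head while the remaining hn - hn2 elements are copied through untouched. The unfused PyTorch sketch below (names are illustrative; the layout matches the kernels: input of shape [sq, b, np, hn], cos/sin of shape [sq, 1, 1, hn2] with hn2 <= hn) is the computation the CUDA path is meant to reproduce:

    import torch

    def rope_forward_reference(t, cos_, sin_):
        # t: [sq, b, np, hn]; cos_/sin_: [sq, 1, 1, hn2], broadcast over b and np
        hn2 = cos_.size(-1)
        t_rope, t_pass = t[..., :hn2], t[..., hn2:]
        x1, x2 = torch.chunk(t_rope, 2, dim=-1)
        rotated = torch.cat((-x2, x1), dim=-1)      # "rotate half"
        out = t_rope * cos_ + rotated * sin_
        return torch.cat((out, t_pass), dim=-1)     # non-rotary tail is passed through

The backward kernel applies the transpose of this map (cos kept, sin shifted by half the rotary width with a sign flip), and the non-rotary tail simply propagates the incoming gradient unchanged.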
From f8815626c15084b1858cc9cdc223a42aedf16146 Mon Sep 17 00:00:00 2001
From: Xin Yao
Date: Wed, 8 Nov 2023 22:39:39 -0800
Subject: [PATCH 2/4] add checks and a unit test

Signed-off-by: Xin Yao
---
 .../fused_rotary_positional_embedding.cpp     |  80 ++++++----
 .../fused_rotary_positional_embedding_cuda.cu |   8 +-
 tests/L0/run_transformer/test_fused_rope.py   | 138 ++++++++++++++++++
 3 files changed, 194 insertions(+), 32 deletions(-)
 create mode 100644 tests/L0/run_transformer/test_fused_rope.py

diff --git a/csrc/megatron/fused_rotary_positional_embedding.cpp b/csrc/megatron/fused_rotary_positional_embedding.cpp
index 952709b96..cc22a10a2 100644
--- a/csrc/megatron/fused_rotary_positional_embedding.cpp
+++ b/csrc/megatron/fused_rotary_positional_embedding.cpp
@@ -18,40 +18,64 @@
 
 namespace fused_rope {
 
-torch::Tensor fwd_cuda(
-    torch::Tensor const& input,
-    torch::Tensor const& cos,
-    torch::Tensor const& sin);
+torch::Tensor fwd_cuda(const torch::Tensor &input, const torch::Tensor &cos,
+                       const torch::Tensor &sin);
 
-torch::Tensor bwd_cuda(
-    torch::Tensor const& output_grads,
-    torch::Tensor const& cos,
-    torch::Tensor const& sin);
+torch::Tensor bwd_cuda(const torch::Tensor &output_grads,
+                       const torch::Tensor &cos, const torch::Tensor &sin);
 
-torch::Tensor fwd(
-    torch::Tensor & input,
-    torch::Tensor & cos,
-    torch::Tensor & sin) {
-  if (!input.is_contiguous())
-    input = input.contiguous();
-  if (!cos.is_contiguous())
-    cos = cos.contiguous();
-  if (!sin.is_contiguous())
-    sin = sin.contiguous();
+torch::Tensor fwd(const at::Tensor &input_, const at::Tensor &cos_,
+                  const at::Tensor &sin_) {
+  auto input = input_.contiguous();
+  auto cos = cos_.contiguous();
+  auto sin = sin_.contiguous();
+
+  TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
+  TORCH_CHECK(cos.dim() == 4, "expected 4D tensor");
+  TORCH_CHECK(sin.dim() == 4, "expected 4D tensor");
+  TORCH_CHECK(input.size(0) == cos.size(0),
+              "expected input and cos tensor have the same sequence length");
+  TORCH_CHECK(input.size(0) == sin.size(0),
+              "expected input and sin tensor have the same sequence length");
+  TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1,
+              "expected the second and third dims of the cos tensor equal 1");
+  TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1,
+              "expected the second and third dims of the sin tensor equal 1");
+  TORCH_CHECK(input.size(3) >= cos.size(3),
+              "expected the last dim of the input tensor is greater than the "
+              "cos tensor");
+  TORCH_CHECK(input.size(3) >= sin.size(3),
+              "expected the last dim of the input tensor is greater than the "
+              "sin tensor");
 
   return fwd_cuda(input, cos, sin);
 }
 
-torch::Tensor bwd(
-    torch::Tensor & output_grads,
-    torch::Tensor & cos,
-    torch::Tensor & sin) {
-  if (!output_grads.is_contiguous())
-    output_grads = output_grads.contiguous();
-  if (!cos.is_contiguous())
-    cos = cos.contiguous();
-  if (!sin.is_contiguous())
-    sin = sin.contiguous();
+torch::Tensor bwd(const torch::Tensor &output_grads_, const at::Tensor &cos_,
+                  const at::Tensor &sin_) {
+  auto output_grads = output_grads_.contiguous();
+  auto cos = cos_.contiguous();
+  auto sin = sin_.contiguous();
+
+  TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
+  TORCH_CHECK(cos.dim() == 4, "expected 4D tensor");
+  TORCH_CHECK(sin.dim() == 4, "expected 4D tensor");
+  TORCH_CHECK(
+      output_grads.size(0) == cos.size(0),
+      "expected output_grads and cos tensor have the same sequence length");
+  TORCH_CHECK(
+      output_grads.size(0) == sin.size(0),
+      "expected output_grads and sin tensor have the same sequence length");
+  TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1,
+              "expected the second and third dims of the cos tensor equal 1");
+  TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1,
+              "expected the second and third dims of the sin tensor equal 1");
+  TORCH_CHECK(
+      output_grads.size(3) >= cos.size(3),
+      "expected the last dim of the output_grads tensor is greater than the "
+      "cos tensor");
+  TORCH_CHECK(
+      output_grads.size(3) >= sin.size(3),
+      "expected the last dim of the output_grads tensor is greater than the "
+      "sin tensor");
 
   return bwd_cuda(output_grads, cos, sin);
 }
diff --git a/csrc/megatron/fused_rotary_positional_embedding_cuda.cu b/csrc/megatron/fused_rotary_positional_embedding_cuda.cu
index 965c468c9..7c09871cc 100644
--- a/csrc/megatron/fused_rotary_positional_embedding_cuda.cu
+++ b/csrc/megatron/fused_rotary_positional_embedding_cuda.cu
@@ -21,8 +21,8 @@
 
 namespace fused_rope {
 
-torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& cos,
-                       torch::Tensor const& sin) {
+torch::Tensor fwd_cuda(const torch::Tensor &input, const torch::Tensor &cos,
+                       const torch::Tensor &sin) {
   const int sq = input.size(0);
   const int b = input.size(1);
   const int np = input.size(2);
@@ -42,8 +42,8 @@ torch::Tensor fwd_cuda(torch::Tensor const& input, torch::Tensor const& cos,
   return output;
 }
 
-torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
-                       torch::Tensor const& cos, torch::Tensor const& sin) {
+torch::Tensor bwd_cuda(const torch::Tensor &output_grads,
+                       const torch::Tensor &cos, const torch::Tensor &sin) {
   const int sq = output_grads.size(0);
   const int b = output_grads.size(1);
   const int np = output_grads.size(2);
diff --git a/tests/L0/run_transformer/test_fused_rope.py b/tests/L0/run_transformer/test_fused_rope.py
new file mode 100644
index 000000000..477842ab7
--- /dev/null
+++ b/tests/L0/run_transformer/test_fused_rope.py
@@ -0,0 +1,138 @@
+"""Test for fused RoPE functions.
+
+Ref: https://github.com/NVIDIA/Megatron-LM/blob/40becfc96c4144985458ac0e0fae45dbb111fbd2/megatron/fused_kernels/tests/test_fused_kernels.py
+"""  # NOQA
+import itertools
+from typing import Tuple, Union
+
+import torch
+from torch.testing._internal import common_utils
+
+
+class FusedRoPEFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx, t: torch.Tensor, cos_: torch.Tensor, sin_: torch.Tensor
+    ) -> torch.Tensor:
+        import fused_rotary_positional_embedding
+
+        output = fused_rotary_positional_embedding.forward(t, cos_, sin_)
+        ctx.save_for_backward(cos_, sin_)
+
+        return output
+
+    @staticmethod
+    def backward(
+        ctx, grad_output: torch.Tensor
+    ) -> Tuple[Union[torch.Tensor, None], ...]:
+        import fused_rotary_positional_embedding
+
+        cos_, sin_ = ctx.saved_tensors
+        grad_q = fused_rotary_positional_embedding.backward(grad_output, cos_, sin_)
+
+        return grad_q, None, None
+
+
+def apply_rotary_pos_emb_fused(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+    cos_ = torch.cos(freqs).to(t.dtype)
+    sin_ = torch.sin(freqs).to(t.dtype)
+    return FusedRoPEFunc.apply(t, cos_, sin_)
+
+
+def _rotate_half(x: torch.Tensor) -> torch.Tensor:
+    """Change sign so the last dimension becomes [-odd, +even]
+
+    Args:
+        x (Tensor): Input tensor
+
+    Returns:
+        Tensor: Tensor rotated half
+    """
+
+    x1, x2 = torch.chunk(x, 2, dim=-1)
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+    """Apply rotary positional embedding to input tensor T.
+
+    check https://kexue.fm/archives/8265 for detailed formulas
+
+    Args:
+        t (Tensor): Input tensor T is of shape [seq_length, ..., dim]
+        freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim]
+
+    Returns:
+        Tensor: The input tensor after applying RoPE
+    """
+    rot_dim = freqs.shape[-1]
+
+    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
+    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
+
+    # first part is cosine component
+    # second part is sine component, need to change signs with _rotate_half method
+    cos_ = torch.cos(freqs).to(t.dtype)
+    sin_ = torch.sin(freqs).to(t.dtype)
+
+    t = (t * cos_) + (_rotate_half(t) * sin_)
+    return torch.cat((t, t_pass), dim=-1)
+
+
+class TestFusedRoPE(common_utils.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.batch_size = 2
+        self.head_num = 64
+        self.seq_length = [2048, 4096]
+        self.hidden_size = [128, 256]
+        self.rotary_percent = [0.5, 1.0]
+        self.dtype = [torch.float32, torch.bfloat16, torch.float16]
+        self.device = torch.cuda.current_device()
+
+    def tearDown(self) -> None:
+        torch.cuda.empty_cache()
+        super().tearDown()
+
+    def test_forward_backward(self):
+        for dtype, seq_length, hidden_size, rotary_percent in itertools.product(
+            self.dtype, self.seq_length, self.hidden_size, self.rotary_percent
+        ):
+            t = torch.rand(
+                (seq_length, self.batch_size, self.head_num, hidden_size),
+                dtype=dtype,
+                device=self.device,
+                requires_grad=True,
+            )
+
+            emb = torch.rand(
+                (seq_length, 1, 1, int(hidden_size * rotary_percent)),
+                dtype=torch.float32,
+                device=self.device,
+            )
+
+            # unfused
+            output_unfused = apply_rotary_pos_emb(t, emb)
+            output_unfused.sum().backward()
+            grad_unfused = t.grad.detach().clone()
+            t.grad = None
+
+            # fused
+            output_fused = apply_rotary_pos_emb_fused(t, emb)
+            output_fused.sum().backward()
+            grad_fused = t.grad.detach().clone()
+
+            self.assertEqual(
+                output_unfused,
+                output_fused,
+                msg=f"{dtype=}, {seq_length=}, {hidden_size=}, {rotary_percent=}",
+            )
+            self.assertEqual(
+                grad_unfused,
+                grad_fused,
+                msg=f"{dtype=}, {seq_length=}, {hidden_size=}, {rotary_percent=}",
+            )
+
+
+if __name__ == "__main__":
+    common_utils.run_tests()
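The shape contract enforced by the new TORCH_CHECKs mirrors what the unit test exercises: a 4-D activation [sq, b, np, hn] together with 4-D cos/sin tensors of shape [sq, 1, 1, hn2] where hn2 <= hn (hn2 < hn corresponds to rotary_percent < 1.0 and leaves a pass-through tail). A minimal way to drive the raw extension directly, assuming it has been built by the setup.py change in the previous patch, is:

    import torch
    import fused_rotary_positional_embedding  # extension built by setup.py

    sq, b, np_, hn, hn2 = 2048, 2, 64, 128, 64   # hn2 < hn exercises the tail copy

    t = torch.rand(sq, b, np_, hn, dtype=torch.float16, device="cuda")
    freqs = torch.rand(sq, 1, 1, hn2, dtype=torch.float32, device="cuda")
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)

    out = fused_rotary_positional_embedding.forward(t, cos_, sin_)     # [sq, b, np, hn]
    grad_in = fused_rotary_positional_embedding.backward(torch.ones_like(out), cos_, sin_)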
From 82a3231d9e42e0ac36b783e801abd911114a890c Mon Sep 17 00:00:00 2001
From: Xin Yao
Date: Thu, 9 Nov 2023 01:38:09 -0800
Subject: [PATCH 3/4] use better block size

Signed-off-by: Xin Yao
---
 .../megatron/fused_rotary_positional_embedding.h | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/csrc/megatron/fused_rotary_positional_embedding.h b/csrc/megatron/fused_rotary_positional_embedding.h
index c3d50fb06..7ac13932d 100644
--- a/csrc/megatron/fused_rotary_positional_embedding.h
+++ b/csrc/megatron/fused_rotary_positional_embedding.h
@@ -35,7 +35,7 @@ __global__ void fused_rope_forward(int sq, int b, int np, int hn, int hn2,
     scalar_t v_cos = cos[sq_id * hn2 + hn_id];
     scalar_t v_sin = sin[sq_id * hn2 + hn_id];
 #pragma unroll
-    for (int head_id = 0; head_id < np; head_id += 1) {
+    for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) {
       int offset_src_dst = offset_block + head_id * hn + hn_id;
       scalar_t v_src = src[offset_src_dst];
       scalar_t v_src_rotate = (hn_id + hn2 / 2 < hn2)
@@ -48,7 +48,7 @@ __global__ void fused_rope_forward(int sq, int b, int np, int hn, int hn2,
   // copy the rest
   if (hn > hn2) {
 #pragma unroll
-    for (int head_id = 0; head_id < np; head_id += 1) {
+    for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) {
       int offset_head = offset_block + head_id * hn;
 #pragma unroll
       for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) {
@@ -72,7 +72,7 @@ __global__ void fused_rope_backward(int sq, int b, int np, int hn, int hn2,
                          ? sin[sq_id * hn2 + hn_id + hn2 / 2]
                          : -sin[sq_id * hn2 + hn_id + hn2 / 2 - hn2];
 #pragma unroll
-    for (int head_id = 0; head_id < np; head_id += 1) {
+    for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) {
       int offset_src_dst = offset_block + head_id * hn + hn_id;
       scalar_t v_src = src[offset_src_dst];
       scalar_t v_src_rotate = (hn_id + hn2 / 2 < hn2)
@@ -85,7 +85,7 @@ __global__ void fused_rope_backward(int sq, int b, int np, int hn, int hn2,
   // handle the tail
   if (hn > hn2) {
 #pragma unroll
-    for (int head_id = 0; head_id < np; head_id += 1) {
+    for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) {
       int offset_head = offset_block + head_id * hn;
 #pragma unroll
       for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) {
@@ -103,9 +103,9 @@ void dispatch_fused_rope_forward(int sq, int b, int np, int hn, int hn2,
                                  const scalar_t* sin, scalar_t* output) {
   auto stream = at::cuda::getCurrentCUDAStream();
 
-  constexpr int threads_per_block = 256;
+  int warps_per_block = np < 16 ? 4 : 8;
   dim3 blocks(sq, b);
-  dim3 threads(threads_per_block);
+  dim3 threads(C10_WARP_SIZE, warps_per_block);
 
   fused_rope_forward<<<blocks, threads, 0, stream>>>(sq, b, np, hn, hn2, input,
                                                      cos, sin, output);
@@ -119,9 +119,9 @@ void dispatch_fused_rope_backward(int sq, int b, int np, int hn, int hn2,
                                   scalar_t* input_grads) {
   auto stream = at::cuda::getCurrentCUDAStream();
 
-  constexpr int threads_per_block = 256;
+  int warps_per_block = np < 16 ? 4 : 8;
  dim3 blocks(sq, b);
-  dim3 threads(threads_per_block);
+  dim3 threads(C10_WARP_SIZE, warps_per_block);
 
   fused_rope_backward<<<blocks, threads, 0, stream>>>(
       sq, b, np, hn, hn2, output_grads, cos, sin, input_grads);

From 89a410f9446040f613b909e5f14865dbc0782786 Mon Sep 17 00:00:00 2001
From: Xin Yao
Date: Thu, 9 Nov 2023 18:43:15 -0800
Subject: [PATCH 4/4] add fused_rope to functional

Signed-off-by: Xin Yao
---
 apex/transformer/functional/__init__.py     |  6 ++
 apex/transformer/functional/fused_rope.py   | 73 +++++++++++++++++
 setup.py                                    |  5 +-
 tests/L0/run_transformer/test_fused_rope.py | 37 ++--------
 4 files changed, 87 insertions(+), 34 deletions(-)
 create mode 100644 apex/transformer/functional/fused_rope.py

diff --git a/apex/transformer/functional/__init__.py b/apex/transformer/functional/__init__.py
index d770c8859..563078c1c 100644
--- a/apex/transformer/functional/__init__.py
+++ b/apex/transformer/functional/__init__.py
@@ -1,5 +1,11 @@
+from apex.transformer.functional.fused_rope import (
+    fused_apply_rotary_pos_emb,
+    fused_apply_rotary_pos_emb_cached,
+)
 from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax
 
 __all__ = [
     "FusedScaleMaskSoftmax",
+    "fused_apply_rotary_pos_emb",
+    "fused_apply_rotary_pos_emb_cached",
 ]
diff --git a/apex/transformer/functional/fused_rope.py b/apex/transformer/functional/fused_rope.py
new file mode 100644
index 000000000..107665535
--- /dev/null
+++ b/apex/transformer/functional/fused_rope.py
@@ -0,0 +1,73 @@
+# coding=utf-8
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Tuple, Union
+import torch
+
+
+class FusedRoPEFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx, t: torch.Tensor, cos_: torch.Tensor, sin_: torch.Tensor
+    ) -> torch.Tensor:
+        import fused_rotary_positional_embedding
+
+        output = fused_rotary_positional_embedding.forward(t, cos_, sin_)
+        ctx.save_for_backward(cos_, sin_)
+
+        return output
+
+    @staticmethod
+    def backward(
+        ctx, grad_output: torch.Tensor
+    ) -> Tuple[Union[torch.Tensor, None], ...]:
+        import fused_rotary_positional_embedding
+
+        cos_, sin_ = ctx.saved_tensors
+        grad_q = fused_rotary_positional_embedding.backward(grad_output, cos_, sin_)
+
+        return grad_q, None, None
+
+
+def fused_apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+    """Apply rotary positional embedding to input tensor T.
+
+    Args:
+        t (Tensor): Input tensor T is of shape [seq_length, ..., dim]
+        freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim]
+
+    Returns:
+        Tensor: The input tensor after applying RoPE
+    """
+    cos_ = torch.cos(freqs).to(t.dtype)
+    sin_ = torch.sin(freqs).to(t.dtype)
+    return FusedRoPEFunc.apply(t, cos_, sin_)
+
+
+def fused_apply_rotary_pos_emb_cached(
+    t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
+) -> torch.Tensor:
+    """Apply rotary positional embedding to input tensor T.
+
+    Args:
+        t (Tensor): Input tensor T is of shape [seq_length, ..., dim]
+        cos (Tensor): Cached cosine of the rotary positional embedding tensor is of shape [seq_length, ..., dim]
+        sin (Tensor): Cached sine of the rotary positional embedding tensor is of shape [seq_length, ..., dim]
+
+    Returns:
+        Tensor: The input tensor after applying RoPE
+    """
+    cos_ = cos.to(t.dtype)
+    sin_ = sin.to(t.dtype)
+    return FusedRoPEFunc.apply(t, cos_, sin_)
diff --git a/setup.py b/setup.py
index bbfbba738..c02d8339d 100644
--- a/setup.py
+++ b/setup.py
@@ -332,7 +332,10 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
     ext_modules.append(
         CUDAExtension(
             name="fused_rotary_positional_embedding",
-            sources=["csrc/megatron/fused_rotary_positional_embedding.cpp", "csrc/megatron/fused_rotary_positional_embedding_cuda.cu"],
+            sources=[
+                "csrc/megatron/fused_rotary_positional_embedding.cpp",
+                "csrc/megatron/fused_rotary_positional_embedding_cuda.cu",
+            ],
             include_dirs=[os.path.join(this_dir, "csrc")],
             extra_compile_args={
                 "cxx": ["-O3"] + version_dependent_macros,
diff --git a/tests/L0/run_transformer/test_fused_rope.py b/tests/L0/run_transformer/test_fused_rope.py
index 477842ab7..be557054e 100644
--- a/tests/L0/run_transformer/test_fused_rope.py
+++ b/tests/L0/run_transformer/test_fused_rope.py
@@ -3,40 +3,10 @@
 Ref: https://github.com/NVIDIA/Megatron-LM/blob/40becfc96c4144985458ac0e0fae45dbb111fbd2/megatron/fused_kernels/tests/test_fused_kernels.py
 """  # NOQA
 import itertools
-from typing import Tuple, Union
 
 import torch
 from torch.testing._internal import common_utils
-
-
-class FusedRoPEFunc(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx, t: torch.Tensor, cos_: torch.Tensor, sin_: torch.Tensor
-    ) -> torch.Tensor:
-        import fused_rotary_positional_embedding
-
-        output = fused_rotary_positional_embedding.forward(t, cos_, sin_)
-        ctx.save_for_backward(cos_, sin_)
-
-        return output
-
-    @staticmethod
-    def backward(
-        ctx, grad_output: torch.Tensor
-    ) -> Tuple[Union[torch.Tensor, None], ...]:
-        import fused_rotary_positional_embedding
-
-        cos_, sin_ = ctx.saved_tensors
-        grad_q = fused_rotary_positional_embedding.backward(grad_output, cos_, sin_)
-
-        return grad_q, None, None
-
-
-def apply_rotary_pos_emb_fused(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
-    cos_ = torch.cos(freqs).to(t.dtype)
-    sin_ = torch.sin(freqs).to(t.dtype)
-    return FusedRoPEFunc.apply(t, cos_, sin_)
+from apex.transformer.functional import fused_apply_rotary_pos_emb
 
 
 def _rotate_half(x: torch.Tensor) -> torch.Tensor:
@@ -52,7 +22,8 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
 
     x1, x2 = torch.chunk(x, 2, dim=-1)
     return torch.cat((-x2, x1), dim=-1)
 
-
+# Copied from Megatron-Core for testing.
+# https://github.com/NVIDIA/Megatron-LM/blob/5f2877d85cb26e47ce6dcdae4b80adf376abf4e8/megatron/core/models/common/embeddings/rotary_pos_embedding.py#L139
 def apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
     """Apply rotary positional embedding to input tensor T.
 
@@ -89,7 +60,7 @@ def test_forward_backward(self):
             t.grad = None
 
             # fused
-            output_fused = apply_rotary_pos_emb_fused(t, emb)
+            output_fused = fused_apply_rotary_pos_emb(t, emb)
             output_fused.sum().backward()
             grad_fused = t.grad.detach().clone()
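With the final patch in place, the fused kernels are reachable through the public helpers exported from apex.transformer.functional. A short usage sketch (shapes follow the unit test; the module and function names are the ones added above):

    import torch
    from apex.transformer.functional import (
        fused_apply_rotary_pos_emb,
        fused_apply_rotary_pos_emb_cached,
    )

    sq, b, np_, hn = 4096, 2, 64, 128
    t = torch.rand(sq, b, np_, hn, dtype=torch.bfloat16, device="cuda", requires_grad=True)
    freqs = torch.rand(sq, 1, 1, hn, dtype=torch.float32, device="cuda")

    # compute cos/sin from freqs on every call
    y = fused_apply_rotary_pos_emb(t, freqs)
    y.sum().backward()

    # or reuse precomputed cos/sin across layers and iterations
    cos, sin = torch.cos(freqs), torch.sin(freqs)
    y_cached = fused_apply_rotary_pos_emb_cached(t, cos, sin)

Both helpers route through FusedRoPEFunc, so autograd picks up the fused backward kernel automatically.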