diff --git a/megatron/fused_kernels/fused_rotary_positional_embedding.cpp b/megatron/fused_kernels/fused_rotary_positional_embedding.cpp index cc22a10a2..ad6b26da0 100644 --- a/megatron/fused_kernels/fused_rotary_positional_embedding.cpp +++ b/megatron/fused_kernels/fused_rotary_positional_embedding.cpp @@ -14,77 +14,97 @@ * limitations under the License. */ -#include +#include + +#include "fused_rotary_positional_embedding.h" +#include "type_shim.h" namespace fused_rope { torch::Tensor fwd_cuda(const torch::Tensor &input, const torch::Tensor &cos, - const torch::Tensor &sin); - -torch::Tensor bwd_cuda(const torch::Tensor &output_grads, - const torch::Tensor &cos, const torch::Tensor &sin); + const torch::Tensor &sin, const bool transpose_output) { + // input sizes: (s, b, h, d) + // s: sequence length + // b: batch size + // h: head num + // d: dim of each head + const int s = input.size(0); + const int b = input.size(1); + const int h = input.size(2); + const int d = input.size(3); + // input strides + const int stride_s = input.stride(0); + const int stride_b = input.stride(1); + const int stride_h = input.stride(2); + const int stride_d = input.stride(3); + // cos/sin's shape is always (s, 1, 1, d2), so the strides are same under + // different memory formats + const int d2 = cos.size(3); -torch::Tensor fwd(const at::Tensor &input_, const at::Tensor &cos_, - const at::Tensor &sin_) { - auto input = input_.contiguous(); - auto cos = cos_.contiguous(); - auto sin = sin_.contiguous(); - TORCH_CHECK(input.dim() == 4, "expected 4D tensor"); - TORCH_CHECK(cos.dim() == 4, "expected 4D tensor"); - TORCH_CHECK(sin.dim() == 4, "expected 4D tensor"); - TORCH_CHECK(input.size(0) == cos.size(0), - "expected input and cos tensor have the same sequence length"); - TORCH_CHECK(input.size(0) == sin.size(0), - "expected input and sin tensor have the same sequence length"); - TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1, - "expected the second and third dims of the cos tensor equal 1"); - TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1, - "expected the second and third dims of the sin tensor equal 1"); - TORCH_CHECK(input.size(3) >= cos.size(3), - "expected the last dim of the input tensor is greater than the " - "cos tensor"); - TORCH_CHECK(input.size(3) >= sin.size(3), - "expected the last dim of the input tensor is greater than the " - "sin tensor"); + // output + auto act_options = input.options().requires_grad(false); + torch::Tensor output; + if (transpose_output) { + output = torch::empty({b, s, h, d}, act_options).transpose(0, 1); + } else { + output = torch::empty({s, b, h, d}, act_options); + } + // output strides + const int o_stride_s = output.stride(0); + const int o_stride_b = output.stride(1); + const int o_stride_h = output.stride(2); + const int o_stride_d = output.stride(3); - return fwd_cuda(input, cos, sin); + DISPATCH_FLOAT_HALF_AND_BFLOAT( + input.scalar_type(), 0, "dispatch_fused_rope_forward", + dispatch_fused_rope_forward( + s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, + o_stride_b, o_stride_h, o_stride_d, input.data_ptr(), + cos.data_ptr(), sin.data_ptr(), + output.data_ptr());); + return output; } -torch::Tensor bwd(const torch::Tensor &output_grads_, const at::Tensor &cos_, - const at::Tensor &sin_) { - auto output_grads = output_grads_.contiguous(); - auto cos = cos_.contiguous(); - auto sin = sin_.contiguous(); - TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor"); - TORCH_CHECK(cos.dim() == 4, "expected 4D tensor"); - TORCH_CHECK(sin.dim() == 4, "expected 4D tensor"); - TORCH_CHECK( - output_grads.size(0) == cos.size(0), - "expected output_grads and cos tensor have the same sequence length"); - TORCH_CHECK( - output_grads.size(0) == sin.size(0), - "expected output_grads and sin tensor have the same sequence length"); - TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1, - "expected the second and third dims of the cos tensor equal 1"); - TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1, - "expected the second and third dims of the sin tensor equal 1"); - TORCH_CHECK( - output_grads.size(3) >= cos.size(3), - "expected the last dim of the output_grads tensor is greater than the " - "cos tensor"); - TORCH_CHECK( - output_grads.size(3) >= sin.size(3), - "expected the last dim of the output_grads tensor is greater than the " - "sin tensor"); - - return bwd_cuda(output_grads, cos, sin); -} +torch::Tensor bwd_cuda(const torch::Tensor &output_grads, + const torch::Tensor &cos, const torch::Tensor &sin, + const bool transpose_output) { + // output_grads sizes: (s, b, h, d) + // s: sequence length + // b: batch size + // h: head num + // d: dim of each head + const int s = output_grads.size(0); + const int b = output_grads.size(1); + const int h = output_grads.size(2); + const int d = output_grads.size(3); + // output_grads strides + const int stride_s = output_grads.stride(0); + const int stride_b = output_grads.stride(1); + const int stride_h = output_grads.stride(2); + const int stride_d = output_grads.stride(3); + // cos/sin's shape is always (s, 1, 1, d2), so the strides are same under + // different memory formats + const int d2 = cos.size(3); -} // end namespace fused_rope + auto act_options = output_grads.options().requires_grad(false); + torch::Tensor input_grads; + if (transpose_output) { + input_grads = torch::empty({b, s, h, d}, act_options).transpose(0, 1); + } else { + input_grads = torch::empty({s, b, h, d}, act_options); + } + const int o_stride_s = input_grads.stride(0); + const int o_stride_b = input_grads.stride(1); + const int o_stride_h = input_grads.stride(2); + const int o_stride_d = input_grads.stride(3); -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &fused_rope::fwd, - "Fused Rotary Positional Embedding -- Forward."); - m.def("backward", &fused_rope::bwd, - "Fused Rotary Positional Embedding -- Backward."); + DISPATCH_FLOAT_HALF_AND_BFLOAT( + output_grads.scalar_type(), 0, "dispatch_fused_rope_backward", + dispatch_fused_rope_backward( + s, b, h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, + o_stride_b, o_stride_h, o_stride_d, + output_grads.data_ptr(), cos.data_ptr(), + sin.data_ptr(), input_grads.data_ptr());) + return input_grads; } +} // end namespace fused_rope diff --git a/megatron/fused_kernels/fused_rotary_positional_embedding.h b/megatron/fused_kernels/fused_rotary_positional_embedding.h index 28dca70a5..3b1b2fe8b 100644 --- a/megatron/fused_kernels/fused_rotary_positional_embedding.h +++ b/megatron/fused_kernels/fused_rotary_positional_embedding.h @@ -25,70 +25,83 @@ namespace { template -__global__ void fused_rope_forward(int sq, int b, int np, int hn, int hn2, +__global__ void fused_rope_forward(int h, int d, int d2, int stride_s, + int stride_b, int stride_h, int stride_d, + int o_stride_s, int o_stride_b, + int o_stride_h, int o_stride_d, const scalar_t* src, const scalar_t* cos, const scalar_t* sin, scalar_t* dst) { - int sq_id = blockIdx.x, b_id = blockIdx.y; - int offset_block = sq_id * b * np * hn + b_id * np * hn; + int s_id = blockIdx.x, b_id = blockIdx.y; + int offset_block = s_id * stride_s + b_id * stride_b; + int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; #pragma unroll - for (int hn_id = threadIdx.x; hn_id < hn2; hn_id += blockDim.x) { - scalar_t v_cos = cos[sq_id * hn2 + hn_id]; - scalar_t v_sin = sin[sq_id * hn2 + hn_id]; + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + scalar_t v_cos = cos[s_id * d2 + d_id]; + scalar_t v_sin = sin[s_id * d2 + d_id]; #pragma unroll - for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) { - int offset_src_dst = offset_block + head_id * hn + hn_id; - scalar_t v_src = src[offset_src_dst]; - scalar_t v_src_rotate = (hn_id + hn2 / 2 < hn2) - ? -src[offset_src_dst + hn2 / 2] - : src[offset_src_dst + hn2 / 2 - hn2]; - dst[offset_src_dst] = v_src * v_cos + v_src_rotate * v_sin; + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_src = offset_block + h_id * stride_h + d_id * stride_d; + int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; + scalar_t v_src = src[offset_src]; + scalar_t v_src_rotate = (d_id + d2 / 2 < d2) + ? -src[offset_src + (d2 / 2) * stride_d] + : src[offset_src + (d2 / 2 - d2) * stride_d]; + dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; } } // copy the rest - if (hn > hn2) { + if (d > d2) { #pragma unroll - for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) { - int offset_head = offset_block + head_id * hn; + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_head = offset_block + h_id * stride_h; + int offset_head_dst = offset_block_dst + h_id * o_stride_h; #pragma unroll - for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) { - dst[offset_head + hn_id] = src[offset_head + hn_id]; + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { + dst[offset_head_dst + d_id * o_stride_d] = + src[offset_head + d_id * stride_d]; } } } } template -__global__ void fused_rope_backward(int sq, int b, int np, int hn, int hn2, +__global__ void fused_rope_backward(int h, int d, int d2, int stride_s, + int stride_b, int stride_h, int stride_d, + int o_stride_s, int o_stride_b, + int o_stride_h, int o_stride_d, const scalar_t* src, const scalar_t* cos, const scalar_t* sin, scalar_t* dst) { - int sq_id = blockIdx.x, b_id = blockIdx.y; - int offset_block = sq_id * b * np * hn + b_id * np * hn; + int s_id = blockIdx.x, b_id = blockIdx.y; + int offset_block = s_id * stride_s + b_id * stride_b; + int offset_block_dst = s_id * o_stride_s + b_id * o_stride_b; #pragma unroll - for (int hn_id = threadIdx.x; hn_id < hn2; hn_id += blockDim.x) { - scalar_t v_cos = cos[sq_id * hn2 + hn_id]; - scalar_t v_sin = (hn_id + hn2 / 2 < hn2) - ? sin[sq_id * hn2 + hn_id + hn2 / 2] - : -sin[sq_id * hn2 + hn_id + hn2 / 2 - hn2]; + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + scalar_t v_cos = cos[s_id * d2 + d_id]; + scalar_t v_sin = (d_id + d2 / 2 < d2) + ? sin[s_id * d2 + d_id + d2 / 2] + : -sin[s_id * d2 + d_id + d2 / 2 - d2]; #pragma unroll - for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) { - int offset_src_dst = offset_block + head_id * hn + hn_id; - scalar_t v_src = src[offset_src_dst]; - scalar_t v_src_rotate = (hn_id + hn2 / 2 < hn2) - ? src[offset_src_dst + hn2 / 2] - : src[offset_src_dst + hn2 / 2 - hn2]; - dst[offset_src_dst] = v_src * v_cos + v_src_rotate * v_sin; + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_src = offset_block + h_id * stride_h + d_id * stride_d; + int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; + scalar_t v_src = src[offset_src]; + scalar_t v_src_rotate = (d_id + d2 / 2 < d2) + ? src[offset_src + (d2 / 2) * stride_d] + : src[offset_src + (d2 / 2 - d2) * stride_d]; + dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; } } // handle the tail - if (hn > hn2) { + if (d > d2) { #pragma unroll - for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) { - int offset_head = offset_block + head_id * hn; + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_head = offset_block + h_id * stride_h; + int offset_head_dst = offset_block_dst + h_id * o_stride_h; #pragma unroll - for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) { - dst[offset_head + hn_id] = src[offset_head + hn_id]; + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { + dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d]; } } } @@ -97,32 +110,40 @@ __global__ void fused_rope_backward(int sq, int b, int np, int hn, int hn2, } // end of anonymous namespace template -void dispatch_fused_rope_forward(int sq, int b, int np, int hn, int hn2, +void dispatch_fused_rope_forward(int s, int b, int h, int d, int d2, + int stride_s, int stride_b, int stride_h, + int stride_d, int o_stride_s, int o_stride_b, + int o_stride_h, int o_stride_d, const scalar_t* input, const scalar_t* cos, const scalar_t* sin, scalar_t* output) { auto stream = at::cuda::getCurrentCUDAStream(); - int warps_per_block = np < 16 ? 4 : 8; - dim3 blocks(sq, b); + int warps_per_block = h < 16 ? 4 : 8; + dim3 blocks(s, b); dim3 threads(C10_WARP_SIZE, warps_per_block); - fused_rope_forward<<>>(sq, b, np, hn, hn2, input, - cos, sin, output); + fused_rope_forward<<>>( + h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, + o_stride_h, o_stride_d, input, cos, sin, output); C10_CUDA_KERNEL_LAUNCH_CHECK(); } template -void dispatch_fused_rope_backward(int sq, int b, int np, int hn, int hn2, +void dispatch_fused_rope_backward(int s, int b, int h, int d, int d2, + int stride_s, int stride_b, int stride_h, + int stride_d, int o_stride_s, int o_stride_b, + int o_stride_h, int o_stride_d, const scalar_t* output_grads, const scalar_t* cos, const scalar_t* sin, scalar_t* input_grads) { auto stream = at::cuda::getCurrentCUDAStream(); - int warps_per_block = np < 16 ? 4 : 8; - dim3 blocks(sq, b); + int warps_per_block = h < 16 ? 4 : 8; + dim3 blocks(s, b); dim3 threads(C10_WARP_SIZE, warps_per_block); fused_rope_backward<<>>( - sq, b, np, hn, hn2, output_grads, cos, sin, input_grads); + h, d, d2, stride_s, stride_b, stride_h, stride_d, o_stride_s, o_stride_b, + o_stride_h, o_stride_d, output_grads, cos, sin, input_grads); C10_CUDA_KERNEL_LAUNCH_CHECK(); }