Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A fused apply_rotary_pos_emb implementation for Megatron-Core #1746

Merged
merged 4 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions apex/transformer/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from apex.transformer.functional.fused_rope import (
fused_apply_rotary_pos_emb,
fused_apply_rotary_pos_emb_cached,
)
from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax

# Names re-exported by this package: the fused softmax module plus the two
# fused rotary-positional-embedding entry points imported above.
__all__ = [
"FusedScaleMaskSoftmax",
"fused_apply_rotary_pos_emb",
"fused_apply_rotary_pos_emb_cached",
]
73 changes: 73 additions & 0 deletions apex/transformer/functional/fused_rope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# coding=utf-8
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union
import torch


class FusedRoPEFunc(torch.autograd.Function):
    """Autograd function backed by the ``fused_rotary_positional_embedding`` extension.

    The extension is imported inside each method rather than at module level,
    so importing this file does not itself require the compiled extension.
    """

    @staticmethod
    def forward(
        ctx, t: torch.Tensor, cos_: torch.Tensor, sin_: torch.Tensor
    ) -> torch.Tensor:
        """Call the extension's ``forward`` and keep ``cos_``/``sin_`` for backward.

        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            t (Tensor): Tensor to rotate.
            cos_ (Tensor): Cosine table passed through to the extension.
            sin_ (Tensor): Sine table passed through to the extension.

        Returns:
            Tensor: Whatever the extension's ``forward`` returns for these inputs.
        """
        import fused_rotary_positional_embedding as rope_ext

        rotated = rope_ext.forward(t, cos_, sin_)
        # Only the tables are needed to compute the gradient w.r.t. ``t``.
        ctx.save_for_backward(cos_, sin_)
        return rotated

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        """Call the extension's ``backward`` with the saved tables.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output (Tensor): Gradient with respect to the forward output.

        Returns:
            tuple: ``(grad_t, None, None)`` -- one slot per forward input; no
            gradient is produced for ``cos_`` or ``sin_``.
        """
        import fused_rotary_positional_embedding as rope_ext

        saved_cos, saved_sin = ctx.saved_tensors
        grad_t = rope_ext.backward(grad_output, saved_cos, saved_sin)
        return grad_t, None, None


def fused_apply_rotary_pos_emb(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply rotary positional embedding to ``t`` through the fused kernel.

    The cosine and sine of ``freqs`` are computed here, cast to ``t``'s dtype
    and passed to ``FusedRoPEFunc``.

    Args:
        t (Tensor): Input tensor. The C++ binding requires it to be 4D.
        freqs (Tensor): Rotary positional embedding angles. The C++ binding
            requires the derived cos/sin tensors to be 4D, to match ``t`` in
            dim 0, to have size 1 in dims 1 and 2, and to have a last dim no
            larger than ``t``'s.

    Returns:
        Tensor: The input tensor after applying RoPE
    """
    cos_, sin_ = (trig(freqs).to(t.dtype) for trig in (torch.cos, torch.sin))
    return FusedRoPEFunc.apply(t, cos_, sin_)


def fused_apply_rotary_pos_emb_cached(
    t: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> torch.Tensor:
    """Apply rotary positional embedding to ``t`` from precomputed cos/sin tables.

    Unlike ``fused_apply_rotary_pos_emb`` this variant does not evaluate any
    trigonometric function; it only casts the given tables to ``t``'s dtype
    before calling ``FusedRoPEFunc``.

    Args:
        t (Tensor): Input tensor. The C++ binding requires it to be 4D.
        cos (Tensor): Cached cosine of the rotary positional embedding. The
            C++ binding requires it to be 4D, to match ``t`` in dim 0, to have
            size 1 in dims 1 and 2, and to have a last dim no larger than
            ``t``'s.
        sin (Tensor): Cached sine of the rotary positional embedding, with the
            same shape requirements as ``cos``.

    Returns:
        Tensor: The input tensor after applying RoPE
    """
    return FusedRoPEFunc.apply(t, cos.to(t.dtype), sin.to(t.dtype))
90 changes: 90 additions & 0 deletions csrc/megatron/fused_rotary_positional_embedding.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/* coding=utf-8
* Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <torch/extension.h>

namespace fused_rope {

// Forward declarations of the CUDA entry points. Their definitions live in
// fused_rotary_positional_embedding_cuda.cu; the wrappers in this file
// validate shapes and then delegate to them.

// Launches the forward kernel for `input` using the `cos`/`sin` tables.
torch::Tensor fwd_cuda(const torch::Tensor &input, const torch::Tensor &cos,
const torch::Tensor &sin);

// Launches the backward kernel for `output_grads` using the `cos`/`sin` tables.
torch::Tensor bwd_cuda(const torch::Tensor &output_grads,
const torch::Tensor &cos, const torch::Tensor &sin);

// Validates the arguments of the forward pass and delegates to fwd_cuda.
//
// Requirements enforced here:
//   * all three tensors live on a CUDA device (raw device pointers are handed
//     to a kernel further down);
//   * input_, cos_ and sin_ are 4D;
//   * cos_/sin_ match input_ in dim 0 and have size 1 in dims 1 and 2;
//   * cos_ and sin_ share the same last dim, which is no larger than
//     input_'s last dim (the kernel indexes both tables with cos's last dim).
//
// Every tensor is made contiguous before being passed on.
torch::Tensor fwd(const at::Tensor &input_, const at::Tensor &cos_,
                  const at::Tensor &sin_) {
  TORCH_CHECK(input_.is_cuda() && cos_.is_cuda() && sin_.is_cuda(),
              "expected input, cos and sin to be CUDA tensors");

  auto input = input_.contiguous();
  auto cos = cos_.contiguous();
  auto sin = sin_.contiguous();
  TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(cos.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(sin.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(input.size(0) == cos.size(0),
              "expected input and cos tensor have the same sequence length");
  TORCH_CHECK(input.size(0) == sin.size(0),
              "expected input and sin tensor have the same sequence length");
  TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1,
              "expected the second and third dims of the cos tensor equal 1");
  TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1,
              "expected the second and third dims of the sin tensor equal 1");
  // The kernel strides through sin using cos's last dim, so a mismatch would
  // read the wrong elements or run past the end of the sin buffer.
  TORCH_CHECK(cos.size(3) == sin.size(3),
              "expected cos and sin tensor have the same last dim");
  TORCH_CHECK(input.size(3) >= cos.size(3),
              "expected the last dim of the input tensor is greater than the "
              "cos tensor");
  TORCH_CHECK(input.size(3) >= sin.size(3),
              "expected the last dim of the input tensor is greater than the "
              "sin tensor");

  return fwd_cuda(input, cos, sin);
}

// Validates the arguments of the backward pass and delegates to bwd_cuda.
//
// Requirements enforced here:
//   * all three tensors live on a CUDA device (raw device pointers are handed
//     to a kernel further down);
//   * output_grads_, cos_ and sin_ are 4D;
//   * cos_/sin_ match output_grads_ in dim 0 and have size 1 in dims 1 and 2;
//   * cos_ and sin_ share the same last dim, which is no larger than
//     output_grads_'s last dim (the kernel indexes both tables with cos's
//     last dim).
//
// Every tensor is made contiguous before being passed on.
torch::Tensor bwd(const torch::Tensor &output_grads_, const at::Tensor &cos_,
                  const at::Tensor &sin_) {
  TORCH_CHECK(output_grads_.is_cuda() && cos_.is_cuda() && sin_.is_cuda(),
              "expected output_grads, cos and sin to be CUDA tensors");

  auto output_grads = output_grads_.contiguous();
  auto cos = cos_.contiguous();
  auto sin = sin_.contiguous();
  TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(cos.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(sin.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(
      output_grads.size(0) == cos.size(0),
      "expected output_grads and cos tensor have the same sequence length");
  TORCH_CHECK(
      output_grads.size(0) == sin.size(0),
      "expected output_grads and sin tensor have the same sequence length");
  TORCH_CHECK(cos.size(1) == 1 && cos.size(2) == 1,
              "expected the second and third dims of the cos tensor equal 1");
  TORCH_CHECK(sin.size(1) == 1 && sin.size(2) == 1,
              "expected the second and third dims of the sin tensor equal 1");
  // The kernel strides through sin using cos's last dim, so a mismatch would
  // read the wrong elements or run past the end of the sin buffer.
  TORCH_CHECK(cos.size(3) == sin.size(3),
              "expected cos and sin tensor have the same last dim");
  TORCH_CHECK(
      output_grads.size(3) >= cos.size(3),
      "expected the last dim of the output_grads tensor is greater than the "
      "cos tensor");
  TORCH_CHECK(
      output_grads.size(3) >= sin.size(3),
      "expected the last dim of the output_grads tensor is greater than the "
      "sin tensor");

  return bwd_cuda(output_grads, cos, sin);
}

} // end namespace fused_rope

// Python bindings: exposes fused_rope::fwd as `forward` and fused_rope::bwd
// as `backward` on the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &fused_rope::fwd,
"Fused Rotary Positional Embedding -- Forward.");
m.def("backward", &fused_rope::bwd,
"Fused Rotary Positional Embedding -- Backward.");
}
129 changes: 129 additions & 0 deletions csrc/megatron/fused_rotary_positional_embedding.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/* coding=utf-8
* Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/macros/Macros.h>
#include <cuda_runtime.h>
#include <torch/extension.h>

namespace {

// Forward RoPE kernel.
//
// src/dst are indexed as a dense [sq, b, np, hn] array and cos/sin as a dense
// [sq, hn2] array. Each block handles one (blockIdx.x, blockIdx.y) =
// (sq_id, b_id) pair; threadIdx.x strides over the last dim and threadIdx.y
// strides over the np dim.
//
// For the first hn2 elements of the last dim:
//   dst[i] = src[i] * cos[i] + rot(src)[i] * sin[i]
// where rot(src)[i] is -src[i + hn2/2] in the first half and src[i - hn2/2]
// in the second half. Elements in [hn2, hn) are copied unchanged.
//
// Element offsets are computed in 64 bits so that tensors with 2^31 or more
// elements do not overflow the index arithmetic.
template <typename scalar_t>
__global__ void fused_rope_forward(int sq, int b, int np, int hn, int hn2,
                                   const scalar_t* src, const scalar_t* cos,
                                   const scalar_t* sin, scalar_t* dst) {
  int sq_id = blockIdx.x, b_id = blockIdx.y;
  // Equivalent to sq_id * b * np * hn + b_id * np * hn, evaluated in 64 bits.
  const int64_t offset_block =
      (static_cast<int64_t>(sq_id) * b + b_id) * np * hn;
#pragma unroll
  for (int hn_id = threadIdx.x; hn_id < hn2; hn_id += blockDim.x) {
    // cos/sin depend only on (sq_id, hn_id), so load them once per column.
    scalar_t v_cos = cos[sq_id * hn2 + hn_id];
    scalar_t v_sin = sin[sq_id * hn2 + hn_id];
#pragma unroll
    for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) {
      const int64_t offset_src_dst =
          offset_block + static_cast<int64_t>(head_id) * hn + hn_id;
      scalar_t v_src = src[offset_src_dst];
      scalar_t v_src_rotate = (hn_id + hn2 / 2 < hn2)
                                  ? -src[offset_src_dst + hn2 / 2]
                                  : src[offset_src_dst + hn2 / 2 - hn2];
      dst[offset_src_dst] = v_src * v_cos + v_src_rotate * v_sin;
    }
  }

  // copy the rest
  if (hn > hn2) {
#pragma unroll
    for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) {
      const int64_t offset_head =
          offset_block + static_cast<int64_t>(head_id) * hn;
#pragma unroll
      for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) {
        const int64_t offset_src_dst = offset_head + hn_id;
        dst[offset_src_dst] = src[offset_src_dst];
      }
    }
  }
}

// Backward RoPE kernel.
//
// src holds the gradient w.r.t. the forward output and dst receives the
// gradient w.r.t. the forward input; both are indexed as a dense
// [sq, b, np, hn] array, cos/sin as a dense [sq, hn2] array. Each block
// handles one (blockIdx.x, blockIdx.y) = (sq_id, b_id) pair; threadIdx.x
// strides over the last dim and threadIdx.y strides over the np dim.
//
// For the first hn2 elements of the last dim this applies the transpose of
// the forward rotation:
//   dst[i] = src[i] * cos[i] + src[i + hn2/2] * sin[i + hn2/2]   (first half)
//   dst[i] = src[i] * cos[i] - src[i - hn2/2] * sin[i - hn2/2]   (second half)
//
// The forward kernel copies elements in [hn2, hn) unchanged, so their
// gradient is the incoming gradient itself. (Writing a constant there is only
// correct when the upstream gradient happens to be all ones.)
//
// Element offsets are computed in 64 bits so that tensors with 2^31 or more
// elements do not overflow the index arithmetic.
template <typename scalar_t>
__global__ void fused_rope_backward(int sq, int b, int np, int hn, int hn2,
                                    const scalar_t* src, const scalar_t* cos,
                                    const scalar_t* sin, scalar_t* dst) {
  int sq_id = blockIdx.x, b_id = blockIdx.y;
  // Equivalent to sq_id * b * np * hn + b_id * np * hn, evaluated in 64 bits.
  const int64_t offset_block =
      (static_cast<int64_t>(sq_id) * b + b_id) * np * hn;
#pragma unroll
  for (int hn_id = threadIdx.x; hn_id < hn2; hn_id += blockDim.x) {
    scalar_t v_cos = cos[sq_id * hn2 + hn_id];
    // The sine is taken from the partner column, with the sign that the
    // forward pass applied to that partner.
    scalar_t v_sin = (hn_id + hn2 / 2 < hn2)
                         ? sin[sq_id * hn2 + hn_id + hn2 / 2]
                         : -sin[sq_id * hn2 + hn_id + hn2 / 2 - hn2];
#pragma unroll
    for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) {
      const int64_t offset_src_dst =
          offset_block + static_cast<int64_t>(head_id) * hn + hn_id;
      scalar_t v_src = src[offset_src_dst];
      scalar_t v_src_rotate = (hn_id + hn2 / 2 < hn2)
                                  ? src[offset_src_dst + hn2 / 2]
                                  : src[offset_src_dst + hn2 / 2 - hn2];
      dst[offset_src_dst] = v_src * v_cos + v_src_rotate * v_sin;
    }
  }

  // handle the tail: identity in forward, so pass the gradient through
  if (hn > hn2) {
#pragma unroll
    for (int head_id = threadIdx.y; head_id < np; head_id += blockDim.y) {
      const int64_t offset_head =
          offset_block + static_cast<int64_t>(head_id) * hn;
#pragma unroll
      for (int hn_id = hn2 + threadIdx.x; hn_id < hn; hn_id += blockDim.x) {
        dst[offset_head + hn_id] = src[offset_head + hn_id];
      }
    }
  }
}

} // end of anonymous namespace

// Launches fused_rope_forward on the current CUDA stream.
//
// The grid is (sq, b): one block per pair of indices in the first two dims.
// Each block is C10_WARP_SIZE threads wide in x and 4 (np < 16) or 8 warps
// tall in y.
//
// Parameters:
//   sq, b, np, hn - extents of the 4D input/output arrays
//   hn2           - extent of the last dim of the cos/sin tables
//   input, cos, sin, output - device pointers to the respective buffers
template <typename scalar_t>
void dispatch_fused_rope_forward(int sq, int b, int np, int hn, int hn2,
                                 const scalar_t* input, const scalar_t* cos,
                                 const scalar_t* sin, scalar_t* output) {
  // A grid with a zero-sized dimension is an invalid launch configuration.
  // With nothing to compute, skip the launch instead of raising an error.
  if (sq == 0 || b == 0) {
    return;
  }

  auto stream = at::cuda::getCurrentCUDAStream();

  int warps_per_block = np < 16 ? 4 : 8;
  dim3 blocks(sq, b);
  dim3 threads(C10_WARP_SIZE, warps_per_block);

  fused_rope_forward<<<blocks, threads, 0, stream>>>(sq, b, np, hn, hn2, input,
                                                     cos, sin, output);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

// Launches fused_rope_backward on the current CUDA stream.
//
// The grid is (sq, b): one block per pair of indices in the first two dims.
// Each block is C10_WARP_SIZE threads wide in x and 4 (np < 16) or 8 warps
// tall in y.
//
// Parameters:
//   sq, b, np, hn - extents of the 4D gradient arrays
//   hn2           - extent of the last dim of the cos/sin tables
//   output_grads, cos, sin, input_grads - device pointers to the buffers
template <typename scalar_t>
void dispatch_fused_rope_backward(int sq, int b, int np, int hn, int hn2,
                                  const scalar_t* output_grads,
                                  const scalar_t* cos, const scalar_t* sin,
                                  scalar_t* input_grads) {
  // A grid with a zero-sized dimension is an invalid launch configuration.
  // With nothing to compute, skip the launch instead of raising an error.
  if (sq == 0 || b == 0) {
    return;
  }

  auto stream = at::cuda::getCurrentCUDAStream();

  int warps_per_block = np < 16 ? 4 : 8;
  dim3 blocks(sq, b);
  dim3 threads(C10_WARP_SIZE, warps_per_block);

  fused_rope_backward<<<blocks, threads, 0, stream>>>(
      sq, b, np, hn, hn2, output_grads, cos, sin, input_grads);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
64 changes: 64 additions & 0 deletions csrc/megatron/fused_rotary_positional_embedding_cuda.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/* coding=utf-8
* Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <ATen/ATen.h>

#include "fused_rotary_positional_embedding.h"
#include "type_shim.h"

namespace fused_rope {

// Allocates the output tensor and dispatches the forward kernel on the
// scalar type of `input`.
//
// The first four sizes of `input` and the fourth size of `cos` are read as
// the kernel extents. The result has the same options as `input` except that
// requires_grad is cleared.
torch::Tensor fwd_cuda(const torch::Tensor &input, const torch::Tensor &cos,
                       const torch::Tensor &sin) {
  // Extents of the input, plus the last dim of the cos table.
  const int dim0 = input.size(0);
  const int dim1 = input.size(1);
  const int dim2 = input.size(2);
  const int dim3_in = input.size(3);
  const int dim3_cos = cos.size(3);

  // Uninitialised buffer; the kernel writes every element it covers.
  auto result_options = input.options().requires_grad(false);
  torch::Tensor result =
      torch::empty({dim0, dim1, dim2, dim3_in}, result_options);

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      input.scalar_type(), 0, "dispatch_fused_rope_forward",
      dispatch_fused_rope_forward(
          dim0, dim1, dim2, dim3_in, dim3_cos, input.data_ptr<scalar_t_0>(),
          cos.data_ptr<scalar_t_0>(), sin.data_ptr<scalar_t_0>(),
          result.data_ptr<scalar_t_0>()););
  return result;
}

// Allocates the input-gradient tensor and dispatches the backward kernel on
// the scalar type of `output_grads`.
//
// The first four sizes of `output_grads` and the fourth size of `cos` are
// read as the kernel extents. The result has the same options as
// `output_grads` except that requires_grad is cleared.
torch::Tensor bwd_cuda(const torch::Tensor &output_grads,
                       const torch::Tensor &cos, const torch::Tensor &sin) {
  // Extents of the incoming gradient, plus the last dim of the cos table.
  const int dim0 = output_grads.size(0);
  const int dim1 = output_grads.size(1);
  const int dim2 = output_grads.size(2);
  const int dim3_in = output_grads.size(3);
  const int dim3_cos = cos.size(3);

  auto grad_options = output_grads.options().requires_grad(false);
  torch::Tensor grad_input =
      torch::empty({dim0, dim1, dim2, dim3_in}, grad_options);

  DISPATCH_FLOAT_HALF_AND_BFLOAT(
      output_grads.scalar_type(), 0, "dispatch_fused_rope_backward",
      dispatch_fused_rope_backward(
          dim0, dim1, dim2, dim3_in, dim3_cos,
          output_grads.data_ptr<scalar_t_0>(), cos.data_ptr<scalar_t_0>(),
          sin.data_ptr<scalar_t_0>(), grad_input.data_ptr<scalar_t_0>());)
  return grad_input;
}
} // end namespace fused_rope
21 changes: 21 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,27 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
)
)

# Register the fused rotary positional embedding extension. It is built from
# the C++ binding and the CUDA source under csrc/megatron, with csrc on the
# include path.
ext_modules.append(
CUDAExtension(
name="fused_rotary_positional_embedding",
sources=[
"csrc/megatron/fused_rotary_positional_embedding.cpp",
"csrc/megatron/fused_rotary_positional_embedding_cuda.cu",
],
include_dirs=[os.path.join(this_dir, "csrc")],
extra_compile_args={
"cxx": ["-O3"] + version_dependent_macros,
# The -U flags undefine the named __CUDA_NO_HALF_* macros for nvcc.
"nvcc": [
"-O3",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
] + version_dependent_macros,
},
)
)

if bare_metal_version >= Version("11.0"):

cc_flag = []
Expand Down
Loading