A fused apply_rotary_pos_emb implementation for Megatron-Core #1746

Conversation
Signed-off-by: Xin Yao <[email protected]>
How about using .cuh instead of .h for clarity?
I just followed the naming convention under the csrc/megatron directory, e.g., generic_scaled_masked_softmax.h, scaled_masked_softmax.h, etc. If you feel .cuh is better, I'm OK to change it.
from typing import Tuple, Union

import torch


class FusedRoPEFunc(torch.autograd.Function):
    """Autograd wrapper around the fused rotary positional embedding CUDA kernel."""

    @staticmethod
    def forward(
        ctx, t: torch.Tensor, cos_: torch.Tensor, sin_: torch.Tensor
    ) -> torch.Tensor:
        # Imported lazily so that apex can be imported even if the extension is not built.
        import fused_rotary_positional_embedding

        output = fused_rotary_positional_embedding.forward(t, cos_, sin_)
        # cos/sin are needed again in backward to rotate the incoming gradient.
        ctx.save_for_backward(cos_, sin_)

        return output

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        import fused_rotary_positional_embedding

        cos_, sin_ = ctx.saved_tensors
        grad_q = fused_rotary_positional_embedding.backward(grad_output, cos_, sin_)

        # No gradients are needed for cos_ and sin_.
        return grad_q, None, None


def apply_rotary_pos_emb_fused(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Precompute cos/sin in the input dtype and dispatch to the fused kernel.
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)
    return FusedRoPEFunc.apply(t, cos_, sin_)
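For context, a minimal usage sketch of the helper above. The tensor layout follows Megatron-Core's [seq_len, batch, head, head_dim] convention; the concrete sizes and the frequency construction are illustrative assumptions, not taken from this PR.

import torch

seq_len, batch, heads, dim = 4096, 2, 64, 128  # assumed sizes for illustration

t = torch.randn(
    seq_len, batch, heads, dim, dtype=torch.bfloat16, device="cuda", requires_grad=True
)

# Standard RoPE frequency table, broadcast over the batch and head dimensions.
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device="cuda").float() / dim))
positions = torch.arange(seq_len, device="cuda").float()
freqs = torch.outer(positions, inv_freq)                    # [seq_len, dim // 2]
emb = torch.cat((freqs, freqs), dim=-1)[:, None, None, :]   # [seq_len, 1, 1, dim]

out = apply_rotary_pos_emb_fused(t, emb)  # same shape as t
out.sum().backward()                      # exercises FusedRoPEFunc.backward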
Out of curiosity, wouldn't it be useful to have this in the apex.transformer.functional namespace?
Thanks for your review. I have added two functions to the apex.transformer.functional namespace (a usage sketch follows below):
- fused_apply_rotary_pos_emb, which is a drop-in replacement for the current apply_rotary_pos_emb in Megatron Core.
- fused_apply_rotary_pos_emb_cached, which would be beneficial when MCore implements caching for the rotary positional embedding.
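A hedged sketch of how these two entry points might be called. The exact signatures are inferred from this discussion (a (t, freqs) pair for the drop-in variant, precomputed cos/sin tensors for the cached variant), so treat them as assumptions rather than the documented API.

import torch
from apex.transformer.functional import (
    fused_apply_rotary_pos_emb,
    fused_apply_rotary_pos_emb_cached,
)

t = torch.randn(2048, 2, 64, 128, device="cuda", dtype=torch.bfloat16)
freqs = torch.randn(2048, 1, 1, 128, device="cuda")  # rotary angles, broadcastable to t

# Drop-in replacement for Megatron-Core's apply_rotary_pos_emb.
out = fused_apply_rotary_pos_emb(t, freqs)

# Cached variant: reuse precomputed cos/sin across calls instead of recomputing them.
cos_ = torch.cos(freqs).to(t.dtype)
sin_ = torch.sin(freqs).to(t.dtype)
out_cached = fused_apply_rotary_pos_emb_cached(t, cos_, sin_)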
This is a fused apply_rotary_pos_emb implementation for Megatron-Core. In my preliminary benchmark, it gives a 2x - 4x speedup over the unfused version; batch_size=2 and head_num=64 are fixed.
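This is not the author's benchmark script, but a rough sketch of such a timing comparison: it times the fused path (reusing apply_rotary_pos_emb_fused from the review snippet above) against a straightforward unfused RoPE in the usual rotate_half formulation, with batch_size=2 and head_num=64 as quoted; the sequence length, head_dim, and iteration counts are assumptions.

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_unfused(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Reference (unfused) RoPE: elementwise cos/sin rotation over the full head_dim.
    return t * torch.cos(freqs).to(t.dtype) + rotate_half(t) * torch.sin(freqs).to(t.dtype)


def time_ms(fn, *args, iters: int = 100) -> float:
    # CUDA-event timing with a short warm-up; returns milliseconds per call.
    for _ in range(10):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters


seq_len, batch, heads, dim = 4096, 2, 64, 128  # batch/heads fixed as in the description
t = torch.randn(seq_len, batch, heads, dim, device="cuda", dtype=torch.bfloat16)
freqs = torch.randn(seq_len, 1, 1, dim, device="cuda")

print("unfused:", time_ms(apply_rotary_pos_emb_unfused, t, freqs), "ms/call")
print("fused:  ", time_ms(apply_rotary_pos_emb_fused, t, freqs), "ms/call")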