
A fused apply_rotary_pos_emb implementation for Megatron-Core #1746

Merged: 4 commits into NVIDIA:master on Nov 14, 2023

Conversation

@yaox12 (Contributor) commented on Nov 9, 2023

This is a fused apply_rotary_pos_emb implementation for Megatron-Core.

In my preliminary benchmark, it gives a 2x-4x speedup over the unfused version; batch_size=2 and head_num=64 are fixed across all configurations. (A generic timing-harness sketch follows the table.)

| dtype | seq_length | hidden_size | rotary_percent | unfused rope | fused rope |
|---|---|---|---|---|---|
| torch.float32 | 2048 | 128 | 0.5 | 0.45 ms | 0.14 ms |
| torch.float32 | 2048 | 128 | 1.0 | 0.67 ms | 0.15 ms |
| torch.float32 | 2048 | 256 | 0.5 | 0.84 ms | 0.27 ms |
| torch.float32 | 2048 | 256 | 1.0 | 1.3 ms | 0.3 ms |
| torch.float32 | 4096 | 128 | 0.5 | 0.85 ms | 0.23 ms |
| torch.float32 | 4096 | 128 | 1.0 | 1.3 ms | 0.3 ms |
| torch.float32 | 4096 | 256 | 0.5 | 1.6 ms | 0.75 ms |
| torch.float32 | 4096 | 256 | 1.0 | 2.6 ms | 0.58 ms |
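For reference, a minimal sketch of a CUDA-event timing harness of the kind that could produce numbers like these. The `benchmark` helper and the tensor shapes are illustrative assumptions, not the script used for the table above; the call at the bottom assumes the `apply_rotary_pos_emb_fused` wrapper shown in the diff further down.

```python
import torch

def benchmark(fn, *args, warmup=10, iters=100):
    """Return the average runtime of fn(*args) in milliseconds, measured with CUDA events."""
    for _ in range(warmup):
        fn(*args)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

# Illustrative inputs in the [seq_len, batch, head_num, hidden_size] layout used above;
# freqs covers half of the last dimension, i.e. rotary_percent=0.5.
t = torch.randn(2048, 2, 64, 128, device="cuda", dtype=torch.float32)
freqs = torch.rand(2048, 1, 1, 64, device="cuda", dtype=torch.float32)

# e.g. print(f"fused rope: {benchmark(apply_rotary_pos_emb_fused, t, freqs):.2f} ms")
```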

Signed-off-by: Xin Yao <[email protected]>
@yaox12 changed the title from "A fused apply_rotary_pos_emb implementation" to "A fused apply_rotary_pos_emb implementation for Megatron-Core" on Nov 9, 2023
Collaborator commented:

how about using .cuh instead of .h for clarity?

Contributor Author (@yaox12) replied:

I just followed the naming convention under the csrc/megatron directory, e.g., generic_scaled_masked_softmax.h, scaled_masked_softmax.h, etc.
If you feel .cuh is better, I'm OK to change it.

Comment on lines 12 to 39
```python
# (imports from earlier in the file)
from typing import Tuple, Union

import torch


class FusedRoPEFunc(torch.autograd.Function):
    """Autograd wrapper around the fused RoPE CUDA extension."""

    @staticmethod
    def forward(
        ctx, t: torch.Tensor, cos_: torch.Tensor, sin_: torch.Tensor
    ) -> torch.Tensor:
        import fused_rotary_positional_embedding

        output = fused_rotary_positional_embedding.forward(t, cos_, sin_)
        # cos/sin are reused to rotate the incoming gradient in backward.
        ctx.save_for_backward(cos_, sin_)

        return output

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[Union[torch.Tensor, None], ...]:
        import fused_rotary_positional_embedding

        cos_, sin_ = ctx.saved_tensors
        grad_q = fused_rotary_positional_embedding.backward(grad_output, cos_, sin_)

        # No gradients are needed for cos_ and sin_.
        return grad_q, None, None


def apply_rotary_pos_emb_fused(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Precompute cos/sin on the Python side; the CUDA extension fuses the rotation itself.
    cos_ = torch.cos(freqs).to(t.dtype)
    sin_ = torch.sin(freqs).to(t.dtype)
    return FusedRoPEFunc.apply(t, cos_, sin_)
```
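As a sanity check, one could compare the fused result against the standard unfused RoPE formula, t * cos(freqs) + rotate_half(t) * sin(freqs). The reference below is written from that well-known formula (assuming full rotation, i.e. freqs spanning the whole last dimension), not copied from Megatron-Core:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard RoPE helper: negate-and-swap the two halves of the last dimension.
    x1, x2 = torch.chunk(x, 2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_reference(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Unfused reference; assumes freqs covers the full last dimension of t.
    return t * torch.cos(freqs).to(t.dtype) + rotate_half(t) * torch.sin(freqs).to(t.dtype)

# torch.testing.assert_close(apply_rotary_pos_emb_fused(t, freqs),
#                            apply_rotary_pos_emb_reference(t, freqs))
```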
Collaborator commented:

Out of curiosity, wouldn't it be useful to have this in the apex.transformer.functional namespace?

Contributor Author (@yaox12) replied:

Thanks for your review. I have added two functions to the apex.transformer.functional namespace.

  1. fused_apply_rotary_pos_emb, which is a drop-in replacement for the current apply_rotary_pos_emb in Megatron Core (a usage sketch follows this comment).
  2. fused_apply_rotary_pos_emb_cached, which would be beneficial once MCore implements caching for the rotary positional embedding.
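
A minimal usage sketch of the drop-in replacement. The import path matches the namespace discussed above, but the tensor shapes (t as [seq_len, batch, num_heads, head_dim], freqs as [seq_len, 1, 1, rot_dim]) follow Megatron-Core's RoPE convention and are assumptions for illustration, not taken from this PR:

```python
import torch
from apex.transformer.functional import fused_apply_rotary_pos_emb

# Assumed shapes: t is [seq_len, batch, num_heads, head_dim],
# freqs is [seq_len, 1, 1, rot_dim] with rot_dim <= head_dim.
t = torch.randn(2048, 2, 64, 128, device="cuda", dtype=torch.float32, requires_grad=True)
freqs = torch.rand(2048, 1, 1, 64, device="cuda", dtype=torch.float32)

out = fused_apply_rotary_pos_emb(t, freqs)  # same call pattern as the unfused helper
out.sum().backward()                        # gradient flows through FusedRoPEFunc.backward
```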

@crcrpar merged commit 08f7402 into NVIDIA:master on Nov 14, 2023
@crcrpar added this to the 23.12 milestone on Nov 14, 2023