Use Llama RMSNorm for Gemma #2974
Conversation
Nice!
Before we merge, let's make sure it doesn't change the outputs (maybe we could add a test like we have for other models, using transformers as a reference).
As a note, using the custom op introduces a slight numerical difference in handling the residual connection. While the original implementation uses the current dtype (f16 or bf16) in
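A minimal sketch of the kind of dtype effect being referenced, assuming the contrast is between keeping the running residual in the activation dtype versus accumulating it in fp32; the custom op's exact behavior is not reproduced here:

```python
import torch

torch.manual_seed(0)
hidden = torch.randn(4, 8, dtype=torch.bfloat16)
deltas = [torch.randn(4, 8, dtype=torch.bfloat16) for _ in range(32)]

def residual_in_bf16(x, deltas):
    # Running residual stays in the activation dtype; each add re-rounds.
    out = x
    for d in deltas:
        out = out + d
    return out

def residual_in_fp32(x, deltas):
    # Running residual accumulated in fp32; cast back once at the end.
    out = x.float()
    for d in deltas:
        out = out + d.float()
    return out.to(x.dtype)

low = residual_in_bf16(hidden, deltas)
high = residual_in_fp32(hidden, deltas)
# Small but non-zero difference after many residual adds.
print((low.float() - high.float()).abs().max())
```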
Gemma's RMSNorm is only slightly different from Llama's RMSNorm. Thus, we can use the existing custom op for it. This optimization leads to ~10% latency reduction.
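For context, the functional difference is that Gemma's RMSNorm scales the normalized input by `(1 + weight)` where Llama's scales by `weight`. A minimal sketch (not the actual vLLM code) of how a Llama-style RMSNorm can be reused by folding the `+1` into the weights at load time; the loader helper below is hypothetical:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Llama-style RMSNorm: y = x / rms(x) * weight."""

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        var = x.float().pow(2).mean(dim=-1, keepdim=True)
        x_normed = x.float() * torch.rsqrt(var + self.eps)
        return (x_normed * self.weight.float()).to(x.dtype)

# Gemma checkpoints store a weight such that the layer computes
# x / rms(x) * (1 + weight).  Adding 1 to the loaded weight once lets the
# Llama-style module (and its fused custom op) produce the same scaling.
def load_gemma_norm_weight(norm: RMSNorm, loaded_weight: torch.Tensor) -> None:
    # Hypothetical loader hook; the name and signature are illustrative.
    norm.weight.data.copy_(loaded_weight + 1.0)
```

With the offset folded into the weights at load time, the existing custom op can be shared between Llama and Gemma and should produce the same results as Gemma's reference implementation, up to the dtype effects noted above.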