
Added RMS normalization layer #2881

Merged: 1 commit merged into google:main from the rmslayernorm branch on Feb 19, 2023

Conversation

chiamp (Collaborator) commented Feb 16, 2023

Resolves #2849.

Added an optional argument use_mean to the _compute_stats function in flax/linen/normalization.py, which computes the mean and variance if set to True, and sets the mean to 0 and computes the variance without subtracting the mean if set to False. The latter mode is useful because taking the square root of this "variance" value (which is done in the _normalize function) gives you the RMS.
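
To make the mechanism concrete, here is a minimal sketch of the idea (a hypothetical simplification, not the actual _compute_stats implementation): with use_mean=False the returned "variance" is simply the mean square E[x^2], so taking its square root later yields the RMS.

```python
import jax.numpy as jnp

def compute_stats(x, axes, use_mean=True):
  # Hypothetical simplification of flax.linen.normalization._compute_stats.
  if use_mean:
    mean = jnp.mean(x, axes)
  else:
    # RMSNorm mode: skip mean subtraction, so the "variance" below is the
    # mean square E[x^2]; its square root (taken in _normalize) is the RMS.
    mean = jnp.zeros_like(jnp.mean(x, axes))
  var = jnp.mean(jnp.square(x), axes) - jnp.square(mean)
  return mean, var
```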

chiamp self-assigned this Feb 16, 2023
codecov-commenter commented Feb 16, 2023

Codecov Report

Merging #2881 (9cff780) into main (5d4040a) will increase coverage by 0.02%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main    #2881      +/-   ##
==========================================
+ Coverage   81.45%   81.47%   +0.02%     
==========================================
  Files          55       55              
  Lines        5779     5798      +19     
==========================================
+ Hits         4707     4724      +17     
- Misses       1072     1074       +2     
Impacted Files                Coverage Δ
flax/linen/__init__.py        100.00% <ø> (ø)
flax/linen/normalization.py   97.41% <100.00%> (+0.29%) ⬆️
flax/core/scope.py            89.91% <0.00%> (-0.22%) ⬇️
flax/linen/module.py          92.37% <0.00%> (-0.13%) ⬇️
flax/configurations.py        85.00% <0.00%> (+0.78%) ⬆️


chiamp requested a review from levskaya on February 16, 2023 04:21
@@ -335,6 +343,70 @@ def __call__(self, x):
self.bias_init, self.scale_init)


class RMSNorm(Module):
Collaborator

Could you please add an example of how to use the layer? I think we should start doing this for every layer, like @cgarciae does in his RNN PR: https://github.com/google/flax/pull/2604/files#r1107264719.

chiamp force-pushed the rmslayernorm branch 2 times, most recently from 4a05ad4 to 42ba933 on February 17, 2023 03:55
levskaya (Collaborator) left a comment

Looks good! Thanks!! -- I second Marc's ask to add a small usage example in the docstring.

chiamp (Collaborator, Author) commented Feb 17, 2023

@marcvanzee @levskaya I added a docstring, let me know if this works!
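
A sketch of the kind of usage example being discussed (the defaults shown, e.g. epsilon=1e-6 and a learnable 'scale' parameter, follow the PR diff):

```python
import jax
import flax.linen as nn

x = jax.random.normal(jax.random.PRNGKey(0), (2, 5))  # dummy input batch
layer = nn.RMSNorm()                                   # default epsilon=1e-6
variables = layer.init(jax.random.PRNGKey(1), x)       # creates the 'scale' parameter (use_scale=True by default)
y = layer.apply(variables, x)                          # RMS-normalized output, same shape as x
```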

levskaya (Collaborator) commented:

@chiamp - I added an exception for the deprecation warning; your tests all seem to pass now!

chiamp (Collaborator, Author) commented Feb 19, 2023

@chiamp - I added an exception for the deprecation warning; your tests all seem to pass now!

Thanks @levskaya!

epsilon: float = 1e-6
dtype: Optional[Dtype] = None
param_dtype: Dtype = jnp.float32
use_bias: bool = True
Collaborator

Sorry, I just noticed this - we probably don't want use_bias and bias_init here, since we're never adjusting the offset?
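
To illustrate the point: RMS normalization only rescales the input by its root mean square; there is no centering step, so there is no offset to add back. A rough functional sketch (a hypothetical helper, not the module code):

```python
import jax.numpy as jnp

def rms_norm(x, scale, epsilon=1e-6):
  # Divide by the root mean square over the feature axis and rescale.
  # There is no bias/offset term, which is why use_bias and bias_init
  # are unnecessary for this layer.
  rms = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + epsilon)
  return (x / rms) * scale
```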

copybara-service bot merged commit 5f0ac50 into google:main on Feb 19, 2023
chiamp deleted the rmslayernorm branch on February 23, 2023 01:23
Development

Successfully merging this pull request may close these issues.

Incorporate RMSNorm #2849
4 participants