Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add node-wise normalization mode in LayerNorm #4944

Merged
merged 14 commits into from
Jul 8, 2022
Merged

Conversation

lightaime
Copy link
Contributor

This pull request adds the node-wise normalization mode in LayerNorm. If "graph" mode is used (by default), each graph will be considered as an element to be normalized. If "node" mode is used, each node will be considered as an element to be normalized.

@codecov
Copy link

codecov bot commented Jul 7, 2022

Codecov Report

Merging #4944 (4af89d3) into master (db5e6d9) will decrease coverage by 1.94%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master    #4944      +/-   ##
==========================================
- Coverage   84.65%   82.71%   -1.95%     
==========================================
  Files         330      330              
  Lines       17885    17887       +2     
==========================================
- Hits        15141    14795     -346     
- Misses       2744     3092     +348     
Impacted Files Coverage Δ
torch_geometric/nn/norm/layer_norm.py 100.00% <100.00%> (ø)
torch_geometric/nn/models/dimenet_utils.py 0.00% <0.00%> (-75.52%) ⬇️
torch_geometric/nn/models/dimenet.py 14.51% <0.00%> (-53.00%) ⬇️
torch_geometric/nn/conv/utils/typing.py 81.25% <0.00%> (-17.50%) ⬇️
torch_geometric/loader/neighbor_loader.py 85.62% <0.00%> (-9.38%) ⬇️
torch_geometric/nn/inits.py 67.85% <0.00%> (-7.15%) ⬇️
torch_geometric/nn/resolver.py 88.00% <0.00%> (-6.00%) ⬇️
torch_geometric/testing/decorators.py 96.66% <0.00%> (-3.34%) ⬇️
torch_geometric/testing/feature_store.py 97.22% <0.00%> (-2.78%) ⬇️
torch_geometric/io/tu.py 93.90% <0.00%> (-2.44%) ⬇️
... and 7 more

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update db5e6d9...4af89d3. Read the comment docs.

@lightaime lightaime marked this pull request as ready for review July 7, 2022 16:32
@lightaime lightaime requested a review from rusty1s July 7, 2022 16:32
Copy link
Member

@rusty1s rusty1s left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super. Only left nitpicky comments :D

torch_geometric/nn/norm/layer_norm.py Outdated Show resolved Hide resolved
torch_geometric/nn/norm/layer_norm.py Outdated Show resolved Hide resolved
torch_geometric/nn/norm/layer_norm.py Outdated Show resolved Hide resolved
torch_geometric/nn/norm/layer_norm.py Outdated Show resolved Hide resolved
torch_geometric/nn/norm/layer_norm.py Show resolved Hide resolved
torch_geometric/nn/norm/layer_norm.py Outdated Show resolved Hide resolved
torch_geometric/nn/norm/layer_norm.py Outdated Show resolved Hide resolved
torch_geometric/nn/norm/layer_norm.py Outdated Show resolved Hide resolved
lightaime and others added 11 commits July 8, 2022 14:52
Co-authored-by: Matthias Fey <[email protected]>
Co-authored-by: Matthias Fey <[email protected]>
Co-authored-by: Matthias Fey <[email protected]>
Co-authored-by: Matthias Fey <[email protected]>
Co-authored-by: Matthias Fey <[email protected]>
Co-authored-by: Matthias Fey <[email protected]>
Co-authored-by: Matthias Fey <[email protected]>
@lightaime
Copy link
Contributor Author

Super. Only left nitpicky comments :D

Thanks! Looks cleaner.

@lightaime lightaime merged commit d220afe into master Jul 8, 2022
@lightaime lightaime deleted the node_layer_norm branch July 8, 2022 12:24
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants