
[WIP] Another take at speeding up FusedLinearLayer #16

Closed · wants to merge 3 commits from the fused_linear_improve branch

Conversation

@blefaudeux (Contributor) commented on Oct 20, 2021:

What does this PR do?

Rewrites most of the fused linear kernel, splitting the work in a more logical way: the gradient over the bias moves into the second kernel, which walks over M, which is exactly what that reduction needs. General code cleanup in that area too, but... not faster than before. Discussing with @ptillet; I must be missing something.
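
To make the split concrete, here is a minimal sketch of the idea (not the actual kernel in this PR; `grad_bias_kernel` and its argument names are made up for illustration): the bias gradient is just a column-wise sum of the incoming gradient over M, so it folds naturally into a kernel that already walks over the rows.

```python
import triton
import triton.language as tl


@triton.jit
def grad_bias_kernel(
    GRAD_OUT,              # pointer to the incoming gradient, shape (M, N)
    GRAD_BIAS,             # pointer to the output bias gradient, shape (N,), fp32
    M, N,                  # problem dimensions
    stride_gm,             # stride between consecutive rows of GRAD_OUT
    BLOCK_N: tl.constexpr,
):
    # One program instance per block of N columns; each instance walks over
    # all M rows and accumulates the column-wise sum that is the bias gradient.
    pid = tl.program_id(axis=0)
    cols = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = cols < N

    acc = tl.zeros((BLOCK_N,), dtype=tl.float32)
    for row in range(0, M):
        g = tl.load(GRAD_OUT + row * stride_gm + cols, mask=mask, other=0.0)
        acc += g.to(tl.float32)

    tl.store(GRAD_BIAS + cols, acc, mask=mask)
```

A production kernel would also fuse the activation gradient into the same pass; the sketch only shows why a kernel that walks over M is the natural home for the bias reduction.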

Latest numbers below; somehow, reintroducing the batch dimension speeds things up (see the sketch after the tables):
--- Type: torch.float16 ---

| Units: TFlops/s | B=8, M=256, K=512 | B=8, M=512, K=1024 | B=4, M=1024, K=1024 | B=2, M=2048, K=2048 | B=2, M=4096, K=4096 |
|---|---|---|---|---|---|
| pytorch - squared_relu - no bias - fw+bw | 4.1 | 8.4 | 8.4 | 11.3 | 13.0 |
| triton - squared_relu - no bias - fw+bw | 6.1 | 9.2 | 9.3 | 10.9 | 12.0 |
| pytorch - squared_relu - bias - fw+bw | 3.9 | 7.7 | 7.7 | 10.7 | 12.6 |
| triton - squared_relu - bias - fw+bw | 5.5 | 8.8 | 8.8 | 10.4 | 11.8 |
| pytorch - gelu - no bias - fw+bw | 5.7 | 10.4 | 10.4 | 12.6 | 13.8 |
| triton - gelu - no bias - fw+bw | 2.8 | 4.3 | 4.3 | 4.7 | 5.0 |
| pytorch - gelu - bias - fw+bw | 5.0 | 9.4 | 9.3 | 11.8 | 13.3 |
| triton - gelu - bias - fw+bw | 2.4 | 4.1 | 4.1 | 4.4 | 4.7 |
| pytorch - leaky_relu - no bias - fw+bw | 6.0 | 10.7 | 10.8 | 13.1 | 14.2 |
| triton - leaky_relu - no bias - fw+bw | 6.7 | 9.6 | 9.7 | 10.8 | 12.0 |
| pytorch - leaky_relu - bias - fw+bw | 5.0 | 9.7 | 9.7 | 12.2 | 13.4 |
| triton - leaky_relu - bias - fw+bw | 6.0 | 9.2 | 9.2 | 10.6 | 11.8 |
| pytorch - relu - no bias - fw+bw | 5.9 | 10.8 | 10.8 | 13.1 | 14.1 |
| triton - relu - no bias - fw+bw | 6.6 | 9.9 | 9.9 | 11.2 | 12.2 |
| pytorch - relu - bias - fw+bw | 5.0 | 9.7 | 9.8 | 12.2 | 13.6 |
| triton - relu - bias - fw+bw | 6.0 | 9.5 | 9.3 | 10.9 | 12.0 |
| pytorch - None - no bias - fw+bw | 7.1 | 12.3 | 12.3 | 13.9 | 14.4 |
| triton - None - no bias - fw+bw | 6.7 | 9.9 | 9.7 | 10.9 | 12.2 |
| pytorch - None - bias - fw+bw | 5.9 | 11.0 | 10.9 | 13.1 | 14.0 |
| triton - None - bias - fw+bw | 6.0 | 9.3 | 9.3 | 10.3 | 11.9 |
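
(Aside on the units: a sketch of how TFlops/s figures like these are typically derived. The helper below is our own, and assumes a square K x K weight, as the column shapes suggest, with fw+bw counted as roughly 3x the forward matmul cost.)

```python
def tflops(b: int, m: int, k: int, time_s: float, backward: bool = False) -> float:
    # Forward matmul on a (B*M, K) x (K, K) problem costs ~2*B*M*K*K FLOPs.
    flops = 2 * b * m * k * k
    if backward:
        # grad_input and grad_weight are each another matmul of the same size,
        # so fw+bw is usually counted as ~3x the forward cost.
        flops *= 3
    return flops / time_s / 1e12
```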

--- Type: torch.float16 ---

| Units: TFlops/s | B=8, M=256, K=512 | B=8, M=512, K=1024 | B=4, M=1024, K=1024 | B=2, M=2048, K=2048 | B=2, M=4096, K=4096 |
|---|---|---|---|---|---|
| pytorch - squared_relu - no bias - fw | 8.7 | 14.7 | 14.7 | 18.7 | 20.1 |
| triton - squared_relu - no bias - fw | 11.4 | 16.5 | 16.5 | 20.2 | 20.8 |
| pytorch - squared_relu - bias - fw | 7.0 | 12.5 | 12.4 | 16.8 | 18.7 |
| triton - squared_relu - bias - fw | 11.3 | 16.1 | 16.0 | 19.2 | 20.1 |
| pytorch - gelu - no bias - fw | 10.4 | 16.8 | 16.8 | 20.1 | 20.4 |
| triton - gelu - no bias - fw | 6.7 | 7.9 | 7.9 | 10.1 | 10.6 |
| pytorch - gelu - bias - fw | 8.1 | 13.9 | 14.0 | 17.9 | 18.8 |
| triton - gelu - bias - fw | 4.2 | 7.1 | 7.1 | 9.0 | 9.2 |
| pytorch - leaky_relu - no bias - fw | 11.2 | 17.6 | 17.6 | 20.6 | 20.7 |
| triton - leaky_relu - no bias - fw | 13.4 | 17.7 | 17.5 | 20.8 | 19.8 |
| pytorch - leaky_relu - bias - fw | 8.5 | 14.4 | 14.5 | 18.3 | 19.3 |
| triton - leaky_relu - bias - fw | 13.3 | 17.2 | 17.1 | 20.0 | 19.6 |
| pytorch - relu - no bias - fw | 11.2 | 17.6 | 17.5 | 20.7 | 20.6 |
| triton - relu - no bias - fw | 13.4 | 19.1 | 19.1 | 22.3 | 21.7 |
| pytorch - relu - bias - fw | 8.5 | 14.5 | 14.5 | 18.6 | 19.4 |
| triton - relu - bias - fw | 13.3 | 18.7 | 18.6 | 22.0 | 21.5 |
| pytorch - None - no bias - fw | 15.2 | 21.8 | 21.8 | 23.5 | 21.7 |
| triton - None - no bias - fw | 13.4 | 17.4 | 17.5 | 20.2 | 19.7 |
| pytorch - None - bias - fw | 10.7 | 17.2 | 17.2 | 20.6 | 20.3 |
| triton - None - bias - fw | 13.4 | 17.0 | 17.0 | 20.4 | 19.6 |
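
On the batch-dimension remark above: a (B, M, K) input can either be flattened into one large (B*M, K) matmul or kept batched; the two are numerically equivalent, as the small illustration below (our own, not code from this PR) shows.

```python
import torch

B, M, K = 8, 256, 512
x = torch.randn(B, M, K, device="cuda", dtype=torch.float16)
w = torch.randn(K, K, device="cuda", dtype=torch.float16)

flat = (x.reshape(B * M, K) @ w).reshape(B, M, K)  # single large 2D matmul
batched = x @ w                                    # batched matmul, B separate tiles
torch.testing.assert_close(flat, batched, rtol=1e-2, atol=1e-2)
```

Which formulation is faster depends on how well the resulting tile shapes map onto the GPU, which is presumably what the benchmark is picking up.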

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

@facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Oct 20, 2021
@blefaudeux marked this pull request as draft on October 20, 2021 05:46
Commits:
  • …n to flatten a bunch of code, to begin with
  • testing the partial grad reduction
  • moving the bias computation to the second kernel, much better fit
  • minor tweaks, perfs still not there on V100
@blefaudeux force-pushed the fused_linear_improve branch from 030229e to ee2f627 on October 20, 2021 19:02
@blefaudeux force-pushed the fused_linear_improve branch from ee2f627 to 618697b on October 20, 2021 22:51
@blefaudeux (Contributor, Author) commented:
No meaningful benefits for now, closing this; will reboot down the line.

@blefaudeux closed this on Oct 20, 2021
@blefaudeux deleted the fused_linear_improve branch on October 26, 2021 23:17
qianfengz added a commit to qianfengz/xformers that referenced this pull request on Feb 7, 2024:
  • …added
  • ensure ck_decoder does not dispatch in test_attn_bias_padded