
add option for using fused kernel #227

Draft: drisspg wants to merge 11 commits into main from use_fused_kernel
Conversation

@drisspg (Contributor) commented Feb 27, 2024

Summary

The fused_sat_cast kernel can be found here: https://github.com/drisspg/driss_torch/blob/6d1be6ec21c5a56cf8ddfeb12a57cfce316e40bb/src/saturated_cast.cu#L243

The other two kernels can be found here: https://github.com/pytorch-labs/float8_experimental/pull/227/files#diff-a1a29e99a81b48419f66c77a301ca1f09c51bf754baf82e0962c9bc243d89310

Eager numbers are from this script, which is a stripped-down version of the full benchmark script, used to compare fused vs. unfused casting:
https://github.com/pytorch-labs/float8_experimental/pull/227/files#diff-729d5216ec3b30dea879056f9eb4a9bac9127501b8c8e6d516b640abb2f106ae
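For context, here is a rough eager-mode sketch of what the unfused cast path does versus the single fused launch; the function names below are illustrative, not the actual kernel API:

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def unfused_sat_cast(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Unfused eager path: three separate elementwise kernels
    # (multiply, clamp/saturate, dtype conversion).
    x_scaled = x * scale
    x_sat = x_scaled.clamp(min=-E4M3_MAX, max=E4M3_MAX)
    return x_sat.to(torch.float8_e4m3fn)

# The fused_sat_cast kernel folds all three steps into one launch,
# conceptually: y = saturated_cast(x, scale, torch.float8_e4m3fn)
# (see the driss_torch link above for the actual CUDA implementation).
```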

Table

Key/Structure:

  • ref_dtype: the dtype of the equivalent linear forward + backward in that precision
  • fuse_cast: whether the fused kernels from this PR are used in eager for some of the casting
  • pt_fp8_speedup: ref_time_sec / pt_fp8_time_sec (see the timing sketch after the table)
| shape | ref_dtype | fuse_cast | ref_time_sec | pt_fp8_time_sec | pt_fp8_speedup |
| --- | --- | --- | --- | --- | --- |
| (16384, 8192, 1280) | torch.bfloat16 | True | 0.002140 | 0.002688 | 0.796081 |
| (16384, 8192, 1280) | torch.bfloat16 | False | 0.002142 | 0.004025 | 0.532102 |
| (16384, 1024, 8192) | torch.bfloat16 | True | 0.001883 | 0.002398 | 0.785198 |
| (16384, 1024, 8192) | torch.bfloat16 | False | 0.001885 | 0.003384 | 0.556938 |
| (16384, 8192, 7168) | torch.bfloat16 | True | 0.010392 | 0.007928 | 1.310714 |
| (16384, 8192, 7168) | torch.bfloat16 | False | 0.010418 | 0.011007 | 0.946480 |
| (16384, 3584, 8192) | torch.bfloat16 | True | 0.005375 | 0.004720 | 1.138892 |
| (16384, 3584, 8192) | torch.bfloat16 | False | 0.005423 | 0.006589 | 0.823073 |
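For reference, the speedup column is just the ratio of the two timing columns; a minimal timing sketch with torch.utils.benchmark (the actual numbers come from the linked benchmark script):

```python
import torch
from torch.utils import benchmark

def median_time_sec(fn, *args) -> float:
    # Median wall-clock time, in seconds, for fn(*args).
    timer = benchmark.Timer(stmt="fn(*args)", globals={"fn": fn, "args": args})
    return timer.blocked_autorange().median

# pt_fp8_speedup = ref_time_sec / pt_fp8_time_sec, where each time is the
# measured linear forward + backward in bf16 (ref) and float8 (pt_fp8).
```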

Traces:

Repro script: https://gist.github.com/drisspg/693b53527859433fc9d8987a1b7e464b

Things still left to do for more perf

During the backward pass we also want the option to transpose, so that we can prepare inputs in the TN format. I still need to add support for this in the kernel, but beyond that this is a less contained change, since we need to signal it to the to_fp8_no_autograd constructor instead of relying on the compiler to generate it.
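A rough eager-mode sketch of what the transposed cast would look like; the transpose flag and its plumbing into to_fp8_no_autograd are hypothetical here:

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def sat_cast_maybe_transposed(x: torch.Tensor, scale: torch.Tensor,
                              transpose: bool = False) -> torch.Tensor:
    # Eager stand-in: a fused kernel would emit the transposed (TN-ready)
    # layout as part of the same launch instead of a separate copy.
    if transpose:
        x = x.t().contiguous()
    return (x * scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
```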

The same could be done for the scale-inverse calls as well, but again we would need to cache these on the fp8 tensor.
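A small sketch of the caching idea; the class and attribute names are illustrative, not the repo's actual fp8 tensor subclass:

```python
import torch

class Fp8TensorSketch:
    # Illustrative only: shows caching 1/scale so the reciprocal kernel is
    # launched once per tensor instead of on every use.
    def __init__(self, data: torch.Tensor, scale: torch.Tensor):
        self._data = data
        self._scale = scale
        self._scale_inv = None  # computed lazily, then reused

    @property
    def scale_inv(self) -> torch.Tensor:
        if self._scale_inv is None:
            self._scale_inv = self._scale.reciprocal()
        return self._scale_inv
```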

@facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Feb 27, 2024
@drisspg force-pushed the use_fused_kernel branch 3 times, most recently from 7a8c3f5 to 4afcd0b on February 28, 2024 at 20:22
@drisspg (Contributor, Author) commented Mar 1, 2024

@y-sq had a good suggestion to just use the amax directly instead of the scale, to avoid launching an extra kernel: drisspg/driss_torch#5

That will require changing the to_float8 constructor to take in an amax instead of a scale; I will work on that tomorrow.
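In eager terms, the change would move the amax-to-scale division out of its own small kernel and into the fused cast; a sketch for the E4M3 case (clamping details omitted, helper names illustrative):

```python
import torch

E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

def amax_to_scale(amax: torch.Tensor) -> torch.Tensor:
    # Today this runs as its own tiny kernel before the cast.
    return E4M3_MAX / amax.clamp(min=1e-12)

def cast_from_amax(x: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    # With the proposed change, the fused kernel would take amax and do
    # this division internally, avoiding the extra kernel launch.
    scale = amax_to_scale(amax)
    return (x * scale).clamp(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
```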
