Reenable SDPA's FA2 During Training with torch.compile #30442
Conversation
Tagging @ArthurZucker and @younesbelkada for review.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I hate having this if/else... but I guess it's for the best here.
I wish this was natively supported.
Quick question: I guess this adds a guard?
Otherwise, LGTM; the slow tests will be triggered once merged.
FYI @fxmarty when you come back
Not sure why the CI errored out after these formatting changes. Locally I still have …
Thank you! It looks okay to me, just suggested a style change.
@warner-benjamin can you make sure the CI is green? The failing ones seem unrelated; maybe merging main will do.
Reenable SDPA's FA2 During Training with torch.compile (#30442)
* Reenable SDPA's FA2 during training with torch.compile
* fix Olmo's SDPA FA2 dispatching too
* update formatting
* improved SDPA comment
* formatting and explanatory comment
* is_causal if statement to one-liner
This PR causes a dynamo graph break at …
@tombousso Yes. Why is it an issue for you? Do you see perf degradation? AFAIK there is no obvious way around it; maybe the newer APIs from pytorch/pytorch#114823 (and possibly other PRs) would help.
https://pytorch.org/docs/main/cond.html may be the way?
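For illustration only, a minimal sketch of what `torch.cond`-based dispatch might look like, not code from this PR. It assumes a PyTorch version where `torch.cond` is available (a prototype API at the time of this thread), and the shapes and the `q_len > 1` predicate are placeholders rather than the actual transformers code:

```python
import torch
import torch.nn.functional as F

def causal_branch(q, k, v):
    # q_len > 1: let SDPA apply causal masking itself (no explicit mask)
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

def non_causal_branch(q, k, v):
    # q_len == 1 (e.g. single-token decoding): causal masking is a no-op
    return F.scaled_dot_product_attention(q, k, v, is_causal=False)

def attention(q, k, v):
    # torch.cond traces both branches, so the shape-dependent predicate
    # q.shape[2] > 1 does not force torch.compile to specialize or break
    # the graph on the dynamic sequence length.
    return torch.cond(q.shape[2] > 1, causal_branch, non_causal_branch, (q, k, v))

q = torch.randn(1, 8, 128, 64)
k, v = torch.randn_like(q), torch.randn_like(q)
out = attention(q, k, v)
```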
Yes, I was seeing perf degradation. I was hoping to get a graph with no breaks to make it easier to see what's going on, and to give the compiler the best opportunity to make optimizations.
@tombousso Could you open an issue for that?
This PR resolves #30010 and completes #30070 by re-enabling the SDPA Flash Attention 2 kernel for `torch.compile` when the model is training. During eval, SDPA dispatches to the efficient kernel with the same logic as in #30070. This fixes the case where an SDPA attention model uses a small amount of memory during training in eager mode but uses a large amount, or OOMs, when compiled, because the wrong SDPA kernel was selected. It shouldn't affect exporting or generation when the model is in eval mode. (See the kernel-eligibility sketch below.)
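As an illustration of the kernel selection described above (not code from this PR; assumes PyTorch 2.3+ for `torch.nn.attention.sdpa_kernel` and a CUDA device): SDPA can only pick the Flash Attention backend when no explicit `attn_mask` is passed, which is why a training step without padding should pass `attn_mask=None` together with `is_causal=True`:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Restrict SDPA to the Flash Attention backend. This succeeds with
# attn_mask=None + is_causal=True, but raises a runtime error if an
# explicit attn_mask is passed, since Flash Attention does not accept one.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
```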
Moving the `is_causal` dispatch logic from an inline expression to an if statement is required to support both `fullgraph=True` and `dynamic=True`. The current code errors out with `dynamic=True` because `q_len > 1` is not a plain Python bool under dynamic shapes, but wrapping it as `bool(q_len > 1)` to fix dynamic breaks `fullgraph=True`. (See the before/after sketch below.)
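A minimal sketch of the before/after pattern described above, paraphrased rather than the exact transformers code; `causal_mask`, `q_len`, and the surrounding names are stand-ins:

```python
import torch
import torch.nn.functional as F

# Stand-in inputs (illustrative, not the transformers model code).
bsz, heads, q_len, head_dim = 1, 8, 128, 64
query = torch.randn(bsz, heads, q_len, head_dim)
key, value = torch.randn_like(query), torch.randn_like(query)
causal_mask = None  # no padding -> no explicit mask, FA2 is eligible

# Before (inline, paraphrased):
#   is_causal=causal_mask is None and q_len > 1
# errors under dynamic=True because q_len > 1 is a SymBool rather than a
# plain Python bool, and bool(q_len > 1) in turn breaks fullgraph=True.

# After: hoist the dispatch into its own statement, which torch.compile
# can guard on under both fullgraph=True and dynamic=True.
is_causal = True if causal_mask is None and q_len > 1 else False

attn_output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=causal_mask,
    is_causal=is_causal,
)
```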
The Llama tests that I could run either all pass or fail in the same state as on main (`LlamaIntegrationTest::test_conversion` & `LlamaIntegrationTest::test_compile_static_cache`). I couldn't run the Gemma tests due to a model gating error, despite having access to Gemma.