
Reenable SDPA's FA2 During Training with torch.compile #30442

Merged
7 commits merged into huggingface:main on Apr 29, 2024

Conversation

warner-benjamin (Contributor)

This PR resolves #30010 and completes #30070 by reenabling the SDPA Flash Attention 2 kernel for torch.compile when the model is training. During eval, SDPA dispatches to the efficient kernel with the same logic as in #30070.

This PR fixes the situation where SDPA attention models use little memory during eager-mode training but use far more memory, or OOM, once compiled, because torch.compile dispatches to the wrong SDPA kernel. It shouldn't affect exporting or generation when the model is in eval mode.

Moving the is_causal dispatch logic from an inline expression to an if statement is required to support both fullgraph=True and dynamic=True. The current code errors out with dynamic=True because q_len > 1 is not the expected bool type, but wrapping it as bool(q_len > 1) to fix dynamic=True breaks fullgraph=True.
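
A minimal sketch of the dispatch pattern this change targets (illustrative names, not the exact transformers code): deciding is_causal in a standalone statement before calling SDPA, rather than inline in the call, is what lets torch.compile handle both fullgraph=True and dynamic=True.

```python
import torch
import torch.nn.functional as F

def sdpa_attention_sketch(query, key, value, causal_mask=None, dropout_p=0.0, training=False):
    # Hypothetical helper for illustration only.
    q_len = query.shape[-2]

    # Dispatch to SDPA's Flash Attention kernel when there is no explicit mask
    # and more than one query token; otherwise an explicit mask is passed and
    # SDPA falls back to another kernel. Keeping this as its own statement
    # (not inline in the SDPA call) is the pattern described above.
    is_causal = True if causal_mask is None and q_len > 1 else False

    return F.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=dropout_p if training else 0.0,
        is_causal=is_causal,
    )
```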

The Llama tests that I could run all either pass or fail in the same way as on main (LlamaIntegrationTest::test_conversion & LlamaIntegrationTest::test_compile_static_cache). I couldn't run the Gemma tests due to a model gating error despite having access to Gemma.

@warner-benjamin (Contributor, Author)

Tagging @ArthurZucker and @younesbelkada for review.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker (Collaborator) left a comment


I hate having this if else.... but I guess it's for the best here.
I wish this was natively supported.

Quick question: I guess this adds a guard?
Otherwise, LGTM, and the slow tests will be triggered once merged.

Review thread on src/transformers/models/cohere/modeling_cohere.py (outdated, resolved)
@ArthurZucker (Collaborator)

fyi @fxmarty when you come back

@warner-benjamin (Contributor, Author)

warner-benjamin commented Apr 24, 2024

Not sure why the CI errored out after these formatting changes. Locally I still have LlamaIntegrationTest::test_conversion & LlamaIntegrationTest::test_compile_static_cache failing, both of which also fail on main. Every other slow test passes for Llama.

@fxmarty (Contributor) left a comment


Thank you! It looks okay to me, just suggested a style change.

Review thread on src/transformers/models/cohere/modeling_cohere.py (outdated, resolved)
@fxmarty (Contributor)

fxmarty commented Apr 29, 2024

@warner-benjamin can you make sure the CI is green? The failing checks seem unrelated; merging main may fix them.

@fxmarty fxmarty merged commit 9df8b30 into huggingface:main Apr 29, 2024
21 checks passed
eigen2017 pushed a commit to eigen2017/transformers that referenced this pull request Apr 30, 2024
…0442)

* Reenable SDPA's FA2 during training with torch.compile

* fix Olmo's SDPA FA2 dispatching too

* update formatting

* improved SDPA comment

* formatting and explanatory comment

* is_causal if statement to one-liner
itazap pushed a commit that referenced this pull request May 14, 2024
@tombousso

This PR causes a dynamo graph break at torch.all(attention_mask == 1) when running with torch.compile in training mode, due to the dynamic control flow. Is there a way to get around this?
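
For context, a toy repro of this kind of break (hypothetical function, not the transformers code): converting a data-dependent tensor check into a Python bool inside a compiled region forces dynamo to split the graph at that point.

```python
import torch

@torch.compile  # fullgraph=False (the default): compiles, but splits into subgraphs
def toy_forward(x, attention_mask):
    # The Python `if` needs a concrete bool, so dynamo must evaluate the
    # data-dependent torch.all(...) eagerly, causing a graph break here.
    if torch.all(attention_mask == 1):
        return x * 2.0
    return x

x = torch.randn(2, 4)
mask = torch.ones(2, 4, dtype=torch.bool)
print(toy_forward(x, mask))
```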

@fxmarty (Contributor)

fxmarty commented Jun 4, 2024

@tombousso Yes. Why is it an issue for you? Do you see perf degradation?

AFAIK there is no obvious way around it; maybe the newer APIs from pytorch/pytorch#114823 and other PRs could help.

@fxmarty (Contributor)

fxmarty commented Jun 4, 2024

https://pytorch.org/docs/main/cond.html may be the way?
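
A rough sketch of how torch.cond (a prototype control-flow operator in recent PyTorch releases; see the link above) could keep such a data-dependent branch inside the graph instead of breaking on a Python if. This is a possible workaround under that assumption, not what transformers currently does.

```python
import torch

def toy_forward_with_cond(x, attention_mask):
    def all_ones_branch(x):
        return x * 2.0

    def has_padding_branch(x):
        return x

    pred = torch.all(attention_mask == 1)  # 0-dim bool tensor, stays in the graph
    # torch.cond traces both branches and selects at runtime, avoiding a graph break.
    return torch.cond(pred, all_ones_branch, has_padding_branch, (x,))

compiled = torch.compile(toy_forward_with_cond, fullgraph=True)
print(compiled(torch.randn(2, 4), torch.ones(2, 4, dtype=torch.bool)))
```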

@tombousso

tombousso commented Jun 4, 2024

Yes, I was seeing perf degradation. I was hoping to get a graph with no breaks to make it easier to see what's going on, and to give the compiler the best opportunity to make optimizations.

@fxmarty (Contributor)

fxmarty commented Jun 5, 2024

@tombousso Could you open an issue for that?

Successfully merging this pull request may close these issues.

Llama uses significantly more memory in 4.38 & 4.39 than 4.37 with identical code