
Disable fused causal attention #14732

Merged 3 commits into main on Feb 21, 2023

Conversation

tianleiwu (Contributor) commented Feb 17, 2023

Description

There is an accuracy regression in the GPT-2 model: the top-1 match rate (vs. the PyTorch model) drops by about 1%. The cause is that fused causal attention uses fp16 accumulation. This change disables it by default; users can set the environment variable ORT_ENABLE_FUSED_CAUSAL_ATTENTION=1 to turn it on manually.
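For users who accept the fp16-accumulation trade-off, a minimal sketch of opting back in from Python could look like the following. The assumption that the variable must be set before the session (and its CUDA kernels) is created is mine, not something stated in this PR:

```python
import os

# Assumption: the flag is read when the CUDA attention kernel is set up,
# so set it before onnxruntime creates the session.
os.environ["ORT_ENABLE_FUSED_CAUSAL_ATTENTION"] = "1"

import onnxruntime as ort

# gpt2.onnx is the model produced by the conversion command shown below.
session = ort.InferenceSession("gpt2.onnx", providers=["CUDAExecutionProvider"])
```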

This PR also updates the GPT-2 parity test script to generate left-side padding, to reflect actual usage.
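For context, left-side padding for GPT-2 batches is usually produced along these lines with the Hugging Face tokenizer; this is an illustrative sketch, not the parity script's actual code:

```python
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.padding_side = "left"            # pad on the left, as real GPT-2 batching does
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

batch = tokenizer(
    ["hello world", "a much longer example sentence"],
    return_tensors="pt",
    padding=True,
)
# batch["input_ids"] is left-padded; batch["attention_mask"] is 0 at the pad positions.
```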

To test:

python -m onnxruntime.transformers.models.gpt2.convert_to_onnx -m gpt2 --output gpt2.onnx -o -p fp16 --use_gpu

The top-1 match rate in the output is on par with ORT 1.13.1.
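The metric itself is argmax agreement between the ONNX Runtime and PyTorch logits. A hypothetical definition (not the parity script's exact implementation) is:

```python
import numpy as np

def top1_match_rate(ort_logits: np.ndarray, torch_logits: np.ndarray) -> float:
    """Fraction of positions where ORT and PyTorch pick the same top-1 next token."""
    ort_top1 = ort_logits.argmax(axis=-1)
    torch_top1 = torch_logits.argmax(axis=-1)
    return float((ort_top1 == torch_top1).mean())
```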

Motivation and Context

wangyems previously approved these changes Feb 17, 2023
yufenglee (Member) commented:

I'm not sure it is a good idea to disable fMHA by default.
The accuracy check is based on dummy inputs, which are meaningless. Intuitively, they tend to generate logits that are more neutral across all tokens, i.e., little or no preference for the next token. It would be better to randomly select 1000 real (meaningful) sentences as the test data set.
In addition, fMHA is only enabled for the context input, not for all iterations.

tianleiwu (Contributor, Author) commented:

> I'm not sure it is a good idea to disable fMHA by default. The accuracy check is based on dummy inputs, which are meaningless. Intuitively, they tend to generate logits that are more neutral across all tokens, i.e., little or no preference for the next token. It would be better to randomly select 1000 real (meaningful) sentences as the test data set. In addition, fMHA is only enabled for the context input, not for all iterations.

Good suggestion. We will improve the test with real sentences and re-evaluate this later.
Based on the current test results, I think it is better to turn it off by default. Even though a 1% drop is small, it is still a regression relative to 1.13.

tianleiwu merged commit c0d2472 into main on Feb 21, 2023
tianleiwu deleted the tlwu/disable_fused_causal_att branch on February 21, 2023 at 17:53
PatriceVignola pushed a commit that referenced this pull request Feb 22, 2023
There is an accuracy regression in the GPT-2 model: the top-1 match rate (vs. the PyTorch model) drops by about 1%. The cause is that fused causal attention uses fp16 accumulation. Disable it by default and add an environment variable ORT_ENABLE_FUSED_CAUSAL_ATTENTION=1 to turn it on manually.

It also updates the GPT-2 parity test script to generate left-side padding, to reflect actual usage.

To test:
```
python -m onnxruntime.transformers.models.gpt2.convert_to_onnx -m gpt2 --output gpt2.onnx -o -p fp16 --use_gpu
```
The top-1 match rate in the output is on par with ORT 1.13.1.