Disable fused causal attention #14732
Conversation
I'm not sure it is a good idea to disable fMHA by default.
Good suggestion. We will improve the test with real sentences and re-evaluate this later.
Description
There is an accuracy regression in the GPT-2 model: the top-1 match rate (vs. the PyTorch model) drops by about 1%. The cause is that fused causal attention uses fp16 accumulation. This change disables it by default and adds an environment variable,
ORT_ENABLE_FUSED_CAUSAL_ATTENTION=1
which turns it back on manually. It also updates the GPT-2 parity test script to generate left-side padding, reflecting actual usage.
To test:
```
python -m onnxruntime.transformers.models.gpt2.convert_to_onnx -m gpt2 --output gpt2.onnx -o -p fp16 --use_gpu
```
The top-1 match rate in the output is on par with ORT 1.13.1.
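The accuracy drop comes from rounding every intermediate sum to half precision. This is not ORT code, just a stdlib-only Python illustration of the effect: accumulating many small terms in an fp16 accumulator stalls once the running sum's fp16 spacing exceeds the increment, while a wide accumulator stays accurate.

```python
import struct

def to_fp16(x: float) -> float:
    # Round a Python float to the nearest IEEE-754 half-precision value
    # by round-tripping through struct's 'e' (binary16) format.
    return struct.unpack('<e', struct.pack('<e', x))[0]

def sum_with_fp16_accumulator(values):
    # Every intermediate sum is rounded back to fp16, as an fp16-accumulating
    # fused kernel would do.
    acc = 0.0
    for v in values:
        acc = to_fp16(acc + v)
    return acc

# 10,000 small terms, each exactly representable in fp16. The true sum is ~10.004.
vals = [to_fp16(0.001)] * 10_000
exact = 10_000 * to_fp16(0.001)

# Once the fp16 running sum reaches 4.0, its spacing (~0.0039) is more than
# twice the increment, so further additions round away and the sum stalls.
err_fp16 = abs(sum_with_fp16_accumulator(vals) - exact)
err_wide = abs(sum(vals) - exact)  # double-precision accumulation stays tiny
```

Here the fp16-accumulated sum is off by several units while the wide accumulator is essentially exact; attention softmax sums are better conditioned than this worst case, but the same mechanism explains a ~1% top-1 mismatch.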
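The actual opt-in gate lives in ORT's native attention code; a minimal Python sketch of the intended semantics (the function name here is illustrative, only the environment variable is real):

```python
import os

def fused_causal_attention_enabled() -> bool:
    # Fused causal attention is off by default; the user must explicitly
    # opt back in with ORT_ENABLE_FUSED_CAUSAL_ATTENTION=1.
    return os.getenv("ORT_ENABLE_FUSED_CAUSAL_ATTENTION", "0") == "1"
```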
Motivation and Context