From 92df994aec90fa96d0cfc6e381767016cfa3c6df Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 9 Dec 2024 18:35:44 -0500
Subject: [PATCH] flash attention forward doesn't play well with torch.compile

---
 src/transformers/modeling_flash_attention_utils.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index ec03ba1eb5fd83..65e97aa076682a 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -228,6 +228,7 @@ def fa_peft_integration_check(
 deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
 
 
+@torch.compiler.disable(recursive=True)
 def _flash_attention_forward(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
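
For context (not part of the patch itself), the sketch below illustrates the effect of the added decorator: torch.compiler.disable(recursive=True) tells torch.compile not to trace the decorated function, and with recursive=True not to trace anything it calls either, so that region runs eagerly via a graph break while the surrounding compiled code is unaffected. The function names here are illustrative stand-ins, not the transformers implementation.

    import torch
    import torch.nn.functional as F


    @torch.compiler.disable(recursive=True)
    def eager_only_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # Stand-in for _flash_attention_forward: torch.compile will not trace
        # into this function (or its callees, because recursive=True); the call
        # site becomes a graph break and this body executes eagerly.
        scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
        return F.softmax(scores, dim=-1) @ v


    @torch.compile
    def model_step(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # This part is traced and compiled as usual.
        q, k, v = q * 1.0, k * 1.0, v * 1.0
        return eager_only_attention(q, k, v)


    if __name__ == "__main__":
        q = k = v = torch.randn(2, 4, 8, 16)
        print(model_step(q, k, v).shape)  # torch.Size([2, 4, 8, 16])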