diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index ec03ba1eb5fd83..65e97aa076682a 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -228,6 +228,7 @@ def fa_peft_integration_check(
 deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
 
 
+@torch.compiler.disable(recursive=True)
 def _flash_attention_forward(
     query_states: torch.Tensor,
     key_states: torch.Tensor,
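
For context, a minimal sketch (not part of the patch; the function names below are hypothetical) of what the added decorator does: torch.compiler.disable(recursive=True) marks a function so that torch.compile does not trace it or anything it calls, and that region runs eagerly instead of being captured in the compiled graph.

import torch

@torch.compiler.disable(recursive=True)
def eager_only_kernel(x: torch.Tensor) -> torch.Tensor:
    # Runs in eager mode even when invoked from compiled code;
    # recursive=True also excludes every function this one calls.
    return torch.softmax(x, dim=-1)

@torch.compile
def model_step(x: torch.Tensor) -> torch.Tensor:
    # The surrounding computation is compiled; the decorated call is
    # left out of the captured graph, analogous to how the patch treats
    # _flash_attention_forward.
    return eager_only_kernel(x * 2.0)

print(model_step(torch.randn(2, 4)).shape)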