From 06bbe747545728b00ece54f16896c65a857c5c33 Mon Sep 17 00:00:00 2001
From: Benjamin Warner
Date: Mon, 29 Apr 2024 11:45:43 -0500
Subject: [PATCH] Reenable SDPA's FA2 During Training with torch.compile (#30442)

* Reenable SDPA's FA2 during training with torch.compile

* fix Olmo's SDPA FA2 dispatching too

* update formatting

* improved SDPA comment

* formatting and explanatory comment

* is_causal if statement to one-liner

---
 src/transformers/modeling_attn_mask_utils.py      |  7 ++++---
 src/transformers/models/cohere/modeling_cohere.py | 13 +++++++++----
 src/transformers/models/gemma/modeling_gemma.py   | 13 +++++++++----
 src/transformers/models/llama/modeling_llama.py   | 13 +++++++++----
 src/transformers/models/olmo/modeling_olmo.py     | 11 +++++++++--
 5 files changed, 40 insertions(+), 17 deletions(-)

diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py
index 44ea1795669f58..8dcf40268d0324 100755
--- a/src/transformers/modeling_attn_mask_utils.py
+++ b/src/transformers/modeling_attn_mask_utils.py
@@ -240,6 +240,7 @@ def _ignore_causal_mask_sdpa(
         inputs_embeds: torch.Tensor,
         past_key_values_length: int,
         sliding_window: Optional[int] = None,
+        is_training: bool = False,
     ) -> bool:
         """
         Detects whether the optional user-specified attention_mask & the automatically created causal mask can be ignored in case PyTorch's SDPA is used, rather relying on SDPA's `is_causal` argument.
@@ -263,11 +264,11 @@
         if attention_mask is None:
             # TODO: When tracing with TorchDynamo with fullgraph=True, the model is recompiled depending on the input shape, thus SDPA's `is_causal` argument is rightfully updated (see https://gist.github.com/fxmarty/1313f39037fc1c112508989628c57363). However, when using `torch.export` or
             # or `torch.onnx.dynamo_export`, we must pass an example input, and `is_causal` behavior is hard-coded. If a user exports a model with q_len > 1, the exported model will hard-code `is_causal=True` which is in general wrong (see https://github.com/pytorch/pytorch/issues/108108).
-            # Thus, we currently can NOT set `ignore_causal_mask = True` here. We would need a `torch._dynamo.is_exporting()` flag.
+            # Thus, we only set `ignore_causal_mask = True` if the model is set to training.
             #
             # Besides, jit.trace can not handle the `q_len > 1` condition for `is_causal` (`TypeError: scaled_dot_product_attention(): argument 'is_causal' must be bool, not Tensor`).
             if (
-                not is_tracing
+                (is_training or not is_tracing)
                 and (query_length == 1 or key_value_length == query_length)
                 and (sliding_window is None or key_value_length < sliding_window)
             ):
@@ -279,7 +280,7 @@
                 raise ValueError(
                     f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
                 )
-        elif not is_tracing and torch.all(attention_mask == 1):
+        elif (is_training or not is_tracing) and torch.all(attention_mask == 1):
             if query_length == 1 or key_value_length == query_length:
                 # For query_length == 1, causal attention and bi-directional attention are the same.
                 ignore_causal_mask = True
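The `is_training` flag above matters because of how SDPA selects kernels: the Flash Attention 2 backend is only eligible when `attn_mask is None`, with causality expressed through the `is_causal` argument instead. Below is a minimal sketch of that dispatch rule, not part of this patch; it assumes PyTorch >= 2.3 and a CUDA device that supports Flash Attention, with fp16 or bf16 inputs.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
mask = torch.ones(128, 128, device="cuda", dtype=torch.bool).tril()

# Restrict SDPA to the Flash Attention backend only, so no silent fallback occurs.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    # Eligible: no attn_mask, causality expressed via `is_causal=True`.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    # Not eligible: an explicit mask (even an equivalent causal one) rules FA2 out.
    try:
        F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    except RuntimeError as err:
        print("Flash Attention rejected the explicit mask:", err)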
diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index 41bb4c0516928c..3d529fd1ec4fe0 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -590,15 +590,17 @@ def forward(
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()

-        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
-        # relying on the `is_causal` argument.
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
+        # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+        is_causal = True if causal_mask is None and q_len > 1 else False
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            is_causal=is_causal,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -996,7 +998,10 @@ def _update_causal_mask(
         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
         # in order to dispatch on Flash Attention 2.
         if AttentionMaskConverter._ignore_causal_mask_sdpa(
-            attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+            attention_mask,
+            inputs_embeds=input_tensor,
+            past_key_values_length=past_seen_tokens,
+            is_training=self.training,
         ):
             return None

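The hoisted one-liner above keeps `is_causal` a plain Python bool that TorchDynamo can guard on, rather than a value computed inline inside the kernel call. A runnable toy sketch of the pattern follows (the module name is illustrative, not from the patch); it compiles cleanly with both `fullgraph=True` and `dynamic=True`.

import torch
import torch.nn.functional as F

class ToySdpaAttention(torch.nn.Module):
    # Mirrors the dispatch pattern each model touched by this patch now uses.
    def forward(self, q, k, v, causal_mask=None):
        q_len = q.shape[-2]
        # Causal only when no mask is supplied and we are not decoding a single token.
        is_causal = True if causal_mask is None and q_len > 1 else False
        return F.scaled_dot_product_attention(
            q, k, v, attn_mask=causal_mask, is_causal=is_causal
        )

attn = torch.compile(ToySdpaAttention(), fullgraph=True, dynamic=True)
q = torch.randn(1, 4, 16, 8)
out = attn(q, q, q)  # no mask and q_len > 1 -> the FA2-eligible `is_causal=True` path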
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index e5b6b207748a53..97e4e5d49f8e0f 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -570,15 +570,17 @@ def forward(
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()

-        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
-        # relying on the `is_causal` argument.
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
+        # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+        is_causal = True if causal_mask is None and q_len > 1 else False
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            is_causal=is_causal,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -982,7 +984,10 @@ def _update_causal_mask(
         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
         # in order to dispatch on Flash Attention 2.
         if AttentionMaskConverter._ignore_causal_mask_sdpa(
-            attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+            attention_mask,
+            inputs_embeds=input_tensor,
+            past_key_values_length=past_seen_tokens,
+            is_training=self.training,
         ):
             return None

diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 905edf5f71a63d..9a2566f2fdd2eb 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -666,15 +666,17 @@ def forward(
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()

-        # In case we are not compiling, we may set `causal_mask` to None, which is required to dispatch to SDPA's Flash Attention 2 backend, rather
-        # relying on the `is_causal` argument.
+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
+        # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+        is_causal = True if causal_mask is None and q_len > 1 else False
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            is_causal=is_causal,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -1074,7 +1076,10 @@ def _update_causal_mask(
         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
         # in order to dispatch on Flash Attention 2.
         if AttentionMaskConverter._ignore_causal_mask_sdpa(
-            attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+            attention_mask,
+            inputs_embeds=input_tensor,
+            past_key_values_length=past_seen_tokens,
+            is_training=self.training,
         ):
             return None

diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index e3b0e05127c52d..87db966e2d8f67 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -647,13 +647,17 @@ def forward(
             key_states = key_states.contiguous()
             value_states = value_states.contiguous()

+        # We dispatch to SDPA's Flash Attention or Efficient kernels via this if statement instead of an
+        # inline conditional assignment to support both torch.compile's `dynamic=True` and `fullgraph=True`
+        is_causal = True if causal_mask is None and q_len > 1 else False
+
         attn_output = torch.nn.functional.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
             attn_mask=causal_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=causal_mask is None and q_len > 1,
+            is_causal=is_causal,
         )

         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -1057,7 +1061,10 @@ def _update_causal_mask(
         # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument,
         # in order to dispatch on Flash Attention 2.
         if AttentionMaskConverter._ignore_causal_mask_sdpa(
-            attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens
+            attention_mask,
+            inputs_embeds=input_tensor,
+            past_key_values_length=past_seen_tokens,
+            is_training=self.training,
         ):
             return None
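Taken together, the change means a padding-free training batch can now hit SDPA's FA2 kernel under torch.compile. Below is a hedged end-to-end sketch of that flow; the checkpoint name is illustrative (any of the Llama, Gemma, Cohere, or OLMo models touched here applies), and a CUDA device with Flash Attention support is assumed.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # illustrative checkpoint; substitute your own
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    attn_implementation="sdpa",  # the SDPA attention classes patched above
    torch_dtype=torch.bfloat16,
).cuda()
model.train()  # is_training=True lets _ignore_causal_mask_sdpa drop the mask

compiled = torch.compile(model)
batch = tokenizer(["The quick brown fox"], return_tensors="pt").to("cuda")
# With no padding, attention_mask is all ones, so the causal mask is ignored
# and SDPA runs with is_causal=True, making the FA2 kernel eligible.
loss = compiled(**batch, labels=batch["input_ids"]).loss
loss.backward()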