diff --git a/src/transformers/models/gemma2/configuration_gemma2.py b/src/transformers/models/gemma2/configuration_gemma2.py
index 47207d7ca12436..7da541207bfe76 100644
--- a/src/transformers/models/gemma2/configuration_gemma2.py
+++ b/src/transformers/models/gemma2/configuration_gemma2.py
@@ -78,6 +78,7 @@ class Gemma2Config(PretrainedConfig):
         attention_dropout (`float`, *optional*, defaults to 0.0): The dropout ratio for the attention probabilities.
         final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
+        attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
         query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
         sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
             size of the sliding window.
@@ -116,6 +117,7 @@ def __init__(
         attention_bias=False,
         attention_dropout=0.0,
         final_logit_softcapping=30.0,
+        attn_logit_softcapping=50.0,
         query_pre_attn_scalar=224,
         sliding_window=4096,
         **kwargs,
@@ -135,6 +137,7 @@ def __init__(
         self.rope_theta = rope_theta
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.attn_logit_softcapping = attn_logit_softcapping
 
         super().__init__(
             pad_token_id=pad_token_id,
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 24e60eddba221b..6b2b47b5159e28 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -256,6 +256,11 @@ def forward(
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
 
+        if self.config.attn_logit_softcapping is not None:
+            attn_weights = attn_weights / self.config.attn_logit_softcapping
+            attn_weights = torch.tanh(attn_weights)
+            attn_weights = attn_weights * self.config.attn_logit_softcapping
+
         if attention_mask is not None:  # no matter the length, we just slice it
             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
             attn_weights = attn_weights + causal_mask
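For reference, here is a minimal standalone sketch of the tanh softcapping that the modeling change applies to the attention logits. The tensor shapes, the generic `1/sqrt(head_dim)` scaling, and the variable names other than `attn_logit_softcapping` are illustrative assumptions, not taken from the model code.

```python
import torch

# Mirrors the new config field added in the diff (default 50.0).
attn_logit_softcapping = 50.0

# Illustrative shapes only: (batch, num_heads, seq_len, head_dim).
batch, num_heads, q_len, kv_len, head_dim = 1, 8, 4, 4, 16
query_states = torch.randn(batch, num_heads, q_len, head_dim)
key_states = torch.randn(batch, num_heads, kv_len, head_dim)
scaling = head_dim ** -0.5  # generic scaling for this sketch

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * scaling

# Softcapping: divide by the cap, squash with tanh, multiply back, so every
# logit ends up inside (-attn_logit_softcapping, +attn_logit_softcapping).
if attn_logit_softcapping is not None:
    attn_weights = attn_weights / attn_logit_softcapping
    attn_weights = torch.tanh(attn_weights)
    attn_weights = attn_weights * attn_logit_softcapping

print(attn_weights.abs().max())  # always strictly below 50.0
```

Because tanh is monotonic and bounded, this keeps the attention scores finite without reordering them; it is the same pattern already used for `final_logit_softcapping` on the output logits.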