Gemma capping is a must for big models (#31698)
* softcapping

* soft cap before the mask

* style

* ...

* super nit
ArthurZucker authored Jun 28, 2024
1 parent cb29897 commit bbf1e61
Showing 2 changed files with 8 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/transformers/models/gemma2/configuration_gemma2.py
@@ -78,6 +78,7 @@ class Gemma2Config(PretrainedConfig):
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
         final_logit_softcapping (`float`, *optional*, defaults to 30.0): scaling factor when applying tanh softcapping on the logits.
+        attn_logit_softcapping (`float`, *optional*, defaults to 50.0): scaling factor when applying tanh softcapping on the attention scores.
         query_pre_attn_scalar (`float`, *optional*, defaults to 224): scaling factor used on the attention scores
         sliding_window (`int`, *optional*, defaults to 4096): in Gemma2, every other layer uses sliding window attention. This is the
             size of the sliding window.
@@ -116,6 +117,7 @@ def __init__(
         attention_bias=False,
         attention_dropout=0.0,
         final_logit_softcapping=30.0,
+        attn_logit_softcapping=50.0,
         query_pre_attn_scalar=224,
         sliding_window=4096,
         **kwargs,
@@ -135,6 +137,7 @@
         self.rope_theta = rope_theta
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
+        self.attn_logit_softcapping = attn_logit_softcapping
 
         super().__init__(
             pad_token_id=pad_token_id,
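For context (not part of the commit itself), a minimal sketch of how the new knob surfaces on the config side, assuming a `transformers` build that already includes this change; the values are simply the documented defaults, and passing `None` disables the capping branch added in `modeling_gemma2.py` below:

```python
from transformers import Gemma2Config

# Illustrative values only: these mirror the documented defaults above.
config = Gemma2Config(
    final_logit_softcapping=30.0,   # tanh cap applied to the final logits
    attn_logit_softcapping=50.0,    # tanh cap applied to the attention scores (added in this commit)
    query_pre_attn_scalar=224,
    sliding_window=4096,
)

print(config.attn_logit_softcapping)  # 50.0; set to None to skip the capping step entirely
```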
5 changes: 5 additions & 0 deletions src/transformers/models/gemma2/modeling_gemma2.py
@@ -256,6 +256,11 @@ def forward(
 
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling
 
+        if self.config.attn_logit_softcapping is not None:
+            attn_weights = attn_weights / self.config.attn_logit_softcapping
+            attn_weights = torch.tanh(attn_weights)
+            attn_weights = attn_weights * self.config.attn_logit_softcapping
+
         if attention_mask is not None:  # no matter the length, we just slice it
             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
             attn_weights = attn_weights + causal_mask
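To make the added branch concrete, here is a small self-contained sketch (not from the repository) of the same tanh softcapping applied to a dummy score tensor: dividing by the cap, applying `tanh`, and rescaling bounds every score to the open interval (-cap, cap):

```python
import torch

def soft_cap(scores: torch.Tensor, cap: float) -> torch.Tensor:
    # Same three steps as the lines added to the attention forward pass above:
    # scale down by the cap, squash with tanh, then scale back up.
    return torch.tanh(scores / cap) * cap

# Dummy attention scores with deliberately large magnitudes.
raw = torch.randn(1, 8, 16, 16) * 200.0
capped = soft_cap(raw, cap=50.0)

assert capped.abs().max() < 50.0  # every score now lies strictly inside (-50, 50)
```

The ordering matters, which is what the "soft cap before the mask" commit message refers to: if the large negative mask values were added first, `tanh` would squash masked positions to roughly -cap instead of effectively minus infinity, so they would no longer be strictly excluded by the softmax.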
