
Commit

use new fp32 softmax kernel (#11776)
MeouSker77 authored Aug 13, 2024
1 parent 23d3acd commit aa861df
Showing 2 changed files with 6 additions and 5 deletions.
python/llm/src/ipex_llm/transformers/models/minicpmv.py (5 changes: 3 additions & 2 deletions)

@@ -42,8 +42,9 @@ def siglip_attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask

-    # upcast attention to fp32
-    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
+    import xe_addons
+    xe_addons.attn_softmax_inplaced(attn_weights)

     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)
python/llm/src/ipex_llm/transformers/models/phi3.py (6 changes: 3 additions & 3 deletions)

@@ -184,9 +184,9 @@ def attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask

-    # upcast attention to fp32
-    attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1,
-                                               dtype=torch.float32).to(value_states.dtype)
+    import xe_addons
+    xe_addons.attn_softmax_inplaced(attn_weights)

     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)
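
For context, both hunks apply the same pattern: the fp32-upcast torch softmax is replaced by IPEX-LLM's in-place xe_addons kernel. The sketch below is illustrative only, not code from this commit; the helper name attn_softmax_fp32 and the ImportError fallback are assumptions added to show the before/after behaviour side by side.

import torch


def attn_softmax_fp32(attn_weights: torch.Tensor) -> torch.Tensor:
    # Illustrative helper (not from the commit): prefer the fused in-place
    # kernel used by this change, fall back to the previous fp32-upcast path.
    try:
        import xe_addons  # IPEX-LLM XPU extension providing the fused kernel
        # Kernel called at the changed sites: per the commit message it runs
        # softmax over the last dim in fp32 and writes back into attn_weights.
        xe_addons.attn_softmax_inplaced(attn_weights)
        return attn_weights
    except ImportError:
        # Old behaviour (the removed lines above): upcast to fp32, softmax,
        # then cast back to the original dtype.
        return torch.nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(attn_weights.dtype)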
