refactor attention_softmax (#12295)
MeouSker77 authored Oct 30, 2024
1 parent 2b2cb9c commit 540eaeb
Showing 14 changed files with 18 additions and 18 deletions.
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/aquila.py
@@ -109,7 +109,7 @@ def aquila_attention_forward(
         )

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)

     if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
4 changes: 2 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/common.py
@@ -79,8 +79,8 @@ def mlp_gelu_forward(self, x: torch.Tensor):
     return fuse_mlp_base(self, GELU, x)


-def attention_softmax(attn_weights: torch.Tensor, training: bool):
-    if attn_weights.is_contiguous() and attn_weights.device.type == "xpu" and not training:
+def attention_softmax(attn_weights: torch.Tensor):
+    if attn_weights.is_contiguous() and attn_weights.device.type == "xpu":
         import xe_addons
         xe_addons.attn_softmax_inplaced(attn_weights)
     else:
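The change in common.py is the substance of this refactor: attention_softmax drops its training argument, so the fused in-place XPU kernel (xe_addons.attn_softmax_inplaced) is taken whenever the weights are contiguous on an XPU device, regardless of training mode, and every call site below now passes only the weights. The fallback branch is cut off in this diff, so the following is only a minimal sketch; its fp32-upcast fallback is an assumption inferred from the "upcast attention to fp32" comments at the call sites, not code taken from the commit.

# Minimal sketch of the updated helper (fallback body assumed, not shown in the diff)
import torch

def attention_softmax(attn_weights: torch.Tensor):
    if attn_weights.is_contiguous() and attn_weights.device.type == "xpu":
        import xe_addons
        # fused in-place softmax kernel on Intel XPU, as in the hunk above
        xe_addons.attn_softmax_inplaced(attn_weights)
    else:
        # assumed fallback: upcast to fp32 for numerical stability, then cast back
        attn_weights = torch.nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(attn_weights.dtype)
    return attn_weights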
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/gemma.py
@@ -220,7 +220,7 @@ def gemma_attention_forward(
         attn_weights = attn_weights + causal_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                          training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/internlm.py
@@ -125,7 +125,7 @@ def internlm_attention_forward(
         attn_weights = attn_weights + attention_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)

     attn_output = attn_output.transpose(1, 2)
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/llama32.py
@@ -237,7 +237,7 @@ def llama_attention_forward(
         attn_weights = attn_weights + causal_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)

     attn_output = attn_output.transpose(1, 2).contiguous()
4 changes: 2 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/minicpmv.py
@@ -56,7 +56,7 @@ def siglip_attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask

-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)

     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)
@@ -161,7 +161,7 @@ def vision_transformer_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
     query_states, key_states, value_states = qkv.chunk(3, dim=1)

     attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_weights = self.attn_drop(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)

4 changes: 2 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/mllama.py
@@ -85,7 +85,7 @@ def mllama_vision_attention_forward(

     # upcast attention to fp32
     from ipex_llm.transformers.models.common import attention_softmax
-    attn_weights = attention_softmax(attn_weights, False)
+    attn_weights = attention_softmax(attn_weights)

     attn_output = torch.matmul(attn_weights, value)

@@ -311,7 +311,7 @@ def mllama_cross_attention_forward(
         attn_weights = attn_weights + causal_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)

     attn_output = attn_output.transpose(1, 2).contiguous()
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/phi.py
@@ -114,7 +114,7 @@ def attention_forward(
         attn_weights = attn_weights + attention_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training).to(hidden_states.dtype)
+    attn_weights = attention_softmax(attn_weights).to(hidden_states.dtype)
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)

2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -185,7 +185,7 @@ def attention_forward(
     attn_weights.div_(math.sqrt(self.head_dim))
     if attention_mask is not None:
         attn_weights.add_(attention_mask)
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)

     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)
4 changes: 2 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/qwen2_vl.py
@@ -219,7 +219,7 @@ def qwen2_vision_attention_forward(
     attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask
-    attn_weights = attention_softmax(attn_weights, False)
+    attn_weights = attention_softmax(attn_weights)
     attn_output = torch.matmul(attn_weights, v)
     attn_output = attn_output.transpose(0, 1)
     attn_output = attn_output.reshape(seq_length, -1)
@@ -298,7 +298,7 @@ def qwen2_vl_attention_forward(
         attn_weights = attn_weights + causal_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)

     attn_output = attn_output.transpose(1, 2).contiguous()
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/sd15.py
@@ -114,7 +114,7 @@ def __call__(
         attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
         if attention_mask is not None:
             attn_weights = attn_weights + attention_mask
-        attn_weights = attention_softmax(attn_weights, False)
+        attn_weights = attention_softmax(attn_weights)
         hidden_states = torch.matmul(attn_weights, value)
         # IPEX-LLM changes end

2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/stablelm.py
@@ -175,7 +175,7 @@ def stablelm_attention_forward(
         attn_weights = attn_weights + attention_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_weights = self.attention_dropout(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)

2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/starcoder2.py
@@ -134,7 +134,7 @@ def attention_forward(
         attn_weights = attn_weights + attention_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/yuan.py
@@ -240,7 +240,7 @@ def yuan_attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask
     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)

     attn_output = attn_output.transpose(1, 2)
