diff --git a/python/llm/src/ipex_llm/transformers/models/aquila.py b/python/llm/src/ipex_llm/transformers/models/aquila.py
index b889c1c1e11..880846667b0 100644
--- a/python/llm/src/ipex_llm/transformers/models/aquila.py
+++ b/python/llm/src/ipex_llm/transformers/models/aquila.py
@@ -109,7 +109,7 @@ def aquila_attention_forward(
         )

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)

     if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index 13ad662e19e..8b0ba92fba2 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -79,8 +79,8 @@ def mlp_gelu_forward(self, x: torch.Tensor):
     return fuse_mlp_base(self, GELU, x)


-def attention_softmax(attn_weights: torch.Tensor, training: bool):
-    if attn_weights.is_contiguous() and attn_weights.device.type == "xpu" and not training:
+def attention_softmax(attn_weights: torch.Tensor):
+    if attn_weights.is_contiguous() and attn_weights.device.type == "xpu":
         import xe_addons
         xe_addons.attn_softmax_inplaced(attn_weights)
     else:
diff --git a/python/llm/src/ipex_llm/transformers/models/gemma.py b/python/llm/src/ipex_llm/transformers/models/gemma.py
index 1731266b27d..0490b4a19c4 100644
--- a/python/llm/src/ipex_llm/transformers/models/gemma.py
+++ b/python/llm/src/ipex_llm/transformers/models/gemma.py
@@ -220,7 +220,7 @@ def gemma_attention_forward(
         attn_weights = attn_weights + causal_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                          training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)
diff --git a/python/llm/src/ipex_llm/transformers/models/internlm.py b/python/llm/src/ipex_llm/transformers/models/internlm.py
index 1851d383207..68e47df6a47 100644
--- a/python/llm/src/ipex_llm/transformers/models/internlm.py
+++ b/python/llm/src/ipex_llm/transformers/models/internlm.py
@@ -125,7 +125,7 @@ def internlm_attention_forward(
         attn_weights = attn_weights + attention_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)

     attn_output = attn_output.transpose(1, 2)
diff --git a/python/llm/src/ipex_llm/transformers/models/llama32.py b/python/llm/src/ipex_llm/transformers/models/llama32.py
index 9bb1d97e266..f105669fa6b 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama32.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama32.py
@@ -237,7 +237,7 @@ def llama_attention_forward(
         attn_weights = attn_weights + causal_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)

     attn_output = attn_output.transpose(1, 2).contiguous()
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpmv.py b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
index 89aca6d0126..9cad8fc8444 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpmv.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpmv.py
@@ -56,7 +56,7 @@ def siglip_attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask

-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.dropout,
                                                training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)
@@ -161,7 +161,7 @@ def vision_transformer_attention_forward(self, x: torch.Tensor) -> torch.Tensor:
     query_states, key_states, value_states = qkv.chunk(3, dim=1)

     attn_weights = torch.matmul(query_states * self.scale, key_states.transpose(2, 3))
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_weights = self.attn_drop(attn_weights)

     attn_output = torch.matmul(attn_weights, value_states)
diff --git a/python/llm/src/ipex_llm/transformers/models/mllama.py b/python/llm/src/ipex_llm/transformers/models/mllama.py
index 2f1142b7dd7..4a05346e3ca 100644
--- a/python/llm/src/ipex_llm/transformers/models/mllama.py
+++ b/python/llm/src/ipex_llm/transformers/models/mllama.py
@@ -85,7 +85,7 @@ def mllama_vision_attention_forward(

     # upcast attention to fp32
     from ipex_llm.transformers.models.common import attention_softmax
-    attn_weights = attention_softmax(attn_weights, False)
+    attn_weights = attention_softmax(attn_weights)

     attn_output = torch.matmul(attn_weights, value)

@@ -311,7 +311,7 @@ def mllama_cross_attention_forward(
         attn_weights = attn_weights + causal_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)

     attn_output = attn_output.transpose(1, 2).contiguous()
diff --git a/python/llm/src/ipex_llm/transformers/models/phi.py b/python/llm/src/ipex_llm/transformers/models/phi.py
index ca68700afc4..7401f3efac3 100644
--- a/python/llm/src/ipex_llm/transformers/models/phi.py
+++ b/python/llm/src/ipex_llm/transformers/models/phi.py
@@ -114,7 +114,7 @@ def attention_forward(
         attn_weights = attn_weights + attention_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training).to(hidden_states.dtype)
+    attn_weights = attention_softmax(attn_weights).to(hidden_states.dtype)
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)

diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py
index fa6c43d6d47..a7d1a1d5ec8 100644
--- a/python/llm/src/ipex_llm/transformers/models/phi3.py
+++ b/python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -185,7 +185,7 @@ def attention_forward(
     attn_weights.div_(math.sqrt(self.head_dim))
     if attention_mask is not None:
         attn_weights.add_(attention_mask)
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)

diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py
index dd0e0de3e82..9696723f127 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py
@@ -219,7 +219,7 @@ def qwen2_vision_attention_forward(
         attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
         if attention_mask is not None:
             attn_weights = attn_weights + attention_mask
-        attn_weights = attention_softmax(attn_weights, False)
+        attn_weights = attention_softmax(attn_weights)
         attn_output = torch.matmul(attn_weights, v)
         attn_output = attn_output.transpose(0, 1)
         attn_output = attn_output.reshape(seq_length, -1)
@@ -298,7 +298,7 @@ def qwen2_vl_attention_forward(
         attn_weights = attn_weights + causal_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)

     attn_output = attn_output.transpose(1, 2).contiguous()
diff --git a/python/llm/src/ipex_llm/transformers/models/sd15.py b/python/llm/src/ipex_llm/transformers/models/sd15.py
index ab999d40974..60d657ee77f 100644
--- a/python/llm/src/ipex_llm/transformers/models/sd15.py
+++ b/python/llm/src/ipex_llm/transformers/models/sd15.py
@@ -114,7 +114,7 @@ def __call__(
         attn_weights = torch.matmul(query * scale, key.transpose(-1, -2))
         if attention_mask is not None:
             attn_weights = attn_weights + attention_mask
-        attn_weights = attention_softmax(attn_weights, False)
+        attn_weights = attention_softmax(attn_weights)
         hidden_states = torch.matmul(attn_weights, value)
         # IPEX-LLM changes end

diff --git a/python/llm/src/ipex_llm/transformers/models/stablelm.py b/python/llm/src/ipex_llm/transformers/models/stablelm.py
index 37639ff92d3..af6c5dee530 100644
--- a/python/llm/src/ipex_llm/transformers/models/stablelm.py
+++ b/python/llm/src/ipex_llm/transformers/models/stablelm.py
@@ -175,7 +175,7 @@ def stablelm_attention_forward(
         attn_weights = attn_weights + attention_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_weights = self.attention_dropout(attn_weights)

     attn_output = torch.matmul(attn_weights, value_states)
diff --git a/python/llm/src/ipex_llm/transformers/models/starcoder2.py b/python/llm/src/ipex_llm/transformers/models/starcoder2.py
index 9ebb0c5ffd3..1bffc1ee67f 100644
--- a/python/llm/src/ipex_llm/transformers/models/starcoder2.py
+++ b/python/llm/src/ipex_llm/transformers/models/starcoder2.py
@@ -134,7 +134,7 @@ def attention_forward(
         attn_weights = attn_weights + attention_mask

     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_weights = torch.nn.functional.dropout(attn_weights, p=self.attention_dropout,
                                                training=self.training)
     attn_output = torch.matmul(attn_weights, value_states)
diff --git a/python/llm/src/ipex_llm/transformers/models/yuan.py b/python/llm/src/ipex_llm/transformers/models/yuan.py
index 339e958b206..800f0273c06 100644
--- a/python/llm/src/ipex_llm/transformers/models/yuan.py
+++ b/python/llm/src/ipex_llm/transformers/models/yuan.py
@@ -240,7 +240,7 @@ def yuan_attention_forward(
     if attention_mask is not None:
         attn_weights = attn_weights + attention_mask
     # upcast attention to fp32
-    attn_weights = attention_softmax(attn_weights, self.training)
+    attn_weights = attention_softmax(attn_weights)
     attn_output = torch.matmul(attn_weights, value_states)

     attn_output = attn_output.transpose(1, 2)
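For quick reference, below is a minimal sketch of what the shared helper in common.py looks like after this patch. The signature and the xpu fast path are taken from the common.py hunk above; the fallback branch is not part of this diff and is only assumed here from the "upcast attention to fp32" comments at the call sites.

# Sketch of the updated helper (assembled from the hunks above, not copied
# verbatim from common.py).
import torch
from torch import nn


def attention_softmax(attn_weights: torch.Tensor):
    if attn_weights.is_contiguous() and attn_weights.device.type == "xpu":
        # fused in-place softmax kernel from IPEX-LLM's native XPU extension
        import xe_addons
        xe_addons.attn_softmax_inplaced(attn_weights)
    else:
        # assumed fallback: upcast to fp32 for numerical stability, then cast back
        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
                                             dtype=torch.float32).to(attn_weights.dtype)
    return attn_weights


# Call sites simply drop the old `training` flag:
#   attn_weights = attention_softmax(attn_weights)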