diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index 4303dbd3a18..fece88384b5 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -281,8 +281,13 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
         key = repeat_kv(key, n_heads // n_kv_heads)
         value = repeat_kv(value, n_heads // n_kv_heads)
 
-    attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query, key, value, mask, is_causal=is_causal, scale=scale
-    )
+    if is_causal and mask is None:
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, is_causal=is_causal, scale=scale
+        )
+    else:
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, mask, scale=scale
+        )
     attn_output = attn_output.to(dtype)  # workaround ipex 2.1's bug
     return attn_output
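
The diff splits the single `scaled_dot_product_attention` call into two branches because PyTorch's SDPA documents that an error is thrown when both `attn_mask` and `is_causal` are set. Below is a minimal, standalone sketch demonstrating that constraint; the tensor shapes and mask are illustrative assumptions, not taken from the patched function.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, seq_len, head_dim).
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)
mask = torch.zeros(1, 1, 4, 4)  # additive mask, broadcastable over heads

# Passing an explicit mask together with is_causal=True raises a RuntimeError,
# which is why the patch dispatches the two cases separately.
try:
    F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=True)
except RuntimeError as e:
    print(e)

# Either argument alone is accepted:
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```

Checking `is_causal and mask is None` (rather than `is_causal` alone) keeps the behavior safe when a caller passes both: the explicit mask then takes precedence and `is_causal` is dropped instead of triggering the error.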