From 4c7cae0cc39adc38e3343bb8447dedf21a29d77b Mon Sep 17 00:00:00 2001
From: Yishuo Wang
Date: Tue, 24 Dec 2024 16:59:16 +0800
Subject: [PATCH] fix UT

---
 python/llm/src/ipex_llm/transformers/models/common.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/models/common.py b/python/llm/src/ipex_llm/transformers/models/common.py
index 4303dbd3a18..fece88384b5 100644
--- a/python/llm/src/ipex_llm/transformers/models/common.py
+++ b/python/llm/src/ipex_llm/transformers/models/common.py
@@ -281,8 +281,13 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
     key = repeat_kv(key, n_heads // n_kv_heads)
     value = repeat_kv(value, n_heads // n_kv_heads)
 
-    attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query, key, value, mask, is_causal=is_causal, scale=scale
-    )
+    if is_causal and mask is None:
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, is_causal=is_causal, scale=scale
+        )
+    else:
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, mask, scale=scale
+        )
     attn_output = attn_output.to(dtype)  # workaround ipex 2.1's bug
     return attn_output
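
Note on the change: torch.nn.functional.scaled_dot_product_attention raises a
RuntimeError when an explicit attn_mask is combined with is_causal=True, so
the patch passes exactly one of the two. Below is a minimal self-contained
sketch of the same dispatch, separate from the patch itself; the tensor shapes
are illustrative only, and PyTorch 2.1+ is assumed for the scale argument.

    import torch
    import torch.nn.functional as F

    # illustrative shapes: (batch, heads, seq_len, head_dim)
    query = torch.randn(1, 8, 16, 64)
    key = torch.randn(1, 8, 16, 64)
    value = torch.randn(1, 8, 16, 64)
    mask = None        # or an explicit additive attention mask tensor
    is_causal = True
    scale = None       # None means the default 1/sqrt(head_dim)

    if is_causal and mask is None:
        # no explicit mask: let the kernel build the causal mask internally
        out = F.scaled_dot_product_attention(
            query, key, value, is_causal=is_causal, scale=scale
        )
    else:
        # an explicit mask already encodes any causal structure, so
        # is_causal must not be passed alongside it
        out = F.scaled_dot_product_attention(
            query, key, value, mask, scale=scale
        )

Keeping both call sites explicit, rather than mutating arguments before a
single call, makes the mutually exclusive mask/is_causal contract visible at
the point of use.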