Commit

fix UT
MeouSker77 committed Dec 24, 2024
1 parent a8e33e4 commit 4c7cae0
Showing 1 changed file with 8 additions and 3 deletions.
python/llm/src/ipex_llm/transformers/models/common.py (11 changes: 8 additions & 3 deletions)
@@ -281,8 +281,13 @@ def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor,
     key = repeat_kv(key, n_heads // n_kv_heads)
     value = repeat_kv(value, n_heads // n_kv_heads)
 
-    attn_output = torch.nn.functional.scaled_dot_product_attention(
-        query, key, value, mask, is_causal=is_causal, scale=scale
-    )
+    if is_causal and mask is None:
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, is_causal=is_causal, scale=scale
+        )
+    else:
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, mask, scale=scale
+        )
     attn_output = attn_output.to(dtype)  # workaround ipex 2.1's bug
     return attn_output
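
Context for the fix (not part of the commit itself): PyTorch's torch.nn.functional.scaled_dot_product_attention rejects an explicit attn_mask combined with is_causal=True, which is why the single call above is split into two branches. Below is a minimal sketch of the patched pattern, assuming PyTorch >= 2.1 for the scale argument; the sdpa helper name and the toy tensor shapes are illustrative, not part of the ipex_llm code:

    import torch
    import torch.nn.functional as F

    # Toy tensors: (batch, heads, seq_len, head_dim); shapes are illustrative.
    q = torch.randn(1, 2, 4, 8)
    k = torch.randn(1, 2, 4, 8)
    v = torch.randn(1, 2, 4, 8)
    mask = torch.zeros(1, 1, 4, 4)  # additive float mask, broadcastable

    # Passing both an explicit mask and is_causal=True raises a RuntimeError,
    # e.g. "Explicit attn_mask should not be set when is_causal=True".
    try:
        F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=True)
    except RuntimeError as e:
        print(e)

    # The patched pattern: take the fused causal path only when no mask is
    # given; otherwise pass the mask and drop is_causal.
    def sdpa(q, k, v, mask=None, is_causal=False, scale=None):
        if is_causal and mask is None:
            return F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale)
        return F.scaled_dot_product_attention(q, k, v, mask, scale=scale)

    print(sdpa(q, k, v, is_causal=True).shape)  # torch.Size([1, 2, 4, 8])
    print(sdpa(q, k, v, mask=mask).shape)       # torch.Size([1, 2, 4, 8])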
