diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4.py b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
index 4a5f2bc0b2d4..4a4481874e27 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm4.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -76,9 +76,14 @@ def chatglm4_model_forward(
     if full_attention_mask is None:
         if (attention_mask is not None and not attention_mask.all()) or\
                 (past_key_values and seq_length != 1):
-            full_attention_mask = self.get_masks(inputs_embeds,
-                                                 past_key_values,
-                                                 padding_mask=attention_mask)
+            if self.config.hidden_size == 4096:
+                full_attention_mask = self.get_masks(input_ids,
+                                                     past_key_values,
+                                                     padding_mask=attention_mask)
+            else:
+                full_attention_mask = self.get_masks(inputs_embeds,
+                                                     past_key_values,
+                                                     padding_mask=attention_mask)
 
     # ipex-llm changes begin
     # 1. replace `rotary_pos_emb` with `inv_freq` and `position_ids`