From 6077f37e8ffaec1195734d7665289f89e215c297 Mon Sep 17 00:00:00 2001 From: cyita Date: Fri, 15 Nov 2024 18:49:54 +0800 Subject: [PATCH] fix glm4-9b error --- .../llm/src/ipex_llm/transformers/models/chatglm4.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4.py b/python/llm/src/ipex_llm/transformers/models/chatglm4.py index 4a5f2bc0b2d4..4a4481874e27 100644 --- a/python/llm/src/ipex_llm/transformers/models/chatglm4.py +++ b/python/llm/src/ipex_llm/transformers/models/chatglm4.py @@ -76,9 +76,14 @@ def chatglm4_model_forward( if full_attention_mask is None: if (attention_mask is not None and not attention_mask.all()) or\ (past_key_values and seq_length != 1): - full_attention_mask = self.get_masks(inputs_embeds, - past_key_values, - padding_mask=attention_mask) + if self.config.hidden_size == 4096: + full_attention_mask = self.get_masks(input_ids, + past_key_values, + padding_mask=attention_mask) + else: + full_attention_mask = self.get_masks(inputs_embeds, + past_key_values, + padding_mask=attention_mask) # ipex-llm changes begin # 1. replace `rotary_pos_emb` with `inv_freq` and `position_ids`