diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4.py b/python/llm/src/ipex_llm/transformers/models/chatglm4.py index 4a5f2bc0b2d4..4a4481874e27 100644 --- a/python/llm/src/ipex_llm/transformers/models/chatglm4.py +++ b/python/llm/src/ipex_llm/transformers/models/chatglm4.py @@ -76,9 +76,14 @@ def chatglm4_model_forward( if full_attention_mask is None: if (attention_mask is not None and not attention_mask.all()) or\ (past_key_values and seq_length != 1): - full_attention_mask = self.get_masks(inputs_embeds, - past_key_values, - padding_mask=attention_mask) + if self.config.hidden_size == 4096: + full_attention_mask = self.get_masks(input_ids, + past_key_values, + padding_mask=attention_mask) + else: + full_attention_mask = self.get_masks(inputs_embeds, + past_key_values, + padding_mask=attention_mask) # ipex-llm changes begin # 1. replace `rotary_pos_emb` with `inv_freq` and `position_ids` diff --git a/python/llm/src/ipex_llm/transformers/speculative.py b/python/llm/src/ipex_llm/transformers/speculative.py index a09c9df48baf..dfc1c8f918a3 100644 --- a/python/llm/src/ipex_llm/transformers/speculative.py +++ b/python/llm/src/ipex_llm/transformers/speculative.py @@ -510,7 +510,7 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa for k, v in past_key_values ] elif self.config.model_type == "chatglm": - if self.config.num_layers == 40 and hasattr(self.config, 'rope_ratio'): + if self.config.num_layers in [28, 40] and hasattr(self.config, 'rope_ratio'): past_key_values = [ (k[:, :, :-(new_cache_size), :], v[:, :, :-(new_cache_size), :]) @@ -768,7 +768,7 @@ def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_ forward_args["attention_mask"] = cur_attention_mask if self.config.model_type == "chatglm": - if self.config.num_layers == 40 and hasattr(self.config, 'rope_ratio'): + if self.config.num_layers in [28, 40] and hasattr(self.config, 'rope_ratio'): past_key_value_len = past_key_values[0][0].shape[2] else: past_key_value_len = past_key_values[0][0].shape[0]