From a2ca7b26196e77c580244aa2608a70884a919394 Mon Sep 17 00:00:00 2001 From: cyita Date: Fri, 15 Nov 2024 18:09:35 +0800 Subject: [PATCH] fix glm4 position id --- python/llm/src/ipex_llm/transformers/speculative.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/speculative.py b/python/llm/src/ipex_llm/transformers/speculative.py index 9800cfea70d2..a09c9df48baf 100644 --- a/python/llm/src/ipex_llm/transformers/speculative.py +++ b/python/llm/src/ipex_llm/transformers/speculative.py @@ -768,7 +768,10 @@ def _non_cpu_ipex_verify(self, verify_input_ids, past_key_values, cur_attention_ forward_args["attention_mask"] = cur_attention_mask if self.config.model_type == "chatglm": - past_key_value_len = past_key_values[0][0].shape[0] + if self.config.num_layers == 40 and hasattr(self.config, 'rope_ratio'): + past_key_value_len = past_key_values[0][0].shape[2] + else: + past_key_value_len = past_key_values[0][0].shape[0] position_ids = torch.arange(verify_input_ids.shape[1], dtype=torch.long, device=verify_input_ids.device) position_ids = position_ids.unsqueeze(0).repeat(1, 1) + past_key_value_len