diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
index d8ca9c8eb87..7a214e571ff 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -87,7 +87,7 @@ def chatglm2_model_forward(
                                          dtype=inputs_embeds.dtype, device=inputs_embeds.device)
     if use_cache:
-        use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[-1])
+        use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[1])
         use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
                                                 input_ids)
         if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4.py b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
index 4874a8957b2..0103b2437c5 100644
--- a/python/llm/src/ipex_llm/transformers/models/chatglm4.py
+++ b/python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -51,7 +51,7 @@ def chatglm4_model_forward(
     if use_cache:
         inputs = input_ids if input_ids is not None else inputs_embeds
-        use_compress_kv = should_use_compresskv(inputs, inputs.shape[-1])
+        use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
         use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.dense_h_to_4h,
                                                 inputs)
         if use_compress_kv and not use_quantize_kv and not isinstance(past_key_values,
diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py
index ccc4afc37ff..6c2680bfa19 100644
--- a/python/llm/src/ipex_llm/transformers/models/llama.py
+++ b/python/llm/src/ipex_llm/transformers/models/llama.py
@@ -128,7 +128,7 @@ def llama_model_forward_4_36(
                                      self.config.num_attention_heads//self.config.num_key_value_heads):
         if not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-    elif should_use_compresskv(input, input.shape[-1]):
+    elif should_use_compresskv(input, input.shape[1]):
         # if use quantize kv, compress kv will be ignored now
         if not isinstance(past_key_values, DynamicCompressCache):
             past_key_values = DynamicCompressCache.from_legacy_cache(
@@ -168,7 +168,7 @@ def llama_model_forward_4_38(
                                      self.config.num_attention_heads//self.config.num_key_value_heads):
         if not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-    elif should_use_compresskv(input, input.shape[-1]):
+    elif should_use_compresskv(input, input.shape[1]):
         # if use quantize kv, compress kv will be ignored now
         if not isinstance(past_key_values, DynamicCompressCache):
             past_key_values = DynamicCompressCache.from_legacy_cache(
@@ -209,7 +209,7 @@ def llama_model_forward_4_41(
                                      self.config.num_attention_heads//self.config.num_key_value_heads):
         if not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-    elif should_use_compresskv(input, input.shape[-1]):
+    elif should_use_compresskv(input, input.shape[1]):
         # if use quantize kv, compress kv will be ignored now
         if not isinstance(past_key_values, DynamicCompressCache):
             past_key_values = DynamicCompressCache.from_legacy_cache(
diff --git a/python/llm/src/ipex_llm/transformers/models/minicpm.py b/python/llm/src/ipex_llm/transformers/models/minicpm.py
index fc31c4b73af..9c7f74dcfc8 100644
--- a/python/llm/src/ipex_llm/transformers/models/minicpm.py
+++ b/python/llm/src/ipex_llm/transformers/models/minicpm.py
@@ -628,7 +628,7 @@ def minicpm_model_forward(
                                      self.config.num_key_value_heads):
         if not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-    elif should_use_compresskv(input, input.shape[-1]):
+    elif should_use_compresskv(input, input.shape[1]):
         if not isinstance(past_key_values, DynamicCompressCache):
             past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py
index f077474fb65..a4ad0256092 100644
--- a/python/llm/src/ipex_llm/transformers/models/mistral.py
+++ b/python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -213,7 +213,7 @@ def mistral_model_forward_4_36(
                                      self.config.num_attention_heads//self.config.num_key_value_heads):
         if not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-    elif should_use_compresskv(input_ids, input_ids.shape[-1]):
+    elif should_use_compresskv(input_ids, input_ids.shape[1]):
         # if use quantize kv, compress kv will be ignored now
         if not isinstance(past_key_values, DynamicCompressCache):
             past_key_values = DynamicCompressCache.from_legacy_cache(
diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py
index 443a9921ee7..54c462a5207 100644
--- a/python/llm/src/ipex_llm/transformers/models/phi3.py
+++ b/python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -258,7 +258,7 @@ def model_forward(
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
     use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input)
-    use_compress_kv = should_use_compresskv(input, input.shape[-1])
+    use_compress_kv = should_use_compresskv(input, input.shape[1])
     if use_cache:
         if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py
index b2ec61a3222..b80f4cff86f 100644
--- a/python/llm/src/ipex_llm/transformers/models/qwen2.py
+++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -118,7 +118,7 @@ def qwen2_model_forward(
         and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
                                   self.config.num_attention_heads//self.config.num_key_value_heads)
     )
-    use_compress_kv = should_use_compresskv(inputs, inputs.shape[-1])
+    use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
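Note on why this one-line change is repeated across all of these files: should_use_compresskv takes the current sequence length as its second argument. With 2-D input_ids of shape (batch, seq_len), shape[-1] and shape[1] agree, but several of these forwards may instead receive 3-D inputs_embeds of shape (batch, seq_len, hidden_size), where shape[-1] picks up the hidden size rather than the sequence length. Indexing dimension 1 gives the sequence length in both cases. A minimal standalone sketch, assuming the standard Hugging Face tensor layout (the concrete sizes below are illustrative, not from the patch):

    import torch

    batch, seq_len, hidden = 2, 16, 4096

    # 2-D token ids: (batch, seq_len) -- dims 1 and -1 coincide
    input_ids = torch.randint(0, 1000, (batch, seq_len))
    assert input_ids.shape[1] == input_ids.shape[-1] == seq_len

    # 3-D embeddings: (batch, seq_len, hidden) -- dims 1 and -1 differ
    inputs_embeds = torch.randn(batch, seq_len, hidden)
    assert inputs_embeds.shape[1] == seq_len   # the actual sequence length
    assert inputs_embeds.shape[-1] == hidden   # hidden size, wrong for a length check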