From 3fe34fc484b273c3ce53afdf3106f4d74775f075 Mon Sep 17 00:00:00 2001 From: Yishuo Wang Date: Thu, 9 Jan 2025 14:26:28 +0800 Subject: [PATCH] update quantize kv cache condition --- .../src/ipex_llm/transformers/models/baichuan.py | 11 +++++++---- .../src/ipex_llm/transformers/models/chatglm2.py | 11 ++++++++--- .../src/ipex_llm/transformers/models/chatglm4.py | 13 +++++++++---- .../src/ipex_llm/transformers/models/chatglm4v.py | 2 +- python/llm/src/ipex_llm/transformers/models/glm.py | 2 +- .../src/ipex_llm/transformers/models/internlm.py | 9 ++++++--- .../llm/src/ipex_llm/transformers/models/llama.py | 2 +- .../llm/src/ipex_llm/transformers/models/minicpm.py | 2 +- .../src/ipex_llm/transformers/models/minicpm3.py | 4 +++- .../llm/src/ipex_llm/transformers/models/mistral.py | 2 +- .../llm/src/ipex_llm/transformers/models/mllama.py | 2 +- python/llm/src/ipex_llm/transformers/models/phi3.py | 8 ++++++-- python/llm/src/ipex_llm/transformers/models/qwen.py | 6 ++++-- .../llm/src/ipex_llm/transformers/models/qwen2.py | 7 ++++--- .../src/ipex_llm/transformers/models/qwen2_moe.py | 6 ++++-- .../src/ipex_llm/transformers/models/qwen2_vl.py | 4 +++- .../src/ipex_llm/transformers/models/stablelm.py | 4 +++- .../src/ipex_llm/transformers/models/starcoder2.py | 4 +++- .../llm/src/ipex_llm/transformers/models/utils.py | 10 +++++++--- python/llm/src/ipex_llm/transformers/models/yuan.py | 3 ++- 20 files changed, 75 insertions(+), 37 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/baichuan.py b/python/llm/src/ipex_llm/transformers/models/baichuan.py index a78e5f8e131..7e2aad9ea92 100644 --- a/python/llm/src/ipex_llm/transformers/models/baichuan.py +++ b/python/llm/src/ipex_llm/transformers/models/baichuan.py @@ -105,7 +105,9 @@ def baichuan_model_7b_forward( if use_cache: inputs = input_ids if input_ids is not None else inputs_embeds use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) - use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs) + use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs, + self.config.num_attention_heads, + self.config.num_attention_heads) if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache): if use_quantize_kv: @@ -278,8 +280,6 @@ def baichuan_attention_forward_7b( key_states = key_states.to(hidden_states.dtype) # IPEX-LLM OPT: kv cache and quantize kv - use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states) - # [CompressKV] if use_compresskv: enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, @@ -290,6 +290,8 @@ def baichuan_attention_forward_7b( query_states, attention_mask, 1, self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH) else: + use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states, + self.num_heads, self.num_heads) key_states, value_states = update_past_key_value( past_key_value, key_states, value_states, kv_seq_len, use_quantize_kv, device @@ -340,7 +342,8 @@ def baichuan_attention_forward_13b( kv_seq_len += past_key_value[0].shape[2] # IPEX-LLM OPT: kv cache and quantize kv - use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states) + use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states, + self.num_heads, self.num_heads) key_states, value_states = update_past_key_value( past_key_value, key_states, value_states, kv_seq_len, use_quantize_kv, device diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py index beb3653a6b5..9ad3a62e880 100644 --- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py +++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py @@ -91,8 +91,13 @@ def chatglm2_model_forward( if use_cache: use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[1]) + n_heads = self.config.num_attention_heads + if self.config.multi_query_attention: + n_kv_heads = self.config.multi_query_group_num + else: + n_kv_heads = n_heads use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj, - input_ids) + input_ids, n_heads, n_kv_heads) if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache): if use_quantize_kv: @@ -285,8 +290,6 @@ def chatglm2_attention_forward( key_states[..., :rot_dim] = k_rot[...] # IPEX-LLM OPT: kv cache and quantize kv - use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states) - # [CompressKV] if use_compresskv: from transformers.configuration_utils import PretrainedConfig @@ -300,6 +303,8 @@ def chatglm2_attention_forward( self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH ) else: + use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states, + n_head, n_kv_head) key_states, value_states = update_past_key_value( past_key_value, key_states, value_states, kv_seq_len, use_quantize_kv, hidden_states.device diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4.py b/python/llm/src/ipex_llm/transformers/models/chatglm4.py index c3adc3720ee..5c92feb905f 100644 --- a/python/llm/src/ipex_llm/transformers/models/chatglm4.py +++ b/python/llm/src/ipex_llm/transformers/models/chatglm4.py @@ -55,8 +55,13 @@ def chatglm4_model_forward( if use_cache: inputs = input_ids if input_ids is not None else inputs_embeds use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) - use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj, - inputs) + n_heads = self.config.num_attention_heads + if self.config.multi_query_attention: + n_kv_heads = self.config.multi_query_group_num + else: + n_kv_heads = n_heads + use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj, inputs, + n_heads, n_kv_heads) if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache): if use_quantize_kv: @@ -211,8 +216,6 @@ def chatglm4_attention_forward( key_states[..., :rot_dim] = k_rot[...] # IPEX-LLM OPT: kv cache and quantize kv - use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states) - # [CompressKV] if use_compresskv: from transformers.configuration_utils import PretrainedConfig @@ -226,6 +229,8 @@ def chatglm4_attention_forward( self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH ) else: + use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states, + n_head, n_kv_head) key_states, value_states = update_past_key_value( past_key_value, key_states, value_states, kv_seq_len, use_quantize_kv, hidden_states.device diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm4v.py b/python/llm/src/ipex_llm/transformers/models/chatglm4v.py index 10028bca176..e5c5e1bac4f 100644 --- a/python/llm/src/ipex_llm/transformers/models/chatglm4v.py +++ b/python/llm/src/ipex_llm/transformers/models/chatglm4v.py @@ -230,7 +230,7 @@ def chatglm4v_attention_forward( key_states[..., :rot_dim] = k_rot[...] # IPEX-LLM OPT: kv cache and quantize kv - use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states) + use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states, n_head, n_kv_head) key_states, value_states = update_past_key_value( past_key_value, key_states, value_states, kv_seq_len, use_quantize_kv, hidden_states.device diff --git a/python/llm/src/ipex_llm/transformers/models/glm.py b/python/llm/src/ipex_llm/transformers/models/glm.py index f0a2d17a541..39326567210 100644 --- a/python/llm/src/ipex_llm/transformers/models/glm.py +++ b/python/llm/src/ipex_llm/transformers/models/glm.py @@ -147,7 +147,7 @@ def glm_model_forward( use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache or inputs.device.type == 'xpu' use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs, - self.config.num_attention_heads // + self.config.num_attention_heads, self.config.num_key_value_heads) if use_cache: diff --git a/python/llm/src/ipex_llm/transformers/models/internlm.py b/python/llm/src/ipex_llm/transformers/models/internlm.py index 9f71fb38cfa..7a67a653477 100644 --- a/python/llm/src/ipex_llm/transformers/models/internlm.py +++ b/python/llm/src/ipex_llm/transformers/models/internlm.py @@ -87,7 +87,8 @@ def internlm_attention_forward( ) # IPEX-LLM OPT: kv cache and quantzie kv cache - use_quantize_kv = use_quantize_kv_cache(self.qkv_proj, hidden_states) + use_quantize_kv = use_quantize_kv_cache(self.qkv_proj, hidden_states, + self.num_heads, self.num_heads) key_states, value_states = update_past_key_value( past_key_value, key_states, value_states, kv_seq_len, use_quantize_kv, hidden_states.device @@ -171,7 +172,8 @@ def internlm2_attention_forward( ) # IPEX-LLM OPT: kv cache and quantzie kv cache - use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states) + use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states, + self.num_heads, self.num_key_value_heads) key_states, value_states = update_past_key_value( past_key_value, key_states, value_states, kv_seq_len, use_quantize_kv, hidden_states.device @@ -346,7 +348,8 @@ def internlm_xcomposser2_attention_forward( query_states, key_states, cos, sin, position_ids, "internlm") # IPEX-LLM OPT: kv cache and quantzie kv cache - use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states) + use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states, + self.num_heads, self.num_key_value_heads) key_states, value_states = update_past_key_value( past_key_value, key_states, value_states, kv_seq_len, use_quantize_kv, device diff --git a/python/llm/src/ipex_llm/transformers/models/llama.py b/python/llm/src/ipex_llm/transformers/models/llama.py index 610f1ac05a3..56a73290e70 100644 --- a/python/llm/src/ipex_llm/transformers/models/llama.py +++ b/python/llm/src/ipex_llm/transformers/models/llama.py @@ -72,7 +72,7 @@ def llama_model_forward( use_cache = True if inputs.device.type == "xpu" else use_cache use_quantize_kv = use_quantize_kv_cache( self.layers[0].mlp.down_proj, inputs, - self.config.num_attention_heads // self.config.num_key_value_heads + self.config.num_attention_heads, self.config.num_key_value_heads ) use_compresskv = should_use_compresskv(inputs, inputs.shape[1]) or \ isinstance(past_key_values, DynamicCompressCache) diff --git a/python/llm/src/ipex_llm/transformers/models/minicpm.py b/python/llm/src/ipex_llm/transformers/models/minicpm.py index 532e992d211..49c230fe6cf 100644 --- a/python/llm/src/ipex_llm/transformers/models/minicpm.py +++ b/python/llm/src/ipex_llm/transformers/models/minicpm.py @@ -159,7 +159,7 @@ def minicpm_model_forward( # IPEX-LLM OPT: kv cache and quantize kv cache inputs = input_ids if input_ids is not None else inputs_embeds use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs, - self.config.num_attention_heads // + self.config.num_attention_heads, self.config.num_key_value_heads) use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \ isinstance(past_key_values, DynamicCompressCache) diff --git a/python/llm/src/ipex_llm/transformers/models/minicpm3.py b/python/llm/src/ipex_llm/transformers/models/minicpm3.py index 8cef25f0989..03e45912a58 100644 --- a/python/llm/src/ipex_llm/transformers/models/minicpm3.py +++ b/python/llm/src/ipex_llm/transformers/models/minicpm3.py @@ -66,7 +66,9 @@ def minicpm3_model_forward( inputs = input_ids if input_ids is not None else inputs_embeds use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = True if inputs.device.type == "xpu" else use_cache - use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs) + num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads + use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs, + num_heads, num_kv_heads) if use_cache: if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache): past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) diff --git a/python/llm/src/ipex_llm/transformers/models/mistral.py b/python/llm/src/ipex_llm/transformers/models/mistral.py index 4534f735aaa..413d12f380e 100644 --- a/python/llm/src/ipex_llm/transformers/models/mistral.py +++ b/python/llm/src/ipex_llm/transformers/models/mistral.py @@ -71,7 +71,7 @@ def mistral_model_forward( use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache or inputs.device.type == 'xpu' use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs, - self.config.num_attention_heads // + self.config.num_attention_heads, self.config.num_key_value_heads) use_compress_kv = should_use_compresskv(inputs, inputs.size(1)) or \ isinstance(past_key_values, DynamicCompressCache) diff --git a/python/llm/src/ipex_llm/transformers/models/mllama.py b/python/llm/src/ipex_llm/transformers/models/mllama.py index 2dc6896962f..29598f8f5a6 100644 --- a/python/llm/src/ipex_llm/transformers/models/mllama.py +++ b/python/llm/src/ipex_llm/transformers/models/mllama.py @@ -113,7 +113,7 @@ def mllama_text_model_forward( use_cache = True if inputs.device.type == "xpu" else use_cache use_quantize_kv = use_quantize_kv_cache( self.layers[0].mlp.down_proj, inputs, - self.config.num_attention_heads // self.config.num_key_value_heads + self.config.num_attention_heads, self.config.num_key_value_heads ) if use_cache: if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache): diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py index 85c41c5c9f3..07f264bda80 100644 --- a/python/llm/src/ipex_llm/transformers/models/phi3.py +++ b/python/llm/src/ipex_llm/transformers/models/phi3.py @@ -249,7 +249,9 @@ def model_forward( # IPEX-LLM OPT: kv cache and quantize kv cache and sdp use_cache = use_cache if use_cache is not None else self.config.use_cache inputs = input_ids if input_ids is not None else inputs_embeds - use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs) + num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads + use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs, + num_heads, num_kv_heads) use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \ isinstance(past_key_values, DynamicCompressCache) if use_cache: @@ -305,7 +307,9 @@ def model_forward( ): # IPEX-LLM OPT: kv cache and quantize kv cache and sdp use_cache = use_cache if use_cache is not None else self.config.use_cache - use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input_ids) + num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads + use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input_ids, + num_heads, num_kv_heads) if use_cache: if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache): past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) diff --git a/python/llm/src/ipex_llm/transformers/models/qwen.py b/python/llm/src/ipex_llm/transformers/models/qwen.py index 590867c85ff..739f638ab8f 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen.py @@ -107,7 +107,8 @@ def qwen_attention_forward( query_states = query_states * logn_tensor.type_as(query_states).expand_as(query_states) # IPEX-LLM OPT: kv cache and quantzie kv cache - use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states) + use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states, + self.num_heads, self.num_heads) key_states, value_states = update_past_key_value( past_key_value, key_states, value_states, kv_seq_len, use_quantize_kv, device @@ -205,7 +206,8 @@ def qwen_attention_forward_registered( query_states = query_states * logn_tensor.type_as(query_states).expand_as(query_states) # IPEX-LLM OPT: kv cache and quantzie kv cache - use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states) + use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states, + self.num_heads, self.num_heads) key_states, value_states = update_past_key_value( past_key_value, key_states, value_states, kv_seq_len, use_quantize_kv, device diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2.py b/python/llm/src/ipex_llm/transformers/models/qwen2.py index 62f48e2d012..bbf5a6bdea7 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen2.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen2.py @@ -113,10 +113,10 @@ def qwen2_model_forward( # ipex-llm changes start # IPEX-LLM OPT: kv cache and quantize kv cache inputs = input_ids if input_ids is not None else inputs_embeds + num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads use_quantize_kv = ( self.config.hidden_size != 3584 # disable quantize kv in specific model - and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs, - self.config.num_attention_heads//self.config.num_key_value_heads) + and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs, num_heads, num_kv_heads) ) use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \ isinstance(past_key_values, DynamicCompressCache) @@ -305,10 +305,11 @@ def qwen2_model_forward_4_42( # ipex-llm changes start # IPEX-LLM OPT: kv cache and quantize kv cache + num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads use_quantize_kv = ( self.config.hidden_size != 3584 # disable quantize kv in specific model and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs_embeds, - self.config.num_attention_heads//self.config.num_key_value_heads) + num_heads, num_kv_heads) ) use_compress_kv = should_use_compresskv(inputs_embeds, inputs_embeds.shape[1]) or \ isinstance(past_key_values, DynamicCompressCache) diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py b/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py index d2e2f0262de..2b20b874d4e 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen2_moe.py @@ -73,8 +73,10 @@ def qwen2moe_model_forward( return_dict: Optional[bool] = None, ): use_cache = use_cache if use_cache is not None else self.config.use_cache - input = input_ids if input_ids is not None else inputs_embeds - use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.shared_expert.up_proj, input) + inputs = input_ids if input_ids is not None else inputs_embeds + num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads + use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.shared_expert.up_proj, inputs, + num_heads, num_kv_heads) if use_cache: if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache): past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) diff --git a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py index 71a63366835..d885e23abfa 100644 --- a/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py +++ b/python/llm/src/ipex_llm/transformers/models/qwen2_vl.py @@ -88,7 +88,9 @@ def qwen2_vl_model_forward( # IPEX-LLM OPT start: kv cache and quantize kv cache inputs = input_ids if input_ids is not None else inputs_embeds use_cache = True if inputs.device.type == "xpu" else use_cache - use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs) + num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads + use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs, + num_heads, num_kv_heads) if use_cache: if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache): past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) diff --git a/python/llm/src/ipex_llm/transformers/models/stablelm.py b/python/llm/src/ipex_llm/transformers/models/stablelm.py index 9965a25e7d1..b6144cccd73 100644 --- a/python/llm/src/ipex_llm/transformers/models/stablelm.py +++ b/python/llm/src/ipex_llm/transformers/models/stablelm.py @@ -69,8 +69,10 @@ def stablelm_model_forward( ): # IPEX-LLM OPT: kv cache and quantize kv cache use_cache = use_cache if use_cache is not None else self.config.use_cache + num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads use_quantize_kv = (self.layers[0].self_attn.head_dim in [64, 80, 96, 128] - and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids)) + and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids, + num_heads, num_kv_heads)) if use_cache: if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache): past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) diff --git a/python/llm/src/ipex_llm/transformers/models/starcoder2.py b/python/llm/src/ipex_llm/transformers/models/starcoder2.py index 7a23c80f3ae..e882ff57b13 100644 --- a/python/llm/src/ipex_llm/transformers/models/starcoder2.py +++ b/python/llm/src/ipex_llm/transformers/models/starcoder2.py @@ -132,7 +132,9 @@ def model_forward( return_dict: Optional[bool] = None, ): use_cache = use_cache if use_cache is not None else self.config.use_cache - use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids) + num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads + use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids, + num_heads, num_kv_heads) if use_cache: if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache): past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py index 93836afb611..e6bdc4b7c06 100644 --- a/python/llm/src/ipex_llm/transformers/models/utils.py +++ b/python/llm/src/ipex_llm/transformers/models/utils.py @@ -74,7 +74,8 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states): return new_cache_k, new_cache_v -def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: int = 1) -> bool: +def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, + num_heads: int, num_kv_heads: int) -> bool: if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None: warnings.warn( "`BIGDL_QUANTIZE_KV_CACHE` is deprecated and will be removed in future releases. " @@ -90,8 +91,11 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: in else: device_name = get_xpu_device_name(x.device) return ( - device_name in ["mtl", "lnl", "arl"] and kv_group == 1 - or device_name in ["arc", "bmg"] and x.size(0) > 1 + num_kv_heads >= 4 + and ( + device_name in ["mtl", "lnl", "arl"] and num_heads // num_kv_heads <= 4 + or device_name in ["arc", "bmg"] and x.size(0) > 1 + ) ) diff --git a/python/llm/src/ipex_llm/transformers/models/yuan.py b/python/llm/src/ipex_llm/transformers/models/yuan.py index e6d3ddbe671..ccc5ff3aad1 100644 --- a/python/llm/src/ipex_llm/transformers/models/yuan.py +++ b/python/llm/src/ipex_llm/transformers/models/yuan.py @@ -158,7 +158,8 @@ def yuan_attention_forward( "yuan") # IPEX-LLM OPT: kv cache and quantzie kv cache - use_quantize_kv = use_quantize_kv_cache(self.qk_proj, hidden_states) + use_quantize_kv = use_quantize_kv_cache(self.qk_proj, hidden_states, + self.num_heads, self.num_heads) key_states, value_states = update_past_key_value( None if past_key_value is None else (past_key_value[0], past_key_value[1]), key_states, value_states,