diff --git a/python/llm/src/ipex_llm/transformers/kv.py b/python/llm/src/ipex_llm/transformers/kv.py
index 100da837a9e..8b20f546893 100644
--- a/python/llm/src/ipex_llm/transformers/kv.py
+++ b/python/llm/src/ipex_llm/transformers/kv.py
@@ -121,6 +121,21 @@ def update(
 
         return self.key_cache[layer_idx], self.value_cache[layer_idx]
 
+    @classmethod
+    def from_reserved(cls, layers: int,
+                      bsz: int, n_head: int, length: int, head_dim: int,
+                      dtype: torch.dtype, device: torch.device):
+        past_key_values = cls()
+        for _i in range(layers):
+            k_cache, v_cache = init_kv_cache(
+                bsz, n_head, head_dim,
+                0, length + cls.KV_ALLOC_BLOCK_LENGTH,
+                dtype, device
+            )
+            past_key_values.key_cache.append(k_cache)
+            past_key_values.value_cache.append(v_cache)
+        return past_key_values
+
 
 # Copied from transformers.models.llama.modeling_llama.repeat_kv
 def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py
index 5c630681cc9..823fb10391a 100644
--- a/python/llm/src/ipex_llm/transformers/models/phi3.py
+++ b/python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -254,9 +254,9 @@ def model_forward(
 ):
     # IPEX-LLM OPT: kv cache and quantize kv cache and sdp
     use_cache = use_cache if use_cache is not None else self.config.use_cache
-    input = input_ids if input_ids is not None else inputs_embeds
-    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input)
-    use_compress_kv = should_use_compresskv(input, input.shape[1])
+    inputs = input_ids if input_ids is not None else inputs_embeds
+    use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
+    use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
 
     if use_cache:
         if use_compress_kv and not isinstance(past_key_values, DynamicCompressCache):
@@ -272,6 +272,14 @@ def model_forward(
                                             DynamicCompressCache
                                             )):
             past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
+        if past_key_values.get_seq_length() == 0:
+            n_layer = self.config.num_hidden_layers
+            n_head = self.config.num_attention_heads
+            head_dim = self.config.hidden_size // self.config.num_attention_heads
+            past_key_values = DynamicNormalCache.from_reserved(
+                n_layer, inputs.size(0), n_head, inputs.size(1), head_dim,
+                inputs.dtype, inputs.device
+            )
     return origin_model_forward(
         self=self,
         input_ids=input_ids,
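
Reviewer note (not part of the patch): a minimal, self-contained sketch of the "reserve once, grow in place" idea behind from_reserved(). The helper names below (reserve_kv, append_in_place) and the block size of 256 are illustrative assumptions, not ipex_llm APIs; the real code goes through init_kv_cache and cls.KV_ALLOC_BLOCK_LENGTH as shown in the diff above.

import torch

KV_ALLOC_BLOCK_LENGTH = 256  # assumption; stands in for cls.KV_ALLOC_BLOCK_LENGTH


def reserve_kv(bsz, n_head, length, head_dim, dtype, device):
    # Allocate storage for the whole prompt plus one growth block up front,
    # but hand back zero-length views so the cache starts logically empty.
    max_len = length + KV_ALLOC_BLOCK_LENGTH
    k_buf = torch.empty(bsz, n_head, max_len, head_dim, dtype=dtype, device=device)
    v_buf = torch.empty(bsz, n_head, max_len, head_dim, dtype=dtype, device=device)
    return k_buf[:, :, :0, :], v_buf[:, :, :0, :]


def append_in_place(cache, new_states):
    # Widen the view over the reserved storage and copy the new tokens in,
    # so no fresh allocation happens while the reservation still has room.
    old_len, add_len = cache.size(2), new_states.size(2)
    grown = cache.as_strided(
        (cache.size(0), cache.size(1), old_len + add_len, cache.size(3)),
        cache.stride())
    grown[:, :, old_len:, :] = new_states
    return grown


# Usage mirroring model_forward(): reserve per layer once, when the cache is empty.
bsz, n_head, seq_len, head_dim, n_layer = 1, 32, 16, 96, 2
layers = [reserve_kv(bsz, n_head, seq_len, head_dim, torch.float32, "cpu")
          for _ in range(n_layer)]
k0, _v0 = layers[0]
k0 = append_in_place(k0, torch.randn(bsz, n_head, seq_len, head_dim))
assert k0.size(2) == seq_len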