From c9d1d8d53c33ae0e6d0d03b7c2022ff6d6657a87 Mon Sep 17 00:00:00 2001 From: sgwhat Date: Tue, 18 Jun 2024 00:22:02 +0800 Subject: [PATCH] add phi-3 model support --- python/llm/src/ipex_llm/transformers/models/phi3.py | 3 ++- python/llm/src/ipex_llm/transformers/pipeline_parallel.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py index c398ae6107e..9247ea947d2 100644 --- a/python/llm/src/ipex_llm/transformers/models/phi3.py +++ b/python/llm/src/ipex_llm/transformers/models/phi3.py @@ -234,7 +234,8 @@ def model_forward( ): # IPEX-LLM OPT: kv cache and quantize kv cache and sdp use_cache = use_cache if use_cache is not None else self.config.use_cache - use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input_ids) + input = input_ids if input_ids is not None else inputs_embeds + use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input) if use_cache: if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache): past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values) diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py index 42e0c6316eb..a031d001fe4 100644 --- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py +++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py @@ -48,6 +48,7 @@ def __init__(self, *args): # to avoid AttributeError in https://github.com/intel-analytics/ipex-llm/blob/main/ # python/llm/src/ipex_llm/transformers/models/llama.py#L119 self.up_proj = DummyLayer() + self.down_proj = DummyLayer() def forward(self, x): return x