diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py index bfa380c2f51..0b7f873e734 100644 --- a/python/llm/src/ipex_llm/transformers/models/phi3.py +++ b/python/llm/src/ipex_llm/transformers/models/phi3.py @@ -277,7 +277,7 @@ def model_forward( head_dim = self.config.hidden_size // self.config.num_attention_heads past_key_values = DynamicNormalCache.from_reserved( n_layer, inputs.size(0), n_head, inputs.size(1), head_dim, - inputs.dtype, inputs.device + self.dtype, inputs.device ) return origin_model_forward( self=self,