diff --git a/python/llm/src/ipex_llm/transformers/models/stablelm.py b/python/llm/src/ipex_llm/transformers/models/stablelm.py
index 9bef4c292cc..00845966b4f 100644
--- a/python/llm/src/ipex_llm/transformers/models/stablelm.py
+++ b/python/llm/src/ipex_llm/transformers/models/stablelm.py
@@ -93,7 +93,7 @@ def stablelm_model_forward(
 ):
     # IPEX-LLM OPT: kv cache and quantize kv cache
     use_cache = use_cache if use_cache is not None else self.config.use_cache
-    use_quantize_kv = (self.layers[0].self_attn.head_dim in [64, 96, 128]
+    use_quantize_kv = (self.layers[0].self_attn.head_dim in [64, 80, 96, 128]
                        and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids))
     if use_cache:
         if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 48a6d1b8345..9d4b44cc245 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -329,7 +329,7 @@ def use_sdp(q_len, kv_len, head_dim, query_states):
     return (
         query_states.device.type == "xpu"
         and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
-        and head_dim in [64, 96, 128]
+        and head_dim in [64, 80, 96, 128]
         and q_len != kv_len  # next token
         and q_len <= 32  # lookup
     )
@@ -347,7 +347,7 @@ def use_sdp_fp8(q_len, kv_len, query_states):
 def use_sdp_causal(q_len, kv_len, head_dim, query_states, training):
     return (
         q_len == kv_len  # first token
-        and head_dim in [64, 96, 128]  # for now
+        and head_dim in [64, 80, 96, 128]  # for now
         and query_states.device.type == "xpu"  # GPU
         and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
         and not query_states.requires_grad and not training  # not training
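
Note (not part of the patch): all three hunks widen the same head_dim gate to accept 80.
A minimal standalone sketch of the updated use_sdp_causal() condition, using a made-up
function name rather than the IPEX-LLM source, so the effect of the change is easy to see:

    import torch

    def sdp_causal_gate(q_len, kv_len, head_dim, query_states, training):
        # Mirrors the gate in utils.py after this change; head_dim 80 now passes.
        return (
            q_len == kv_len                                      # first token
            and head_dim in [64, 80, 96, 128]                    # 80 newly allowed
            and query_states.device.type == "xpu"                # Intel GPU only
            and query_states.dtype in [torch.float, torch.half]  # fp32/fp16
            and not query_states.requires_grad and not training  # inference only
        )

On a machine without an XPU device the device check alone keeps this False, so the two
use_sdp* gates above remain inactive outside the Intel GPU paths.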