Commit
update quantize kv cache condition (#12681)
MeouSker77 authored Jan 9, 2025
1 parent 5d8081a commit 7234c9b
Showing 20 changed files with 75 additions and 37 deletions.
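
At a glance, the commit changes `use_quantize_kv_cache` to take the attention head count and the KV head count as two separate arguments instead of a single pre-computed `kv_group` ratio, and updates every call site to match. A minimal before/after sketch of the call-site pattern, assuming a generic decoder model with the usual `num_attention_heads` / `num_key_value_heads` config fields (the attribute names here are illustrative, not tied to any one file below):

# Before: callers passed a grouping ratio (default 1)
use_quantize_kv = use_quantize_kv_cache(
    self.layers[0].mlp.up_proj, inputs,
    self.config.num_attention_heads // self.config.num_key_value_heads)

# After: callers pass both head counts; the helper derives the ratio itself
# and can additionally gate on the absolute number of KV heads
use_quantize_kv = use_quantize_kv_cache(
    self.layers[0].mlp.up_proj, inputs,
    self.config.num_attention_heads,
    self.config.num_key_value_heads)
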
11 changes: 7 additions & 4 deletions python/llm/src/ipex_llm/transformers/models/baichuan.py
@@ -73,7 +73,9 @@ def baichuan_model_7b_forward(
if use_cache:
inputs = input_ids if input_ids is not None else inputs_embeds
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
self.config.num_attention_heads,
self.config.num_attention_heads)
if use_compress_kv and not isinstance(past_key_values,
DynamicCompressCache):
if use_quantize_kv:
@@ -246,8 +248,6 @@ def baichuan_attention_forward_7b(
key_states = key_states.to(hidden_states.dtype)

# IPEX-LLM OPT: kv cache and quantize kv
use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)

# [CompressKV]
if use_compresskv:
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
@@ -258,6 +258,8 @@
query_states, attention_mask, 1,
self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
else:
use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states,
self.num_heads, self.num_heads)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, device
@@ -308,7 +310,8 @@ def baichuan_attention_forward_13b(
kv_seq_len += past_key_value[0].shape[2]

# IPEX-LLM OPT: kv cache and quantize kv
use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states,
self.num_heads, self.num_heads)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, device
11 changes: 8 additions & 3 deletions python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -63,8 +63,13 @@ def chatglm2_model_forward(

if use_cache:
use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[1])
n_heads = self.config.num_attention_heads
if self.config.multi_query_attention:
n_kv_heads = self.config.multi_query_group_num
else:
n_kv_heads = n_heads
use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj,
input_ids)
input_ids, n_heads, n_kv_heads)
if use_compress_kv and not isinstance(past_key_values,
DynamicCompressCache):
if use_quantize_kv:
@@ -257,8 +262,6 @@ def chatglm2_attention_forward(
key_states[..., :rot_dim] = k_rot[...]

# IPEX-LLM OPT: kv cache and quantize kv
use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)

# [CompressKV]
if use_compresskv:
from transformers.configuration_utils import PretrainedConfig
@@ -272,6 +275,8 @@
self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
)
else:
use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states,
n_head, n_kv_head)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, hidden_states.device
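
ChatGLM-style configs do not store the KV head count directly, so the chatglm2 model-forward hunk above (and the matching chatglm4 hunk below) derives it from `multi_query_group_num` when multi-query attention is enabled. A small stand-alone sketch of that derivation, where `config` is a stand-in for the model's configuration object rather than the actual config class:

def resolve_chatglm_kv_heads(config):
    # With multi-query attention, keys/values use `multi_query_group_num` heads;
    # otherwise the KV head count equals the attention head count.
    n_heads = config.num_attention_heads
    if config.multi_query_attention:
        n_kv_heads = config.multi_query_group_num
    else:
        n_kv_heads = n_heads
    return n_heads, n_kv_heads
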
13 changes: 9 additions & 4 deletions python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -55,8 +55,13 @@ def chatglm4_model_forward(
if use_cache:
inputs = input_ids if input_ids is not None else inputs_embeds
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj,
inputs)
n_heads = self.config.num_attention_heads
if self.config.multi_query_attention:
n_kv_heads = self.config.multi_query_group_num
else:
n_kv_heads = n_heads
use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj, inputs,
n_heads, n_kv_heads)
if use_compress_kv and not isinstance(past_key_values,
DynamicCompressCache):
if use_quantize_kv:
@@ -211,8 +216,6 @@ def chatglm4_attention_forward(
key_states[..., :rot_dim] = k_rot[...]

# IPEX-LLM OPT: kv cache and quantize kv
use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)

# [CompressKV]
if use_compresskv:
from transformers.configuration_utils import PretrainedConfig
@@ -226,6 +229,8 @@
self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
)
else:
use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states,
n_head, n_kv_head)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, hidden_states.device
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/chatglm4v.py
@@ -230,7 +230,7 @@ def chatglm4v_attention_forward(
key_states[..., :rot_dim] = k_rot[...]

# IPEX-LLM OPT: kv cache and quantize kv
use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states, n_head, n_kv_head)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, hidden_states.device
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/glm.py
@@ -147,7 +147,7 @@ def glm_model_forward(
use_cache = use_cache if use_cache is not None else self.config.use_cache
use_cache = use_cache or inputs.device.type == 'xpu'
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
self.config.num_attention_heads //
self.config.num_attention_heads,
self.config.num_key_value_heads)

if use_cache:
9 changes: 6 additions & 3 deletions python/llm/src/ipex_llm/transformers/models/internlm.py
@@ -87,7 +87,8 @@ def internlm_attention_forward(
)

# IPEX-LLM OPT: kv cache and quantize kv cache
use_quantize_kv = use_quantize_kv_cache(self.qkv_proj, hidden_states)
use_quantize_kv = use_quantize_kv_cache(self.qkv_proj, hidden_states,
self.num_heads, self.num_heads)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, hidden_states.device
@@ -171,7 +172,8 @@ def internlm2_attention_forward(
)

# IPEX-LLM OPT: kv cache and quantize kv cache
use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states)
use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states,
self.num_heads, self.num_key_value_heads)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, hidden_states.device
@@ -346,7 +348,8 @@ def internlm_xcomposser2_attention_forward(
query_states, key_states, cos, sin, position_ids, "internlm")

# IPEX-LLM OPT: kv cache and quantize kv cache
use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states)
use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states,
self.num_heads, self.num_key_value_heads)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, device
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/llama.py
@@ -72,7 +72,7 @@ def llama_model_forward(
use_cache = True if inputs.device.type == "xpu" else use_cache
use_quantize_kv = use_quantize_kv_cache(
self.layers[0].mlp.down_proj, inputs,
self.config.num_attention_heads // self.config.num_key_value_heads
self.config.num_attention_heads, self.config.num_key_value_heads
)
use_compresskv = should_use_compresskv(inputs, inputs.shape[1]) or \
isinstance(past_key_values, DynamicCompressCache)
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/minicpm.py
@@ -159,7 +159,7 @@ def minicpm_model_forward(
# IPEX-LLM OPT: kv cache and quantize kv cache
inputs = input_ids if input_ids is not None else inputs_embeds
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
self.config.num_attention_heads //
self.config.num_attention_heads,
self.config.num_key_value_heads)
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
isinstance(past_key_values, DynamicCompressCache)
4 changes: 3 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/minicpm3.py
@@ -66,7 +66,9 @@ def minicpm3_model_forward(
inputs = input_ids if input_ids is not None else inputs_embeds
use_cache = use_cache if use_cache is not None else self.config.use_cache
use_cache = True if inputs.device.type == "xpu" else use_cache
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
num_heads, num_kv_heads)
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -71,7 +71,7 @@ def mistral_model_forward(
use_cache = use_cache if use_cache is not None else self.config.use_cache
use_cache = use_cache or inputs.device.type == 'xpu'
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
self.config.num_attention_heads //
self.config.num_attention_heads,
self.config.num_key_value_heads)
use_compress_kv = should_use_compresskv(inputs, inputs.size(1)) or \
isinstance(past_key_values, DynamicCompressCache)
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/mllama.py
@@ -113,7 +113,7 @@ def mllama_text_model_forward(
use_cache = True if inputs.device.type == "xpu" else use_cache
use_quantize_kv = use_quantize_kv_cache(
self.layers[0].mlp.down_proj, inputs,
self.config.num_attention_heads // self.config.num_key_value_heads
self.config.num_attention_heads, self.config.num_key_value_heads
)
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
8 changes: 6 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -249,7 +249,9 @@ def model_forward(
# IPEX-LLM OPT: kv cache and quantize kv cache and sdp
use_cache = use_cache if use_cache is not None else self.config.use_cache
inputs = input_ids if input_ids is not None else inputs_embeds
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
num_heads, num_kv_heads)
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
isinstance(past_key_values, DynamicCompressCache)
if use_cache:
@@ -305,7 +307,9 @@ def model_forward(
):
# IPEX-LLM OPT: kv cache and quantize kv cache and sdp
use_cache = use_cache if use_cache is not None else self.config.use_cache
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input_ids)
num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input_ids,
num_heads, num_kv_heads)
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
6 changes: 4 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/qwen.py
@@ -107,7 +107,8 @@ def qwen_attention_forward(
query_states = query_states * logn_tensor.type_as(query_states).expand_as(query_states)

# IPEX-LLM OPT: kv cache and quantize kv cache
use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states)
use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states,
self.num_heads, self.num_heads)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, device
@@ -205,7 +206,8 @@ def qwen_attention_forward_registered(
query_states = query_states * logn_tensor.type_as(query_states).expand_as(query_states)

# IPEX-LLM OPT: kv cache and quantize kv cache
use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states)
use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states,
self.num_heads, self.num_heads)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, device
7 changes: 4 additions & 3 deletions python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -113,10 +113,10 @@ def qwen2_model_forward(
# ipex-llm changes start
# IPEX-LLM OPT: kv cache and quantize kv cache
inputs = input_ids if input_ids is not None else inputs_embeds
num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
use_quantize_kv = (
self.config.hidden_size != 3584 # disable quantize kv in specific model
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
self.config.num_attention_heads//self.config.num_key_value_heads)
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs, num_heads, num_kv_heads)
)
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
isinstance(past_key_values, DynamicCompressCache)
@@ -305,10 +305,11 @@ def qwen2_model_forward_4_42(

# ipex-llm changes start
# IPEX-LLM OPT: kv cache and quantize kv cache
num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
use_quantize_kv = (
self.config.hidden_size != 3584 # disable quantize kv in specific model
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs_embeds,
self.config.num_attention_heads//self.config.num_key_value_heads)
num_heads, num_kv_heads)
)
use_compress_kv = should_use_compresskv(inputs_embeds, inputs_embeds.shape[1]) or \
isinstance(past_key_values, DynamicCompressCache)
6 changes: 4 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
@@ -73,8 +73,10 @@ def qwen2moe_model_forward(
return_dict: Optional[bool] = None,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
input = input_ids if input_ids is not None else inputs_embeds
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.shared_expert.up_proj, input)
inputs = input_ids if input_ids is not None else inputs_embeds
num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.shared_expert.up_proj, inputs,
num_heads, num_kv_heads)
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
4 changes: 3 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/qwen2_vl.py
@@ -88,7 +88,9 @@ def qwen2_vl_model_forward(
# IPEX-LLM OPT start: kv cache and quantize kv cache
inputs = input_ids if input_ids is not None else inputs_embeds
use_cache = True if inputs.device.type == "xpu" else use_cache
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
num_heads, num_kv_heads)
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
4 changes: 3 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/stablelm.py
@@ -69,8 +69,10 @@ def stablelm_model_forward(
):
# IPEX-LLM OPT: kv cache and quantize kv cache
use_cache = use_cache if use_cache is not None else self.config.use_cache
num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
use_quantize_kv = (self.layers[0].self_attn.head_dim in [64, 80, 96, 128]
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids))
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids,
num_heads, num_kv_heads))
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
4 changes: 3 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/starcoder2.py
@@ -132,7 +132,9 @@ def model_forward(
return_dict: Optional[bool] = None,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids)
num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids,
num_heads, num_kv_heads)
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
10 changes: 7 additions & 3 deletions python/llm/src/ipex_llm/transformers/models/utils.py
@@ -74,7 +74,8 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
return new_cache_k, new_cache_v


def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: int = 1) -> bool:
def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor,
num_heads: int, num_kv_heads: int) -> bool:
if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
warnings.warn(
"`BIGDL_QUANTIZE_KV_CACHE` is deprecated and will be removed in future releases. "
@@ -90,8 +91,11 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: in
else:
device_name = get_xpu_device_name(x.device)
return (
device_name in ["mtl", "lnl", "arl"] and kv_group == 1
or device_name in ["arc", "bmg"] and x.size(0) > 1
num_kv_heads >= 4
and (
device_name in ["mtl", "lnl", "arl"] and num_heads // num_kv_heads <= 4
or device_name in ["arc", "bmg"] and x.size(0) > 1
)
)


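
The reworked condition in utils.py enables the quantized (FP8) KV cache only when the model has at least 4 KV heads, and then only on mtl/lnl/arl devices when the grouped-query ratio is at most 4, or on arc/bmg devices when the batch size is greater than 1. Below is a stand-alone sketch of just that predicate with a few hypothetical inputs; it deliberately omits the environment-variable overrides and other early checks the real helper performs, and the function name is illustrative:

def quantize_kv_condition(device_name, batch_size, num_heads, num_kv_heads):
    # Mirrors the return expression added above.
    return (
        num_kv_heads >= 4
        and (
            device_name in ["mtl", "lnl", "arl"] and num_heads // num_kv_heads <= 4
            or device_name in ["arc", "bmg"] and batch_size > 1
        )
    )

# For a hypothetical 32-head / 8-KV-head model (group size 4):
# quantize_kv_condition("mtl", 1, 32, 8)  -> True   (ratio 4 <= 4)
# quantize_kv_condition("arc", 1, 32, 8)  -> False  (batch size is 1)
# quantize_kv_condition("arc", 4, 32, 8)  -> True   (batch size > 1)
# quantize_kv_condition("mtl", 1, 32, 2)  -> False  (fewer than 4 KV heads)
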
3 changes: 2 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/yuan.py
@@ -158,7 +158,8 @@ def yuan_attention_forward(
"yuan")

# IPEX-LLM OPT: kv cache and quantize kv cache
use_quantize_kv = use_quantize_kv_cache(self.qk_proj, hidden_states)
use_quantize_kv = use_quantize_kv_cache(self.qk_proj, hidden_states,
self.num_heads, self.num_heads)
key_states, value_states = update_past_key_value(
None if past_key_value is None else (past_key_value[0], past_key_value[1]),
key_states, value_states,
