update quantize kv cache condition #12681

Merged
11 changes: 7 additions & 4 deletions python/llm/src/ipex_llm/transformers/models/baichuan.py
@@ -105,7 +105,9 @@ def baichuan_model_7b_forward(
if use_cache:
inputs = input_ids if input_ids is not None else inputs_embeds
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs)
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
self.config.num_attention_heads,
self.config.num_attention_heads)
if use_compress_kv and not isinstance(past_key_values,
DynamicCompressCache):
if use_quantize_kv:
@@ -278,8 +280,6 @@ def baichuan_attention_forward_7b(
key_states = key_states.to(hidden_states.dtype)

# IPEX-LLM OPT: kv cache and quantize kv
use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)

# [CompressKV]
if use_compresskv:
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value,
@@ -290,6 +290,8 @@ def baichuan_attention_forward_7b(
query_states, attention_mask, 1,
self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
else:
use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states,
self.num_heads, self.num_heads)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, device
@@ -340,7 +342,8 @@ def baichuan_attention_forward_13b(
kv_seq_len += past_key_value[0].shape[2]

# IPEX-LLM OPT: kv cache and quantize kv
use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states)
use_quantize_kv = use_quantize_kv_cache(self.W_pack, hidden_states,
self.num_heads, self.num_heads)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, device
11 changes: 8 additions & 3 deletions python/llm/src/ipex_llm/transformers/models/chatglm2.py
@@ -91,8 +91,13 @@ def chatglm2_model_forward(

if use_cache:
use_compress_kv = should_use_compresskv(input_ids, input_ids.shape[1])
n_heads = self.config.num_attention_heads
if self.config.multi_query_attention:
n_kv_heads = self.config.multi_query_group_num
else:
n_kv_heads = n_heads
use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj,
input_ids)
input_ids, n_heads, n_kv_heads)
if use_compress_kv and not isinstance(past_key_values,
DynamicCompressCache):
if use_quantize_kv:
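The multi-query branch added above now decides whether the FP8 KV cache is enabled at all, so a quick worked example may help; the head counts below are hypothetical and not taken from any specific checkpoint:

# Hypothetical ChatGLM-style config: 32 attention heads, multi_query_attention=True
# with multi_query_group_num = 2  ->  n_heads = 32, n_kv_heads = 2.
# With the updated helper, the num_kv_heads >= 4 check fails, so use_quantize_kv
# is False and past_key_values is never converted to DynamicFp8Cache.
# A config without multi_query_attention passes n_kv_heads = n_heads = 32 and
# keeps quantized KV on the devices listed in the utils.py hunk further below.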
@@ -285,8 +290,6 @@ def chatglm2_attention_forward(
key_states[..., :rot_dim] = k_rot[...]

# IPEX-LLM OPT: kv cache and quantize kv
use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)

# [CompressKV]
if use_compresskv:
from transformers.configuration_utils import PretrainedConfig
@@ -300,6 +303,8 @@ def chatglm2_attention_forward(
self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
)
else:
use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states,
n_head, n_kv_head)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, hidden_states.device
13 changes: 9 additions & 4 deletions python/llm/src/ipex_llm/transformers/models/chatglm4.py
@@ -55,8 +55,13 @@ def chatglm4_model_forward(
if use_cache:
inputs = input_ids if input_ids is not None else inputs_embeds
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj,
inputs)
n_heads = self.config.num_attention_heads
if self.config.multi_query_attention:
n_kv_heads = self.config.multi_query_group_num
else:
n_kv_heads = n_heads
use_quantize_kv = use_quantize_kv_cache(self.encoder.layers[0].mlp.gate_proj, inputs,
n_heads, n_kv_heads)
if use_compress_kv and not isinstance(past_key_values,
DynamicCompressCache):
if use_quantize_kv:
@@ -211,8 +216,6 @@ def chatglm4_attention_forward(
key_states[..., :rot_dim] = k_rot[...]

# IPEX-LLM OPT: kv cache and quantize kv
use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)

# [CompressKV]
if use_compresskv:
from transformers.configuration_utils import PretrainedConfig
@@ -226,6 +229,8 @@ def chatglm4_attention_forward(
self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH
)
else:
use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states,
n_head, n_kv_head)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, hidden_states.device
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/chatglm4v.py
@@ -230,7 +230,7 @@ def chatglm4v_attention_forward(
key_states[..., :rot_dim] = k_rot[...]

# IPEX-LLM OPT: kv cache and quantize kv
use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states)
use_quantize_kv = use_quantize_kv_cache(self.query_key_value, query_states, n_head, n_kv_head)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, hidden_states.device
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/glm.py
@@ -147,7 +147,7 @@ def glm_model_forward(
use_cache = use_cache if use_cache is not None else self.config.use_cache
use_cache = use_cache or inputs.device.type == 'xpu'
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
self.config.num_attention_heads //
self.config.num_attention_heads,
self.config.num_key_value_heads)

if use_cache:
9 changes: 6 additions & 3 deletions python/llm/src/ipex_llm/transformers/models/internlm.py
@@ -87,7 +87,8 @@ def internlm_attention_forward(
)

# IPEX-LLM OPT: kv cache and quantzie kv cache
use_quantize_kv = use_quantize_kv_cache(self.qkv_proj, hidden_states)
use_quantize_kv = use_quantize_kv_cache(self.qkv_proj, hidden_states,
self.num_heads, self.num_heads)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, hidden_states.device
@@ -171,7 +172,8 @@ def internlm2_attention_forward(
)

# IPEX-LLM OPT: kv cache and quantzie kv cache
use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states)
use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states,
self.num_heads, self.num_key_value_heads)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, hidden_states.device
@@ -346,7 +348,8 @@ def internlm_xcomposser2_attention_forward(
query_states, key_states, cos, sin, position_ids, "internlm")

# IPEX-LLM OPT: kv cache and quantzie kv cache
use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states)
use_quantize_kv = use_quantize_kv_cache(self.wqkv, hidden_states,
self.num_heads, self.num_key_value_heads)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, device
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/llama.py
@@ -72,7 +72,7 @@ def llama_model_forward(
use_cache = True if inputs.device.type == "xpu" else use_cache
use_quantize_kv = use_quantize_kv_cache(
self.layers[0].mlp.down_proj, inputs,
self.config.num_attention_heads // self.config.num_key_value_heads
self.config.num_attention_heads, self.config.num_key_value_heads
)
use_compresskv = should_use_compresskv(inputs, inputs.shape[1]) or \
isinstance(past_key_values, DynamicCompressCache)
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/minicpm.py
@@ -159,7 +159,7 @@ def minicpm_model_forward(
# IPEX-LLM OPT: kv cache and quantize kv cache
inputs = input_ids if input_ids is not None else inputs_embeds
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
self.config.num_attention_heads //
self.config.num_attention_heads,
self.config.num_key_value_heads)
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
isinstance(past_key_values, DynamicCompressCache)
4 changes: 3 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/minicpm3.py
@@ -66,7 +66,9 @@ def minicpm3_model_forward(
inputs = input_ids if input_ids is not None else inputs_embeds
use_cache = use_cache if use_cache is not None else self.config.use_cache
use_cache = True if inputs.device.type == "xpu" else use_cache
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
num_heads, num_kv_heads)
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/mistral.py
@@ -71,7 +71,7 @@ def mistral_model_forward(
use_cache = use_cache if use_cache is not None else self.config.use_cache
use_cache = use_cache or inputs.device.type == 'xpu'
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
self.config.num_attention_heads //
self.config.num_attention_heads,
self.config.num_key_value_heads)
use_compress_kv = should_use_compresskv(inputs, inputs.size(1)) or \
isinstance(past_key_values, DynamicCompressCache)
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/mllama.py
@@ -113,7 +113,7 @@ def mllama_text_model_forward(
use_cache = True if inputs.device.type == "xpu" else use_cache
use_quantize_kv = use_quantize_kv_cache(
self.layers[0].mlp.down_proj, inputs,
self.config.num_attention_heads // self.config.num_key_value_heads
self.config.num_attention_heads, self.config.num_key_value_heads
)
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
8 changes: 6 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -249,7 +249,9 @@ def model_forward(
# IPEX-LLM OPT: kv cache and quantize kv cache and sdp
use_cache = use_cache if use_cache is not None else self.config.use_cache
inputs = input_ids if input_ids is not None else inputs_embeds
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
num_heads, num_kv_heads)
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
isinstance(past_key_values, DynamicCompressCache)
if use_cache:
@@ -305,7 +307,9 @@ def model_forward(
):
# IPEX-LLM OPT: kv cache and quantize kv cache and sdp
use_cache = use_cache if use_cache is not None else self.config.use_cache
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input_ids)
num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input_ids,
num_heads, num_kv_heads)
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
6 changes: 4 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/qwen.py
@@ -107,7 +107,8 @@ def qwen_attention_forward(
query_states = query_states * logn_tensor.type_as(query_states).expand_as(query_states)

# IPEX-LLM OPT: kv cache and quantzie kv cache
use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states)
use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states,
self.num_heads, self.num_heads)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, device
@@ -205,7 +206,8 @@ def qwen_attention_forward_registered(
query_states = query_states * logn_tensor.type_as(query_states).expand_as(query_states)

# IPEX-LLM OPT: kv cache and quantzie kv cache
use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states)
use_quantize_kv = use_quantize_kv_cache(self.c_attn, hidden_states,
self.num_heads, self.num_heads)
key_states, value_states = update_past_key_value(
past_key_value, key_states, value_states,
kv_seq_len, use_quantize_kv, device
7 changes: 4 additions & 3 deletions python/llm/src/ipex_llm/transformers/models/qwen2.py
@@ -113,10 +113,10 @@ def qwen2_model_forward(
# ipex-llm changes start
# IPEX-LLM OPT: kv cache and quantize kv cache
inputs = input_ids if input_ids is not None else inputs_embeds
num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
use_quantize_kv = (
self.config.hidden_size != 3584 # disable quantize kv in specific model
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs,
self.config.num_attention_heads//self.config.num_key_value_heads)
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs, num_heads, num_kv_heads)
)
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1]) or \
isinstance(past_key_values, DynamicCompressCache)
@@ -305,10 +305,11 @@ def qwen2_model_forward_4_42(

# ipex-llm changes start
# IPEX-LLM OPT: kv cache and quantize kv cache
num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
use_quantize_kv = (
self.config.hidden_size != 3584 # disable quantize kv in specific model
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, inputs_embeds,
self.config.num_attention_heads//self.config.num_key_value_heads)
num_heads, num_kv_heads)
)
use_compress_kv = should_use_compresskv(inputs_embeds, inputs_embeds.shape[1]) or \
isinstance(past_key_values, DynamicCompressCache)
6 changes: 4 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/qwen2_moe.py
@@ -73,8 +73,10 @@ def qwen2moe_model_forward(
return_dict: Optional[bool] = None,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
input = input_ids if input_ids is not None else inputs_embeds
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.shared_expert.up_proj, input)
inputs = input_ids if input_ids is not None else inputs_embeds
num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.shared_expert.up_proj, inputs,
num_heads, num_kv_heads)
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
4 changes: 3 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/qwen2_vl.py
@@ -88,7 +88,9 @@ def qwen2_vl_model_forward(
# IPEX-LLM OPT start: kv cache and quantize kv cache
inputs = input_ids if input_ids is not None else inputs_embeds
use_cache = True if inputs.device.type == "xpu" else use_cache
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs,
num_heads, num_kv_heads)
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
4 changes: 3 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/stablelm.py
@@ -69,8 +69,10 @@ def stablelm_model_forward(
):
# IPEX-LLM OPT: kv cache and quantize kv cache
use_cache = use_cache if use_cache is not None else self.config.use_cache
num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
use_quantize_kv = (self.layers[0].self_attn.head_dim in [64, 80, 96, 128]
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids))
and use_quantize_kv_cache(self.layers[0].mlp.up_proj, input_ids,
num_heads, num_kv_heads))
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
4 changes: 3 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/starcoder2.py
@@ -132,7 +132,9 @@ def model_forward(
return_dict: Optional[bool] = None,
):
use_cache = use_cache if use_cache is not None else self.config.use_cache
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids)
num_heads, num_kv_heads = self.config.num_attention_heads, self.config.num_key_value_heads
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.c_fc, input_ids,
num_heads, num_kv_heads)
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
10 changes: 7 additions & 3 deletions python/llm/src/ipex_llm/transformers/models/utils.py
@@ -74,7 +74,8 @@ def append_kv_cache(cache_k, cache_v, key_states, value_states):
return new_cache_k, new_cache_v


def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: int = 1) -> bool:
def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor,
num_heads: int, num_kv_heads: int) -> bool:
if os.environ.get("BIGDL_QUANTIZE_KV_CACHE", None) is not None:
warnings.warn(
"`BIGDL_QUANTIZE_KV_CACHE` is deprecated and will be removed in future releases. "
@@ -90,8 +91,11 @@ def use_quantize_kv_cache(linear: torch.nn.Module, x: torch.Tensor, kv_group: in
else:
device_name = get_xpu_device_name(x.device)
return (
device_name in ["mtl", "lnl", "arl"] and kv_group == 1
or device_name in ["arc", "bmg"] and x.size(0) > 1
num_kv_heads >= 4
and (
device_name in ["mtl", "lnl", "arl"] and num_heads // num_kv_heads <= 4
or device_name in ["arc", "bmg"] and x.size(0) > 1
)
)


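Taken together, the utils.py hunk above replaces the old single kv_group argument with explicit num_heads / num_kv_heads, and every caller touched in this diff now passes both values. Below is a minimal sketch of the resulting decision logic, assuming the environment-variable branches earlier in use_quantize_kv_cache have not already returned; the helper name is hypothetical and introduced only for illustration:

def quantize_kv_condition(device_name: str, batch_size: int,
                          num_heads: int, num_kv_heads: int) -> bool:
    # Mirrors the return expression added above: quantized (FP8) KV cache now
    # requires at least 4 KV heads, and on mtl/lnl/arl additionally a small
    # GQA ratio, while arc/bmg keep the batch-size > 1 condition.
    return (
        num_kv_heads >= 4
        and (
            (device_name in ["mtl", "lnl", "arl"] and num_heads // num_kv_heads <= 4)
            or (device_name in ["arc", "bmg"] and batch_size > 1)
        )
    )

# For example, 32 heads with 8 KV heads on "mtl" (ratio 4) -> True,
# while 32 heads with 2 KV heads -> False regardless of device.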
3 changes: 2 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/yuan.py
@@ -158,7 +158,8 @@ def yuan_attention_forward(
"yuan")

# IPEX-LLM OPT: kv cache and quantzie kv cache
use_quantize_kv = use_quantize_kv_cache(self.qk_proj, hidden_states)
use_quantize_kv = use_quantize_kv_cache(self.qk_proj, hidden_states,
self.num_heads, self.num_heads)
key_states, value_states = update_past_key_value(
None if past_key_value is None else (past_key_value[0], past_key_value[1]),
key_states, value_states,