llama 3.1/3.2 support compresskv #12347

Merged (6 commits) on Nov 6, 2024. Changes shown are from 5 commits.
16 changes: 16 additions & 0 deletions python/llm/src/ipex_llm/transformers/kv.py
@@ -356,6 +356,22 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
             return 0
         return self.real_kv_len
 
+    @classmethod
+    def from_legacy_cache(
+        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        num_hidden_layers: int = None
+    ) -> "DynamicCache":
+        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
+        backward compatibility."""
+        cache = cls(num_hidden_layers)
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                invalidInputError(
+                    len(key_states) == 0 and len(value_states) == 0,
+                    "from_legacy_cache should be called with an empty kv cache.")
+        return cache
+
 
 class DynamicCompressFp8Cache(DynamicCompressCache, DynamicFp8Cache):
     def update(
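The new classmethod lets DynamicCompressCache (and, via inheritance, DynamicCompressFp8Cache) be constructed through the same from_legacy_cache code path that llama_model_forward below already uses for the other cache classes, but it only accepts an empty legacy cache, presumably because compressed entries cannot be rebuilt from plain (key, value) tuples. A minimal usage sketch, not part of the diff:

from ipex_llm.transformers.kv import DynamicCompressCache

# First forward pass: there is no legacy cache yet, so an empty compress-kv
# cache is created.  Passing a non-empty legacy cache would fail the
# invalidInputError check above.
past_key_values = DynamicCompressCache.from_legacy_cache(None)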
37 changes: 32 additions & 5 deletions python/llm/src/ipex_llm/transformers/models/llama32.py
@@ -50,7 +50,10 @@
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
-from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
+from ipex_llm.transformers.models.utils import should_use_compresskv, \
+    is_enough_kv_cache_room_4_36
+from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, DynamicCompressCache, \
+    DynamicCompressFp8Cache
 
 
 def llama_model_forward(
@@ -83,11 +86,23 @@ def llama_model_forward(
         self.layers[0].mlp.down_proj, inputs,
         self.config.num_attention_heads // self.config.num_key_value_heads
     )
+    use_compresskv = should_use_compresskv(inputs, inputs.shape[1]) or \
+        isinstance(past_key_values, DynamicCompressCache)
     if use_cache:
-        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+        if use_compresskv and not isinstance(past_key_values, DynamicCompressCache):
+            if use_quantize_kv:
+                past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
+            else:
+                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+        elif use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+        elif (
+            not use_quantize_kv
+            and not use_compresskv
+            and not isinstance(past_key_values, DynamicNormalCache)
+        ):
             past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
 
     # IPEX-LLM OPT end
 
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
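For readability, the branch order above can be restated as a small helper (an illustrative sketch, not code from the PR; select_cache_class is a hypothetical name):

from ipex_llm.transformers.kv import (DynamicNormalCache, DynamicFp8Cache,
                                      DynamicCompressCache, DynamicCompressFp8Cache)

def select_cache_class(use_compresskv: bool, use_quantize_kv: bool):
    # Hypothetical helper summarizing llama_model_forward: compress-kv takes
    # priority and is combined with fp8 quantization when both are enabled;
    # otherwise fall back to the fp8-only or the plain cache.
    if use_compresskv:
        return DynamicCompressFp8Cache if use_quantize_kv else DynamicCompressCache
    if use_quantize_kv:
        return DynamicFp8Cache
    return DynamicNormalCache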
@@ -182,6 +197,9 @@ def llama_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
 
+    # [CompressKV]
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+
     qkv = self.qkv_proj(hidden_states)
     qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
     qkv = qkv.transpose(1, 2)
@@ -201,8 +219,17 @@
     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
     if past_key_value is not None:
-        key_states, value_states = past_key_value.update(key_states, value_states,
-                                                          self.layer_idx, None)
+        # [CompressKV]
+        if use_compresskv:
+            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                          q_len)
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx,
+                query_states, attention_mask, self.num_key_value_groups,
+                self.config, enough_kv_room, 256)
+        else:
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                              self.layer_idx, None)
 
     kv_seq_len = key_states.size(2)
     if attention_mask is not None:  # no matter the length, we just slice it
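As a reading aid (not from the PR), the two update paths above differ only in the extra context handed to the cache; a hedged restatement with a hypothetical wrapper, passing the trailing 256 through unchanged as in the diff:

from ipex_llm.transformers.kv import DynamicCompressCache

def update_kv_cache(past_key_value, key_states, value_states, layer_idx, *,
                    query_states, attention_mask, num_key_value_groups, config,
                    enough_kv_room):
    # Hypothetical wrapper restating the branch above: the compress-kv cache is
    # given the query, attention mask and model config so it can decide which
    # key/value entries to keep, while the normal/fp8 caches simply append the
    # new states.
    if isinstance(past_key_value, DynamicCompressCache):
        return past_key_value.update(key_states, value_states, layer_idx,
                                     query_states, attention_mask,
                                     num_key_value_groups, config,
                                     enough_kv_room, 256)
    return past_key_value.update(key_states, value_states, layer_idx, None)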