llama 3.1/3.2 support compresskv (#12347)
* llama 3.1/3.2 support compresskv

* update

* fix transformers 4.45 error

* fix style

* fix typo

* disable llama3.2 1b compresskv
cyita authored Nov 6, 2024
1 parent d984c06 commit f24352a
Showing 2 changed files with 50 additions and 5 deletions.
16 changes: 16 additions & 0 deletions python/llm/src/ipex_llm/transformers/kv.py
@@ -356,6 +356,22 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
             return 0
         return self.real_kv_len
 
+    @classmethod
+    def from_legacy_cache(
+        cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        num_hidden_layers: int = None
+    ) -> "DynamicCache":
+        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
+        backward compatibility."""
+        cache = cls(num_hidden_layers)
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                invalidInputError(
+                    len(key_states) == 0 and len(value_states) == 0,
+                    "from_legacy_cache should be called with an empty kv cache.")
+        return cache
+
 
 class DynamicCompressFp8Cache(DynamicCompressCache, DynamicFp8Cache):
     def update(
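Note on the guard above: a compress-kv cache can only be created from an empty legacy cache, since kv entries written by a non-compressing cache cannot be retroactively compressed. Below is a minimal standalone sketch of that behavior (illustration only, not the ipex_llm class; the hypothetical name SketchCompressCache is used, and invalidInputError is stood in for by a plain assert):

# Minimal sketch of the empty-cache guard in from_legacy_cache:
# only an empty legacy cache is accepted.
from typing import Optional

import torch


class SketchCompressCache:
    def __init__(self, num_hidden_layers: Optional[int] = None):
        self.key_cache, self.value_cache = [], []

    @classmethod
    def from_legacy_cache(cls, past_key_values=None, num_hidden_layers=None):
        cache = cls(num_hidden_layers)
        if past_key_values is not None:
            for key_states, value_states in past_key_values:
                # Reject any populated entry rather than trying to convert it.
                assert len(key_states) == 0 and len(value_states) == 0, \
                    "from_legacy_cache should be called with an empty kv cache."
        return cache


# An empty legacy cache converts cleanly; a populated one would trip the assert.
SketchCompressCache.from_legacy_cache(((torch.empty(0), torch.empty(0)),))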
39 changes: 34 additions & 5 deletions python/llm/src/ipex_llm/transformers/models/llama32.py
@@ -50,7 +50,10 @@
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import should_use_fuse_rope
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
-from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
+from ipex_llm.transformers.models.utils import should_use_compresskv, \
+    is_enough_kv_cache_room_4_36
+from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, DynamicCompressCache, \
+    DynamicCompressFp8Cache
 
 
 def llama_model_forward(
@@ -83,11 +86,25 @@ def llama_model_forward(
         self.layers[0].mlp.down_proj, inputs,
         self.config.num_attention_heads // self.config.num_key_value_heads
     )
+    use_compresskv = should_use_compresskv(inputs, inputs.shape[1]) or \
+        isinstance(past_key_values, DynamicCompressCache)
+    # disable llama3.2 1b for prefill performance and output quality
+    use_compresskv = use_compresskv and self.config.hidden_size != 2048
     if use_cache:
-        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+        if use_compresskv and not isinstance(past_key_values, DynamicCompressCache):
+            if use_quantize_kv:
+                past_key_values = DynamicCompressFp8Cache.from_legacy_cache(past_key_values)
+            else:
+                past_key_values = DynamicCompressCache.from_legacy_cache(past_key_values)
+        elif use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        elif not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+        elif (
+            not use_quantize_kv
+            and not use_compresskv
+            and not isinstance(past_key_values, DynamicNormalCache)
+        ):
             past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
+
     # IPEX-LLM OPT end
 
     return_dict = return_dict if return_dict is not None else self.config.use_return_dict
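The if/elif chain above establishes a priority order when picking the cache implementation: compress-kv first (combined with fp8 when quantized kv is also enabled), then fp8-only, then the plain cache. A tiny standalone sketch of that decision, using only the two boolean flags computed above (class names returned as strings for illustration; the real classes live in ipex_llm.transformers.kv and are constructed via from_legacy_cache):

# Sketch of the cache-selection priority in llama_model_forward.
def select_cache_class(use_compresskv: bool, use_quantize_kv: bool) -> str:
    if use_compresskv:
        # compress-kv wins; it has an fp8 flavour when quantized kv is also on
        return "DynamicCompressFp8Cache" if use_quantize_kv else "DynamicCompressCache"
    if use_quantize_kv:
        return "DynamicFp8Cache"
    return "DynamicNormalCache"


assert select_cache_class(True, True) == "DynamicCompressFp8Cache"
assert select_cache_class(True, False) == "DynamicCompressCache"
assert select_cache_class(False, True) == "DynamicFp8Cache"
assert select_cache_class(False, False) == "DynamicNormalCache"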
@@ -182,6 +199,9 @@ def llama_attention_forward(
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
     bsz, q_len, _ = hidden_states.size()
 
+    # [CompressKV]
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+
     qkv = self.qkv_proj(hidden_states)
     qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
     qkv = qkv.transpose(1, 2)
@@ -201,8 +221,17 @@
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
 
     if past_key_value is not None:
-        key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, None)
+        # [CompressKV]
+        if use_compresskv:
+            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx,
+                                                          q_len)
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx,
+                query_states, attention_mask, self.num_key_value_groups,
+                self.config, enough_kv_room, 256)
+        else:
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                             self.layer_idx, None)
 
     kv_seq_len = key_states.size(2)
     if attention_mask is not None:  # no matter the length, we just slice it
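On the compress-kv path, update receives the query states, attention mask, number of kv groups, config, a room flag, and a window size (256) in addition to the usual arguments. A conceptual standalone sketch of what such an update can do with them during prefill, assuming equal query/kv head counts and a simple top-k heuristic (the actual heuristic in ipex_llm.transformers.kv may differ):

# Conceptual sketch only: score each cached key by the total attention mass it
# receives from the prompt queries and keep only the top `keep` positions, so
# the kv cache stays bounded. Assumes no GQA grouping of heads.
import math

import torch


def compress_kv_prefill(key_states, value_states, query_states, keep=256):
    # shapes: [batch, num_heads, seq_len, head_dim]
    bsz, n_heads, seq_len, head_dim = key_states.shape
    if seq_len <= keep:
        return key_states, value_states
    attn = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
    scores = attn.softmax(dim=-1).sum(dim=2)          # [batch, heads, seq_len]
    idx = scores.topk(keep, dim=-1).indices.sort(dim=-1).values
    idx = idx.unsqueeze(-1).expand(-1, -1, -1, head_dim)
    return key_states.gather(2, idx), value_states.gather(2, idx)


# Tiny smoke test with random tensors: 8 cached positions pruned down to 4.
k, v, q = (torch.randn(1, 2, 8, 4) for _ in range(3))
ck, cv = compress_kv_prefill(k, v, q, keep=4)
assert ck.shape == (1, 2, 4, 4) and cv.shape == (1, 2, 4, 4)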
