From dd46c141bdd235f4d1258999b67a3fda66fb4af8 Mon Sep 17 00:00:00 2001
From: Yina Chen <33650826+cyita@users.noreply.github.com>
Date: Fri, 9 Aug 2024 10:43:43 +0300
Subject: [PATCH] Phi3 support compresskv (#11733)

* phi3 support compresskv
* fix phi3 mtl error
* fix conflict with quant kv
* fix abnormal on mtl
* fix style
* use sliding window size to compress kv
* support sliding window
* fix style
* fix style
* temp: partial support quant kv
* support quant kv with compress kv, todo: model check
* temp
* fix style
* fix style
* remove prepare
* address comment
* default -> 1.8k
---
 python/llm/src/ipex_llm/transformers/kv.py        | 181 +++++++++++-------
 .../src/ipex_llm/transformers/models/phi3.py      |  45 ++++-
 .../src/ipex_llm/transformers/models/utils.py     |   2 +-
 3 files changed, 146 insertions(+), 82 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/kv.py b/python/llm/src/ipex_llm/transformers/kv.py
index 1543ab34d5e..0e3803f5737 100644
--- a/python/llm/src/ipex_llm/transformers/kv.py
+++ b/python/llm/src/ipex_llm/transformers/kv.py
@@ -155,62 +155,71 @@ def compress_kv(attn_config, key_states, query_states, value_states, attention_m
     if q_len <= attn_config.max_capacity_prompt:
         return key_states, value_states
     else:
-        key_states_expand = repeat_kv(key_states, num_key_value_groups).to(key_states.device)
-        attn_weights = torch.matmul(query_states[..., -attn_config.window_size:, :],
-                                    key_states_expand.transpose(2, 3)) / math.sqrt(head_dim)
-        mask = torch.full((attn_config.window_size, attn_config.window_size),
-                          torch.finfo(attn_weights.dtype).min,
-                          device=attn_weights.device)
-        mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
-        mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
-        mask = mask.to(attn_weights.device)
-        attention_mask = mask[None, None, :, :]
-
-        attn_weights[:, :, -attn_config.window_size:, -attn_config.window_size:] += attention_mask
-
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1,
-                                             dtype=torch.float32).to(query_states.dtype)
-        attn_weights_sum = attn_weights[:, :, -attn_config.window_size:,
-                                        :-attn_config.window_size].sum(dim=-2)
-        if attn_config.pooling == 'avgpool':
-            if num_key_value_groups > 1:
-                attn_cache = F.avg_pool2d(attn_weights_sum, kernel_size=(num_key_value_groups,
-                                                                         attn_config.kernel_size),
-                                          padding=(0, attn_config.kernel_size//2),
-                                          stride=(num_key_value_groups, 1))
-            else:
-                attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size=attn_config.kernel_size,
-                                          padding=attn_config.kernel_size//2, stride=1)
-        elif attn_config.pooling == 'maxpool':
-            if num_key_value_groups > 1:
-                attn_cache = F.max_pool2d(attn_weights_sum,
-                                          kernel_size=(num_key_value_groups,
-                                                       attn_config.kernel_size),
-                                          padding=(0, attn_config.kernel_size//2),
-                                          stride=(num_key_value_groups, 1))
-            else:
-                attn_cache = F.max_pool1d(attn_weights_sum, kernel_size=attn_config.kernel_size,
-                                          padding=attn_config.kernel_size//2, stride=1)
+        sliding_window_size = getattr(attn_config, "sliding_window", None)
+        if sliding_window_size is not None and sliding_window_size <= 2500:
+            return key_states[:, :, -sliding_window_size:, :], \
+                value_states[:, :, -sliding_window_size:, :]
         else:
-            invalidInputError(False, 'Pooling method not supported')
-        indices = attn_cache.topk(attn_config.max_capacity_prompt - attn_config.window_size,
-                                  dim=-1).indices
-        indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
-        k_past_compress = key_states[:, :, :-attn_config.window_size, :].gather(dim=2,
-                                                                                index=indices)
-        v_past_compress = value_states[:, :, :-attn_config.window_size, :].gather(dim=2,
-                                                                                  index=indices)
-        k_cur = key_states[:, :, -attn_config.window_size:, :]
-        v_cur = value_states[:, :, -attn_config.window_size:, :]
-        key_states = torch.cat([k_past_compress, k_cur], dim=2)
-        value_states = torch.cat([v_past_compress, v_cur], dim=2)
-        return key_states, value_states
+            key_states_expand = repeat_kv(key_states, num_key_value_groups).to(key_states.device)
+            attn_weights = torch.matmul(query_states[..., -attn_config.window_size:, :],
+                                        key_states_expand.transpose(2, 3)) / math.sqrt(head_dim)
+            mask = torch.full((attn_config.window_size, attn_config.window_size),
+                              torch.finfo(attn_weights.dtype).min,
+                              device=attn_weights.device)
+            mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
+            mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
+            mask = mask.to(attn_weights.device)
+            attention_mask = mask[None, None, :, :]
+
+            attn_weights[:, :, -attn_config.window_size:,
+                         -attn_config.window_size:] += attention_mask
+
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1,
+                                                 dtype=torch.float32).to(query_states.dtype)
+            attn_weights_sum = attn_weights[:, :, -attn_config.window_size:,
+                                            :-attn_config.window_size].sum(dim=-2)
+            if attn_config.pooling == 'avgpool':
+                if num_key_value_groups > 1:
+                    attn_cache = F.avg_pool2d(attn_weights_sum,
+                                              kernel_size=(num_key_value_groups,
+                                                           attn_config.kernel_size),
+                                              padding=(0, attn_config.kernel_size//2),
+                                              stride=(num_key_value_groups, 1))
+                else:
+                    attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size=attn_config.kernel_size,
+                                              padding=attn_config.kernel_size//2, stride=1)
+            elif attn_config.pooling == 'maxpool':
+                if num_key_value_groups > 1:
+                    attn_cache = F.max_pool2d(attn_weights_sum,
+                                              kernel_size=(num_key_value_groups,
+                                                           attn_config.kernel_size),
+                                              padding=(0, attn_config.kernel_size//2),
+                                              stride=(num_key_value_groups, 1))
+                else:
+                    attn_cache = F.max_pool1d(attn_weights_sum, kernel_size=attn_config.kernel_size,
+                                              padding=attn_config.kernel_size//2, stride=1)
+            else:
+                invalidInputError(False, 'Pooling method not supported')
+            indices = attn_cache.topk(attn_config.max_capacity_prompt - attn_config.window_size,
+                                      dim=-1).indices
+            indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
+            k_past_compress = key_states[:, :, :-attn_config.window_size, :]\
+                .gather(dim=2, index=indices)
+            v_past_compress = value_states[:, :, :-attn_config.window_size, :]\
+                .gather(dim=2, index=indices)
+            k_cur = key_states[:, :, -attn_config.window_size:, :]
+            v_cur = value_states[:, :, -attn_config.window_size:, :]
+            key_states = torch.cat([k_past_compress, k_cur], dim=2)
+            value_states = torch.cat([v_past_compress, v_cur], dim=2)
+            return key_states, value_states
 
 
 class DynamicCompressCache(DynamicCache):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, quant_kv=False, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.real_kv_len = 0
+        self.quant_kv = quant_kv
+        self.append_kv_func = append_fp8_kv_cache if quant_kv else append_kv_cache
 
     def update_seen_tokens(self, layer_idx, q_len):
         if layer_idx == 0:
@@ -260,49 +269,62 @@ def update(
             self.key_cache.append(key_states_compress)
             self.value_cache.append(value_states_compress)
 
-            k_cache_compressed, v_cache_compressed = init_kv_cache(
-                bsz, num_heads, head_dim,
-                0, key_states_compress.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                key_states.dtype, key_states.device
-            )
-            k_cache_compressed, v_cache_compressed = append_kv_cache(
+            if not self.quant_kv:
+                k_cache_compressed, v_cache_compressed = init_kv_cache(
+                    bsz, num_heads, head_dim,
+                    0, key_states_compress.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                    key_states.dtype, key_states.device
+                )
+            else:
+                k_cache_compressed, v_cache_compressed = init_fp8_kv_cache(
+                    bsz, num_heads, seq_len, head_dim,
+                    device=key_states.device,
+                )
+            k_cache_compressed, v_cache_compressed = self.append_kv_func(
                 k_cache_compressed, v_cache_compressed,
                 key_states_compress, value_states_compress)
             self.key_cache[layer_idx] = k_cache_compressed
             self.value_cache[layer_idx] = v_cache_compressed
 
             if key_states.stride(2) != head_dim:
-                k_cache, v_cache = init_kv_cache(
-                    bsz, num_heads, head_dim,
-                    0, key_states.size(2),
-                    key_states.dtype, key_states.device
-                )
-                k_cache, v_cache = append_kv_cache(k_cache, v_cache, key_states, value_states)
+                if not self.quant_kv:
+                    k_cache, v_cache = init_kv_cache(
+                        bsz, num_heads, head_dim,
+                        0, key_states.size(2),
+                        key_states.dtype, key_states.device
+                    )
+                else:
+                    k_cache, v_cache = init_fp8_kv_cache(
+                        bsz, num_heads, 0, head_dim, key_states.device
+                    )
+                k_cache, v_cache = self.append_kv_func(k_cache, v_cache,
+                                                       key_states, value_states)
                 return k_cache, v_cache
             else:
                 return key_states, value_states
         else:
             cache_k = self.key_cache[layer_idx]
             cache_v = self.value_cache[layer_idx]
-            if not enough_kv_room:
+            if not enough_kv_room and not self.quant_kv:
                 # allocate new
-                new_c_k, new_c_v = extend_kv_cache(bsz,
-                                                   num_heads,  # Support GQA
-                                                   head_dim,
-                                                   cache_k.size(2),
-                                                   cache_k.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
-                                                   dtype=cache_k.dtype,
-                                                   device=query_states.device)
+                new_c_k, new_c_v = extend_kv_cache(
+                    bsz,
+                    num_heads,  # Support GQA
+                    head_dim,
+                    cache_k.size(2),
+                    cache_k.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
+                    dtype=cache_k.dtype,
+                    device=query_states.device)
                 new_c_k[:] = cache_k
                 new_c_v[:] = cache_v
                 cache_k = new_c_k
                 cache_v = new_c_v
-            key_states, value_states = append_kv_cache(cache_k,
-                                                       cache_v,
-                                                       key_states,
-                                                       value_states)
+            key_states, value_states = self.append_kv_func(cache_k,
+                                                           cache_v,
+                                                           key_states,
+                                                           value_states)
 
             # update past_key_value
             self.key_cache[layer_idx] = key_states
@@ -316,3 +338,14 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         if len(self.key_cache) <= layer_idx:
             return 0
         return self.real_kv_len
+
+    @classmethod
+    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+                          quantize_kv: Optional[bool] = False) -> "DynamicCache":
+        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
+        cache = cls(quantize_kv)
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                cache.update(key_states, value_states, layer_idx)
+        return cache
diff --git a/python/llm/src/ipex_llm/transformers/models/phi3.py b/python/llm/src/ipex_llm/transformers/models/phi3.py
index 6378b6fe348..60e3c394eb5 100644
--- a/python/llm/src/ipex_llm/transformers/models/phi3.py
+++ b/python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -31,6 +31,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import math
 import torch
 import warnings
@@ -40,11 +41,13 @@
 from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
 from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
 from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
-from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
+from ipex_llm.transformers.models.utils import should_use_compresskv, is_enough_kv_cache_room_4_36
+from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, DynamicCompressCache
 from typing import Optional, Tuple, List
 from transformers.models.phi.modeling_phi import repeat_kv
 from transformers.cache_utils import Cache
 
+KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))
 
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
@@ -94,6 +97,9 @@ def attention_forward(
 
     bsz, q_len, _ = hidden_states.size()
 
+    # [CompressKV]
+    use_compresskv = isinstance(past_key_value, DynamicCompressCache)
+
     qkv = self.qkv_proj(hidden_states)
     qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
     qkv = qkv.transpose(1, 2)
@@ -127,12 +133,26 @@
                                                      cos, sin, position_ids)
 
     if past_key_value is not None:
-        key_states, value_states = past_key_value.update(key_states, value_states,
-                                                         self.layer_idx, None)
+        # [CompressKV]
+        if use_compresskv:
+            enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx,
+                query_states, attention_mask, self.num_key_value_groups,
+                self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
+        else:
+            key_states, value_states = past_key_value.update(key_states, value_states,
+                                                             self.layer_idx, None)
 
     if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
+        # [CompressKV]
+        if use_compresskv:
+            # print(attention_mask.shape)
+            context_len = key_states.size(2)
+            attention_mask = attention_mask[:, :, :, -context_len:]
         import xe_addons
-        if isinstance(past_key_value, DynamicFp8Cache):
+        if isinstance(past_key_value,
+                      DynamicFp8Cache) or (use_compresskv and past_key_value.quant_kv):
             attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
                                             attention_mask)
         else:
@@ -148,7 +168,8 @@
         #     attn_output = xe_addons.sdp_causal(query_states, key_states,
         #                                        value_states, attention_mask)
     else:
-        if isinstance(past_key_value, DynamicFp8Cache):
+        if isinstance(past_key_value,
+                      DynamicFp8Cache) or (use_compresskv and past_key_value.quant_kv):
             key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
                                                             query_states.dtype)
         # repeat k/v heads if n_kv_heads < n_heads
@@ -235,10 +256,20 @@ def model_forward(
     use_cache = use_cache if use_cache is not None else self.config.use_cache
     input = input_ids if input_ids is not None else inputs_embeds
     use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input)
+    use_compress_kv = should_use_compresskv(input, input.shape[-1])
     if use_cache:
-        if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
+        if use_compress_kv and not isinstance(past_key_values,
+                                              DynamicCompressCache):
+            past_key_values = DynamicCompressCache.\
+                from_legacy_cache(past_key_values,
+                                  quantize_kv=use_quantize_kv)
+        if use_quantize_kv and not isinstance(past_key_values,
+                                              (DynamicFp8Cache, DynamicCompressCache)):
             past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
-        if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
+        if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
+                                                                          (DynamicNormalCache,
+                                                                           DynamicCompressCache
+                                                                           )):
             past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
     return origin_model_forward(
         self=self,
diff --git a/python/llm/src/ipex_llm/transformers/models/utils.py b/python/llm/src/ipex_llm/transformers/models/utils.py
index 14375dd6a70..be19f5ac0b5 100644
--- a/python/llm/src/ipex_llm/transformers/models/utils.py
+++ b/python/llm/src/ipex_llm/transformers/models/utils.py
@@ -490,7 +490,7 @@ def should_use_compresskv(x: torch.Tensor, prompt_len: int):
     if use_compress_kv is None:
         return (
             get_xpu_device_type(x) == "mtl"
-            and prompt_len >= 2500
+            and prompt_len >= 1800
             and prompt_len <= 4500
         )
     else:
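
A note on what the new code path does, for readers who do not want to trace the diff line by line: compress_kv keeps the full prompt KV cache when the prompt is short, now simply keeps the last sliding_window tokens for models with a small sliding window (the fast path added above), and otherwise scores every earlier position by how strongly the last window_size query tokens attend to it, keeping only the top-scoring positions plus the recent window. The sketch below is a minimal illustration of that selection step under simplifying assumptions (no GQA, i.e. num_key_value_groups == 1, and 'avgpool' pooling only); the function name and default values are illustrative, not ipex-llm API.

    # Minimal sketch of the selection step implemented by compress_kv.
    # Assumes num_key_value_groups == 1 and average pooling; names and
    # defaults here are illustrative, not the ipex-llm implementation.
    import math
    import torch
    import torch.nn.functional as F

    def compress_kv_sketch(key_states, value_states, query_states,
                           window_size=32, max_capacity_prompt=1024,
                           kernel_size=5, sliding_window=None):
        # all tensors: [batch, num_heads, seq_len, head_dim]
        q_len, head_dim = query_states.shape[2], query_states.shape[3]
        if q_len <= max_capacity_prompt:
            return key_states, value_states               # short prompt: keep everything
        if sliding_window is not None and sliding_window <= 2500:
            return (key_states[:, :, -sliding_window:, :],     # fast path: keep only the
                    value_states[:, :, -sliding_window:, :])   # last sliding_window tokens
        # score earlier positions by attention from the last window_size queries
        attn = torch.matmul(query_states[:, :, -window_size:, :],
                            key_states.transpose(2, 3)) / math.sqrt(head_dim)
        causal = torch.triu(torch.full((window_size, window_size),
                                       torch.finfo(attn.dtype).min,
                                       device=attn.device), diagonal=1)
        attn[:, :, :, -window_size:] += causal            # causal mask inside the window
        attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(query_states.dtype)
        scores = attn[:, :, :, :-window_size].sum(dim=-2)      # [batch, heads, q_len - window_size]
        scores = F.avg_pool1d(scores, kernel_size=kernel_size,
                              padding=kernel_size // 2, stride=1)
        # keep the highest-scoring past positions plus the recent window
        keep = scores.topk(max_capacity_prompt - window_size, dim=-1).indices
        keep = keep.unsqueeze(-1).expand(-1, -1, -1, head_dim)
        k_past = key_states[:, :, :-window_size, :].gather(dim=2, index=keep)
        v_past = value_states[:, :, :-window_size, :].gather(dim=2, index=keep)
        key_states = torch.cat([k_past, key_states[:, :, -window_size:, :]], dim=2)
        value_states = torch.cat([v_past, value_states[:, :, -window_size:, :]], dim=2)
        return key_states, value_states

In the patch itself this path is only taken when should_use_compresskv returns True, which after the threshold change above means MTL devices with prompt lengths between 1800 and 4500 tokens (unless overridden through the use_compress_kv setting read in that helper), and DynamicCompressCache.from_legacy_cache(..., quantize_kv=use_quantize_kv) decides whether the compressed entries are additionally stored as FP8 via append_fp8_kv_cache.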