Phi3 support compresskv #11733

Merged: 19 commits, Aug 9, 2024
181 changes: 107 additions & 74 deletions python/llm/src/ipex_llm/transformers/kv.py
@@ -155,62 +155,71 @@ def compress_kv(attn_config, key_states, query_states, value_states, attention_m
if q_len <= attn_config.max_capacity_prompt:
return key_states, value_states
else:
key_states_expand = repeat_kv(key_states, num_key_value_groups).to(key_states.device)
attn_weights = torch.matmul(query_states[..., -attn_config.window_size:, :],
key_states_expand.transpose(2, 3)) / math.sqrt(head_dim)
mask = torch.full((attn_config.window_size, attn_config.window_size),
torch.finfo(attn_weights.dtype).min,
device=attn_weights.device)
mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(attn_weights.device)
attention_mask = mask[None, None, :, :]

attn_weights[:, :, -attn_config.window_size:, -attn_config.window_size:] += attention_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
attn_weights_sum = attn_weights[:, :, -attn_config.window_size:,
:-attn_config.window_size].sum(dim=-2)
if attn_config.pooling == 'avgpool':
if num_key_value_groups > 1:
attn_cache = F.avg_pool2d(attn_weights_sum, kernel_size=(num_key_value_groups,
attn_config.kernel_size),
padding=(0, attn_config.kernel_size//2),
stride=(num_key_value_groups, 1))
else:
attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size=attn_config.kernel_size,
padding=attn_config.kernel_size//2, stride=1)
elif attn_config.pooling == 'maxpool':
if num_key_value_groups > 1:
attn_cache = F.max_pool2d(attn_weights_sum,
kernel_size=(num_key_value_groups,
attn_config.kernel_size),
padding=(0, attn_config.kernel_size//2),
stride=(num_key_value_groups, 1))
else:
attn_cache = F.max_pool1d(attn_weights_sum, kernel_size=attn_config.kernel_size,
padding=attn_config.kernel_size//2, stride=1)
sliding_window_size = getattr(attn_config, "sliding_window", None)
if sliding_window_size is not None and sliding_window_size <= 2500:
return key_states[:, :, -sliding_window_size:, :], \
value_states[:, :, -sliding_window_size:, :]
else:
invalidInputError(False, 'Pooling method not supported')
indices = attn_cache.topk(attn_config.max_capacity_prompt - attn_config.window_size,
dim=-1).indices
indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
k_past_compress = key_states[:, :, :-attn_config.window_size, :].gather(dim=2,
index=indices)
v_past_compress = value_states[:, :, :-attn_config.window_size, :].gather(dim=2,
index=indices)
k_cur = key_states[:, :, -attn_config.window_size:, :]
v_cur = value_states[:, :, -attn_config.window_size:, :]
key_states = torch.cat([k_past_compress, k_cur], dim=2)
value_states = torch.cat([v_past_compress, v_cur], dim=2)
return key_states, value_states
key_states_expand = repeat_kv(key_states, num_key_value_groups).to(key_states.device)
attn_weights = torch.matmul(query_states[..., -attn_config.window_size:, :],
key_states_expand.transpose(2, 3)) / math.sqrt(head_dim)
mask = torch.full((attn_config.window_size, attn_config.window_size),
torch.finfo(attn_weights.dtype).min,
device=attn_weights.device)
mask_cond = torch.arange(mask.size(-1), device=attn_weights.device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(attn_weights.device)
attention_mask = mask[None, None, :, :]

attn_weights[:, :, -attn_config.window_size:,
-attn_config.window_size:] += attention_mask

attn_weights = nn.functional.softmax(attn_weights, dim=-1,
dtype=torch.float32).to(query_states.dtype)
attn_weights_sum = attn_weights[:, :, -attn_config.window_size:,
:-attn_config.window_size].sum(dim=-2)
if attn_config.pooling == 'avgpool':
if num_key_value_groups > 1:
attn_cache = F.avg_pool2d(attn_weights_sum,
kernel_size=(num_key_value_groups,
attn_config.kernel_size),
padding=(0, attn_config.kernel_size//2),
stride=(num_key_value_groups, 1))
else:
attn_cache = F.avg_pool1d(attn_weights_sum, kernel_size=attn_config.kernel_size,
padding=attn_config.kernel_size//2, stride=1)
elif attn_config.pooling == 'maxpool':
if num_key_value_groups > 1:
attn_cache = F.max_pool2d(attn_weights_sum,
kernel_size=(num_key_value_groups,
attn_config.kernel_size),
padding=(0, attn_config.kernel_size//2),
stride=(num_key_value_groups, 1))
else:
attn_cache = F.max_pool1d(attn_weights_sum, kernel_size=attn_config.kernel_size,
padding=attn_config.kernel_size//2, stride=1)
else:
invalidInputError(False, 'Pooling method not supported')
indices = attn_cache.topk(attn_config.max_capacity_prompt - attn_config.window_size,
dim=-1).indices
indices = indices.unsqueeze(-1).expand(-1, -1, -1, head_dim)
k_past_compress = key_states[:, :, :-attn_config.window_size, :]\
.gather(dim=2, index=indices)
v_past_compress = value_states[:, :, :-attn_config.window_size, :]\
.gather(dim=2, index=indices)
k_cur = key_states[:, :, -attn_config.window_size:, :]
v_cur = value_states[:, :, -attn_config.window_size:, :]
key_states = torch.cat([k_past_compress, k_cur], dim=2)
value_states = torch.cat([v_past_compress, v_cur], dim=2)
return key_states, value_states
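
For orientation, the compression step above is essentially a SnapKV-style selection: attention weights from the last window_size queries are pooled, the top (max_capacity_prompt - window_size) past positions are kept via topk + gather, and the recent window is concatenated back unchanged. A minimal, self-contained sketch of that pattern with made-up shapes (nothing below is taken verbatim from the PR):

import torch

bsz, heads, seq_len, head_dim = 1, 2, 16, 4
window, capacity = 4, 8   # stand-ins for attn_config.window_size / max_capacity_prompt

key = torch.randn(bsz, heads, seq_len, head_dim)
value = torch.randn(bsz, heads, seq_len, head_dim)
# pooled attention mass each past position received from the last `window` queries
scores = torch.rand(bsz, heads, seq_len - window)

keep = scores.topk(capacity - window, dim=-1).indices       # (bsz, heads, 4)
keep = keep.unsqueeze(-1).expand(-1, -1, -1, head_dim)      # (bsz, heads, 4, head_dim)
k_past = key[:, :, :-window, :].gather(dim=2, index=keep)   # compressed history
v_past = value[:, :, :-window, :].gather(dim=2, index=keep)
k_new = torch.cat([k_past, key[:, :, -window:, :]], dim=2)  # (bsz, heads, capacity, head_dim)
v_new = torch.cat([v_past, value[:, :, -window:, :]], dim=2)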


class DynamicCompressCache(DynamicCache):
def __init__(self, *args, **kwargs):
def __init__(self, quant_kv=False, *args, **kwargs):
super().__init__(*args, **kwargs)
self.real_kv_len = 0
self.quant_kv = quant_kv
self.append_kv_func = append_fp8_kv_cache if quant_kv else append_kv_cache

def update_seen_tokens(self, layer_idx, q_len):
if layer_idx == 0:
@@ -260,49 +269,62 @@ def update(
self.key_cache.append(key_states_compress)
self.value_cache.append(value_states_compress)

k_cache_compressed, v_cache_compressed = init_kv_cache(
bsz, num_heads, head_dim,
0, key_states_compress.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
key_states.dtype, key_states.device
)
k_cache_compressed, v_cache_compressed = append_kv_cache(
if not self.quant_kv:
k_cache_compressed, v_cache_compressed = init_kv_cache(
bsz, num_heads, head_dim,
0, key_states_compress.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
key_states.dtype, key_states.device
)
else:
k_cache_compressed, v_cache_compressed = init_fp8_kv_cache(
bsz, num_heads, seq_len, head_dim,
device=key_states.device,
)
k_cache_compressed, v_cache_compressed = self.append_kv_func(
k_cache_compressed, v_cache_compressed,
key_states_compress, value_states_compress)
self.key_cache[layer_idx] = k_cache_compressed
self.value_cache[layer_idx] = v_cache_compressed

if key_states.stride(2) != head_dim:
k_cache, v_cache = init_kv_cache(
bsz, num_heads, head_dim,
0, key_states.size(2),
key_states.dtype, key_states.device
)
k_cache, v_cache = append_kv_cache(k_cache, v_cache, key_states, value_states)
if not self.quant_kv:
k_cache, v_cache = init_kv_cache(
bsz, num_heads, head_dim,
0, key_states.size(2),
key_states.dtype, key_states.device
)
else:
k_cache, v_cache = init_fp8_kv_cache(
bsz, num_heads, 0, head_dim, key_states.device
)
k_cache, v_cache = self.append_kv_func(k_cache, v_cache,
key_states, value_states)
return k_cache, v_cache
else:
return key_states, value_states
else:
cache_k = self.key_cache[layer_idx]
cache_v = self.value_cache[layer_idx]
if not enough_kv_room:
if not enough_kv_room and not self.quant_kv:
# allocate new
new_c_k, new_c_v = extend_kv_cache(bsz,
num_heads, # Support GQA
head_dim,
cache_k.size(2),
cache_k.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=query_states.device)
new_c_k, new_c_v = extend_kv_cache(
bsz,
num_heads, # Support GQA
head_dim,
cache_k.size(2),
cache_k.size(2) + KV_CACHE_ALLOC_BLOCK_LENGTH,
dtype=cache_k.dtype,
device=query_states.device)

new_c_k[:] = cache_k
new_c_v[:] = cache_v
cache_k = new_c_k
cache_v = new_c_v

key_states, value_states = append_kv_cache(cache_k,
cache_v,
key_states,
value_states)
key_states, value_states = self.append_kv_func(cache_k,
cache_v,
key_states,
value_states)

# update past_key_value
self.key_cache[layer_idx] = key_states
@@ -316,3 +338,14 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
if len(self.key_cache) <= layer_idx:
return 0
return self.real_kv_len

@classmethod
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
quantize_kv: Optional[bool] = False) -> "DynamicCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
cache = cls(quantize_kv)
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
return cache
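
As a rough usage sketch (not part of the PR), the only behavioural difference the new quant_kv flag introduces is which append helper the cache dispatches to, and from_legacy_cache simply forwards that flag:

from ipex_llm.transformers.kv import DynamicCompressCache

# fp16/fp32 appends via append_kv_cache
cache = DynamicCompressCache(quant_kv=False)
# fp8 appends via append_fp8_kv_cache
quant_cache = DynamicCompressCache(quant_kv=True)

# what model_forward in phi3.py does when no cache exists yet
past_key_values = DynamicCompressCache.from_legacy_cache(None, quantize_kv=True)
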
46 changes: 39 additions & 7 deletions python/llm/src/ipex_llm/transformers/models/phi3.py
@@ -31,6 +31,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import math
import torch
import warnings
@@ -40,11 +41,13 @@
from ipex_llm.transformers.models.utils import mlp_fusion_check, SILU
from ipex_llm.transformers.models.utils import use_sdp, use_sdp_causal
from ipex_llm.transformers.models.utils import use_quantize_kv_cache, restore_fp8_kv_cache
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache
from ipex_llm.transformers.models.utils import should_use_compresskv, is_enough_kv_cache_room_4_36
from ipex_llm.transformers.kv import DynamicNormalCache, DynamicFp8Cache, DynamicCompressCache

from typing import Optional, Tuple, List
from transformers.models.phi.modeling_phi import repeat_kv
from transformers.cache_utils import Cache
KV_CACHE_ALLOC_BLOCK_LENGTH = int(os.environ.get("KV_CACHE_ALLOC_BLOCK_LENGTH", 256))


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
@@ -94,6 +97,9 @@ def attention_forward(

bsz, q_len, _ = hidden_states.size()

# [CompressKV]
use_compresskv = isinstance(past_key_value, DynamicCompressCache)

qkv = self.qkv_proj(hidden_states)
qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
qkv = qkv.transpose(1, 2)
@@ -127,12 +133,26 @@
cos, sin, position_ids)

if past_key_value is not None:
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, None)
# [CompressKV]
if use_compresskv:
enough_kv_room = is_enough_kv_cache_room_4_36(past_key_value, self.layer_idx)
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx,
query_states, attention_mask, self.num_key_value_groups,
self.config, enough_kv_room, KV_CACHE_ALLOC_BLOCK_LENGTH)
else:
key_states, value_states = past_key_value.update(key_states, value_states,
self.layer_idx, None)

if use_sdp(q_len, kv_seq_len, self.head_dim, query_states):
# [CompressKV]
if use_compresskv:
# print(attention_mask.shape)
context_len = key_states.size(2)
attention_mask = attention_mask[:, :, :, -context_len:]
import xe_addons
if isinstance(past_key_value, DynamicFp8Cache):
if isinstance(past_key_value,
DynamicFp8Cache) or (use_compresskv and past_key_value.quant_kv):
attn_output = xe_addons.sdp_fp8(query_states, key_states, value_states,
attention_mask)
else:
@@ -148,7 +168,8 @@
# attn_output = xe_addons.sdp_causal(query_states, key_states,
# value_states, attention_mask)
else:
if isinstance(past_key_value, DynamicFp8Cache):
if isinstance(past_key_value,
DynamicFp8Cache) or (use_compresskv and past_key_value.quant_kv):
key_states, value_states = restore_fp8_kv_cache(key_states, value_states,
query_states.dtype)
# repeat k/v heads if n_kv_heads < n_heads
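
One detail worth highlighting from the [CompressKV] branch earlier in attention_forward: because the compressed KV sequence is shorter than the original prompt, the causal attention_mask built for the full prompt has to be cropped to the last key_states.size(2) columns before SDP. A toy illustration of that crop (shapes invented, not from the PR):

import torch

q_len, full_kv_len, kept_kv_len = 4, 32, 12
attention_mask = torch.zeros(1, 1, q_len, full_kv_len)      # mask built for the full prompt
context_len = kept_kv_len                                    # key_states.size(2) after compression
attention_mask = attention_mask[:, :, :, -context_len:]      # (1, 1, 4, 12), matches compressed KV
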
@@ -235,10 +256,21 @@ def model_forward(
use_cache = use_cache if use_cache is not None else self.config.use_cache
input = input_ids if input_ids is not None else inputs_embeds
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input)
use_compress_kv = should_use_compresskv(input, input.shape[-1])
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
if use_compress_kv and not isinstance(past_key_values,
DynamicCompressCache):
# TODO: use quantize kv only support phi3 mini-4k, medium-4k
past_key_values = DynamicCompressCache.\
from_legacy_cache(past_key_values,
quantize_kv=use_quantize_kv)
if use_quantize_kv and not isinstance(past_key_values,
(DynamicFp8Cache, DynamicCompressCache)):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
if not use_quantize_kv and not isinstance(past_key_values, DynamicNormalCache):
if not use_quantize_kv and not use_compress_kv and not isinstance(past_key_values,
(DynamicNormalCache,
DynamicCompressCache
)):
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
return origin_model_forward(
self=self,
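
Taken together, the model_forward changes establish a precedence when a fresh cache has to be created: compress-kv first (with fp8 quantization folded inside it when enabled), then plain fp8, then the normal cache. A small illustrative helper, not present in the PR, that states that precedence for the fresh-cache case:

def pick_cache_class(use_compress_kv: bool, use_quantize_kv: bool) -> str:
    # Illustrative only: which cache class a fresh past_key_values would be
    # converted to under the new selection logic in model_forward.
    if use_compress_kv:
        # compress-kv wins; quantize_kv only changes its internal append path
        return f"DynamicCompressCache(quantize_kv={use_quantize_kv})"
    if use_quantize_kv:
        return "DynamicFp8Cache"
    return "DynamicNormalCache"

assert pick_cache_class(True, True) == "DynamicCompressCache(quantize_kv=True)"
assert pick_cache_class(False, True) == "DynamicFp8Cache"
assert pick_cache_class(False, False) == "DynamicNormalCache"
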
2 changes: 1 addition & 1 deletion python/llm/src/ipex_llm/transformers/models/utils.py
@@ -485,7 +485,7 @@ def update_past_key_value(past_key_value, key_states, value_states,
return key_states, value_states


def should_use_compresskv(x: torch.Tensor, prompt_len: int):
def should_use_compresskv(x: torch.Tensor, prompt_len: int, sliding_window: int=None):
Review comment (Contributor): is sliding_window used?

use_compress_kv = os.environ.get("IPEX_LLM_COMPRESS_KV_CACHE", None)
if use_compress_kv is None:
return (
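
The switch read above is the IPEX_LLM_COMPRESS_KV_CACHE environment variable; when it is unset, should_use_compresskv falls back to its default heuristic (truncated in this view). A minimal way to force the feature while experimenting, assuming the usual "1"/"0" string values (the accepted values are not shown in the diff):

import os

# Assumed toggle values; the diff only shows the variable being read.
os.environ["IPEX_LLM_COMPRESS_KV_CACHE"] = "1"   # request compress-kv
# os.environ["IPEX_LLM_COMPRESS_KV_CACHE"] = "0" # fall back to the default caches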