Skip to content

Commit

Permalink
optimize phi3 memory usage
Browse files Browse the repository at this point in the history
  • Loading branch information
MeouSker77 committed Aug 20, 2024
1 parent 2946420 commit 6153be0
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
15 changes: 15 additions & 0 deletions python/llm/src/ipex_llm/transformers/kv.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,21 @@ def update(

return self.key_cache[layer_idx], self.value_cache[layer_idx]

@classmethod
def from_reserved(cls, layers: int,
bsz: int, n_head: int, length: int, head_dim: int,
dtype: torch.dtype, device: torch.device):
past_key_values = cls()
for _i in range(layers):
k_cache, v_cache = init_kv_cache(
bsz, n_head, head_dim,
0, length + cls.KV_ALLOC_BLOCK_LENGTH,
dtype, device
)
past_key_values.key_cache.append(k_cache)
past_key_values.value_cache.append(v_cache)
return past_key_values


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
Expand Down
14 changes: 11 additions & 3 deletions python/llm/src/ipex_llm/transformers/models/phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,9 @@ def model_forward(
):
# IPEX-LLM OPT: kv cache and quantize kv cache and sdp
use_cache = use_cache if use_cache is not None else self.config.use_cache
input = input_ids if input_ids is not None else inputs_embeds
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input)
use_compress_kv = should_use_compresskv(input, input.shape[1])
inputs = input_ids if input_ids is not None else inputs_embeds
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, inputs)
use_compress_kv = should_use_compresskv(inputs, inputs.shape[1])
if use_cache:
if use_compress_kv and not isinstance(past_key_values,
DynamicCompressCache):
Expand All @@ -272,6 +272,14 @@ def model_forward(
DynamicCompressCache
)):
past_key_values = DynamicNormalCache.from_legacy_cache(past_key_values)
if past_key_values.get_seq_length() == 0:
n_layer = self.config.num_hidden_layers
n_head = self.config.num_attention_heads
head_dim = self.config.hidden_size // self.config.num_attention_heads
past_key_values = DynamicNormalCache.from_reserved(
n_layer, inputs.size(0), n_head, inputs.size(1), head_dim,
inputs.dtype, inputs.device
)
return origin_model_forward(
self=self,
input_ids=input_ids,
Expand Down

0 comments on commit 6153be0

Please sign in to comment.