Skip to content

Commit

Permalink
Optimize speculative decoding PVC memory usage (#10329)
Browse files Browse the repository at this point in the history
* optimize memory

* update

* update

* update

* support other models

* update

* fix style
  • Loading branch information
cyita authored Mar 6, 2024
1 parent cc79684 commit 9ea499c
Showing 1 changed file with 60 additions and 2 deletions.
62 changes: 60 additions & 2 deletions python/llm/src/bigdl/llm/transformers/speculative.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
import os
import copy
import logging
import warnings
import inspect
import transformers
from packaging import version
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from transformers import top_k_top_p_filtering, GenerationConfig, \
LogitsProcessorList, StoppingCriteriaList
Expand Down Expand Up @@ -367,6 +367,55 @@ def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_s
delta_past_value.to(torch.float32)


def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_len=256,
model_type="llama"):
from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \
extend_kv_cache
enough_kv_room = True
if model_type not in ["chatglm", "qwen", "baichuan", "llama", "mistral",
"gptj", "opt"]:
return past_key_values, False
cache_k = past_key_values[0][0]
if model_type == "chatglm":
cache_k = cache_k.permute(1, 2, 0, 3)
elif model_type == "qwen":
cache_k = cache_k.transpose(1, 2)

enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value=(cache_k, None),
seq_len=max_step_draft)
bsz, num_heads, current_seq_len, head_dim = cache_k.shape
device = past_key_values[0][0].device
if not enough_kv_room:
past_key_values = list(past_key_values)
for i in range(len(past_key_values)):
cache_k = past_key_values[i][0]
cache_v = past_key_values[i][1]
if model_type == "chatglm":
cache_k = cache_k.permute(1, 2, 0, 3)
cache_v = cache_v.permute(1, 2, 0, 3)
elif model_type == "qwen":
cache_k = cache_k.transpose(1, 2)
cache_v = cache_v.transpose(1, 2)
new_cache_k, new_cache_v = extend_kv_cache(
bsz,
num_heads, # Support GQA
head_dim,
cache_k.size(2),
current_seq_len + max_step_draft + kv_alloc_block_len,
dtype=cache_v.dtype,
device=device)
new_cache_k[:] = cache_k
new_cache_v[:] = cache_v
if model_type == "chatglm":
past_key_values[i] = (new_cache_k.permute(2, 0, 1, 3),
new_cache_v.permute(2, 0, 1, 3))
elif model_type == "qwen":
past_key_values[i] = (new_cache_k.transpose(1, 2), new_cache_v.transpose(1, 2))
else:
past_key_values[i] = (new_cache_k, new_cache_v)
return past_key_values, not enough_kv_room


@torch.no_grad()
def speculative_generate(self,
inputs: Optional[torch.Tensor] = None,
Expand Down Expand Up @@ -504,6 +553,9 @@ def speculative_generate(self,

self.clear_benchmarks()

if self.device.type == 'xpu':
torch.xpu.empty_cache()

# Example:
# Target model forward for the first token
# Step 1. target_model(prompt) -> a
Expand Down Expand Up @@ -562,6 +614,10 @@ def speculative_generate(self,
past_key_values_storage, _enable_ipex)
original_draft_past_key_values = draft_past_key_values
else:
past_key_values, extend_kv = _check_and_extend_kv_cache(past_key_values,
max_step_draft,
max_new_tokens - step + 40,
self.config.model_type)
draft_past_key_values = past_key_values
draft_generate_ids[:, 0] = current_input_ids
draft_prob_list = []
Expand Down Expand Up @@ -742,6 +798,8 @@ def speculative_generate(self,
output_ids = greedy(logits)
if self.device.type == 'xpu':
torch.xpu.synchronize()
if extend_kv:
torch.xpu.empty_cache()
toc = time.time()
self.verify_time.append(toc - tic)
self.generate_time.append(self.draft_time[-1] + self.verify_time[-1])
Expand Down

0 comments on commit 9ea499c

Please sign in to comment.