diff --git a/python/llm/src/bigdl/llm/transformers/speculative.py b/python/llm/src/bigdl/llm/transformers/speculative.py index 2b1fb319764..1e3f002922a 100644 --- a/python/llm/src/bigdl/llm/transformers/speculative.py +++ b/python/llm/src/bigdl/llm/transformers/speculative.py @@ -24,8 +24,8 @@ import os import copy import logging -import warnings -import inspect +import transformers +from packaging import version from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union from transformers import top_k_top_p_filtering, GenerationConfig, \ LogitsProcessorList, StoppingCriteriaList @@ -367,6 +367,55 @@ def _update_past_key_values_storage_cpu(self, past_key_values, past_key_values_s delta_past_value.to(torch.float32) +def _check_and_extend_kv_cache(past_key_values, max_step_draft, kv_alloc_block_len=256, + model_type="llama"): + from bigdl.llm.transformers.models.utils import is_enough_kv_cache_room_4_31, \ + extend_kv_cache + enough_kv_room = True + if model_type not in ["chatglm", "qwen", "baichuan", "llama", "mistral", + "gptj", "opt"]: + return past_key_values, False + cache_k = past_key_values[0][0] + if model_type == "chatglm": + cache_k = cache_k.permute(1, 2, 0, 3) + elif model_type == "qwen": + cache_k = cache_k.transpose(1, 2) + + enough_kv_room = is_enough_kv_cache_room_4_31(past_key_value=(cache_k, None), + seq_len=max_step_draft) + bsz, num_heads, current_seq_len, head_dim = cache_k.shape + device = past_key_values[0][0].device + if not enough_kv_room: + past_key_values = list(past_key_values) + for i in range(len(past_key_values)): + cache_k = past_key_values[i][0] + cache_v = past_key_values[i][1] + if model_type == "chatglm": + cache_k = cache_k.permute(1, 2, 0, 3) + cache_v = cache_v.permute(1, 2, 0, 3) + elif model_type == "qwen": + cache_k = cache_k.transpose(1, 2) + cache_v = cache_v.transpose(1, 2) + new_cache_k, new_cache_v = extend_kv_cache( + bsz, + num_heads, # Support GQA + head_dim, + cache_k.size(2), + current_seq_len + max_step_draft + kv_alloc_block_len, + dtype=cache_v.dtype, + device=device) + new_cache_k[:] = cache_k + new_cache_v[:] = cache_v + if model_type == "chatglm": + past_key_values[i] = (new_cache_k.permute(2, 0, 1, 3), + new_cache_v.permute(2, 0, 1, 3)) + elif model_type == "qwen": + past_key_values[i] = (new_cache_k.transpose(1, 2), new_cache_v.transpose(1, 2)) + else: + past_key_values[i] = (new_cache_k, new_cache_v) + return past_key_values, not enough_kv_room + + @torch.no_grad() def speculative_generate(self, inputs: Optional[torch.Tensor] = None, @@ -504,6 +553,9 @@ def speculative_generate(self, self.clear_benchmarks() + if self.device.type == 'xpu': + torch.xpu.empty_cache() + # Example: # Target model forward for the first token # Step 1. target_model(prompt) -> a @@ -562,6 +614,10 @@ def speculative_generate(self, past_key_values_storage, _enable_ipex) original_draft_past_key_values = draft_past_key_values else: + past_key_values, extend_kv = _check_and_extend_kv_cache(past_key_values, + max_step_draft, + max_new_tokens - step + 40, + self.config.model_type) draft_past_key_values = past_key_values draft_generate_ids[:, 0] = current_input_ids draft_prob_list = [] @@ -742,6 +798,8 @@ def speculative_generate(self, output_ids = greedy(logits) if self.device.type == 'xpu': torch.xpu.synchronize() + if extend_kv: + torch.xpu.empty_cache() toc = time.time() self.verify_time.append(toc - tic) self.generate_time.append(self.draft_time[-1] + self.verify_time[-1])