From 7d8bc834155079a781a82b37d4df25229667de68 Mon Sep 17 00:00:00 2001 From: Xiangyu Tian <109123695+xiangyuT@users.noreply.github.com> Date: Fri, 5 Jul 2024 13:10:35 +0800 Subject: [PATCH] LLM: Partial Prefilling for Pipeline Parallel Serving (#11457) LLM: Partial Prefilling for Pipeline Parallel Serving --- .../GPU/Pipeline-Parallel-FastAPI/README.md | 6 +- .../pipeline_serving.py | 5 +- .../GPU/Pipeline-Parallel-FastAPI/run.sh | 6 +- .../transformers/pipeline_parallel.py | 346 +++++++++++++----- 4 files changed, 261 insertions(+), 102 deletions(-) diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md index 1c994d29554..278c8c549b6 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md @@ -57,8 +57,12 @@ pip install trl==0.8.1 bash run.sh ``` -> Note: INT4 optimization is applied to the model by default. You could specify other low bit optimizations (such as 'fp8' and 'fp6') through `--low-bit`. Besides, you could change `NUM_GPUS` to the number of GPUs you have on your machine. +### Command Line Arguments in `run.sh` +> Note: INT4 optimization is applied to the model by default. You could specify other low bit optimizations (such as 'fp8' and 'fp6') through `--low-bit`. Besides, you could change `NUM_GPUS` to the number of GPUs you have on your machine. Other relative settings are listed below: +- `--low-bit`: Sets the low bit optimizations (such as 'sym_int4', 'fp16', 'fp8' and 'fp6') for the model. +- `--max-num-seqs`: Sets the maximum batch size on a single card during pipeline parallel serving. +- `--max-prefilled-seqs`: Sets the maximum batch size for prefilled sequences. Use `0` to disable partial prefetching and process all requests in a single batch. ### 3. Sample Input and Output diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py index d917891c937..0567280d26c 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py @@ -306,18 +306,21 @@ async def main(): help='The port number on which the server will run.') parser.add_argument('--max-num-seqs', type=int, default=8, help='Max num sequences in a batch.') + parser.add_argument('--max-prefilled-seqs', type=int, default=0, + help='Max num sequences in a batch during prefilling.') args = parser.parse_args() model_path = args.repo_id_or_model_path low_bit = args.low_bit max_num_seqs = args.max_num_seqs + max_prefilled_seqs = args.max_prefilled_seqs # serialize model initialization so that we do not run out of CPU memory for i in range(my_size): if my_rank == i: logger.info("start model initialization") global local_model - local_model = ModelRunner(model_path, my_rank, my_size, low_bit, max_num_seqs) + local_model = ModelRunner(model_path, my_rank, my_size, low_bit, max_num_seqs, max_prefilled_seqs) logger.info("model initialized") dist.barrier() # Load tokenizer diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh index d1bbacf4acd..91bf6161bff 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh @@ -24,11 +24,13 @@ source $basekit_root/setvars.sh --force source $basekit_root/ccl/latest/env/vars.sh --force export USE_XETLA=OFF -export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2 +if [[ $KERNEL_VERSION != *"6.5"* ]]; then + export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 +fi export TORCH_LLM_ALLREDUCE=0 export MODEL_PATH=YOUR_MODEL_PATH export NUM_GPUS=2 export IPEX_LLM_QUANTIZE_KV_CACHE=1 -CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit fp8 --max-num-seqs 4 +CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit fp8 --max-num-seqs 4 --max-prefilled-seqs 0 diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py index 23d945314c0..e514d3d3cac 100644 --- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py +++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py @@ -163,6 +163,9 @@ def pipeline_parallel(model, pipeline_parallel_stages): model._modules['lm_head'] = DummyLayer() model.pipeline_parallel_stages = pipeline_parallel_stages + model.layer_start = layer_start + model.layer_end = layer_end + model.num_layers = num_layers model = model.to(f'xpu:{local_rank}') return model @@ -364,6 +367,9 @@ class BatchTask(BaseModel): prompt_lengths: List[int] stopped: bool + prefilled_index: int + partial_prefilling: int + def make_attention_mask(prompt_lengths): max_length = max(prompt_lengths) @@ -375,7 +381,7 @@ def make_attention_mask(prompt_lengths): class ModelRunner: """Implementation for pipeline parallel multi-stage serving.""" - def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs, + def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs, max_prefilled_seqs, torch_dtype=torch.float16): self.pp_config = PPConfig(rank, world_size) self.dtype = torch_dtype @@ -404,7 +410,11 @@ def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs, self.print_len = {} self.is_finish = {} self.model_name = checkpoint - self.layer_start = 0 + # self.layer_start = 0 + # self.layer_end = 0 + + self.max_prefilled_seqs = max_prefilled_seqs + self.partial_output_dict = {} def load_model(self, model_path, world_size, low_bit='sym_int4'): from ipex_llm.transformers import AutoModelForCausalLM, AutoModel @@ -427,11 +437,90 @@ def load_model(self, model_path, world_size, low_bit='sym_int4'): model = model.eval() return model + def prepare_batch(self, cur_batch): + if self.rank == 0: + cur_input_start = cur_batch.prefilled_index + if self.max_prefilled_seqs > 0: + if cur_input_start < cur_batch.batch_size: + cur_input_end = cur_input_start + self.max_prefilled_seqs + cur_input_end = min(cur_input_end, cur_batch.batch_size) + cur_batch.partial_prefilling = cur_input_end - cur_input_start + else: + cur_batch.partial_prefilling = 0 + + return cur_batch + + def cat_kv_cache(self, model_type, kv_cache_1, kv_cache_2): + if model_type in ["baichuan", "chatglm"]: + result = [] + for sub_tuple1, sub_tuple2 in zip(kv_cache_1, kv_cache_2): + if sub_tuple1 is None: + sub_result = [sub_tuple2] + elif sub_tuple2 is None: + sub_result = [sub_tuple1] + else: + sub_result = [] + for t1, t2 in zip(sub_tuple1, sub_tuple2): + if t1 is None: + sub_result.append(t2) + elif t2 is None: + sub_result.append(t1) + else: + if model_type == "chatglm" and self.model.config.num_layers != 40: + sub_result.append(torch.cat((t1, t2), dim=1)) + else: + sub_result.append(torch.cat((t1, t2), dim=0)) + result.append(tuple(sub_result)) + return tuple(result) + else: + # num_layers = self.model.layer_end - self.model.layer_start + for layer_idx in range(self.model.num_layers): + kv_cache_1.key_cache[layer_idx] = \ + torch.cat([kv_cache_1.key_cache[layer_idx], + kv_cache_2.key_cache[layer_idx]], dim=0) + kv_cache_1.value_cache[layer_idx] = \ + torch.cat([kv_cache_1.value_cache[layer_idx], + kv_cache_2.value_cache[layer_idx]], dim=0) + + return kv_cache_1 + + def update_kv_cache(self, kv_cache, cur_id): + layer_start = self.model.layer_start + layer_end = self.model.layer_end + num_layers = self.model.num_layers + + if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40: + # for glm-4-9b-chat + if self.past_key_values_dict.get(cur_id, None) is None: + value_placeholder = torch.empty_like((kv_cache)[-1][0]) + past_key_values_placeholder = tuple( + (value_placeholder, value_placeholder) for _ in range(layer_start) + ) + (kv_cache)[:layer_end - layer_start] + tuple( + (value_placeholder, value_placeholder) for _ in range(layer_end, num_layers) + ) + kv_cache = past_key_values_placeholder + else: + pass + elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0: + value_placeholder = torch.empty_like((kv_cache)[-1][0]) + kv_cache = tuple((value_placeholder, value_placeholder)) + \ + tuple(None for _ in range(layer_start)) + \ + (kv_cache)[layer_start:] + # past_key_values_placeholder = tuple( + # (value_placeholder, value_placeholder) for _ in range(layer_start) + # ) + (kv_cache)[layer_start:] + # kv_cache = past_key_values_placeholder + else: + pass + + return kv_cache + @torch.no_grad() def model_step(self, input, cur_batch): if cur_batch is None or cur_batch.stopped or input is None: - return None + return None, cur_batch + # logger.info(f"{self.rank} {cur_batch} {input.shape}") cur_id = cur_batch.batch_id _past_key_values = self.past_key_values_dict.get(cur_id, None) attention_mask = make_attention_mask(cur_batch.prompt_lengths).to(input.device) @@ -439,44 +528,71 @@ def model_step(self, input, cur_batch): if self.rank == 0: input_ids = input inputs_embeds = None + + if cur_batch.partial_prefilling > 0: + cur_input_start = cur_batch.prefilled_index + cur_input_end = cur_input_start + cur_batch.partial_prefilling + input_ids = input_ids[cur_input_start:cur_input_end] + attention_mask = attention_mask[cur_input_start:cur_input_end] + tmp_past_key_values = _past_key_values + _past_key_values = None else: input_ids = None inputs_embeds = input - torch.xpu.empty_cache() + if cur_batch.partial_prefilling > 0: + cur_input_start = cur_batch.prefilled_index + cur_input_end = cur_input_start + cur_batch.partial_prefilling + attention_mask = attention_mask[cur_input_start:cur_input_end] + tmp_past_key_values = _past_key_values + _past_key_values = None + + # torch.xpu.empty_cache() output = self.model(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=_past_key_values, attention_mask=attention_mask, use_cache=True,) - if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40: - # for glm-4-9b-chat - if self.past_key_values_dict.get(cur_id, None) is None: - value_placeholder = torch.empty_like((output.past_key_values)[-1][0]) - past_key_values_placeholder = tuple( - (value_placeholder, value_placeholder) for _ in range(layer_start) - ) + (output.past_key_values)[: layer_end - layer_start] + tuple( - (value_placeholder, value_placeholder) for _ in range(layer_end, num_layers) - ) - _past_key_values = past_key_values_placeholder + if cur_batch.partial_prefilling > 0: + cur_batch.prefilled_index = cur_input_end + if tmp_past_key_values is None: + tmp_past_key_values = output.past_key_values else: - _past_key_values = output.past_key_values - elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0: - # for baichuan2 and chatglm3 - value_placeholder = torch.empty_like((output.past_key_values)[-1][0]) - past_key_values_placeholder = tuple( - (value_placeholder, value_placeholder) for _ in range(layer_start) - ) + (output.past_key_values)[layer_start:] - _past_key_values = past_key_values_placeholder + tmp_past_key_values = self.cat_kv_cache(self.model.config.model_type, + tmp_past_key_values, + output.past_key_values) + # torch.xpu.empty_cache() + + if cur_batch.prefilled_index == cur_batch.batch_size: + tmp_past_key_values = self.update_kv_cache(tmp_past_key_values, cur_id) + + self.past_key_values_dict[cur_id] = tmp_past_key_values + + if self.pp_config.is_tail: + _pre_output = self.partial_output_dict.get(cur_id, None) + tmp_output = output.logits.to(self.dtype) + tmp_output = torch.argmax(tmp_output[:, -1:, :], dim=-1) + if _pre_output is None: + _pre_output = tmp_output + else: + _pre_output = torch.cat((_pre_output, tmp_output), dim=0) + self.partial_output_dict[cur_id] = _pre_output else: - _past_key_values = output.past_key_values - self.past_key_values_dict[cur_id] = _past_key_values + _past_key_values = self.update_kv_cache(output.past_key_values, cur_id) + self.past_key_values_dict[cur_id] = _past_key_values torch.xpu.synchronize() if not self.pp_config.is_tail: - return output[0].to(self.dtype) + return output[0].to(self.dtype), cur_batch else: - return output.logits + if cur_batch.partial_prefilling > 0 and \ + cur_batch.prefilled_index == cur_batch.batch_size: + _output = self.partial_output_dict.pop(cur_id, None) + cur_batch.partial_prefilling = 0 + return _output, cur_batch + else: + _output = torch.argmax(output.logits[:, -1:, :], dim=-1) + return _output, cur_batch def is_initialized(self): return True @@ -504,6 +620,8 @@ async def add_request(self, tokenizer): input_len=input_ids.size(1), prompt_lengths=[sum(attention_mask[i, :]) for i in range(input_ids.size(0))], stopped=False, + prefilled_index=0, + partial_prefilling=0, ) self.input_ids_dict[new_batch.batch_id] = input_ids @@ -517,11 +635,15 @@ def clear_batch(self, cur_id): self.token_times.pop(cur_id, None) self.past_key_values_dict.pop(cur_id, None) + self.is_finish.pop(cur_id, None) + self.partial_output_dict.pop(cur_id, None) + async def process_step(self, tokenizer, result_dict): cur_batch = None if self.rank == 0: if self.send_buff is not None: + # logger.info(f"send {self.rank} {self.send_buff.shape}") dist.send(self.send_buff, dst=self.next_rank) if self.on_going_batches[0] is not None: @@ -530,6 +652,7 @@ async def process_step(self, tokenizer, result_dict): if cur_batch is None: if not self.waiting_requests.empty(): + # wait more requests to be put in self.waiting_requests await asyncio.sleep(0.01) cur_batch = await self.add_request(tokenizer) cur_input = self.input_ids_dict[cur_batch.batch_id] @@ -539,84 +662,99 @@ async def process_step(self, tokenizer, result_dict): if (cur_batch is not None) and (not cur_batch.stopped) and (cur_input is None): cur_id = cur_batch.batch_id - next_ids = torch.empty((cur_batch.batch_size, 1,), device=f'xpu:{self.rank}', - dtype=torch.int64) - dist.recv(next_ids, src=self.pre_rank) + # cur_batch = self.prepare_batch(cur_batch) + if cur_batch.prefilled_index >= cur_batch.batch_size: + cur_batch.partial_prefilling = 0 + if cur_batch.partial_prefilling > 0: + next_ids = torch.empty((cur_batch.partial_prefilling, 1,), + device=f'xpu:{self.rank}', dtype=torch.int64) + else: + next_ids = torch.empty((cur_batch.batch_size, 1,), + device=f'xpu:{self.rank}', dtype=torch.int64) - if self.tokens.get(cur_id, None) is None: - self.tokens[cur_id] = [] - - if len(next_ids.shape) == 1: - next_ids = next_ids.unsqueeze(0) - self.tokens[cur_id].append(next_ids) - self.token_times[cur_id].append(time.perf_counter()) - cur_input = next_ids - cur_batch.input_len = 1 - cur_batch.prompt_lengths = [x + 1 for x in cur_batch.prompt_lengths] - - for index, request_id in enumerate(cur_batch.request_ids): - - if not self.is_finish.get(request_id, False): - remain = cur_batch.max_tokens - len(self.tokens[cur_id]) - - if self.streamer.get(request_id, None) is None: - self.streamer[request_id] = asyncio.Queue() - - # Currently ignore eos for benchmark - # if next_ids[index].int() == tokenizer.eos_token_id: - # remain = 0 - # self.is_finish[request_id] = True - - if self.token_cache.get(request_id, None) is None: - self.token_cache[request_id] = [] - self.print_len[request_id] = 0 - self.token_cache[request_id].extend(next_ids[index].tolist()) - - text = tokenizer.decode(self.token_cache[request_id]) - if text.endswith("\n"): - printable_text = text[self.print_len[request_id]:] - self.token_cache[request_id] = [] - self.print_len[request_id] = 0 - elif len(text) > 0 and _is_chinese_char(ord(text[-1])): - printable_text = text[self.print_len[request_id]:] - self.print_len[request_id] += len(printable_text) - else: - printable_text = text[self.print_len[request_id]: text.rfind(" ") + 1] - self.print_len[request_id] += len(printable_text) + # logger.info(f"recv {self.rank} {next_ids.shape}") + dist.recv(next_ids, src=self.pre_rank) - if remain > 0: - await self.streamer[request_id].put((remain, printable_text)) - else: - printable_text = printable_text + text[self.print_len[request_id]:] - self.token_cache.pop(request_id, None) - self.print_len.pop(request_id, None) - await self.streamer[request_id].put((remain, printable_text)) - - if len(self.tokens[cur_id]) >= cur_batch.max_tokens: - # Finish a batch - outputs = torch.cat(self.tokens[cur_id], dim=1) - outputs = outputs.cpu() - output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=False) - for request_id, output_str in zip(cur_batch.request_ids, output_strs): - with self.dict_lock: - result_dict[request_id] = output_str - - cur_times = self.token_times[cur_id] - first_token = cur_times[1] - cur_times[0] - next_token = (cur_times[-1] - cur_times[1]) / (len(self.tokens[cur_id]) - 1) - logger.info(f"First token latency: {first_token}, " - f"next token latency: {next_token}") - self.clear_batch(cur_id) - cur_batch.stopped = True + if cur_batch.partial_prefilling > 0: + cur_input = self.input_ids_dict[cur_batch.batch_id] + else: + if self.tokens.get(cur_id, None) is None: + self.tokens[cur_id] = [] + + if len(next_ids.shape) == 1: + next_ids = next_ids.unsqueeze(0) + self.tokens[cur_id].append(next_ids) + self.token_times[cur_id].append(time.perf_counter()) + cur_input = next_ids + cur_batch.input_len = 1 + cur_batch.prompt_lengths = [x + 1 for x in cur_batch.prompt_lengths] + + for index, request_id in enumerate(cur_batch.request_ids): + + if not self.is_finish.get(request_id, False): + remain = cur_batch.max_tokens - len(self.tokens[cur_id]) + + if self.streamer.get(request_id, None) is None: + self.streamer[request_id] = asyncio.Queue() + + # Currently ignore eos for benchmark + # if next_ids[index].int() == tokenizer.eos_token_id: + # remain = 0 + # self.is_finish[request_id] = True + + if self.token_cache.get(request_id, None) is None: + self.token_cache[request_id] = [] + self.print_len[request_id] = 0 + self.token_cache[request_id].extend(next_ids[index].tolist()) + + text = tokenizer.decode(self.token_cache[request_id]) + if text.endswith("\n"): + printable_text = text[self.print_len[request_id]:] + self.token_cache[request_id] = [] + self.print_len[request_id] = 0 + elif len(text) > 0 and _is_chinese_char(ord(text[-1])): + printable_text = text[self.print_len[request_id]:] + self.print_len[request_id] += len(printable_text) + else: + r_index = text.rfind(" ") + 1 + printable_text = text[self.print_len[request_id]: r_index] + self.print_len[request_id] += len(printable_text) + + if remain > 0: + await self.streamer[request_id].put((remain, printable_text)) + else: + printable_text = printable_text + text[self.print_len[request_id]:] + self.token_cache.pop(request_id, None) + self.print_len.pop(request_id, None) + await self.streamer[request_id].put((remain, printable_text)) + + if len(self.tokens[cur_id]) >= cur_batch.max_tokens: + # Finish a batch + outputs = torch.cat(self.tokens[cur_id], dim=1) + outputs = outputs.cpu() + output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=False) + for request_id, output_str in zip(cur_batch.request_ids, output_strs): + with self.dict_lock: + result_dict[request_id] = output_str + + cur_times = self.token_times[cur_id] + first_token = cur_times[1] - cur_times[0] + next_token = (cur_times[-1] - cur_times[1]) / (len(self.tokens[cur_id]) - 1) + logger.info(f"First token latency: {first_token}, " + f"next token latency: {next_token}") + self.clear_batch(cur_id) + cur_batch.stopped = True else: if (cur_batch is not None) and cur_batch.stopped: cur_batch = None if cur_batch is not None: + cur_batch = self.prepare_batch(cur_batch) dist.broadcast_object_list([cur_batch], src=0) else: if self.send_buff is not None: + # logger.info(f"send {self.rank} {self.send_buff.shape}") dist.send(self.send_buff, dst=self.next_rank) batch_list = [None] @@ -629,14 +767,26 @@ async def process_step(self, tokenizer, result_dict): if cur_batch.stopped: self.clear_batch(cur_batch.batch_id) else: + cur_batch = self.prepare_batch(cur_batch) cur_len = cur_batch.input_len - cur_input = torch.empty((cur_batch.batch_size, cur_len, self.hidden_size,), - device=f'xpu:{self.rank}', dtype=self.dtype) + if cur_batch.partial_prefilling: + cur_input = torch.empty( + (cur_batch.partial_prefilling, cur_len, self.hidden_size,), + device=f'xpu:{self.rank}', + dtype=self.dtype, + ) + else: + cur_input = torch.empty( + (cur_batch.batch_size, cur_len, self.hidden_size,), + device=f'xpu:{self.rank}', + dtype=self.dtype, + ) + # logger.info(f"recv {self.rank} {cur_input.shape}") dist.recv(cur_input, src=self.pre_rank) - output = self.model_step(cur_input, cur_batch) - if output is not None and self.rank == self.world_size - 1: - output = torch.argmax(output[:, -1:, :], dim=-1) + output, cur_batch = self.model_step(cur_input, cur_batch) + # if output is not None and self.rank == self.world_size - 1: + # output = torch.argmax(output[:, -1:, :], dim=-1) if output is not None: # dist.send(output, dst=self.next_rank)