From c157bcd9bb2e507c2bd832cbaf261cf129dcc540 Mon Sep 17 00:00:00 2001 From: xiangyuT Date: Thu, 27 Jun 2024 23:25:44 +0800 Subject: [PATCH 01/10] init --- .../pipeline_serving.py | 5 +- .../transformers/pipeline_parallel.py | 255 ++++++++++++------ 2 files changed, 175 insertions(+), 85 deletions(-) diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py index d917891c937..0567280d26c 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py @@ -306,18 +306,21 @@ async def main(): help='The port number on which the server will run.') parser.add_argument('--max-num-seqs', type=int, default=8, help='Max num sequences in a batch.') + parser.add_argument('--max-prefilled-seqs', type=int, default=0, + help='Max num sequences in a batch during prefilling.') args = parser.parse_args() model_path = args.repo_id_or_model_path low_bit = args.low_bit max_num_seqs = args.max_num_seqs + max_prefilled_seqs = args.max_prefilled_seqs # serialize model initialization so that we do not run out of CPU memory for i in range(my_size): if my_rank == i: logger.info("start model initialization") global local_model - local_model = ModelRunner(model_path, my_rank, my_size, low_bit, max_num_seqs) + local_model = ModelRunner(model_path, my_rank, my_size, low_bit, max_num_seqs, max_prefilled_seqs) logger.info("model initialized") dist.barrier() # Load tokenizer diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py index 6db55d536cc..bc5971aecc4 100644 --- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py +++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py @@ -146,6 +146,8 @@ def pipeline_parallel(model, pipeline_parallel_stages): model._modules['lm_head'] = DummyLayer() model.pipeline_parallel_stages = pipeline_parallel_stages + model.layer_start = layer_start + model.layer_end = layer_end model = model.to(f'xpu:{local_rank}') return model @@ -331,6 +333,9 @@ class BatchTask(BaseModel): prompt_lengths: List[int] stopped: bool + prefilled_index: int + partial_prefilling: int + def make_attention_mask(prompt_lengths): max_length = max(prompt_lengths) @@ -342,7 +347,7 @@ def make_attention_mask(prompt_lengths): class ModelRunner: """Implementation for pipeline parallel multi-stage serving.""" - def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs, + def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs, max_prefilled_seqs, torch_dtype=torch.float16): self.pp_config = PPConfig(rank, world_size) self.dtype = torch_dtype @@ -371,7 +376,9 @@ def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs, self.print_len = {} self.is_finish = {} self.model_name = checkpoint - self.layer_start = 0 + + self.max_prefilled_seqs = max_prefilled_seqs + self.partial_output_dict = {} def load_model(self, model_path, world_size, low_bit='sym_int4'): from ipex_llm.transformers import AutoModelForCausalLM, AutoModel @@ -394,11 +401,26 @@ def load_model(self, model_path, world_size, low_bit='sym_int4'): model = model.eval() return model + + def prepare_batch(self, cur_batch): + if self.rank == 0: + cur_input_start = cur_batch.prefilled_index + if self.max_prefilled_seqs > 0: + if cur_input_start < cur_batch.batch_size: + cur_input_end = cur_input_start + self.max_prefilled_seqs + cur_input_end = 
min(cur_input_end, cur_batch.batch_size) + cur_batch.partial_prefilling = cur_input_end - cur_input_start + else: + cur_batch.partial_prefilling = 0 + + return cur_batch + @torch.no_grad() def model_step(self, input, cur_batch): if cur_batch is None or cur_batch.stopped or input is None: - return None + return None, cur_batch + # logger.info(f"{self.rank} {cur_batch} {input.shape}") cur_id = cur_batch.batch_id _past_key_values = self.past_key_values_dict.get(cur_id, None) attention_mask = make_attention_mask(cur_batch.prompt_lengths).to(input.device) @@ -406,10 +428,25 @@ def model_step(self, input, cur_batch): if self.rank == 0: input_ids = input inputs_embeds = None + + if cur_batch.partial_prefilling > 0: + cur_input_start = cur_batch.prefilled_index + cur_input_end = cur_input_start + cur_batch.partial_prefilling + input_ids = input_ids[cur_input_start:cur_input_end] + attention_mask = attention_mask[cur_input_start:cur_input_end] + tmp_past_key_values = _past_key_values + _past_key_values = None else: input_ids = None inputs_embeds = input + if cur_batch.partial_prefilling > 0: + cur_input_start = cur_batch.prefilled_index + cur_input_end = cur_input_start + cur_batch.partial_prefilling + attention_mask = attention_mask[cur_input_start:cur_input_end] + tmp_past_key_values = _past_key_values + _past_key_values = None + torch.xpu.empty_cache() output = self.model(input_ids=input_ids, inputs_embeds=inputs_embeds, @@ -417,20 +454,49 @@ def model_step(self, input, cur_batch): attention_mask=attention_mask, use_cache=True,) - if self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0: - value_placeholder = torch.empty_like((output.past_key_values)[-1][0]) - past_key_values_placeholder = tuple( - (value_placeholder, value_placeholder) for _ in range(layer_start) - ) + (output.past_key_values)[layer_start:] - _past_key_values = past_key_values_placeholder + if cur_batch.partial_prefilling > 0: + cur_batch.prefilled_index = cur_input_end + if tmp_past_key_values is None: + tmp_past_key_values = output.past_key_values + else: + num_layers = self.model.layer_end - self.model.layer_start + for layer_idx in range(num_layers): + tmp_past_key_values.key_cache[layer_idx] = \ + torch.cat([tmp_past_key_values.key_cache[layer_idx], + output.past_key_values.key_cache[layer_idx]], dim=0) + tmp_past_key_values.value_cache[layer_idx] = \ + torch.cat([tmp_past_key_values.value_cache[layer_idx], + output.past_key_values.value_cache[layer_idx]], dim=0) + + self.past_key_values_dict[cur_id] = tmp_past_key_values + + if self.pp_config.is_tail: + _pre_output = self.partial_output_dict.get(cur_id, None) + if _pre_output is None: + _pre_output = output.logits.to(self.dtype) + else: + _pre_output = torch.cat((_pre_output, output.logits.to(self.dtype)), dim=0) + self.partial_output_dict[cur_id] = _pre_output else: - _past_key_values = output.past_key_values - self.past_key_values_dict[cur_id] = _past_key_values + if self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0: + value_placeholder = torch.empty_like((output.past_key_values)[-1][0]) + past_key_values_placeholder = tuple( + (value_placeholder, value_placeholder) for _ in range(layer_start) + ) + (output.past_key_values)[layer_start:] + _past_key_values = past_key_values_placeholder + else: + _past_key_values = output.past_key_values + self.past_key_values_dict[cur_id] = _past_key_values torch.xpu.synchronize() if not self.pp_config.is_tail: - return output[0].to(self.dtype) + return output[0].to(self.dtype), cur_batch 
else: - return output.logits + if cur_batch.partial_prefilling > 0 and cur_batch.prefilled_index == cur_batch.batch_size: + _output = self.partial_output_dict.get(cur_id, None) + cur_batch.partial_prefilling = 0 + return _output, cur_batch + else: + return output.logits, cur_batch def is_initialized(self): return True @@ -458,6 +524,8 @@ async def add_request(self, tokenizer): input_len=input_ids.size(1), prompt_lengths=[sum(attention_mask[i, :]) for i in range(input_ids.size(0))], stopped=False, + prefilled_index=0, + partial_prefilling=0, ) self.input_ids_dict[new_batch.batch_id] = input_ids @@ -476,6 +544,7 @@ async def process_step(self, tokenizer, result_dict): if self.rank == 0: if self.send_buff is not None: + # logger.info(f"send {self.rank} {self.send_buff.shape}") dist.send(self.send_buff, dst=self.next_rank) if self.on_going_batches[0] is not None: @@ -493,84 +562,96 @@ async def process_step(self, tokenizer, result_dict): if (cur_batch is not None) and (not cur_batch.stopped) and (cur_input is None): cur_id = cur_batch.batch_id - next_ids = torch.empty((cur_batch.batch_size, 1,), device=f'xpu:{self.rank}', - dtype=torch.int64) - dist.recv(next_ids, src=self.pre_rank) + cur_batch = self.prepare_batch(cur_batch) + if cur_batch.partial_prefilling > 0: + next_ids = torch.empty((cur_batch.partial_prefilling, 1,), + device=f'xpu:{self.rank}', dtype=torch.int64) + else: + next_ids = torch.empty((cur_batch.batch_size, 1,), + device=f'xpu:{self.rank}', dtype=torch.int64) - if self.tokens.get(cur_id, None) is None: - self.tokens[cur_id] = [] - - if len(next_ids.shape) == 1: - next_ids = next_ids.unsqueeze(0) - self.tokens[cur_id].append(next_ids) - self.token_times[cur_id].append(time.perf_counter()) - cur_input = next_ids - cur_batch.input_len = 1 - cur_batch.prompt_lengths = [x + 1 for x in cur_batch.prompt_lengths] - - for index, request_id in enumerate(cur_batch.request_ids): - - if not self.is_finish.get(request_id, False): - remain = cur_batch.max_tokens - len(self.tokens[cur_id]) - - if self.streamer.get(request_id, None) is None: - self.streamer[request_id] = asyncio.Queue() - - # Currently ignore eos for benchmark - # if next_ids[index].int() == tokenizer.eos_token_id: - # remain = 0 - # self.is_finish[request_id] = True - - if self.token_cache.get(request_id, None) is None: - self.token_cache[request_id] = [] - self.print_len[request_id] = 0 - self.token_cache[request_id].extend(next_ids[index].tolist()) - - text = tokenizer.decode(self.token_cache[request_id]) - if text.endswith("\n"): - printable_text = text[self.print_len[request_id]:] - self.token_cache[request_id] = [] - self.print_len[request_id] = 0 - elif len(text) > 0 and _is_chinese_char(ord(text[-1])): - printable_text = text[self.print_len[request_id]:] - self.print_len[request_id] += len(printable_text) - else: - printable_text = text[self.print_len[request_id]: text.rfind(" ") + 1] - self.print_len[request_id] += len(printable_text) - - if remain > 0: - await self.streamer[request_id].put((remain, printable_text)) - else: - printable_text = printable_text + text[self.print_len[request_id]:] - self.token_cache.pop(request_id, None) - self.print_len.pop(request_id, None) - await self.streamer[request_id].put((remain, printable_text)) - - if len(self.tokens[cur_id]) >= cur_batch.max_tokens: - # Finish a batch - outputs = torch.cat(self.tokens[cur_id], dim=1) - outputs = outputs.cpu() - output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=False) - for request_id, output_str in zip(cur_batch.request_ids, 
output_strs): - with self.dict_lock: - result_dict[request_id] = output_str - - cur_times = self.token_times[cur_id] - first_token = cur_times[1] - cur_times[0] - next_token = (cur_times[-1] - cur_times[1]) / (len(self.tokens[cur_id]) - 1) - logger.info(f"First token latency: {first_token}, " - f"next token latency: {next_token}") - self.clear_batch(cur_id) - cur_batch.stopped = True + # logger.info(f"recv {self.rank} {next_ids.shape}") + dist.recv(next_ids, src=self.pre_rank) + + if cur_batch.partial_prefilling > 0: + cur_input = self.input_ids_dict[cur_batch.batch_id] + else: + if self.tokens.get(cur_id, None) is None: + self.tokens[cur_id] = [] + + if len(next_ids.shape) == 1: + next_ids = next_ids.unsqueeze(0) + self.tokens[cur_id].append(next_ids) + self.token_times[cur_id].append(time.perf_counter()) + cur_input = next_ids + cur_batch.input_len = 1 + cur_batch.prompt_lengths = [x + 1 for x in cur_batch.prompt_lengths] + + for index, request_id in enumerate(cur_batch.request_ids): + + if not self.is_finish.get(request_id, False): + remain = cur_batch.max_tokens - len(self.tokens[cur_id]) + + if self.streamer.get(request_id, None) is None: + self.streamer[request_id] = asyncio.Queue() + + # Currently ignore eos for benchmark + # if next_ids[index].int() == tokenizer.eos_token_id: + # remain = 0 + # self.is_finish[request_id] = True + + if self.token_cache.get(request_id, None) is None: + self.token_cache[request_id] = [] + self.print_len[request_id] = 0 + self.token_cache[request_id].extend(next_ids[index].tolist()) + + text = tokenizer.decode(self.token_cache[request_id]) + if text.endswith("\n"): + printable_text = text[self.print_len[request_id]:] + self.token_cache[request_id] = [] + self.print_len[request_id] = 0 + elif len(text) > 0 and _is_chinese_char(ord(text[-1])): + printable_text = text[self.print_len[request_id]:] + self.print_len[request_id] += len(printable_text) + else: + printable_text = text[self.print_len[request_id]: text.rfind(" ") + 1] + self.print_len[request_id] += len(printable_text) + + if remain > 0: + await self.streamer[request_id].put((remain, printable_text)) + else: + printable_text = printable_text + text[self.print_len[request_id]:] + self.token_cache.pop(request_id, None) + self.print_len.pop(request_id, None) + await self.streamer[request_id].put((remain, printable_text)) + + if len(self.tokens[cur_id]) >= cur_batch.max_tokens: + # Finish a batch + outputs = torch.cat(self.tokens[cur_id], dim=1) + outputs = outputs.cpu() + output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=False) + for request_id, output_str in zip(cur_batch.request_ids, output_strs): + with self.dict_lock: + result_dict[request_id] = output_str + + cur_times = self.token_times[cur_id] + first_token = cur_times[1] - cur_times[0] + next_token = (cur_times[-1] - cur_times[1]) / (len(self.tokens[cur_id]) - 1) + logger.info(f"First token latency: {first_token}, " + f"next token latency: {next_token}") + self.clear_batch(cur_id) + cur_batch.stopped = True else: if (cur_batch is not None) and cur_batch.stopped: cur_batch = None if cur_batch is not None: + cur_batch = self.prepare_batch(cur_batch) dist.broadcast_object_list([cur_batch], src=0) else: if self.send_buff is not None: + # logger.info(f"send {self.rank} {self.send_buff.shape}") dist.send(self.send_buff, dst=self.next_rank) batch_list = [None] @@ -583,12 +664,18 @@ async def process_step(self, tokenizer, result_dict): if cur_batch.stopped: self.clear_batch(cur_batch.batch_id) else: + cur_batch = 
self.prepare_batch(cur_batch) cur_len = cur_batch.input_len - cur_input = torch.empty((cur_batch.batch_size, cur_len, self.hidden_size,), - device=f'xpu:{self.rank}', dtype=self.dtype) + if cur_batch.partial_prefilling: + cur_input = torch.empty((cur_batch.partial_prefilling, cur_len, self.hidden_size,), + device=f'xpu:{self.rank}', dtype=self.dtype) + else: + cur_input = torch.empty((cur_batch.batch_size, cur_len, self.hidden_size,), + device=f'xpu:{self.rank}', dtype=self.dtype) + # logger.info(f"recv {self.rank} {cur_input.shape}") dist.recv(cur_input, src=self.pre_rank) - output = self.model_step(cur_input, cur_batch) + output, cur_batch = self.model_step(cur_input, cur_batch) if output is not None and self.rank == self.world_size - 1: output = torch.argmax(output[:, -1:, :], dim=-1) From bf788c3c3af476fb42e7215fe717147252e701a1 Mon Sep 17 00:00:00 2001 From: xiangyuT Date: Fri, 28 Jun 2024 09:43:00 +0800 Subject: [PATCH 02/10] refine --- .../transformers/pipeline_parallel.py | 37 +++++++++++-------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py index bc5971aecc4..30b7f652084 100644 --- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py +++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py @@ -337,6 +337,7 @@ class BatchTask(BaseModel): partial_prefilling: int + def make_attention_mask(prompt_lengths): max_length = max(prompt_lengths) attention_mask = torch.zeros((len(prompt_lengths), max_length), dtype=torch.int64) @@ -376,6 +377,8 @@ def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs, max_pref self.print_len = {} self.is_finish = {} self.model_name = checkpoint + # self.layer_start = 0 + # self.layer_end = 0 self.max_prefilled_seqs = max_prefilled_seqs self.partial_output_dict = {} @@ -459,14 +462,17 @@ def model_step(self, input, cur_batch): if tmp_past_key_values is None: tmp_past_key_values = output.past_key_values else: - num_layers = self.model.layer_end - self.model.layer_start - for layer_idx in range(num_layers): - tmp_past_key_values.key_cache[layer_idx] = \ - torch.cat([tmp_past_key_values.key_cache[layer_idx], - output.past_key_values.key_cache[layer_idx]], dim=0) - tmp_past_key_values.value_cache[layer_idx] = \ - torch.cat([tmp_past_key_values.value_cache[layer_idx], - output.past_key_values.value_cache[layer_idx]], dim=0) + if self.model.config.model_type in ["baichuan", "chatglm"]: + tmp_past_key_values = torch.cat((tmp_past_key_values, output.past_key_values), dim=0) + else: + num_layers = self.model.layer_end - self.model.layer_start + for layer_idx in range(num_layers): + tmp_past_key_values.key_cache[layer_idx] = \ + torch.cat([tmp_past_key_values.key_cache[layer_idx], + output.past_key_values.key_cache[layer_idx]], dim=0) + tmp_past_key_values.value_cache[layer_idx] = \ + torch.cat([tmp_past_key_values.value_cache[layer_idx], + output.past_key_values.value_cache[layer_idx]], dim=0) self.past_key_values_dict[cur_id] = tmp_past_key_values @@ -539,6 +545,9 @@ def clear_batch(self, cur_id): self.token_times.pop(cur_id, None) self.past_key_values_dict.pop(cur_id, None) + self.is_finish.pop(cur_id, None) + self.partial_output_dict.pop(cur_id, None) + async def process_step(self, tokenizer, result_dict): cur_batch = None @@ -564,11 +573,9 @@ async def process_step(self, tokenizer, result_dict): cur_id = cur_batch.batch_id cur_batch = self.prepare_batch(cur_batch) if cur_batch.partial_prefilling > 0: - 
next_ids = torch.empty((cur_batch.partial_prefilling, 1,), - device=f'xpu:{self.rank}', dtype=torch.int64) + next_ids = torch.empty((cur_batch.partial_prefilling, 1,), device=f'xpu:{self.rank}', dtype=torch.int64) else: - next_ids = torch.empty((cur_batch.batch_size, 1,), - device=f'xpu:{self.rank}', dtype=torch.int64) + next_ids = torch.empty((cur_batch.batch_size, 1,), device=f'xpu:{self.rank}', dtype=torch.int64) # logger.info(f"recv {self.rank} {next_ids.shape}") dist.recv(next_ids, src=self.pre_rank) @@ -667,11 +674,9 @@ async def process_step(self, tokenizer, result_dict): cur_batch = self.prepare_batch(cur_batch) cur_len = cur_batch.input_len if cur_batch.partial_prefilling: - cur_input = torch.empty((cur_batch.partial_prefilling, cur_len, self.hidden_size,), - device=f'xpu:{self.rank}', dtype=self.dtype) + cur_input = torch.empty((cur_batch.partial_prefilling, cur_len, self.hidden_size,), device=f'xpu:{self.rank}', dtype=self.dtype) else: - cur_input = torch.empty((cur_batch.batch_size, cur_len, self.hidden_size,), - device=f'xpu:{self.rank}', dtype=self.dtype) + cur_input = torch.empty((cur_batch.batch_size, cur_len, self.hidden_size,), device=f'xpu:{self.rank}', dtype=self.dtype) # logger.info(f"recv {self.rank} {cur_input.shape}") dist.recv(cur_input, src=self.pre_rank) From d1af961af5d983450706238cb571b0ebb384f446 Mon Sep 17 00:00:00 2001 From: xiangyuT Date: Fri, 28 Jun 2024 15:25:17 +0800 Subject: [PATCH 03/10] add support for chatglm2/3 --- .../ipex_llm/transformers/models/chatglm2.py | 1 + .../transformers/pipeline_parallel.py | 76 ++++++++++++++----- 2 files changed, 59 insertions(+), 18 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/models/chatglm2.py b/python/llm/src/ipex_llm/transformers/models/chatglm2.py index 2bff252150d..352aaea38b9 100644 --- a/python/llm/src/ipex_llm/transformers/models/chatglm2.py +++ b/python/llm/src/ipex_llm/transformers/models/chatglm2.py @@ -80,6 +80,7 @@ def chatglm2_model_forward( else: inputs_embeds = inputs_embeds.transpose(0, 1).contiguous() seq_length, batch_size, _ = inputs_embeds.shape + input_ids = torch.empty((batch_size, seq_length), device=inputs_embeds.device) if full_attention_mask is None: if (attention_mask is not None and not attention_mask.all()) or ( diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py index 30b7f652084..b1aec56e975 100644 --- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py +++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py @@ -417,7 +417,41 @@ def prepare_batch(self, cur_batch): cur_batch.partial_prefilling = 0 return cur_batch - + + + def cat_kv_cache(self, model_type, kv_cache_1, kv_cache_2): + if model_type in ["baichuan", "chatglm"]: + result = [] + for sub_tuple1, sub_tuple2 in zip(kv_cache_1, kv_cache_2): + if sub_tuple1 is None: + sub_result = [sub_tuple2] + elif sub_tuple2 is None: + sub_result = [sub_tuple1] + else: + sub_result = [] + for t1, t2 in zip(sub_tuple1, sub_tuple2): + if t1 is None: + sub_result.append(t2) + elif t2 is None: + sub_result.append(t1) + else: + if model_type == "chatglm": + sub_result.append(torch.cat((t1, t2), dim=1)) + else: + sub_result.append(torch.cat((t1, t2), dim=0)) + result.append(tuple(sub_result)) + return tuple(result) + else: + num_layers = self.model.layer_end - self.model.layer_start + for layer_idx in range(num_layers): + kv_cache_1.key_cache[layer_idx] = \ + torch.cat([kv_cache_1.key_cache[layer_idx], + kv_cache_2.key_cache[layer_idx]], 
dim=0) + kv_cache_1.value_cache[layer_idx] = \ + torch.cat([kv_cache_1.value_cache[layer_idx], + kv_cache_2.value_cache[layer_idx]], dim=0) + + return kv_cache_1 @torch.no_grad() def model_step(self, input, cur_batch): if cur_batch is None or cur_batch.stopped or input is None: @@ -462,26 +496,31 @@ def model_step(self, input, cur_batch): if tmp_past_key_values is None: tmp_past_key_values = output.past_key_values else: - if self.model.config.model_type in ["baichuan", "chatglm"]: - tmp_past_key_values = torch.cat((tmp_past_key_values, output.past_key_values), dim=0) - else: - num_layers = self.model.layer_end - self.model.layer_start - for layer_idx in range(num_layers): - tmp_past_key_values.key_cache[layer_idx] = \ - torch.cat([tmp_past_key_values.key_cache[layer_idx], - output.past_key_values.key_cache[layer_idx]], dim=0) - tmp_past_key_values.value_cache[layer_idx] = \ - torch.cat([tmp_past_key_values.value_cache[layer_idx], - output.past_key_values.value_cache[layer_idx]], dim=0) + tmp_past_key_values = self.cat_kv_cache(self.model.config.model_type, tmp_past_key_values, output.past_key_values) + torch.xpu.empty_cache() + + if cur_batch.prefilled_index == cur_batch.batch_size: + if self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0: + value_placeholder = torch.empty_like((tmp_past_key_values)[-1][0]) + tmp_past_key_values = tuple((value_placeholder, value_placeholder)) + \ + tuple(None for _ in range(layer_start)) + \ + (tmp_past_key_values)[layer_start:] + + # past_key_values_placeholder = tuple( + # (value_placeholder, value_placeholder) for _ in range(layer_start) + # ) + (output.past_key_values)[layer_start:] + # _past_key_values = past_key_values_placeholder self.past_key_values_dict[cur_id] = tmp_past_key_values if self.pp_config.is_tail: _pre_output = self.partial_output_dict.get(cur_id, None) + tmp_output = output.logits.to(self.dtype) + tmp_output = torch.argmax(tmp_output[:, -1:, :], dim=-1) if _pre_output is None: - _pre_output = output.logits.to(self.dtype) + _pre_output = tmp_output else: - _pre_output = torch.cat((_pre_output, output.logits.to(self.dtype)), dim=0) + _pre_output = torch.cat((_pre_output, tmp_output), dim=0) self.partial_output_dict[cur_id] = _pre_output else: if self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0: @@ -498,11 +537,12 @@ def model_step(self, input, cur_batch): return output[0].to(self.dtype), cur_batch else: if cur_batch.partial_prefilling > 0 and cur_batch.prefilled_index == cur_batch.batch_size: - _output = self.partial_output_dict.get(cur_id, None) + _output = self.partial_output_dict.pop(cur_id, None) cur_batch.partial_prefilling = 0 return _output, cur_batch else: - return output.logits, cur_batch + _output = torch.argmax(output.logits[:, -1:, :], dim=-1) + return _output, cur_batch def is_initialized(self): return True @@ -681,8 +721,8 @@ async def process_step(self, tokenizer, result_dict): dist.recv(cur_input, src=self.pre_rank) output, cur_batch = self.model_step(cur_input, cur_batch) - if output is not None and self.rank == self.world_size - 1: - output = torch.argmax(output[:, -1:, :], dim=-1) + # if output is not None and self.rank == self.world_size - 1: + # output = torch.argmax(output[:, -1:, :], dim=-1) if output is not None: # dist.send(output, dst=self.next_rank) From be7cc05f88e7b432f1526fe4bdb3c0d76ba843ef Mon Sep 17 00:00:00 2001 From: xiangyuT Date: Mon, 1 Jul 2024 16:11:43 +0800 Subject: [PATCH 04/10] fix --- python/llm/src/ipex_llm/transformers/pipeline_parallel.py | 4 
+++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py index b1aec56e975..cec1a90e39f 100644 --- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py +++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py @@ -611,7 +611,9 @@ async def process_step(self, tokenizer, result_dict): if (cur_batch is not None) and (not cur_batch.stopped) and (cur_input is None): cur_id = cur_batch.batch_id - cur_batch = self.prepare_batch(cur_batch) + # cur_batch = self.prepare_batch(cur_batch) + if cur_batch.prefilled_index >= cur_batch.batch_size: + cur_batch.partial_prefilling = 0 if cur_batch.partial_prefilling > 0: next_ids = torch.empty((cur_batch.partial_prefilling, 1,), device=f'xpu:{self.rank}', dtype=torch.int64) else: From 5627c4d33af9c014a0b44389a242fc368c770355 Mon Sep 17 00:00:00 2001 From: xiangyuT Date: Tue, 2 Jul 2024 09:36:23 +0800 Subject: [PATCH 05/10] format --- .../transformers/pipeline_parallel.py | 44 ++++++++++++------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py index cec1a90e39f..d033fbb2110 100644 --- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py +++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py @@ -337,7 +337,6 @@ class BatchTask(BaseModel): partial_prefilling: int - def make_attention_mask(prompt_lengths): max_length = max(prompt_lengths) attention_mask = torch.zeros((len(prompt_lengths), max_length), dtype=torch.int64) @@ -404,7 +403,6 @@ def load_model(self, model_path, world_size, low_bit='sym_int4'): model = model.eval() return model - def prepare_batch(self, cur_batch): if self.rank == 0: cur_input_start = cur_batch.prefilled_index @@ -415,9 +413,8 @@ def prepare_batch(self, cur_batch): cur_batch.partial_prefilling = cur_input_end - cur_input_start else: cur_batch.partial_prefilling = 0 - + return cur_batch - def cat_kv_cache(self, model_type, kv_cache_1, kv_cache_2): if model_type in ["baichuan", "chatglm"]: @@ -446,12 +443,13 @@ def cat_kv_cache(self, model_type, kv_cache_1, kv_cache_2): for layer_idx in range(num_layers): kv_cache_1.key_cache[layer_idx] = \ torch.cat([kv_cache_1.key_cache[layer_idx], - kv_cache_2.key_cache[layer_idx]], dim=0) + kv_cache_2.key_cache[layer_idx]], dim=0) kv_cache_1.value_cache[layer_idx] = \ torch.cat([kv_cache_1.value_cache[layer_idx], - kv_cache_2.value_cache[layer_idx]], dim=0) + kv_cache_2.value_cache[layer_idx]], dim=0) return kv_cache_1 + @torch.no_grad() def model_step(self, input, cur_batch): if cur_batch is None or cur_batch.stopped or input is None: @@ -465,7 +463,7 @@ def model_step(self, input, cur_batch): if self.rank == 0: input_ids = input inputs_embeds = None - + if cur_batch.partial_prefilling > 0: cur_input_start = cur_batch.prefilled_index cur_input_end = cur_input_start + cur_batch.partial_prefilling @@ -496,7 +494,9 @@ def model_step(self, input, cur_batch): if tmp_past_key_values is None: tmp_past_key_values = output.past_key_values else: - tmp_past_key_values = self.cat_kv_cache(self.model.config.model_type, tmp_past_key_values, output.past_key_values) + tmp_past_key_values = self.cat_kv_cache(self.model.config.model_type, + tmp_past_key_values, + output.past_key_values) torch.xpu.empty_cache() if cur_batch.prefilled_index == cur_batch.batch_size: @@ -505,7 +505,7 @@ def model_step(self, input, cur_batch): 
tmp_past_key_values = tuple((value_placeholder, value_placeholder)) + \ tuple(None for _ in range(layer_start)) + \ (tmp_past_key_values)[layer_start:] - + # past_key_values_placeholder = tuple( # (value_placeholder, value_placeholder) for _ in range(layer_start) # ) + (output.past_key_values)[layer_start:] @@ -536,7 +536,8 @@ def model_step(self, input, cur_batch): if not self.pp_config.is_tail: return output[0].to(self.dtype), cur_batch else: - if cur_batch.partial_prefilling > 0 and cur_batch.prefilled_index == cur_batch.batch_size: + if cur_batch.partial_prefilling > 0 and \ + cur_batch.prefilled_index == cur_batch.batch_size: _output = self.partial_output_dict.pop(cur_id, None) cur_batch.partial_prefilling = 0 return _output, cur_batch @@ -615,13 +616,15 @@ async def process_step(self, tokenizer, result_dict): if cur_batch.prefilled_index >= cur_batch.batch_size: cur_batch.partial_prefilling = 0 if cur_batch.partial_prefilling > 0: - next_ids = torch.empty((cur_batch.partial_prefilling, 1,), device=f'xpu:{self.rank}', dtype=torch.int64) + next_ids = torch.empty((cur_batch.partial_prefilling, 1,), + device=f'xpu:{self.rank}', dtype=torch.int64) else: - next_ids = torch.empty((cur_batch.batch_size, 1,), device=f'xpu:{self.rank}', dtype=torch.int64) + next_ids = torch.empty((cur_batch.batch_size, 1,), + device=f'xpu:{self.rank}', dtype=torch.int64) # logger.info(f"recv {self.rank} {next_ids.shape}") dist.recv(next_ids, src=self.pre_rank) - + if cur_batch.partial_prefilling > 0: cur_input = self.input_ids_dict[cur_batch.batch_id] else: @@ -663,7 +666,8 @@ async def process_step(self, tokenizer, result_dict): printable_text = text[self.print_len[request_id]:] self.print_len[request_id] += len(printable_text) else: - printable_text = text[self.print_len[request_id]: text.rfind(" ") + 1] + r_index = text.rfind(" ") + 1 + printable_text = text[self.print_len[request_id]: r_index] self.print_len[request_id] += len(printable_text) if remain > 0: @@ -716,9 +720,17 @@ async def process_step(self, tokenizer, result_dict): cur_batch = self.prepare_batch(cur_batch) cur_len = cur_batch.input_len if cur_batch.partial_prefilling: - cur_input = torch.empty((cur_batch.partial_prefilling, cur_len, self.hidden_size,), device=f'xpu:{self.rank}', dtype=self.dtype) + cur_input = torch.empty( + (cur_batch.partial_prefilling, cur_len, self.hidden_size,), + device=f'xpu:{self.rank}', + dtype=self.dtype, + ) else: - cur_input = torch.empty((cur_batch.batch_size, cur_len, self.hidden_size,), device=f'xpu:{self.rank}', dtype=self.dtype) + cur_input = torch.empty( + (cur_batch.batch_size, cur_len, self.hidden_size,), + device=f'xpu:{self.rank}', + dtype=self.dtype, + ) # logger.info(f"recv {self.rank} {cur_input.shape}") dist.recv(cur_input, src=self.pre_rank) From 7ffeb75a17f013a030eb0c34b9e178396d584b8f Mon Sep 17 00:00:00 2001 From: xiangyuT Date: Wed, 3 Jul 2024 14:28:35 +0800 Subject: [PATCH 06/10] refine --- python/llm/src/ipex_llm/transformers/pipeline_parallel.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py index d033fbb2110..93fd1c9e9b3 100644 --- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py +++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py @@ -482,7 +482,7 @@ def model_step(self, input, cur_batch): tmp_past_key_values = _past_key_values _past_key_values = None - torch.xpu.empty_cache() + # torch.xpu.empty_cache() output = 
self.model(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=_past_key_values, @@ -497,7 +497,7 @@ def model_step(self, input, cur_batch): tmp_past_key_values = self.cat_kv_cache(self.model.config.model_type, tmp_past_key_values, output.past_key_values) - torch.xpu.empty_cache() + # torch.xpu.empty_cache() if cur_batch.prefilled_index == cur_batch.batch_size: if self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0: @@ -603,6 +603,7 @@ async def process_step(self, tokenizer, result_dict): if cur_batch is None: if not self.waiting_requests.empty(): + # wait more requests to be put in self.waiting_requests await asyncio.sleep(0.01) cur_batch = await self.add_request(tokenizer) cur_input = self.input_ids_dict[cur_batch.batch_id] From cf1e9e20e9ff9a50faaa0b46ec9f569ed68c56b8 Mon Sep 17 00:00:00 2001 From: xiangyuT Date: Thu, 4 Jul 2024 09:46:53 +0800 Subject: [PATCH 07/10] refine --- .../transformers/pipeline_parallel.py | 119 +++++++++++------- 1 file changed, 76 insertions(+), 43 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py index 38911777ecf..9b066235cd9 100644 --- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py +++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py @@ -154,6 +154,7 @@ def pipeline_parallel(model, pipeline_parallel_stages): model.pipeline_parallel_stages = pipeline_parallel_stages model.layer_start = layer_start model.layer_end = layer_end + model.num_layers = num_layers model = model.to(f'xpu:{local_rank}') return model @@ -457,8 +458,8 @@ def cat_kv_cache(self, model_type, kv_cache_1, kv_cache_2): result.append(tuple(sub_result)) return tuple(result) else: - num_layers = self.model.layer_end - self.model.layer_start - for layer_idx in range(num_layers): + # num_layers = self.model.layer_end - self.model.layer_start + for layer_idx in range(self.model.num_layers): kv_cache_1.key_cache[layer_idx] = \ torch.cat([kv_cache_1.key_cache[layer_idx], kv_cache_2.key_cache[layer_idx]], dim=0) @@ -467,6 +468,40 @@ def cat_kv_cache(self, model_type, kv_cache_1, kv_cache_2): kv_cache_2.value_cache[layer_idx]], dim=0) return kv_cache_1 + + + def update_kv_cache(self, kv_cache, cur_id): + layer_start = self.model.layer_start + layer_end = self.model.layer_end + num_layers = self.model.num_layers + + if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40: + # for glm-4-9b-chat + if self.past_key_values_dict.get(cur_id, None) is None: + value_placeholder = torch.empty_like((kv_cache)[-1][0]) + past_key_values_placeholder = tuple( + (value_placeholder, value_placeholder) for _ in range(layer_start) + ) + (kv_cache)[:layer_end - layer_start] + tuple( + (value_placeholder, value_placeholder) for _ in range(layer_end, num_layers) + ) + kv_cache = past_key_values_placeholder + else: + pass + elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0: + value_placeholder = torch.empty_like((kv_cache)[-1][0]) + kv_cache = tuple((value_placeholder, value_placeholder)) + \ + tuple(None for _ in range(layer_start)) + \ + (kv_cache)[layer_start:] + + # past_key_values_placeholder = tuple( + # (value_placeholder, value_placeholder) for _ in range(layer_start) + # ) + (kv_cache)[layer_start:] + # kv_cache = past_key_values_placeholder + else: + pass + + return kv_cache + @torch.no_grad() def model_step(self, input, cur_batch): @@ -518,28 +553,25 @@ def model_step(self, input, cur_batch): # 
torch.xpu.empty_cache() if cur_batch.prefilled_index == cur_batch.batch_size: - if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40: - # for glm-4-9b-chat - if self.past_key_values_dict.get(cur_id, None) is None: - value_placeholder = torch.empty_like((output.past_key_values)[-1][0]) - past_key_values_placeholder = tuple( - (value_placeholder, value_placeholder) for _ in range(layer_start) - ) + (output.past_key_values)[: layer_end - layer_start] + tuple( - (value_placeholder, value_placeholder) for _ in range(layer_end, num_layers) - ) - _past_key_values = past_key_values_placeholder - else: - _past_key_values = output.past_key_values - elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0: - value_placeholder = torch.empty_like((tmp_past_key_values)[-1][0]) - tmp_past_key_values = tuple((value_placeholder, value_placeholder)) + \ - tuple(None for _ in range(layer_start)) + \ - (tmp_past_key_values)[layer_start:] - - # past_key_values_placeholder = tuple( - # (value_placeholder, value_placeholder) for _ in range(layer_start) - # ) + (output.past_key_values)[layer_start:] - # _past_key_values = past_key_values_placeholder + tmp_past_key_values = self.update_kv_cache(tmp_past_key_values, cur_id) + # # TODO: remove reduntent code here + # if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40: + # # for glm-4-9b-chat + # if self.past_key_values_dict.get(cur_id, None) is None: + # value_placeholder = torch.empty_like((output.past_key_values)[-1][0]) + # past_key_values_placeholder = tuple( + # (value_placeholder, value_placeholder) for _ in range(layer_start) + # ) + (output.past_key_values)[: layer_end - layer_start] + tuple( + # (value_placeholder, value_placeholder) for _ in range(layer_end, num_layers) + # ) + # _past_key_values = past_key_values_placeholder + # else: + # _past_key_values = output.past_key_values + # elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0: + # value_placeholder = torch.empty_like((tmp_past_key_values)[-1][0]) + # tmp_past_key_values = tuple((value_placeholder, value_placeholder)) + \ + # tuple(None for _ in range(layer_start)) + \ + # (tmp_past_key_values)[layer_start:] self.past_key_values_dict[cur_id] = tmp_past_key_values @@ -553,25 +585,26 @@ def model_step(self, input, cur_batch): _pre_output = torch.cat((_pre_output, tmp_output), dim=0) self.partial_output_dict[cur_id] = _pre_output else: - if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40: - # for glm-4-9b-chat - if self.past_key_values_dict.get(cur_id, None) is None: - value_placeholder = torch.empty_like((output.past_key_values)[-1][0]) - past_key_values_placeholder = tuple( - (value_placeholder, value_placeholder) for _ in range(layer_start) - ) + (output.past_key_values)[: layer_end - layer_start] + tuple( - (value_placeholder, value_placeholder) for _ in range(layer_end, num_layers) - ) - _past_key_values = past_key_values_placeholder - else: - _past_key_values = output.past_key_values - elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0: - # for baichuan2 and chatglm3 - value_placeholder = torch.empty_like((output.past_key_values)[-1][0]) - past_key_values_placeholder = tuple( - (value_placeholder, value_placeholder) for _ in range(layer_start) - ) + (output.past_key_values)[layer_start:] - _past_key_values = past_key_values_placeholder + # if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40: 
+ # # for glm-4-9b-chat + # if self.past_key_values_dict.get(cur_id, None) is None: + # value_placeholder = torch.empty_like((output.past_key_values)[-1][0]) + # past_key_values_placeholder = tuple( + # (value_placeholder, value_placeholder) for _ in range(layer_start) + # ) + (output.past_key_values)[: layer_end - layer_start] + tuple( + # (value_placeholder, value_placeholder) for _ in range(layer_end, num_layers) + # ) + # _past_key_values = past_key_values_placeholder + # else: + # _past_key_values = output.past_key_values + # elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0: + # # for baichuan2 and chatglm3 + # value_placeholder = torch.empty_like((output.past_key_values)[-1][0]) + # past_key_values_placeholder = tuple( + # (value_placeholder, value_placeholder) for _ in range(layer_start) + # ) + (output.past_key_values)[layer_start:] + # _past_key_values = past_key_values_placeholder + _past_key_values = self.update_kv_cache(_past_key_values, cur_id) self.past_key_values_dict[cur_id] = _past_key_values torch.xpu.synchronize() if not self.pp_config.is_tail: From fdc87dfaba22c7b14177329257ff2899c5d820c5 Mon Sep 17 00:00:00 2001 From: xiangyuT Date: Thu, 4 Jul 2024 10:03:46 +0800 Subject: [PATCH 08/10] format --- .../transformers/pipeline_parallel.py | 42 +------------------ 1 file changed, 1 insertion(+), 41 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py index 9b066235cd9..7ba4cea1ae4 100644 --- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py +++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py @@ -468,13 +468,12 @@ def cat_kv_cache(self, model_type, kv_cache_1, kv_cache_2): kv_cache_2.value_cache[layer_idx]], dim=0) return kv_cache_1 - def update_kv_cache(self, kv_cache, cur_id): layer_start = self.model.layer_start layer_end = self.model.layer_end num_layers = self.model.num_layers - + if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40: # for glm-4-9b-chat if self.past_key_values_dict.get(cur_id, None) is None: @@ -492,7 +491,6 @@ def update_kv_cache(self, kv_cache, cur_id): kv_cache = tuple((value_placeholder, value_placeholder)) + \ tuple(None for _ in range(layer_start)) + \ (kv_cache)[layer_start:] - # past_key_values_placeholder = tuple( # (value_placeholder, value_placeholder) for _ in range(layer_start) # ) + (kv_cache)[layer_start:] @@ -502,7 +500,6 @@ def update_kv_cache(self, kv_cache, cur_id): return kv_cache - @torch.no_grad() def model_step(self, input, cur_batch): if cur_batch is None or cur_batch.stopped or input is None: @@ -554,24 +551,6 @@ def model_step(self, input, cur_batch): if cur_batch.prefilled_index == cur_batch.batch_size: tmp_past_key_values = self.update_kv_cache(tmp_past_key_values, cur_id) - # # TODO: remove reduntent code here - # if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40: - # # for glm-4-9b-chat - # if self.past_key_values_dict.get(cur_id, None) is None: - # value_placeholder = torch.empty_like((output.past_key_values)[-1][0]) - # past_key_values_placeholder = tuple( - # (value_placeholder, value_placeholder) for _ in range(layer_start) - # ) + (output.past_key_values)[: layer_end - layer_start] + tuple( - # (value_placeholder, value_placeholder) for _ in range(layer_end, num_layers) - # ) - # _past_key_values = past_key_values_placeholder - # else: - # _past_key_values = output.past_key_values - # elif 
self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0: - # value_placeholder = torch.empty_like((tmp_past_key_values)[-1][0]) - # tmp_past_key_values = tuple((value_placeholder, value_placeholder)) + \ - # tuple(None for _ in range(layer_start)) + \ - # (tmp_past_key_values)[layer_start:] self.past_key_values_dict[cur_id] = tmp_past_key_values @@ -585,25 +564,6 @@ def model_step(self, input, cur_batch): _pre_output = torch.cat((_pre_output, tmp_output), dim=0) self.partial_output_dict[cur_id] = _pre_output else: - # if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40: - # # for glm-4-9b-chat - # if self.past_key_values_dict.get(cur_id, None) is None: - # value_placeholder = torch.empty_like((output.past_key_values)[-1][0]) - # past_key_values_placeholder = tuple( - # (value_placeholder, value_placeholder) for _ in range(layer_start) - # ) + (output.past_key_values)[: layer_end - layer_start] + tuple( - # (value_placeholder, value_placeholder) for _ in range(layer_end, num_layers) - # ) - # _past_key_values = past_key_values_placeholder - # else: - # _past_key_values = output.past_key_values - # elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0: - # # for baichuan2 and chatglm3 - # value_placeholder = torch.empty_like((output.past_key_values)[-1][0]) - # past_key_values_placeholder = tuple( - # (value_placeholder, value_placeholder) for _ in range(layer_start) - # ) + (output.past_key_values)[layer_start:] - # _past_key_values = past_key_values_placeholder _past_key_values = self.update_kv_cache(_past_key_values, cur_id) self.past_key_values_dict[cur_id] = _past_key_values torch.xpu.synchronize() From 7e74090fc886e9e7328d7ce28fb901a3ce867e84 Mon Sep 17 00:00:00 2001 From: xiangyuT Date: Thu, 4 Jul 2024 13:48:53 +0800 Subject: [PATCH 09/10] fix --- python/llm/src/ipex_llm/transformers/pipeline_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py index 7ba4cea1ae4..72908d61876 100644 --- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py +++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py @@ -564,7 +564,7 @@ def model_step(self, input, cur_batch): _pre_output = torch.cat((_pre_output, tmp_output), dim=0) self.partial_output_dict[cur_id] = _pre_output else: - _past_key_values = self.update_kv_cache(_past_key_values, cur_id) + _past_key_values = self.update_kv_cache(output.past_key_values, cur_id) self.past_key_values_dict[cur_id] = _past_key_values torch.xpu.synchronize() if not self.pp_config.is_tail: From 6e74538b7f3073e98693a7b842bb47c12a2f33bc Mon Sep 17 00:00:00 2001 From: xiangyuT Date: Fri, 5 Jul 2024 11:11:32 +0800 Subject: [PATCH 10/10] refine readme --- python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md | 6 +++++- python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh | 6 ++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md index b01e9282758..a18816ab72f 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md @@ -32,8 +32,12 @@ pip install transformers==4.37.0 bash run.sh ``` -> Note: INT4 optimization is applied to the model by default. 
You could specify other low bit optimizations (such as 'fp8' and 'fp6') through `--low-bit`. Besides, you could change `NUM_GPUS` to the number of GPUs you have on your machine.
+### Command Line Arguments in `run.sh`
+> Note: INT4 optimization is applied to the model by default. You could specify other low bit optimizations (such as 'fp8' and 'fp6') through `--low-bit`. Besides, you could change `NUM_GPUS` to the number of GPUs you have on your machine. Other related settings are listed below:
+- `--low-bit`: Sets the low-bit optimization (such as 'sym_int4', 'fp16', 'fp8' and 'fp6') for the model.
+- `--max-num-seqs`: Sets the maximum batch size on a single card during pipeline-parallel serving.
+- `--max-prefilled-seqs`: Sets the maximum number of sequences prefilled in one forward step. Use `0` to disable partial prefilling and prefill the whole batch at once.
 
 ### 3. Sample Input and Output
diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh
index d1bbacf4acd..91bf6161bff 100644
--- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh
+++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh
@@ -24,11 +24,13 @@ source $basekit_root/setvars.sh --force
 source $basekit_root/ccl/latest/env/vars.sh --force
 export USE_XETLA=OFF
-export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
+if [[ $KERNEL_VERSION != *"6.5"* ]]; then
+    export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+fi
 export TORCH_LLM_ALLREDUCE=0
 export MODEL_PATH=YOUR_MODEL_PATH
 export NUM_GPUS=2
 export IPEX_LLM_QUANTIZE_KV_CACHE=1
-CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit fp8 --max-num-seqs 4
+CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit fp8 --max-num-seqs 4 --max-prefilled-seqs 0
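
The `--max-prefilled-seqs` option above drives the chunked prefilling added in `ModelRunner.prepare_batch` and `cat_kv_cache`: each step prefills at most that many sequences, concatenates the resulting KV cache onto the cache of the already-prefilled part of the batch, and keeps only the last-token ids until the whole batch is covered. Below is a minimal, self-contained sketch of that bookkeeping; `toy_prefill` is a hypothetical stand-in for the real pipeline forward pass and is not part of the patch.

```python
import torch


def toy_prefill(prompt_chunk: torch.Tensor):
    """Stand-in for one pipeline forward pass over a chunk of prompts.

    Returns fake per-layer (key, value) caches plus fake logits so the
    bookkeeping below can run without a real model.
    """
    bsz, seqlen = prompt_chunk.shape
    num_layers, head_dim, vocab = 2, 4, 16
    kv = [(torch.zeros(bsz, seqlen, head_dim), torch.zeros(bsz, seqlen, head_dim))
          for _ in range(num_layers)]
    logits = torch.randn(bsz, seqlen, vocab)
    return kv, logits


def chunked_prefill(input_ids: torch.Tensor, max_prefilled_seqs: int):
    """Prefill a padded batch in slices of at most `max_prefilled_seqs` sequences.

    Mirrors the bookkeeping of ModelRunner.prepare_batch / cat_kv_cache /
    partial_output_dict in this patch: per-chunk KV caches are concatenated
    along the batch dimension and only the last-token ids are kept.
    """
    batch_size = input_ids.size(0)
    if max_prefilled_seqs <= 0:  # 0 disables partial prefilling: one full-batch prefill
        kv, logits = toy_prefill(input_ids)
        return kv, torch.argmax(logits[:, -1:, :], dim=-1)

    merged_kv, next_ids = None, []
    prefilled_index = 0
    while prefilled_index < batch_size:
        end = min(prefilled_index + max_prefilled_seqs, batch_size)
        kv, logits = toy_prefill(input_ids[prefilled_index:end])
        if merged_kv is None:
            merged_kv = kv
        else:  # concatenate this chunk's cache onto the accumulated cache along the batch dim
            merged_kv = [(torch.cat([k1, k2], dim=0), torch.cat([v1, v2], dim=0))
                         for (k1, v1), (k2, v2) in zip(merged_kv, kv)]
        next_ids.append(torch.argmax(logits[:, -1:, :], dim=-1))
        prefilled_index = end
    return merged_kv, torch.cat(next_ids, dim=0)


if __name__ == "__main__":
    prompts = torch.randint(0, 16, (8, 5))        # 8 padded prompts of length 5
    kv_cache, first_tokens = chunked_prefill(prompts, max_prefilled_seqs=4)
    print(first_tokens.shape)                     # torch.Size([8, 1])
```

With `max_prefilled_seqs=0`, the sketch falls back to a single full-batch prefill, matching the `--max-prefilled-seqs 0` default used in `run.sh`.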