Skip to content

Commit

Permalink
LLM: Refine Pipeline Parallel FastAPI (#11587)
Browse files Browse the repository at this point in the history
Refine Pipeline Parallel FastAPI
  • Loading branch information
xiangyuT authored Jul 22, 2024
1 parent 4d56ef5 commit 060792a
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions python/llm/src/ipex_llm/transformers/pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,11 +462,15 @@ class BatchTask(BaseModel):
partial_prefilling: int


def make_attention_mask(prompt_lengths):
def make_attention_mask(prompt_lengths, device):
max_length = max(prompt_lengths)
attention_mask = torch.zeros((len(prompt_lengths), max_length), dtype=torch.int64)
for i, length in enumerate(prompt_lengths):
attention_mask[i, max_length - length:] = 1
batch_size = len(prompt_lengths)

range_tensor = torch.arange(max_length, device=device).expand(batch_size, max_length)
prompt_lengths_tensor = torch.tensor(prompt_lengths, device=device).unsqueeze(1)
attention_mask = range_tensor >= max_length - prompt_lengths_tensor
attention_mask = attention_mask.to(torch.int64)

return attention_mask


Expand Down Expand Up @@ -501,6 +505,8 @@ def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs, max_pref
self.print_len = {}
self.is_finish = {}
self.model_name = checkpoint

self.device = f"xpu:{self.rank}"
# self.layer_start = 0
# self.layer_end = 0

Expand Down Expand Up @@ -543,7 +549,7 @@ def prepare_batch(self, cur_batch):
return cur_batch

def cat_kv_cache(self, model_type, kv_cache_1, kv_cache_2):
if model_type in ["baichuan", "chatglm"]:
if model_type in ["baichuan", "chatglm", "mixtral"]:
result = []
for sub_tuple1, sub_tuple2 in zip(kv_cache_1, kv_cache_2):
if sub_tuple1 is None:
Expand Down Expand Up @@ -613,7 +619,7 @@ def model_step(self, input, cur_batch):
# logger.info(f"{self.rank} {cur_batch} {input.shape}")
cur_id = cur_batch.batch_id
_past_key_values = self.past_key_values_dict.get(cur_id, None)
attention_mask = make_attention_mask(cur_batch.prompt_lengths).to(input.device)
attention_mask = make_attention_mask(cur_batch.prompt_lengths, input.device)

if self.rank == 0:
input_ids = input
Expand All @@ -637,7 +643,7 @@ def model_step(self, input, cur_batch):
tmp_past_key_values = _past_key_values
_past_key_values = None

# torch.xpu.empty_cache()
torch.xpu.empty_cache()
output = self.model(input_ids=input_ids,
inputs_embeds=inputs_embeds,
past_key_values=_past_key_values,
Expand All @@ -661,7 +667,7 @@ def model_step(self, input, cur_batch):

if self.pp_config.is_tail:
_pre_output = self.partial_output_dict.get(cur_id, None)
tmp_output = output.logits.to(self.dtype)
tmp_output = output.logits
tmp_output = torch.argmax(tmp_output[:, -1:, :], dim=-1)
if _pre_output is None:
_pre_output = tmp_output
Expand All @@ -674,7 +680,10 @@ def model_step(self, input, cur_batch):
self.past_key_values_dict[cur_id] = _past_key_values
torch.xpu.synchronize()
if not self.pp_config.is_tail:
return output[0].to(self.dtype), cur_batch
_output = output[0]
if _output.dtype != self.dtype:
_output = _output.to(self.dtype)
return _output, cur_batch
else:
if cur_batch.partial_prefilling > 0 and \
cur_batch.prefilled_index == cur_batch.batch_size:
Expand Down

0 comments on commit 060792a

Please sign in to comment.