diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md
new file mode 100644
index 00000000000..f133dda5d0a
--- /dev/null
+++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md
@@ -0,0 +1,33 @@
+# Serve IPEX-LLM on Multiple Intel GPUs in a Multi-Stage Pipeline-Parallel Fashion
+
+This example demonstrates how to run IPEX-LLM serving on multiple [Intel GPUs](../README.md) with pipeline parallelism.
+
+## Requirements
+
+To run this example with IPEX-LLM on Intel GPUs, there are some recommended hardware and software requirements for your machine; please refer to [here](../README.md#recommended-requirements) for more information. For this particular example, you will need at least two GPUs on your machine.
+
+## Example
+
+### 1. Install
+
+```bash
+conda create -n llm python=3.11
+conda activate llm
+# the below command will install intel_extension_for_pytorch==2.1.10+xpu as default
+pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+pip install oneccl_bind_pt==2.1.100 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
+# configure oneAPI environment variables
+source /opt/intel/oneapi/setvars.sh
+# pip install git+https://github.com/microsoft/DeepSpeed.git@ed8aed5
+# pip install git+https://github.com/intel/intel-extension-for-deepspeed.git@0eb734b
+pip install mpi4py fastapi uvicorn
+conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc
+```
+
+### 2. Run pipeline-parallel serving on multiple GPUs
+
+```bash
+# You need to set MODEL_PATH in run.sh first
+bash run.sh
+```
+
diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/llama_models.py b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/llama_models.py
new file mode 100644
index 00000000000..4244a735ee9
--- /dev/null
+++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/llama_models.py
@@ -0,0 +1,327 @@
+from transformers.modeling_utils import PreTrainedModel
+from transformers.models.llama.modeling_llama import LlamaConfig, LlamaDecoderLayer, LlamaRMSNorm, LlamaPreTrainedModel
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+
+from torch import nn
+import torch
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from typing import List, Optional, Tuple, Union, Iterator
+from transformers.utils import logging
+logger = logging.get_logger(__name__)
+import numpy as np
+import time
+from transformers import AutoTokenizer, AutoConfig
+import torch.distributed as dist
+from pipeline_models import (
+    _make_causal_mask, _expand_mask, DummyLayer, PPConfig,
+    PipelineBaseModel,
+)
+
+
+class LlamaModel(LlamaPreTrainedModel):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers.
Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.config = config + + # pp modification + self.pp_config = PPConfig(pp_rank=dist.get_rank(), pp_world_size=dist.get_world_size()) + nr_slices = self.pp_config.pp_world_size + # self.config.num_hidden_layers = 8 + slice_size = (self.config.num_hidden_layers + nr_slices - + 1) // nr_slices + self.layer_start = slice_size * self.pp_config.pp_rank + self.layer_end = self.layer_start + min(slice_size, + self.config.num_hidden_layers - self.layer_start) + self.num_layers = self.layer_end - self.layer_start + layers = [] + for i in range(self.config.num_hidden_layers): + if i < self.layer_start or i >= self.layer_end: + layers.append(DummyLayer()) + else: + layers.append(LlamaDecoderLayer(config)) + self.layers = nn.ModuleList(layers) + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds for pp + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + assert self.pp_config.is_head, "input_ids is only supported on the head stage" + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + assert not self.pp_config.is_head, "inputs_embeds is only supported 
on the tail stage" + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx in range(self.num_layers): + decoder_layer = self.layers[self.layer_start + idx] + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if self.pp_config.is_tail: + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class LlamaForCausalLM(LlamaPreTrainedModel): + + def __init__(self, config: LlamaConfig): + super().__init__(config=config) + self.config = config + self.pp_config = PPConfig(pp_rank=dist.get_rank(), pp_world_size=dist.get_world_size()) + self.model = LlamaModel(config) + self.pretraining_tp = config.pretraining_tp + self.vocab_size = config.vocab_size + if self.pp_config.is_tail: + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + 
position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.pp_config.is_tail: + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + return outputs + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_models.py b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_models.py new file mode 100644 index 00000000000..c96211e4604 --- /dev/null +++ 
b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_models.py @@ -0,0 +1,510 @@ +from torch import nn +import torch +import torch.distributed as dist +import intel_extension_for_pytorch as ipex + +from typing import List, Optional, Tuple, Union, Iterator +import time +from transformers import AutoTokenizer, AutoConfig +from transformers.utils import logging +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +import numpy as np +import asyncio, uuid +import threading + +logger = logging.get_logger(__name__) + + +class PPConfig: + """Configuration for ModelSlices.""" + + def __init__(self, pp_rank: int, pp_world_size: int) -> None: + self.pp_rank = pp_rank + self.pp_world_size = pp_world_size + self.is_head = self.pp_rank == 0 + self.is_tail = self.pp_rank == self.pp_world_size - 1 + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
+ """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + +class DummyLayer(nn.Module): + pass + + +class PipelineBaseModel(nn.Module): + def __init__(self, config): + self.pp_config = PPConfig(pp_rank=dist.get_rank(), pp_world_size=dist.get_world_size()) + nr_slices = self.pp_config.pp_world_size + # self.config.num_hidden_layers = 8 + slice_size = (self.config.num_hidden_layers + nr_slices - + 1) // nr_slices + self.layer_start = slice_size * self.pp_config.pp_rank + self.layer_end = self.layer_start + min(slice_size, + self.config.num_hidden_layers - self.layer_start) + self.num_layers = self.layer_end - self.layer_start + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds 
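+        # In this pipeline-parallel setup, the head stage (pp_rank 0) is fed input_ids,
+        # while every later stage receives the previous stage's hidden states as inputs_embeds.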
+ if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + assert self.pp_config.is_head, "input_ids is only supported on the head stage" + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + assert not self.pp_config.is_head, "inputs_embeds is only supported on the tail stage" + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx in range(self.num_layers): + decoder_layer = self.layers[self.layer_start + idx] + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if self.pp_config.is_tail: + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +def load_model(checkpoint): + from llama_models import LlamaForCausalLM + if 'llama' in checkpoint.lower(): + model = LlamaForCausalLM.from_pretrained(checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.float16) + return model + +from pydantic import BaseModel +class BatchTask(BaseModel): + batch_id: str + request_ids: List[str] + max_tokens: int + batch_size: int + input_len: int + # plain_texts: List[str] + prompt_lengths: List[int] + stopped: bool + # input_ids: torch.Tensor + # attention_mask: torch.Tensor + + +def 
make_attention_mask(prompt_lengths): + max_length = max(prompt_lengths) + attention_mask = torch.zeros((len(prompt_lengths), max_length), dtype=torch.int64) + for i, length in enumerate(prompt_lengths): + attention_mask[i, max_length - length:] = 1 + return attention_mask + +class ModelRunner: + + def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs): + + import sys + self.pp_config = PPConfig(rank, world_size) + + start = time.perf_counter() + model = load_model(checkpoint) + end = time.perf_counter() + logger.info(f"Time to load weights: {end - start:.2f}s") + from ipex_llm import optimize_model + + model = optimize_model(model, low_bit=low_bit) + + model = model.to(torch.float16).to(f'xpu:{rank}') + self.model = model + self.rank = rank + self.world_size = world_size + self.pre_rank = (self.rank - 1) % self.world_size + self.next_rank = (self.rank + 1) % self.world_size + self.hidden_size = self.model.config.hidden_size + + self.max_num_seqs = max_num_seqs + self.on_going_batches = [None] * self.world_size + self.input_ids_dict = {} + # self.attention_mask_dict = {} + self.past_key_values_dict = {} + self.tokens = {} + self.token_times = {} + self.dtype = torch.float16 + + self.waiting_requests = asyncio.Queue() + self.send_buff = None + self.dict_lock = threading.Lock() + + + # def generate(self, input_ids=None, max_tokens=5, attention_mask=None): + # times = [] + # with torch.no_grad(): + # _input_ids = None + # _past_key_values = None + # bs = input_ids.shape[0] + # output_ids = input_ids.clone() + # for i in range(max_tokens): + # start = time.perf_counter() + # if _input_ids is None: + # _input_ids = input_ids + # if self.rank == 0: + # outputs = self.model(input_ids=_input_ids, attention_mask=attention_mask, past_key_values=_past_key_values, use_cache=True) + # else: + # inputs_embeds = torch.empty(_input_ids.shape + (self.hidden_size,) , device=f'xpu:{self.rank}', dtype=torch.float32) + # dist.recv(inputs_embeds, src=self.pre_rank) + # outputs = self.model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, past_key_values=_past_key_values, use_cache=True) + + # if self.rank == self.world_size - 1: + # logits = outputs.logits + # next_ids = torch.argmax(logits[:, -1:, :], dim=-1) + # assert next_ids.shape == (bs, 1) + # dist.broadcast(next_ids, src=self.rank) + # else: + # dist.send(outputs.last_hidden_state, dst=self.next_rank) + # next_ids = torch.empty((bs, 1), device=f'xpu:{self.rank}', dtype=torch.int64) + # dist.broadcast(next_ids, src=self.world_size - 1) + + # _input_ids = next_ids + # output_ids = torch.cat([output_ids, next_ids], dim=-1) + # _past_key_values = outputs.past_key_values + # end = time.perf_counter() + # times.append(end - start) + + # if self.rank == 0: + # logger.info(f"first token latency: {times[0]}, rest token avg latecy: {np.mean(times[1:])}") + # return output_ids + + + def model_step(self, input, cur_batch): + if cur_batch is None or cur_batch.stopped or input is None: + return None + + cur_id = cur_batch.batch_id + _past_key_values = self.past_key_values_dict.get(cur_id, None) + # attention_mask = self.attention_mask_dict[cur_id] + attention_mask = make_attention_mask(cur_batch.prompt_lengths) + + if self.rank == 0: + input_ids = input + inputs_embeds = None + else: + input_ids = None + inputs_embeds = input + output = self.model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=_past_key_values, + use_cache=True + ) + self.past_key_values_dict[cur_id] = 
output.past_key_values + if not self.pp_config.is_tail: + return output.last_hidden_state + else: + # logger.info(f"logits: {output.logits.shape}") + return output.logits + + + def is_initialized(self): + return True + + + async def add_request(self, tokenizer): + request_ids, prompt_requests = [], [] + for _ in range(self.max_num_seqs): + if self.waiting_requests.empty(): + break + + tmp_result = await self.waiting_requests.get() + # logger.info(tmp_result) + request_id, prompt_request = tmp_result + request_ids.append(request_id) + prompt_requests.append(prompt_request) + + plain_texts = [req.prompt for req in prompt_requests] + inputs = tokenizer(plain_texts, return_tensors="pt", padding=True) + input_ids = inputs.input_ids.to(f'xpu:{self.rank}') + attention_mask = inputs.attention_mask.to(f'xpu:{self.rank}') + new_batch = BatchTask( + batch_id="batch_" + str(uuid.uuid4()), + request_ids=request_ids, + max_tokens=max([req.n_predict for req in prompt_requests]), + batch_size=input_ids.size(0), + input_len=input_ids.size(1), + prompt_lengths=[sum(attention_mask[i,:]) for i in range(input_ids.size(0))], + stopped=False, + # plain_texts=plain_texts, + # input_ids=input_ids, + # attention_mask=attention_mask, + ) + + self.input_ids_dict[new_batch.batch_id] = input_ids + self.token_times[new_batch.batch_id] = [time.perf_counter()] + # self.attention_mask_dict[new_batch.batch_id] = attention_mask + + return new_batch + + + def clear_batch(self, cur_id): + self.input_ids_dict.pop(cur_id, None) + self.tokens.pop(cur_id, None) + self.token_times.pop(cur_id, None) + # self.attention_mask_dict.pop(cur_id, None) + self.past_key_values_dict.pop(cur_id, None) + # torch.xpu.empty_cache() + + + async def process_step(self, tokenizer, result_dict): + cur_batch = None + + if self.rank == 0: + if self.on_going_batches[0] is not None: + cur_batch = self.on_going_batches[0] + cur_input = None + + if cur_batch is None: + if not self.waiting_requests.empty(): + # await asyncio.sleep(0.01) + cur_batch = await self.add_request(tokenizer) + cur_input = self.input_ids_dict[cur_batch.batch_id] + else: + cur_batch = None + cur_input = None + + if (cur_batch is not None) and (not cur_batch.stopped) and (cur_input is None): + cur_id = cur_batch.batch_id + next_ids = torch.empty((cur_batch.batch_size, 1,), device=f'xpu:{self.rank}', dtype=torch.int64) + # logger.info(f"rank: {self.rank}, recv: {next_ids.shape}") + dist.recv(next_ids, src=self.pre_rank) + + if self.tokens.get(cur_id, None) is None: + self.tokens[cur_id] = [] + + if len(next_ids.shape) == 1: + next_ids = next_ids.unsqueeze(0) + self.tokens[cur_id].append(next_ids) + self.token_times[cur_id].append(time.perf_counter()) + # self.input_ids_dict[cur_id] += next_ids + cur_input = next_ids + # cur_batch.input_len += 1 + cur_batch.input_len = 1 + cur_batch.prompt_lengths = [x + 1 for x in cur_batch.prompt_lengths] + if len(self.tokens[cur_id]) >= cur_batch.max_tokens: + # Finish a batch + # logger.info(self.tokens[cur_id]) + outputs = torch.cat(self.tokens[cur_id], dim=1) + outputs = outputs.cpu() + output_strs = tokenizer.batch_decode(outputs, skip_special_tokens=False) + for request_id, output_str in zip(cur_batch.request_ids, output_strs): + with self.dict_lock: + result_dict[request_id] = output_str + + cur_times = self.token_times[cur_id] + first_token = cur_times[1] - cur_times[0] + next_token = (cur_times[-1] - cur_times[1]) / (len(self.tokens[cur_id]) - 1) + logger.info(f"First token latency: {first_token}, next token latency: {next_token}") + 
self.clear_batch(cur_id) + cur_batch.stopped = True + else: + if (cur_batch is not None) and cur_batch.stopped: + cur_batch = None + + if self.send_buff is not None: + # logger.info(f"rank: {self.rank}, send: {self.send_buff.shape}") + dist.send(self.send_buff, dst=self.next_rank) + dist.broadcast_object_list([cur_batch], src=0) + + else: + batch_list = [None] + dist.broadcast_object_list(batch_list, src=0) + + cur_batch = batch_list[0] + cur_input = None + + if self.send_buff is not None: + # logger.info(f"rank: {self.rank}, send: {self.send_buff.shape}") + dist.send(self.send_buff, dst=self.next_rank) + + if cur_batch is not None: + if cur_batch.stopped: + self.clear_batch(cur_batch.batch_id) + else: + cur_len = cur_batch.input_len + cur_input = torch.empty((cur_batch.batch_size, cur_len, self.hidden_size,), device=f'xpu:{self.rank}', dtype=self.dtype) + # logger.info(f"rank: {self.rank}, recv: {cur_input.shape}") + dist.recv(cur_input, src=self.pre_rank) + + # if self.attention_mask_dict.get(cur_batch.batch_id, None) is None: + # self.attention_mask_dict[cur_batch.batch_id] = make_attention_mask(cur_batch.prompt_lengths) + + # if self.rank == 0: + # logger.info(f"rank: {self.rank}, {batch_list}") + + output = self.model_step(cur_input, cur_batch) + if output is not None and self.rank == self.world_size - 1: + output = torch.argmax(output[:, -1:, :], dim=-1) + + if output is not None: + # dist.send(output, dst=self.next_rank) + self.send_buff = output + else: + self.send_buff = None + if self.rank == 0: + self.on_going_batches[:-1] = self.on_going_batches[1:] + self.on_going_batches[self.world_size - 1] = cur_batch + diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py new file mode 100644 index 00000000000..a73b5294660 --- /dev/null +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py @@ -0,0 +1,148 @@ +from pipeline_models import ModelRunner +import torch.nn.parallel +import torch.distributed as dist +import os +import intel_extension_for_pytorch as ipex + +import oneccl_bindings_for_pytorch + +from transformers.utils import logging +logger = logging.get_logger(__name__) + +os.environ['MASTER_ADDR'] = '127.0.0.1' +os.environ['MASTER_PORT'] = '29501' + +backend = 'ccl' +dist.init_process_group(backend) +my_rank = dist.get_rank() +my_size = dist.get_world_size() +device = f"xpu:{my_rank}" +logger.info(f"rank: {my_rank}, size: {my_size}") + +import time +from transformers import AutoTokenizer, AutoConfig, LlamaTokenizer +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +import uvicorn +import asyncio, uuid +from typing import Dict, List, Optional +import argparse + +def get_int_from_env(env_keys, default): + """Returns the first positive env value found in the `env_keys` list or the default.""" + for e in env_keys: + val = int(os.environ.get(e, -1)) + if val >= 0: + return val + return int(default) + + +class PromptRequest(BaseModel): + prompt: str + n_predict: int = 32 + + +empty_req = PromptRequest(prompt="", n_predict=0) + +app = FastAPI() +global tokenizer +global local_model + +request_queue: asyncio.Queue = asyncio.Queue() +result_dict: Dict[str, str] = {} +local_rank = my_rank +max_num_seqs = get_int_from_env(["MAX_NUM_SEQS"], "16") + + +@app.post("/generate/") +async def generate(prompt_request: PromptRequest): + request_id = str(uuid.uuid4()) + await local_model.waiting_requests.put((request_id, prompt_request)) + while True: + if 
request_id in result_dict:
+            with local_model.dict_lock:
+                output_str = result_dict[request_id]
+            if len(output_str) == 0:
+                logger.info(f"Empty result for request {request_id}")
+                # await asyncio.sleep(0.1)
+                # continue
+            result_dict.pop(request_id)
+            return {"generated_text": output_str}
+        await asyncio.sleep(0)
+
+
+def generate_text(prompt: List[str], n_predict = 32):
+    while prompt[-1] == "":
+        prompt = prompt[:-1]
+    if isinstance(n_predict, list):
+        n_predict = max(n_predict)
+
+    inputs = tokenizer(prompt, return_tensors="pt", padding=True)
+    input_ids = inputs.input_ids.to(f'xpu:{local_rank}')
+    print(inputs)
+    attention_mask = inputs.attention_mask.to(f'xpu:{local_rank}')
+    output = local_model.generate(input_ids,
+                                  max_tokens=n_predict,
+                                  # attention_mask=attention_mask,
+                                  # max_new_tokens=n_predict,
+                                  # min_new_tokens=n_predict,
+                                  # do_sample=False,
+                                  # use_cache=True
+                                  )
+    torch.xpu.synchronize()
+
+    return output
+
+
+async def process_requests(local_model, result_dict):
+    while True:
+        await asyncio.sleep(0)
+        await local_model.process_step(tokenizer, result_dict)
+
+
+@app.on_event("startup")
+async def startup_event():
+    asyncio.create_task(process_requests(local_model, result_dict))
+
+async def main():
+    parser = argparse.ArgumentParser(description='Predict tokens using FastAPI with multi-GPU pipeline parallelism')
+    parser.add_argument('--repo-id-or-model-path', type=str, default="meta-llama/Llama-2-7b-chat-hf",
+                        help='The huggingface repo id for the Llama2 (e.g. `meta-llama/Llama-2-7b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf` and `meta-llama/Llama-2-70b-chat-hf`) to be downloaded'
+                             ', or the path to the huggingface checkpoint folder')
+    parser.add_argument('--low-bit', type=str, default='sym_int4',
+                        help='The quantization type the model will convert to.')
+    parser.add_argument('--port', type=int, default=8000,
+                        help='The port number on which the server will run.')
+    parser.add_argument('--max-num-seqs', type=int, default=8,
+                        help='Max num sequences in a batch.')
+
+    args = parser.parse_args()
+    model_path = args.repo_id_or_model_path
+    low_bit = args.low_bit
+    max_num_seqs = args.max_num_seqs
+
+    # serialize model initialization so that we do not run out of CPU memory
+    for i in range(my_size):
+        if my_rank == i:
+            logger.info("start model initialization")
+            global local_model
+            local_model = ModelRunner(model_path, my_rank, my_size, low_bit, max_num_seqs)
+            logger.info("model initialized")
+        dist.barrier()
+    # Load tokenizer
+    global tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, padding_side='left')
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    if local_rank == 0:
+        config = uvicorn.Config(app=app, host="0.0.0.0", port=args.port)
+        server = uvicorn.Server(config)
+        await server.serve()
+    else:
+        while True:
+            await asyncio.sleep(0)
+            await local_model.process_step(tokenizer, result_dict)
+
+if __name__ == "__main__":
+    asyncio.run(main())
\ No newline at end of file
diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh
new file mode 100644
index 00000000000..1e55c9d80ed
--- /dev/null
+++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh
@@ -0,0 +1,11 @@
+source /opt/intel/oneapi/setvars.sh
+export no_proxy=localhost
+export FI_PROVIDER=tcp
+export OMP_NUM_THREADS=8
+
+export USE_XETLA=OFF
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2
+export TORCH_LLM_ALLREDUCE=0
+
+export MODEL_PATH=YOUR_MODEL_PATH
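+# Launch one process per pipeline stage: --nproc-per-node should match the number of
+# GPUs (pipeline stages) you want to use, which is 2 in this example.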
+CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node 2 pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit fp8
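+
+# Once the server is up, requests can be sent to the /generate/ endpoint, for example
+# (assuming the default port 8000; adjust the prompt and n_predict as needed):
+#   curl -X POST http://localhost:8000/generate/ \
+#        -H "Content-Type: application/json" \
+#        -d '{"prompt": "What is AI?", "n_predict": 32}'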