-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
LLM: Add Pipeline-Parallel-FastAPI example (#10917)
Add multi-stage Pipeline-Parallel-FastAPI example --------- Co-authored-by: hzjane <[email protected]>
- Loading branch information
Showing
5 changed files
with
1,029 additions
and
0 deletions.
There are no files selected for viewing
33 changes: 33 additions & 0 deletions
33
python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Serve IPEX-LLM on Multiple Intel GPUs in multi-stage pipeline parallel fashion | ||
|
||
This example demonstrates how to run IPEX-LLM serving on multiple [Intel GPUs](../README.md) with Pipeline Parallel. | ||
|
||
## Requirements | ||
|
||
To run this example with IPEX-LLM on Intel GPUs, we have some recommended requirements for your machine, please refer to [here](../README.md#recommended-requirements) for more information. For this particular example, you will need at least two GPUs on your machine. | ||
|
||
## Example | ||
|
||
### 1. Install | ||
|
||
```bash | ||
conda create -n llm python=3.11 | ||
conda activate llm | ||
# below command will install intel_extension_for_pytorch==2.1.10+xpu as default | ||
pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ | ||
pip install oneccl_bind_pt==2.1.100 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ | ||
# configures OneAPI environment variables | ||
source /opt/intel/oneapi/setvars.sh | ||
# pip install git+https://github.com/microsoft/DeepSpeed.git@ed8aed5 | ||
# pip install git+https://github.com/intel/intel-extension-for-deepspeed.git@0eb734b | ||
pip install mpi4py fastapi uvicorn | ||
conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc | ||
``` | ||
|
||
### 2. Run pipeline parallel serving on multiple GPUs | ||
|
||
```bash | ||
# Need to set MODEL_PATH in run.sh first | ||
bash run.sh | ||
``` | ||
|
327 changes: 327 additions & 0 deletions
327
python/llm/example/GPU/Pipeline-Parallel-FastAPI/llama_models.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,327 @@ | ||
from transformers.modeling_utils import PreTrainedModel | ||
from transformers.models.llama.modeling_llama import LlamaConfig, LlamaDecoderLayer, LlamaRMSNorm, LlamaPreTrainedModel | ||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast | ||
|
||
from torch import nn | ||
import torch | ||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss | ||
from typing import List, Optional, Tuple, Union, Iterator | ||
from transformers.utils import logging | ||
logger = logging.get_logger(__name__) | ||
import numpy as np | ||
import time | ||
from transformers import AutoTokenizer, AutoConfig | ||
import torch.distributed as dist | ||
from pipeline_models import ( | ||
_make_causal_mask, _expand_mask, DummyLayer, PPConfig, | ||
PipelineBaseModel, | ||
) | ||
|
||
|
||
class LlamaModel(LlamaPreTrainedModel): | ||
""" | ||
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] | ||
Args: | ||
config: LlamaConfig | ||
""" | ||
|
||
def __init__(self, config: LlamaConfig): | ||
super().__init__(config) | ||
self.config = config | ||
|
||
# pp modification | ||
self.pp_config = PPConfig(pp_rank=dist.get_rank(), pp_world_size=dist.get_world_size()) | ||
nr_slices = self.pp_config.pp_world_size | ||
# self.config.num_hidden_layers = 8 | ||
slice_size = (self.config.num_hidden_layers + nr_slices - | ||
1) // nr_slices | ||
self.layer_start = slice_size * self.pp_config.pp_rank | ||
self.layer_end = self.layer_start + min(slice_size, | ||
self.config.num_hidden_layers - self.layer_start) | ||
self.num_layers = self.layer_end - self.layer_start | ||
layers = [] | ||
for i in range(self.config.num_hidden_layers): | ||
if i < self.layer_start or i >= self.layer_end: | ||
layers.append(DummyLayer()) | ||
else: | ||
layers.append(LlamaDecoderLayer(config)) | ||
self.layers = nn.ModuleList(layers) | ||
|
||
self.padding_idx = config.pad_token_id | ||
self.vocab_size = config.vocab_size | ||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) | ||
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) | ||
|
||
|
||
def get_input_embeddings(self): | ||
return self.embed_tokens | ||
|
||
def set_input_embeddings(self, value): | ||
self.embed_tokens = value | ||
|
||
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask | ||
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): | ||
# create causal mask | ||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] | ||
combined_attention_mask = None | ||
if input_shape[-1] > 1: | ||
combined_attention_mask = _make_causal_mask( | ||
input_shape, | ||
inputs_embeds.dtype, | ||
device=inputs_embeds.device, | ||
past_key_values_length=past_key_values_length, | ||
) | ||
|
||
if attention_mask is not None: | ||
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] | ||
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( | ||
inputs_embeds.device | ||
) | ||
combined_attention_mask = ( | ||
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask | ||
) | ||
|
||
return combined_attention_mask | ||
|
||
def forward( | ||
self, | ||
input_ids: torch.LongTensor = None, | ||
attention_mask: Optional[torch.Tensor] = None, | ||
position_ids: Optional[torch.LongTensor] = None, | ||
past_key_values: Optional[List[torch.FloatTensor]] = None, | ||
inputs_embeds: Optional[torch.FloatTensor] = None, | ||
use_cache: Optional[bool] = None, | ||
output_attentions: Optional[bool] = None, | ||
output_hidden_states: Optional[bool] = None, | ||
return_dict: Optional[bool] = None, | ||
) -> Union[Tuple, BaseModelOutputWithPast]: | ||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | ||
output_hidden_states = ( | ||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | ||
) | ||
use_cache = use_cache if use_cache is not None else self.config.use_cache | ||
|
||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
|
||
# retrieve input_ids and inputs_embeds for pp | ||
if input_ids is not None and inputs_embeds is not None: | ||
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") | ||
elif input_ids is not None: | ||
assert self.pp_config.is_head, "input_ids is only supported on the head stage" | ||
batch_size, seq_length = input_ids.shape | ||
elif inputs_embeds is not None: | ||
assert not self.pp_config.is_head, "inputs_embeds is only supported on the tail stage" | ||
batch_size, seq_length, _ = inputs_embeds.shape | ||
else: | ||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") | ||
|
||
seq_length_with_past = seq_length | ||
past_key_values_length = 0 | ||
|
||
if past_key_values is not None: | ||
past_key_values_length = past_key_values[0][0].shape[2] | ||
seq_length_with_past = seq_length_with_past + past_key_values_length | ||
|
||
if position_ids is None: | ||
device = input_ids.device if input_ids is not None else inputs_embeds.device | ||
position_ids = torch.arange( | ||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device | ||
) | ||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length) | ||
else: | ||
position_ids = position_ids.view(-1, seq_length).long() | ||
|
||
if inputs_embeds is None: | ||
inputs_embeds = self.embed_tokens(input_ids) | ||
# embed positions | ||
if attention_mask is None: | ||
attention_mask = torch.ones( | ||
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device | ||
) | ||
attention_mask = self._prepare_decoder_attention_mask( | ||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length | ||
) | ||
|
||
hidden_states = inputs_embeds | ||
|
||
# decoder layers | ||
all_hidden_states = () if output_hidden_states else None | ||
all_self_attns = () if output_attentions else None | ||
next_decoder_cache = () if use_cache else None | ||
|
||
for idx in range(self.num_layers): | ||
decoder_layer = self.layers[self.layer_start + idx] | ||
if output_hidden_states: | ||
all_hidden_states += (hidden_states,) | ||
|
||
past_key_value = past_key_values[idx] if past_key_values is not None else None | ||
|
||
layer_outputs = decoder_layer( | ||
hidden_states, | ||
attention_mask=attention_mask, | ||
position_ids=position_ids, | ||
past_key_value=past_key_value, | ||
output_attentions=output_attentions, | ||
use_cache=use_cache, | ||
) | ||
|
||
hidden_states = layer_outputs[0] | ||
|
||
if use_cache: | ||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) | ||
|
||
if output_attentions: | ||
all_self_attns += (layer_outputs[1],) | ||
|
||
if self.pp_config.is_tail: | ||
hidden_states = self.norm(hidden_states) | ||
|
||
# add hidden states from the last decoder layer | ||
if output_hidden_states: | ||
all_hidden_states += (hidden_states,) | ||
|
||
next_cache = next_decoder_cache if use_cache else None | ||
if not return_dict: | ||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) | ||
return BaseModelOutputWithPast( | ||
last_hidden_state=hidden_states, | ||
past_key_values=next_cache, | ||
hidden_states=all_hidden_states, | ||
attentions=all_self_attns, | ||
) | ||
|
||
|
||
class LlamaForCausalLM(LlamaPreTrainedModel): | ||
|
||
def __init__(self, config: LlamaConfig): | ||
super().__init__(config=config) | ||
self.config = config | ||
self.pp_config = PPConfig(pp_rank=dist.get_rank(), pp_world_size=dist.get_world_size()) | ||
self.model = LlamaModel(config) | ||
self.pretraining_tp = config.pretraining_tp | ||
self.vocab_size = config.vocab_size | ||
if self.pp_config.is_tail: | ||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) | ||
|
||
def get_input_embeddings(self): | ||
return self.model.embed_tokens | ||
|
||
def set_input_embeddings(self, value): | ||
self.model.embed_tokens = value | ||
|
||
def get_output_embeddings(self): | ||
return self.lm_head | ||
|
||
def set_output_embeddings(self, new_embeddings): | ||
self.lm_head = new_embeddings | ||
|
||
def set_decoder(self, decoder): | ||
self.model = decoder | ||
|
||
def get_decoder(self): | ||
return self.model | ||
|
||
def forward( | ||
self, | ||
input_ids: torch.LongTensor = None, | ||
attention_mask: Optional[torch.Tensor] = None, | ||
position_ids: Optional[torch.LongTensor] = None, | ||
past_key_values: Optional[List[torch.FloatTensor]] = None, | ||
inputs_embeds: Optional[torch.FloatTensor] = None, | ||
labels: Optional[torch.LongTensor] = None, | ||
use_cache: Optional[bool] = None, | ||
output_attentions: Optional[bool] = None, | ||
output_hidden_states: Optional[bool] = None, | ||
return_dict: Optional[bool] = None, | ||
) -> Union[Tuple, CausalLMOutputWithPast]: | ||
|
||
|
||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | ||
output_hidden_states = ( | ||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | ||
) | ||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | ||
|
||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) | ||
outputs = self.model( | ||
input_ids=input_ids, | ||
attention_mask=attention_mask, | ||
position_ids=position_ids, | ||
past_key_values=past_key_values, | ||
inputs_embeds=inputs_embeds, | ||
use_cache=use_cache, | ||
output_attentions=output_attentions, | ||
output_hidden_states=output_hidden_states, | ||
return_dict=return_dict, | ||
) | ||
|
||
if self.pp_config.is_tail: | ||
hidden_states = outputs[0] | ||
logits = self.lm_head(hidden_states) | ||
logits = logits.float() | ||
|
||
loss = None | ||
if labels is not None: | ||
# Shift so that tokens < n predict n | ||
shift_logits = logits[..., :-1, :].contiguous() | ||
shift_labels = labels[..., 1:].contiguous() | ||
# Flatten the tokens | ||
loss_fct = CrossEntropyLoss() | ||
shift_logits = shift_logits.view(-1, self.config.vocab_size) | ||
shift_labels = shift_labels.view(-1) | ||
# Enable model parallelism | ||
shift_labels = shift_labels.to(shift_logits.device) | ||
loss = loss_fct(shift_logits, shift_labels) | ||
|
||
if not return_dict: | ||
output = (logits,) + outputs[1:] | ||
return (loss,) + output if loss is not None else output | ||
|
||
return CausalLMOutputWithPast( | ||
loss=loss, | ||
logits=logits, | ||
past_key_values=outputs.past_key_values, | ||
hidden_states=outputs.hidden_states, | ||
attentions=outputs.attentions, | ||
) | ||
return outputs | ||
|
||
def prepare_inputs_for_generation( | ||
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs | ||
): | ||
if past_key_values: | ||
input_ids = input_ids[:, -1:] | ||
|
||
position_ids = kwargs.get("position_ids", None) | ||
if attention_mask is not None and position_ids is None: | ||
# create position_ids on the fly for batch generation | ||
position_ids = attention_mask.long().cumsum(-1) - 1 | ||
position_ids.masked_fill_(attention_mask == 0, 1) | ||
if past_key_values: | ||
position_ids = position_ids[:, -1].unsqueeze(-1) | ||
|
||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step | ||
if inputs_embeds is not None and past_key_values is None: | ||
model_inputs = {"inputs_embeds": inputs_embeds} | ||
else: | ||
model_inputs = {"input_ids": input_ids} | ||
|
||
model_inputs.update( | ||
{ | ||
"position_ids": position_ids, | ||
"past_key_values": past_key_values, | ||
"use_cache": kwargs.get("use_cache"), | ||
"attention_mask": attention_mask, | ||
} | ||
) | ||
return model_inputs | ||
|
||
@staticmethod | ||
def _reorder_cache(past_key_values, beam_idx): | ||
reordered_past = () | ||
for layer_past in past_key_values: | ||
reordered_past += ( | ||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), | ||
) | ||
return reordered_past |
Oops, something went wrong.