Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add phi-3 model support for pipeline parallel inference #11334

Merged
merged 2 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ To run this example with IPEX-LLM on Intel GPUs, we have some recommended requir
- [Qwen/Qwen1.5-14B-Chat](./run_qwen1.5_arc_2_card.sh)
- [baichuan-inc/Baichuan2-7B-Chat](./run_baichuan2_arc_2_card.sh)
- [baichuan-inc/Baichuan2-13B-Chat](./run_baichuan2_arc_2_card.sh)
- [microsoft/Phi-3-mini-4k-instruct](./run_phi3_arc_2_card.sh)
- [microsoft/Phi-3-medium-4k-instruct](./run_phi3_arc_2_card.sh)


## Example: Run pipeline parallel inference on multiple GPUs

Expand Down Expand Up @@ -81,6 +84,22 @@ bash run_baichuan2_arc_2_card.sh

</details>

</details>

<details>
<summary> Show Phi3 example </summary>

#### Run Phi-3-mini-4k-instruct / Phi-3-medium-4k-instruct on two Intel Arc A770

You could specify `--repo-id-or-model-path` in the test script to be the huggingface repo id for Phi3 to be downloaded, or the path to the huggingface checkpoint folder. Besides, you could change `NUM_GPUS` to the number of GPUs you have on your machine.

```bash
pip install transformers==4.37.0
bash run_phi3_arc_2_card.sh
```

</details>

### 3. Sample Output
#### [meta-llama/Llama-2-13b-chat-hf](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf)
```log
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

source /opt/intel/oneapi/setvars.sh
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=9090
export FI_PROVIDER=tcp
export USE_XETLA=OFF
export OMP_NUM_THREADS=6
export IPEX_LLM_QUANTIZE_KV_CACHE=1
if [[ $KERNEL_VERSION != *"6.5"* ]]; then
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
fi
export TORCH_LLM_ALLREDUCE=0

NUM_GPUS=2 # number of used GPU

# To run Phi-3-medium-4k-instruct
CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
generate.py --repo-id-or-model-path 'microsoft/Phi-3-medium-4k-instruct' --gpu-num $NUM_GPUS

# # To run Phi-3-mini-4k-instruct
# CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
# generate.py --repo-id-or-model-path 'microsoft/Phi-3-mini-4k-instruct' --gpu-num $NUM_GPUS
3 changes: 2 additions & 1 deletion python/llm/src/ipex_llm/transformers/models/phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,8 @@ def model_forward(
):
# IPEX-LLM OPT: kv cache and quantize kv cache and sdp
use_cache = use_cache if use_cache is not None else self.config.use_cache
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input_ids)
input = input_ids if input_ids is not None else inputs_embeds
use_quantize_kv = use_quantize_kv_cache(self.layers[0].mlp.down_proj, input)
if use_cache:
if use_quantize_kv and not isinstance(past_key_values, DynamicFp8Cache):
past_key_values = DynamicFp8Cache.from_legacy_cache(past_key_values)
Expand Down
1 change: 1 addition & 0 deletions python/llm/src/ipex_llm/transformers/pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(self, *args):
# to avoid AttributeError in https://github.com/intel-analytics/ipex-llm/blob/main/
# python/llm/src/ipex_llm/transformers/models/llama.py#L119
self.up_proj = DummyLayer()
self.down_proj = DummyLayer()

def forward(self, x):
return x
Expand Down
Loading