Support finishing PP inference once eos_token_id is found #11336

Merged (1 commit, Jun 18, 2024)

3 changes: 2 additions & 1 deletion python/llm/example/GPU/Pipeline-Parallel-Inference/README.md
@@ -11,6 +11,7 @@ To run this example with IPEX-LLM on Intel GPUs, we have some recommended requir
- [meta-llama/Meta-Llama-3-8B-Instruct](./run_llama_arc_2_card.sh)
- [Qwen/Qwen1.5-7B-Chat](./run_qwen1.5_arc_2_card.sh)
- [Qwen/Qwen1.5-14B-Chat](./run_qwen1.5_arc_2_card.sh)
- [Qwen/Qwen1.5-32B-Chat](./run_qwen1.5_arc_2_card.sh)
- [baichuan-inc/Baichuan2-7B-Chat](./run_baichuan2_arc_2_card.sh)
- [baichuan-inc/Baichuan2-13B-Chat](./run_baichuan2_arc_2_card.sh)

@@ -54,7 +55,7 @@ bash run_llama_arc_2_card.sh
<details>
<summary> Show Qwen1.5 example </summary>

#### Run Qwen1.5-7B-Chat / Qwen1.5-14B-Chat on two Intel Arc A770
#### Run Qwen1.5-7B-Chat / Qwen1.5-14B-Chat / Qwen1.5-32B-Chat on two Intel Arc A770

You could specify `--repo-id-or-model-path` in the test script to be the huggingface repo id for Qwen1.5 to be downloaded, or the path to the huggingface checkpoint folder. Besides, you could change `NUM_GPUS` to the number of GPUs you have on your machine.

python/llm/example/GPU/Pipeline-Parallel-Inference/generate.py
@@ -46,6 +46,7 @@
optimize_model=True,
trust_remote_code=True,
use_cache=True,
torch_dtype=torch.float16,
pipeline_parallel_stages=args.gpu_num)

# Load tokenizer
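For orientation, a minimal sketch of what the example's model load looks like with this change applied. Only the keyword arguments visible in the hunk above come from the diff; the import path, the model id, and the `load_in_4bit` setting are assumptions based on typical IPEX-LLM GPU examples.

```python
import torch
from ipex_llm.transformers import AutoModelForCausalLM  # assumed import, as in other IPEX-LLM GPU examples

# Each torchrun rank loads the checkpoint; torch_dtype=torch.float16 keeps the
# non-quantized parts in fp16, which helps fit larger models such as Qwen1.5-32B-Chat.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen1.5-32B-Chat",       # or a local checkpoint folder (assumed model id)
    load_in_4bit=True,             # assumption: the example's usual low-bit setting
    optimize_model=True,
    trust_remote_code=True,
    use_cache=True,
    torch_dtype=torch.float16,     # the line added in this PR
    pipeline_parallel_stages=2,    # split the model across two Intel Arc GPUs
)
```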
python/llm/example/GPU/Pipeline-Parallel-Inference/run_qwen1.5_arc_2_card.sh
@@ -34,3 +34,7 @@ CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $N
# # To run Qwen1.5-14B-Chat
# CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
# generate.py --repo-id-or-model-path 'Qwen/Qwen1.5-14B-Chat' --gpu-num $NUM_GPUS

# # To run Qwen1.5-32B-Chat
# CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS \
# generate.py --repo-id-or-model-path 'Qwen/Qwen1.5-32B-Chat' --gpu-num $NUM_GPUS
59 changes: 58 additions & 1 deletion python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -25,6 +25,9 @@
import numpy as np
from typing import Callable, List, Optional
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from ipex_llm.utils.common import invalidInputError
import logging
logger = logging.getLogger(__name__)

# patch GenerationMixin.generate
from transformers import GenerationMixin
@@ -117,12 +120,34 @@ def generate(
**kwargs,
):
if hasattr(self, 'pipeline_parallel_stages') and self.pipeline_parallel_stages > 1:
# priority: `generation_config` argument > `model.generation_config`
if generation_config is None:
if (
self.generation_config._from_model_config
and self.generation_config._original_object_hash == hash(self.generation_config)
and self.config._has_non_default_generation_parameters()
):
new_generation_config = GenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config:
self.generation_config = new_generation_config
generation_config = self.generation_config

if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, list):
eos_token_id = eos_token_id[0]
logger.warning("Setting `pad_token_id` to `eos_token_id`: "
f"{eos_token_id} for open-end generation.")
generation_config.pad_token_id = eos_token_id

if generation_config is not None and generation_config.max_new_tokens is not None:
max_new_tokens = generation_config.max_new_tokens
else:
max_new_tokens = kwargs.get("max_new_tokens", None)

return self.pipeline_parallel_generate(inputs=inputs,
max_new_tokens=max_new_tokens,)
max_new_tokens=max_new_tokens,
generation_config=generation_config,)

return original_generate(self,
inputs=inputs,
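To make the new pre-dispatch logic concrete, here is a small standalone sketch of the pad-token fallback (the token ids and the `max_new_tokens` value are made-up example values):

```python
from transformers import GenerationConfig

# A config with EOS defined but no pad token, a common situation for chat models.
cfg = GenerationConfig(eos_token_id=[151643, 151645], max_new_tokens=64)

# Mirrors the patched generate(): reuse the (first) EOS id as the pad id so that
# already-finished sequences have a well-defined token to emit during PP decoding.
if cfg.pad_token_id is None and cfg.eos_token_id is not None:
    eos = cfg.eos_token_id
    cfg.pad_token_id = eos[0] if isinstance(eos, list) else eos

print(cfg.pad_token_id)    # 151643
print(cfg.max_new_tokens)  # 64, forwarded to pipeline_parallel_generate as the step budget
```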
@@ -142,6 +167,7 @@ def generate(
def pipeline_parallel_generate(self,
inputs: Optional[torch.Tensor] = None,
max_new_tokens: int = 32,
generation_config: Optional[GenerationConfig] = None,
**kwargs):
local_rank = dist.get_rank()
pre_rank = (local_rank - 1) % self.pipeline_parallel_stages
@@ -153,12 +179,22 @@ def pipeline_parallel_generate(self,
self.first_token_time = 0
self.next_token_time = []

pad_token_id = generation_config.pad_token_id
eos_token_id = generation_config.eos_token_id
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
eos_token_id_tensor = torch.tensor(eos_token_id).to(inputs.device) \
if eos_token_id is not None else None

_input_ids = None
_past_key_values = None
bs = inputs.shape[0]
output_ids = inputs.clone()

step = 0
# keep track of which sequences are already finished
unfinished_sequences = torch.ones(inputs.shape[0], dtype=torch.long, device=inputs.device)
this_peer_finished = False
while True:
if step >= max_new_tokens:
break
@@ -189,6 +225,14 @@ def pipeline_parallel_generate(self,
_input_ids = next_ids
output_ids = torch.cat([output_ids, next_ids], dim=-1)

# finished sentences should have their next token be a padding token
next_ids = next_ids.squeeze()
if eos_token_id is not None:
if pad_token_id is None:
invalidInputError(False, "If `eos_token_id` is defined, "
"make sure that `pad_token_id` is defined.")
next_ids = next_ids * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

if isinstance(outputs.past_key_values, tuple) and local_rank != 0:
value_placeholder = torch.empty_like((outputs.past_key_values)[-1][0])
past_key_values_placeholder = tuple(
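As a worked illustration of the padding step in the hunk above, with a toy batch of three sequences (token ids and pad id are illustrative only):

```python
import torch

pad_token_id = 0
next_ids = torch.tensor([11, 22, 33])           # token sampled for each sequence this step
unfinished_sequences = torch.tensor([1, 0, 1])  # the middle sequence already emitted EOS

# Finished rows are overwritten with the pad token; live rows keep their sampled token.
next_ids = next_ids * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
print(next_ids)  # tensor([11,  0, 33])
```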
@@ -203,6 +247,19 @@ def pipeline_parallel_generate(self,
self.first_token_time = toc - tic
else:
self.next_token_time.append(toc - tic)

# if eos_token was found in one sentence, set sentence to finished
if eos_token_id_tensor is not None:
unfinished_sequences = unfinished_sequences.mul(
next_ids.tile(eos_token_id_tensor.shape[0], 1)
.ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)
# stop when each sentence is finished
if unfinished_sequences.max() == 0:
this_peer_finished = True
if this_peer_finished:
break

step += 1
if self.device.type == 'xpu':
torch.xpu.synchronize()
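Finally, a toy walk-through of the stopping bookkeeping added above: the `unfinished_sequences` mask is multiplied by an indicator that the new token matches none of the EOS ids, and the loop ends once every row has finished. All values below are illustrative only.

```python
import torch

eos_token_id_tensor = torch.tensor([2, 7])      # two possible EOS ids (illustrative)
unfinished_sequences = torch.ones(3, dtype=torch.long)
next_ids = torch.tensor([5, 2, 7])              # tokens produced at this decode step

# Compare each sequence's new token against every EOS id; prod(dim=0) is 1 only when
# it matches none of them, so rows that hit an EOS drop to 0 and stay there.
unfinished_sequences = unfinished_sequences.mul(
    next_ids.tile(eos_token_id_tensor.shape[0], 1)
    .ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
)
print(unfinished_sequences)        # tensor([1, 0, 0])
print(unfinished_sequences.max())  # tensor(1): keep decoding; 0 would break the loop
```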