add generation infer
tastelikefeet committed Oct 31, 2024
1 parent bd60d9a commit 11ab4f0
Showing 8 changed files with 223 additions and 58 deletions.
2 changes: 2 additions & 0 deletions swift/llm/argument/infer_args.py
@@ -43,6 +43,7 @@ class VllmArguments:
Args:
gpu_memory_utilization (float): GPU memory utilization. Default is 0.9.
tensor_parallel_size (int): Tensor parallelism size. Default is 1.
pipeline_parallel_size (int): Pipeline parallelism size. Default is 1.
max_num_seqs (int): Maximum number of sequences. Default is 256.
max_model_len (Optional[int]): Maximum model length. Default is None.
disable_custom_all_reduce (bool): Flag to disable custom all-reduce. Default is True.
@@ -55,6 +56,7 @@ class VllmArguments:
# vllm
gpu_memory_utilization: float = 0.9
tensor_parallel_size: int = 1
pipeline_parallel_size: int = 1
max_num_seqs: int = 256
max_model_len: Optional[int] = None
disable_custom_all_reduce: bool = True # Default values different from vllm
1 change: 1 addition & 0 deletions swift/llm/infer/infer.py
@@ -92,6 +92,7 @@ def get_infer_engine(self) -> InferEngine:
kwargs.update({
'gpu_memory_utilization': args.gpu_memory_utilization,
'tensor_parallel_size': args.tensor_parallel_size,
'pipeline_parallel_size': args.pipeline_parallel_size,
'max_num_seqs': args.max_num_seqs,
'max_model_len': args.max_model_len,
'disable_custom_all_reduce': args.disable_custom_all_reduce,
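The two hunks above introduce pipeline_parallel_size alongside tensor_parallel_size and forward it with the rest of the vLLM settings. Below is a minimal, self-contained sketch of that flow; the dataclass only mirrors the fields visible in this diff, and build_engine_kwargs is an illustrative stand-in for the kwargs.update(...) call in get_infer_engine, not swift's actual API.

from dataclasses import dataclass
from typing import Optional


@dataclass
class VllmArguments:
    # Mirrors the fields shown in swift/llm/argument/infer_args.py above.
    gpu_memory_utilization: float = 0.9
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1  # new in this commit
    max_num_seqs: int = 256
    max_model_len: Optional[int] = None
    disable_custom_all_reduce: bool = True


def build_engine_kwargs(args: VllmArguments) -> dict:
    # Illustrative counterpart of the kwargs.update(...) call in infer.py.
    return {
        'gpu_memory_utilization': args.gpu_memory_utilization,
        'tensor_parallel_size': args.tensor_parallel_size,
        'pipeline_parallel_size': args.pipeline_parallel_size,
        'max_num_seqs': args.max_num_seqs,
        'max_model_len': args.max_model_len,
        'disable_custom_all_reduce': args.disable_custom_all_reduce,
    }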
67 changes: 42 additions & 25 deletions swift/llm/infer/infer_engine/pt_engine.py
@@ -1,25 +1,28 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import inspect
import json
import time
from copy import deepcopy
from dataclasses import dataclass
from threading import Thread
from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union

import json
import torch
from PIL import Image
from tqdm import tqdm
from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
from transformers import GenerationConfig, LogitsProcessorList
from transformers.utils import is_torch_npu_available

from swift.llm import Template, TemplateMeta, to_device
from swift.plugin import Metric
from swift.tuners import Swift
from swift.utils import get_logger
from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, InferRequest, RequestConfig,
random_uuid)
from .infer_engine import InferEngine
from .utils import InferStreamer, InferTools, LogitsStreamer, StopWordsCriteria, TokensIteratorStreamer
from .utils import InferStreamer, LogitsStreamer, TokensIteratorStreamer
from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ImageObject, ImagesResponse, InferRequest,
MultiModalRequestMixin, RequestConfig, random_uuid)

logger = get_logger()

@@ -167,7 +170,7 @@ def _infer_stream(
if lora_request is not None:
kwargs['adapter_names'] = self._get_adapter_names(lora_request)
num_prompt_tokens = self._get_num_tokens(inputs)
stopping_criteria = StoppingCriteriaList([StopWordsCriteria(self.tokenizer, generation_config.stop_words)])
generation_props = template.prepare_for_generation(inputs, self.model)
if generation_config.num_beams != 1:
error_msg = 'Streaming generation does not support beam search.'
raise ValueError(error_msg)
@@ -179,16 +182,18 @@ def _model_generate(*args, **kwargs):
torch.npu.set_device(self.model.device)
self.model.generate(*args, **kwargs)

logits_processors = LogitsProcessorList(generation_props.logits_processors)
logits_streamer = None
if generation_config.output_logits:
logits_streamer = LogitsStreamer()
kwargs['logits_processor'] = LogitsProcessorList([logits_streamer])
logits_processors.append(logits_streamer)
kwargs['logits_processor'] = logits_processors

thread = Thread(
target=_model_generate,
kwargs={
'generation_config': generation_config,
'stopping_criteria': stopping_criteria,
'stopping_criteria': generation_props.criterias,
'streamer': streamer,
**inputs,
**kwargs
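The streaming hunk above runs model.generate on a worker Thread while the caller drains a token streamer, and it now also threads the template's logits processors and stopping criteria through prepare_for_generation. Below is a minimal sketch of the same thread-plus-streamer pattern in plain Hugging Face transformers; TextIteratorStreamer stands in for swift's TokensIteratorStreamer and the model id is only a placeholder.

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = 'Qwen/Qwen2-0.5B-Instruct'  # placeholder; any causal LM works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer('Hello, who are you?', return_tensors='pt')
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until the sequence finishes, so it runs on a worker thread
# while the main thread consumes decoded text pieces as they are produced.
thread = Thread(target=model.generate, kwargs={'streamer': streamer, 'max_new_tokens': 64, **inputs})
thread.start()
for text_delta in streamer:
    print(text_delta, end='', flush=True)
thread.join()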
@@ -276,16 +281,20 @@ def _infer_full(self,
inputs: Dict[str, Any],
generation_config: GenerationConfig,
*,
lora_request: Optional[PtLoRARequest] = None) -> List[ChatCompletionResponse]:
lora_request: Optional[PtLoRARequest] = None) -> Union[List[ChatCompletionResponse], List[ImagesResponse]]:
# bos_token TODO: encoder-decoder
kwargs = {}
if lora_request is not None:
kwargs['adapter_names'] = self._get_adapter_names(lora_request)
num_prompt_tokens = self._get_num_tokens(inputs)
stopping_criteria = StoppingCriteriaList([StopWordsCriteria(self.tokenizer, generation_config.stop_words)])
generation_props = template.prepare_for_generation(inputs, self.model)
output = dict(
self.model.generate(
generation_config=generation_config, stopping_criteria=stopping_criteria, **inputs, **kwargs))
generation_config=generation_config,
stopping_criteria=generation_props.criterias,
**inputs,
logits_processor=generation_props.logits_processors,
**kwargs))
batched_generate_ids = output['sequences']
batched_generate_ids = template.get_generate_ids(batched_generate_ids, num_prompt_tokens)
batched_logprobs = self.preprocess_logits(
@@ -304,17 +313,25 @@ def _infer_full(self,

logprobs = self._get_logprobs(self.tokenizer, logprobs_list, generate_ids, generation_config.top_logprobs)
usage_info = self._get_usage_info(num_prompt_tokens, len(generate_ids))
response = InferTools.safe_decode(template, generate_ids, True)

toolcall = self._get_toolcall(response, True)
choices = [
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
finish_reason=None,
logprobs=logprobs)
]
res.append(ChatCompletionResponse(model=self.model_dir, choices=choices, usage=usage_info))
response = template.safe_decode(generate_ids, True)
if isinstance(response, str):
toolcall = self._get_toolcall(response, True)
choices = [
ChatCompletionResponseChoice(
index=0,
message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
finish_reason=None,
logprobs=logprobs)
]
res.append(ChatCompletionResponse(model=self.model_dir, choices=choices, usage=usage_info))
elif isinstance(response, Image.Image):
res.append(ImagesResponse(
created=time.time(),
data=[ImageObject(
b64_json=MultiModalRequestMixin._to_base64(response)
)]
))

return res

@torch.inference_mode()
@@ -350,7 +367,7 @@ def _infer(
*,
template: Optional[Template] = None,
lora_request: Optional[PtLoRARequest] = None,
) -> Union[List[ChatCompletionResponse], Iterator[List[Optional[ChatCompletionStreamResponse]]]]:
) -> Union[List[ChatCompletionResponse], List[ImagesResponse], Iterator[List[Optional[ChatCompletionStreamResponse]]]]:
self.model.eval()
request_config = deepcopy(request_config or RequestConfig())
if template is None:
@@ -390,7 +407,7 @@ def infer(
template: Optional[Template] = None,
use_tqdm: Optional[bool] = None,
lora_request: Optional[PtLoRARequest] = None
) -> Union[List[ChatCompletionResponse], Iterator[List[Optional[ChatCompletionStreamResponse]]]]:
) -> Union[List[ChatCompletionResponse], List[ImagesResponse], Iterator[List[Optional[ChatCompletionStreamResponse]]]]:
if use_tqdm is None:
use_tqdm = request_config is None or not request_config.stream
prog_bar = tqdm(total=len(infer_requests), dynamic_ncols=True, disable=not use_tqdm)
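In the _infer_full change above, template.safe_decode may now return a PIL image instead of a string; the image is base64-encoded through MultiModalRequestMixin._to_base64 and wrapped into an ImagesResponse. The helper below is only a plausible sketch of what such an encoder does, assuming PNG serialization; it is not the actual swift implementation.

import base64
from io import BytesIO

from PIL import Image


def to_base64(image: Image.Image) -> str:
    # Serialize the image to an in-memory PNG, then base64-encode the raw bytes.
    buffer = BytesIO()
    image.save(buffer, format='PNG')
    return base64.b64encode(buffer.getvalue()).decode('utf-8')


# A 1x1 red pixel becomes a short base64 string suitable for ImageObject.b64_json.
print(to_base64(Image.new('RGB', (1, 1), 'red'))[:24])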
30 changes: 0 additions & 30 deletions swift/llm/infer/infer_engine/utils.py
@@ -25,36 +25,6 @@ def _is_chinese_char(cp: int) -> bool:

return False

@staticmethod
def _skip_stop_tokens(generate_ids: List[int], stop_tokens: List[int], is_finished: bool) -> List[int]:
len_tokens = len(stop_tokens)
if is_finished and generate_ids[-len_tokens:] == stop_tokens:
return generate_ids[:-len_tokens]
if not is_finished:
for i in range(len_tokens, 0, -1):
if generate_ids[-i:] == stop_tokens[:i]:
return generate_ids[:-i]
return generate_ids

@staticmethod
def safe_decode(template: Template, generate_ids: List[int], is_finished: bool, **decode_kwargs) -> str:
# Do not print template_meta.suffix[-1] and eos_token.
tokenizer = template.tokenizer

if len(generate_ids) > 0 and generate_ids[-1] == tokenizer.eos_token_id:
generate_ids = generate_ids[:-1]
# skip suffix and eos_token
template_suffix = template.template_meta.suffix[-1]
if isinstance(template_suffix, str):
template_suffix = tokenizer.encode(template_suffix, add_special_tokens=False)
generate_ids = InferTools._skip_stop_tokens(generate_ids, template_suffix, is_finished)
return tokenizer.decode(generate_ids, **decode_kwargs)
# if not is_finished or is_finished and response[-len_suffix:] == template_suffix:
# # To avoid response length being shorter than previous response length during streaming.
# # TODO:check
# # idx = max(len(response) - len_suffix, 0, self.print_idx)
# response = response[:-len_suffix]


class InferStreamer(InferTools):

13 changes: 13 additions & 0 deletions swift/llm/infer/protocol.py
@@ -204,6 +204,13 @@ class ChatMessage:
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None


@dataclass
class ImageObject:
b64_json: Optional[str] = None
revised_prompt: Optional[str] = None
url: Optional[str] = None


@dataclass
class ChatCompletionResponseChoice:
index: int
@@ -239,6 +246,12 @@ def to_cmpl_response(self) -> 'CompletionResponse':
return CompletionResponse(self.model, choices, deepcopy(self.usage), id_, created=self.created)


@dataclass
class ImagesResponse:
created: int

data: List[ImageObject]

@dataclass
class CompletionResponse:
model: str
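The new ImageObject and ImagesResponse dataclasses appear to follow the shape of an OpenAI-style images response. A small usage sketch follows; the import path is inferred from this file's location and the payload bytes are placeholders.

import base64

from swift.llm.infer.protocol import ImageObject, ImagesResponse

png_bytes = b'\x89PNG...'  # placeholder for raw image bytes produced by the backend
image = ImageObject(b64_json=base64.b64encode(png_bytes).decode('utf-8'))
response = ImagesResponse(created=1730246400, data=[image])
print(len(response.data), response.data[0].b64_json[:12])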
36 changes: 35 additions & 1 deletion swift/llm/template/base.py
@@ -20,7 +20,7 @@

from .agent import loss_scale_map, split_str_parts_by
from .template_inputs import InferRequest, StdTemplateInputs, TemplateInputs
from .utils import Context, ContextType, Prompt, Word, fetch_one, findall
from .utils import Context, ContextType, Prompt, Word, fetch_one, findall, GenerationProperty, StopWordsCriteria
from .vision_utils import load_batch, load_image, normalize_bbox, rescale_image

logger = get_logger()
@@ -169,6 +169,40 @@ def encode(
def post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
return {}

@staticmethod
def _skip_stop_tokens(generate_ids: List[int], stop_tokens: List[int], is_finished: bool) -> List[int]:
len_tokens = len(stop_tokens)
if is_finished and generate_ids[-len_tokens:] == stop_tokens:
return generate_ids[:-len_tokens]
if not is_finished:
for i in range(len_tokens, 0, -1):
if generate_ids[-i:] == stop_tokens[:i]:
return generate_ids[:-i]
return generate_ids

def safe_decode(self, generate_ids: List[int], is_finished: bool, **decode_kwargs) -> Any:
# Do not print template_meta.suffix[-1] and eos_token.
tokenizer = self.tokenizer

if len(generate_ids) > 0 and generate_ids[-1] == tokenizer.eos_token_id:
generate_ids = generate_ids[:-1]
# skip suffix and eos_token
template_suffix = self.template_meta.suffix[-1]
if isinstance(template_suffix, str):
template_suffix = tokenizer.encode(template_suffix, add_special_tokens=False)
generate_ids = self._skip_stop_tokens(generate_ids, template_suffix, is_finished)
return tokenizer.decode(generate_ids, **decode_kwargs)
# if not is_finished or is_finished and response[-len_suffix:] == template_suffix:
# # To avoid response length being shorter than previous response length during streaming.
# # TODO:check
# # idx = max(len(response) - len_suffix, 0, self.print_idx)
# response = response[:-len_suffix]

def prepare_for_generation(self, example, model) -> GenerationProperty:
return GenerationProperty(
criterias=StopWordsCriteria(self.tokenizer, model.generation_config.stop_words)
)

def _preprocess_objects(self, inputs: StdTemplateInputs, objects: List[Dict[str, Any]]):
# Load image into PIL format
images = inputs.images
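safe_decode and _skip_stop_tokens now live on Template, and prepare_for_generation packages the stop-word criteria for the engine. Below is a worked illustration of the _skip_stop_tokens behavior with made-up token ids: while generation is still running, any trailing prefix of the stop sequence is held back so a partial stop word never leaks to the caller.

def skip_stop_tokens(generate_ids, stop_tokens, is_finished):
    # Standalone copy of the logic shown above, for illustration only.
    len_tokens = len(stop_tokens)
    if is_finished and generate_ids[-len_tokens:] == stop_tokens:
        return generate_ids[:-len_tokens]
    if not is_finished:
        for i in range(len_tokens, 0, -1):
            if generate_ids[-i:] == stop_tokens[:i]:
                return generate_ids[:-i]
    return generate_ids


stop = [50, 51, 52]
print(skip_stop_tokens([10, 11, 50, 51, 52], stop, is_finished=True))   # [10, 11]: full stop sequence trimmed
print(skip_stop_tokens([10, 11, 50, 51], stop, is_finished=False))      # [10, 11]: partial match held back
print(skip_stop_tokens([10, 11, 50, 51], stop, is_finished=True))       # [10, 11, 50, 51]: no full match, kept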