diff --git a/swift/llm/argument/infer_args.py b/swift/llm/argument/infer_args.py
index 73b43f1c6..6f3f92ef8 100644
--- a/swift/llm/argument/infer_args.py
+++ b/swift/llm/argument/infer_args.py
@@ -43,6 +43,7 @@ class VllmArguments:
Args:
gpu_memory_utilization (float): GPU memory utilization. Default is 0.9.
tensor_parallel_size (int): Tensor parallelism size. Default is 1.
+ pipeline_parallel_size (int): Pipeline parallelism size. Default is 1.
max_num_seqs (int): Maximum number of sequences. Default is 256.
max_model_len (Optional[int]): Maximum model length. Default is None.
disable_custom_all_reduce (bool): Flag to disable custom all-reduce. Default is True.
@@ -55,6 +56,7 @@ class VllmArguments:
# vllm
gpu_memory_utilization: float = 0.9
tensor_parallel_size: int = 1
+ pipeline_parallel_size: int = 1
max_num_seqs: int = 256
max_model_len: Optional[int] = None
disable_custom_all_reduce: bool = True # Default values different from vllm
diff --git a/swift/llm/infer/infer.py b/swift/llm/infer/infer.py
index 08477677d..a4cdfb7a4 100644
--- a/swift/llm/infer/infer.py
+++ b/swift/llm/infer/infer.py
@@ -92,6 +92,7 @@ def get_infer_engine(self) -> InferEngine:
kwargs.update({
'gpu_memory_utilization': args.gpu_memory_utilization,
'tensor_parallel_size': args.tensor_parallel_size,
+ 'pipeline_parallel_size': args.pipeline_parallel_size,
'max_num_seqs': args.max_num_seqs,
'max_model_len': args.max_model_len,
'disable_custom_all_reduce': args.disable_custom_all_reduce,
diff --git a/swift/llm/infer/infer_engine/pt_engine.py b/swift/llm/infer/infer_engine/pt_engine.py
index 32f5e79d0..3b71c250d 100644
--- a/swift/llm/infer/infer_engine/pt_engine.py
+++ b/swift/llm/infer/infer_engine/pt_engine.py
@@ -1,25 +1,28 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import inspect
+import json
+import time
from copy import deepcopy
from dataclasses import dataclass
from threading import Thread
from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union
-import json
import torch
+from PIL import Image
from tqdm import tqdm
-from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+from transformers import GenerationConfig, LogitsProcessorList
from transformers.utils import is_torch_npu_available
from swift.llm import Template, TemplateMeta, to_device
from swift.plugin import Metric
from swift.tuners import Swift
from swift.utils import get_logger
-from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
- ChatCompletionStreamResponse, ChatMessage, DeltaMessage, InferRequest, RequestConfig,
- random_uuid)
from .infer_engine import InferEngine
-from .utils import InferStreamer, InferTools, LogitsStreamer, StopWordsCriteria, TokensIteratorStreamer
+from .utils import InferStreamer, LogitsStreamer, TokensIteratorStreamer
+from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+ ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ImageObject, ImagesResponse, InferRequest,
+ MultiModalRequestMixin, RequestConfig, random_uuid)
logger = get_logger()
@@ -167,7 +170,7 @@ def _infer_stream(
if lora_request is not None:
kwargs['adapter_names'] = self._get_adapter_names(lora_request)
num_prompt_tokens = self._get_num_tokens(inputs)
- stopping_criteria = StoppingCriteriaList([StopWordsCriteria(self.tokenizer, generation_config.stop_words)])
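+ # Stopping criteria and any model-specific logits processors now come from the template.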
+ generation_props = template.prepare_for_generation(inputs, self.model)
if generation_config.num_beams != 1:
error_msg = 'Streaming generation does not support beam search.'
raise ValueError(error_msg)
@@ -179,16 +182,18 @@ def _model_generate(*args, **kwargs):
torch.npu.set_device(self.model.device)
self.model.generate(*args, **kwargs)
+ logits_processors = LogitsProcessorList(generation_props.logits_processors or [])
logits_streamer = None
if generation_config.output_logits:
logits_streamer = LogitsStreamer()
- kwargs['logits_processor'] = LogitsProcessorList([logits_streamer])
+ logits_processors.append(logits_streamer)
+ kwargs['logits_processor'] = logits_processors
thread = Thread(
target=_model_generate,
kwargs={
'generation_config': generation_config,
- 'stopping_criteria': stopping_criteria,
+ 'stopping_criteria': generation_props.criterias,
'streamer': streamer,
**inputs,
**kwargs
@@ -276,16 +281,20 @@ def _infer_full(self,
inputs: Dict[str, Any],
generation_config: GenerationConfig,
*,
- lora_request: Optional[PtLoRARequest] = None) -> List[ChatCompletionResponse]:
+ lora_request: Optional[PtLoRARequest] = None) -> Union[List[ChatCompletionResponse], List[ImagesResponse]]:
# bos_token TODO: encoder-decoder
kwargs = {}
if lora_request is not None:
kwargs['adapter_names'] = self._get_adapter_names(lora_request)
num_prompt_tokens = self._get_num_tokens(inputs)
- stopping_criteria = StoppingCriteriaList([StopWordsCriteria(self.tokenizer, generation_config.stop_words)])
+ generation_props = template.prepare_for_generation(inputs, self.model)
output = dict(
self.model.generate(
- generation_config=generation_config, stopping_criteria=stopping_criteria, **inputs, **kwargs))
+ generation_config=generation_config,
+ stopping_criteria=generation_props.criterias,
+ **inputs,
+ logits_processor=generation_props.logits_processors,
+ **kwargs))
batched_generate_ids = output['sequences']
batched_generate_ids = template.get_generate_ids(batched_generate_ids, num_prompt_tokens)
batched_logprobs = self.preprocess_logits(
@@ -304,17 +313,25 @@ def _infer_full(self,
logprobs = self._get_logprobs(self.tokenizer, logprobs_list, generate_ids, generation_config.top_logprobs)
usage_info = self._get_usage_info(num_prompt_tokens, len(generate_ids))
- response = InferTools.safe_decode(template, generate_ids, True)
-
- toolcall = self._get_toolcall(response, True)
- choices = [
- ChatCompletionResponseChoice(
- index=0,
- message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
- finish_reason=None,
- logprobs=logprobs)
- ]
- res.append(ChatCompletionResponse(model=self.model_dir, choices=choices, usage=usage_info))
+ response = template.safe_decode(generate_ids, True)
+ if isinstance(response, str):
+ toolcall = self._get_toolcall(response, True)
+ choices = [
+ ChatCompletionResponseChoice(
+ index=0,
+ message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
+ finish_reason=None,
+ logprobs=logprobs)
+ ]
+ res.append(ChatCompletionResponse(model=self.model_dir, choices=choices, usage=usage_info))
+ elif isinstance(response, Image.Image):
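+ # Image-generation templates (e.g. Emu3) decode to a PIL image; return it base64-encoded in an OpenAI-style ImagesResponse.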
+ res.append(
+ ImagesResponse(
+ created=int(time.time()),
+ data=[ImageObject(b64_json=MultiModalRequestMixin._to_base64(response))]))
+
return res
@torch.inference_mode()
@@ -350,7 +367,7 @@ def _infer(
*,
template: Optional[Template] = None,
lora_request: Optional[PtLoRARequest] = None,
- ) -> Union[List[ChatCompletionResponse], Iterator[List[Optional[ChatCompletionStreamResponse]]]]:
+ ) -> Union[List[ChatCompletionResponse], List[ImagesResponse], Iterator[List[Optional[ChatCompletionStreamResponse]]]]:
self.model.eval()
request_config = deepcopy(request_config or RequestConfig())
if template is None:
@@ -390,7 +407,7 @@ def infer(
template: Optional[Template] = None,
use_tqdm: Optional[bool] = None,
lora_request: Optional[PtLoRARequest] = None
- ) -> Union[List[ChatCompletionResponse], Iterator[List[Optional[ChatCompletionStreamResponse]]]]:
+ ) -> Union[List[ChatCompletionResponse], List[ImagesResponse], Iterator[List[Optional[ChatCompletionStreamResponse]]]]:
if use_tqdm is None:
use_tqdm = request_config is None or not request_config.stream
prog_bar = tqdm(total=len(infer_requests), dynamic_ncols=True, disable=not use_tqdm)
diff --git a/swift/llm/infer/infer_engine/utils.py b/swift/llm/infer/infer_engine/utils.py
index dd9d01858..84c8b0687 100644
--- a/swift/llm/infer/infer_engine/utils.py
+++ b/swift/llm/infer/infer_engine/utils.py
@@ -25,36 +25,6 @@ def _is_chinese_char(cp: int) -> bool:
return False
- @staticmethod
- def _skip_stop_tokens(generate_ids: List[int], stop_tokens: List[int], is_finished: bool) -> List[int]:
- len_tokens = len(stop_tokens)
- if is_finished and generate_ids[-len_tokens:] == stop_tokens:
- return generate_ids[:-len_tokens]
- if not is_finished:
- for i in range(len_tokens, 0, -1):
- if generate_ids[-i:] == stop_tokens[:i]:
- return generate_ids[:-i]
- return generate_ids
-
- @staticmethod
- def safe_decode(template: Template, generate_ids: List[int], is_finished: bool, **decode_kwargs) -> str:
- # Do not print template_meta.suffix[-1] and eos_token.
- tokenizer = template.tokenizer
-
- if len(generate_ids) > 0 and generate_ids[-1] == tokenizer.eos_token_id:
- generate_ids = generate_ids[:-1]
- # skip suffix and eos_token
- template_suffix = template.template_meta.suffix[-1]
- if isinstance(template_suffix, str):
- template_suffix = tokenizer.encode(template_suffix, add_special_tokens=False)
- generate_ids = InferTools._skip_stop_tokens(generate_ids, template_suffix, is_finished)
- return tokenizer.decode(generate_ids, **decode_kwargs)
- # if not is_finished or is_finished and response[-len_suffix:] == template_suffix:
- # # To avoid response length being shorter than previous response length during streaming.
- # # TODO:check
- # # idx = max(len(response) - len_suffix, 0, self.print_idx)
- # response = response[:-len_suffix]
-
class InferStreamer(InferTools):
diff --git a/swift/llm/infer/protocol.py b/swift/llm/infer/protocol.py
index fef39ddad..6892db340 100644
--- a/swift/llm/infer/protocol.py
+++ b/swift/llm/infer/protocol.py
@@ -204,6 +204,13 @@ class ChatMessage:
tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
+@dataclass
+class ImageObject:
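+ """A single generated image, following the OpenAI images API object (base64 data, optional revised prompt or URL)."""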
+ b64_json: Optional[str] = None
+ revised_prompt: Optional[str] = None
+ url: Optional[str] = None
+
+
@dataclass
class ChatCompletionResponseChoice:
index: int
@@ -239,6 +246,12 @@ def to_cmpl_response(self) -> 'CompletionResponse':
return CompletionResponse(self.model, choices, deepcopy(self.usage), id_, created=self.created)
+@dataclass
+class ImagesResponse:
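+ """Image-generation response: Unix creation timestamp plus the list of generated images."""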
+ created: int
+ data: List[ImageObject]
+
@dataclass
class CompletionResponse:
model: str
diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index 31759905c..50032761c 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -20,7 +20,7 @@
from .agent import loss_scale_map, split_str_parts_by
from .template_inputs import InferRequest, StdTemplateInputs, TemplateInputs
-from .utils import Context, ContextType, Prompt, Word, fetch_one, findall
+from .utils import Context, ContextType, GenerationProperty, Prompt, StopWordsCriteria, Word, fetch_one, findall
from .vision_utils import load_batch, load_image, normalize_bbox, rescale_image
logger = get_logger()
@@ -169,6 +169,40 @@ def encode(
def post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
return {}
+ @staticmethod
+ def _skip_stop_tokens(generate_ids: List[int], stop_tokens: List[int], is_finished: bool) -> List[int]:
+ len_tokens = len(stop_tokens)
+ if is_finished and generate_ids[-len_tokens:] == stop_tokens:
+ return generate_ids[:-len_tokens]
+ if not is_finished:
+ for i in range(len_tokens, 0, -1):
+ if generate_ids[-i:] == stop_tokens[:i]:
+ return generate_ids[:-i]
+ return generate_ids
+
+ def safe_decode(self, generate_ids: List[int], is_finished: bool, **decode_kwargs) -> Any:
+ # Do not print template_meta.suffix[-1] and eos_token.
+ tokenizer = self.tokenizer
+
+ if len(generate_ids) > 0 and generate_ids[-1] == tokenizer.eos_token_id:
+ generate_ids = generate_ids[:-1]
+ # skip suffix and eos_token
+ template_suffix = self.template_meta.suffix[-1]
+ if isinstance(template_suffix, str):
+ template_suffix = tokenizer.encode(template_suffix, add_special_tokens=False)
+ generate_ids = self._skip_stop_tokens(generate_ids, template_suffix, is_finished)
+ return tokenizer.decode(generate_ids, **decode_kwargs)
+ # if not is_finished or is_finished and response[-len_suffix:] == template_suffix:
+ # # To avoid response length being shorter than previous response length during streaming.
+ # # TODO:check
+ # # idx = max(len(response) - len_suffix, 0, self.print_idx)
+ # response = response[:-len_suffix]
+
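+ # Hook for templates to supply stopping criteria and extra logits processors at generation time
+ # (overridden by multimodal templates such as Emu3Gen).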
+ def prepare_for_generation(self, example, model) -> GenerationProperty:
+ return GenerationProperty(
+ criterias=[StopWordsCriteria(self.tokenizer, model.generation_config.stop_words)])
+
def _preprocess_objects(self, inputs: StdTemplateInputs, objects: List[Dict[str, Any]]):
# Load image into PIL format
images = inputs.images
diff --git a/swift/llm/template/template/template.py b/swift/llm/template/template/template.py
index b0b7330bb..03ad1985d 100644
--- a/swift/llm/template/template/template.py
+++ b/swift/llm/template/template/template.py
@@ -1,16 +1,19 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
import json
import torch
+from PIL import Image
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from swift.utils import get_logger, upper_bound
from ..base import Template
from ..constant import TemplateType
from ..register import register_template
-from ..utils import Context, align_image_inputs, findall
+from ..utils import Context, GenerationProperty, StopWordsCriteria, align_image_inputs, findall
from ..vision_utils import (load_batch, load_image, load_video_cogvlm2, load_video_internvl,
load_video_minicpmv_mplug_owl3, transform_image)
from .llama import Llama3Template, Llama3TemplateMixin
@@ -231,6 +234,123 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] =
Template(['[INST] '], ['{{SYSTEM}}\n\n', '{{QUERY}}[/INST]'], ['[INST] '], ['']))
+class Emu3GenTemplate(Template):
+
+ NULL_PROMPT_PROB = 0.1
+
+ COOKBOOK_SIZE = 32768
+
+ APPLY_LOSS_ON_ONLY_VISION = True
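+ # Optional negative prompt for classifier-free guidance, taken from the environment.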
+ NEGATIVE_PROMPT = os.environ.get('NEGATIVE_PROMPT')
+
+ def _init_template(self, *args, **kwargs):
+ super()._init_template(*args, **kwargs)
+ self.bov = self.tokenizer.encode(self.tokenizer.processor.visual_template[0].format(token_id=0))[0]
+ self.eov = self.tokenizer.encode(self.tokenizer.processor.visual_template[0].format(token_id=self.COOKBOOK_SIZE - 1))[0]
+ self.config = kwargs.get('config')
+
+ def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ query = example['query']
+
+ kwargs = dict(
+ mode='U' if self._is_training else 'G',
+ ratio='1:1',
+ image_area=self.config.image_area,
+ return_tensors='pt',
+ padding='longest',
+ )
+
+ # image
+ raw_image = example.get('images', None)
+ inputs = self.tokenizer.processor(query, raw_image, **kwargs)
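+ # Supervise only the visual tokens: ids outside [bov, eov] are masked to -100 below.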
+ labels = inputs["input_ids"]
+ if self.APPLY_LOSS_ON_ONLY_VISION:
+ labels = torch.where(torch.logical_and(labels >= self.bov, labels <= self.eov), labels, -100)
+
+ inputs["labels"] = labels
+ for k, v in inputs.items():
+ inputs[k] = v.squeeze(0)
+ return inputs, {}
+
+ def prepare_for_output(self, output: str) -> str:
+ return output
+
+ def prepare_for_generation(self, example, model) -> GenerationProperty:
+ from transformers import (LogitsProcessorList, PrefixConstrainedLogitsProcessor,
+ UnbatchedClassifierFreeGuidanceLogitsProcessor)
+
+ kwargs = dict(
+ mode='G',
+ ratio='1:1',
+ image_area=self.config.image_area,
+ return_tensors='pt',
+ padding='longest',
+ )
+ negative_prompt = self.NEGATIVE_PROMPT
+ if 'negative_prompt' in example:
+ negative_prompt = example['negative_prompt']
+
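+ # Classifier-free guidance contrasts the prompt with the negative prompt; the prefix constraint keeps decoding on the valid visual-token grid.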
+ classifier_free_guidance = 3.0
+ h, w = self.tokenizer.processor.calculate_generate_size(
+ '1:1', self.config.image_area, self.tokenizer.processor.vision_tokenizer.spatial_scale_factor)
+ # h = pos_inputs.image_size[:, 0]
+ # w = pos_inputs.image_size[:, 1]
+ neg_inputs = self.tokenizer.processor(text=negative_prompt, **kwargs)
+ constrained_fn = self.tokenizer.processor.build_prefix_constrained_fn(h, w)
+ logits_processors = LogitsProcessorList([
+ UnbatchedClassifierFreeGuidanceLogitsProcessor(
+ classifier_free_guidance,
+ model,
+ unconditional_ids=neg_inputs.input_ids.to("cuda:0"),
+ ),
+ PrefixConstrainedLogitsProcessor(
+ constrained_fn,
+ num_beams=1,
+ ),
+ ])
+
+ return GenerationProperty(
+ logits_processors=logits_processors,
+ criterias=[StopWordsCriteria(self.tokenizer, model.generation_config.stop_words)])
+
+ def safe_decode(self, generate_ids: List[int], is_finished: bool, **decode_kwargs) -> Optional[Image.Image]:
+ mm_list = self.tokenizer.processor.decode(generate_ids)
+ for im in mm_list:
+ if isinstance(im, Image.Image):
+ return im
+
+ def format_image_prompt(self, image_tokens):
+ h, w = image_tokens.shape
+ imgstr = self.tokenizer.processor.to_imgstr(image_tokens)
+
+ image_prompt = (
+ self.tokenizer.boi_token +
+ f"{h}*{w}" +
+ self.tokenizer.img_token +
+ imgstr +
+ self.tokenizer.eol_token +
+ self.tokenizer.eof_token +
+ self.tokenizer.eoi_token
+ )
+
+ return image_prompt
+
+ def smart_resize(self, image):
+ w, h = image.size
+ current_area = h * w
+ target_ratio = (self.config.image_area / current_area) ** 0.5
+
+ th = int(round(h * target_ratio))
+ tw = int(round(w * target_ratio))
+
+ image = image.resize((tw, th))
+ return image
+
+
class ReflectionTemplate(Llama3TemplateMixin, Template):
system = ('You are a world-class AI system, capable of complex reasoning and reflection. '
'Reason through the query inside tags, and then provide your final '
diff --git a/swift/llm/template/utils.py b/swift/llm/template/utils.py
index 40d3445a1..7cbcca6c4 100644
--- a/swift/llm/template/utils.py
+++ b/swift/llm/template/utils.py
@@ -1,9 +1,10 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import re
+from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
import torch
-from transformers import PreTrainedTokenizerBase, StoppingCriteria
+from transformers import LogitsProcessor, PreTrainedTokenizerBase, StoppingCriteria
from swift.llm import History
@@ -18,6 +19,13 @@ class ContextType:
OTHER = 'other'
+@dataclass
+class GenerationProperty:
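+ """Generation-time hooks supplied by a template: extra logits processors and stopping criteria for model.generate()."""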
+
+ logits_processors: Optional[List[LogitsProcessor]] = None
+ criterias: Optional[List[StoppingCriteria]] = None
+
+
class StopWordsCriteria(StoppingCriteria):
"""Adding extra stop words in template to prevent unstoppable generation
Like suffixes and chat seps in the template.