diff --git a/swift/llm/argument/infer_args.py b/swift/llm/argument/infer_args.py
index 73b43f1c6..6f3f92ef8 100644
--- a/swift/llm/argument/infer_args.py
+++ b/swift/llm/argument/infer_args.py
@@ -43,6 +43,7 @@ class VllmArguments:
     Args:
         gpu_memory_utilization (float): GPU memory utilization. Default is 0.9.
         tensor_parallel_size (int): Tensor parallelism size. Default is 1.
+        pipeline_parallel_size (int): Pipeline parallelism size. Default is 1.
         max_num_seqs (int): Maximum number of sequences. Default is 256.
         max_model_len (Optional[int]): Maximum model length. Default is None.
         disable_custom_all_reduce (bool): Flag to disable custom all-reduce. Default is True.
@@ -55,6 +56,7 @@ class VllmArguments:
     # vllm
     gpu_memory_utilization: float = 0.9
     tensor_parallel_size: int = 1
+    pipeline_parallel_size: int = 1
     max_num_seqs: int = 256
     max_model_len: Optional[int] = None
     disable_custom_all_reduce: bool = True  # Default values different from vllm
diff --git a/swift/llm/infer/infer.py b/swift/llm/infer/infer.py
index 08477677d..a4cdfb7a4 100644
--- a/swift/llm/infer/infer.py
+++ b/swift/llm/infer/infer.py
@@ -92,6 +92,7 @@ def get_infer_engine(self) -> InferEngine:
         kwargs.update({
             'gpu_memory_utilization': args.gpu_memory_utilization,
             'tensor_parallel_size': args.tensor_parallel_size,
+            'pipeline_parallel_size': args.pipeline_parallel_size,
             'max_num_seqs': args.max_num_seqs,
             'max_model_len': args.max_model_len,
             'disable_custom_all_reduce': args.disable_custom_all_reduce,
diff --git a/swift/llm/infer/infer_engine/pt_engine.py b/swift/llm/infer/infer_engine/pt_engine.py
index 32f5e79d0..3b71c250d 100644
--- a/swift/llm/infer/infer_engine/pt_engine.py
+++ b/swift/llm/infer/infer_engine/pt_engine.py
@@ -1,25 +1,28 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import inspect
+import json
+import time
 from copy import deepcopy
 from dataclasses import dataclass
 from threading import Thread
 from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Optional, Union

-import json
 import torch
+from PIL import Image
 from tqdm import tqdm
-from transformers import GenerationConfig, LogitsProcessorList, StoppingCriteriaList
+from transformers import GenerationConfig, LogitsProcessorList
 from transformers.utils import is_torch_npu_available

 from swift.llm import Template, TemplateMeta, to_device
 from swift.plugin import Metric
 from swift.tuners import Swift
 from swift.utils import get_logger
-from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
-                        ChatCompletionStreamResponse, ChatMessage, DeltaMessage, InferRequest, RequestConfig,
-                        random_uuid)
 from .infer_engine import InferEngine
-from .utils import InferStreamer, InferTools, LogitsStreamer, StopWordsCriteria, TokensIteratorStreamer
+from .utils import InferStreamer, LogitsStreamer, TokensIteratorStreamer
+from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+                        ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ImageObject, ImagesResponse,
+                        InferRequest, MultiModalRequestMixin, RequestConfig, random_uuid)

 logger = get_logger()

@@ -167,7 +170,7 @@ def _infer_stream(
         if lora_request is not None:
             kwargs['adapter_names'] = self._get_adapter_names(lora_request)
         num_prompt_tokens = self._get_num_tokens(inputs)
-        stopping_criteria = StoppingCriteriaList([StopWordsCriteria(self.tokenizer, generation_config.stop_words)])
+        generation_props = template.prepare_for_generation(inputs, self.model)
         if generation_config.num_beams != 1:
             error_msg = 'Streaming generation does not support beam search.'
             raise ValueError(error_msg)
@@ -179,16 +182,18 @@ def _model_generate(*args, **kwargs):
                 torch.npu.set_device(self.model.device)
             self.model.generate(*args, **kwargs)

+        logits_processors = LogitsProcessorList(generation_props.logits_processors or [])
         logits_streamer = None
         if generation_config.output_logits:
             logits_streamer = LogitsStreamer()
-            kwargs['logits_processor'] = LogitsProcessorList([logits_streamer])
+            logits_processors.append(logits_streamer)
+        kwargs['logits_processor'] = logits_processors

         thread = Thread(
             target=_model_generate,
             kwargs={
                 'generation_config': generation_config,
-                'stopping_criteria': stopping_criteria,
+                'stopping_criteria': generation_props.criterias,
                 'streamer': streamer,
                 **inputs,
                 **kwargs
@@ -276,16 +281,20 @@ def _infer_full(self,
                     inputs: Dict[str, Any],
                     generation_config: GenerationConfig,
                     *,
-                    lora_request: Optional[PtLoRARequest] = None) -> List[ChatCompletionResponse]:
+                    lora_request: Optional[PtLoRARequest] = None) -> Union[List[ChatCompletionResponse], List[ImagesResponse]]:
         # bos_token TODO: encoder-decoder
         kwargs = {}
         if lora_request is not None:
             kwargs['adapter_names'] = self._get_adapter_names(lora_request)
         num_prompt_tokens = self._get_num_tokens(inputs)
-        stopping_criteria = StoppingCriteriaList([StopWordsCriteria(self.tokenizer, generation_config.stop_words)])
+        generation_props = template.prepare_for_generation(inputs, self.model)
         output = dict(
             self.model.generate(
-                generation_config=generation_config, stopping_criteria=stopping_criteria, **inputs, **kwargs))
+                generation_config=generation_config,
+                stopping_criteria=generation_props.criterias,
+                logits_processor=generation_props.logits_processors,
+                **inputs,
+                **kwargs))
         batched_generate_ids = output['sequences']
         batched_generate_ids = template.get_generate_ids(batched_generate_ids, num_prompt_tokens)
         batched_logprobs = self.preprocess_logits(
@@ -304,17 +313,25 @@ def _infer_full(self,
             logprobs = self._get_logprobs(self.tokenizer, logprobs_list, generate_ids, generation_config.top_logprobs)
             usage_info = self._get_usage_info(num_prompt_tokens, len(generate_ids))
-            response = InferTools.safe_decode(template, generate_ids, True)
-
-            toolcall = self._get_toolcall(response, True)
-            choices = [
-                ChatCompletionResponseChoice(
-                    index=0,
-                    message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
-                    finish_reason=None,
-                    logprobs=logprobs)
-            ]
-            res.append(ChatCompletionResponse(model=self.model_dir, choices=choices, usage=usage_info))
+            response = template.safe_decode(generate_ids, True)
+            if isinstance(response, str):
+                toolcall = self._get_toolcall(response, True)
+                choices = [
+                    ChatCompletionResponseChoice(
+                        index=0,
+                        message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
+                        finish_reason=None,
+                        logprobs=logprobs)
+                ]
+                res.append(ChatCompletionResponse(model=self.model_dir, choices=choices, usage=usage_info))
+            elif isinstance(response, Image.Image):
+                res.append(ImagesResponse(
+                    created=int(time.time()),
+                    data=[ImageObject(b64_json=MultiModalRequestMixin._to_base64(response))]))

         return res

     @torch.inference_mode()
@@ -350,7 +367,7 @@ def _infer(
         *,
         template: Optional[Template] = None,
         lora_request: Optional[PtLoRARequest] = None,
-    ) -> Union[List[ChatCompletionResponse], Iterator[List[Optional[ChatCompletionStreamResponse]]]]:
+    ) -> Union[List[ChatCompletionResponse], List[ImagesResponse], Iterator[List[Optional[ChatCompletionStreamResponse]]]]:
         self.model.eval()
         request_config = deepcopy(request_config or RequestConfig())
         if template is None:
@@ -390,7 +407,7 @@ def infer(
         template: Optional[Template] = None,
         use_tqdm: Optional[bool] = None,
         lora_request: Optional[PtLoRARequest] = None
-    ) -> Union[List[ChatCompletionResponse], Iterator[List[Optional[ChatCompletionStreamResponse]]]]:
+    ) -> Union[List[ChatCompletionResponse], List[ImagesResponse], Iterator[List[Optional[ChatCompletionStreamResponse]]]]:
         if use_tqdm is None:
             use_tqdm = request_config is None or not request_config.stream
         prog_bar = tqdm(total=len(infer_requests), dynamic_ncols=True, disable=not use_tqdm)
diff --git a/swift/llm/infer/infer_engine/utils.py b/swift/llm/infer/infer_engine/utils.py
index dd9d01858..84c8b0687 100644
--- a/swift/llm/infer/infer_engine/utils.py
+++ b/swift/llm/infer/infer_engine/utils.py
@@ -25,36 +25,6 @@ def _is_chinese_char(cp: int) -> bool:
         return False

-    @staticmethod
-    def _skip_stop_tokens(generate_ids: List[int], stop_tokens: List[int], is_finished: bool) -> List[int]:
-        len_tokens = len(stop_tokens)
-        if is_finished and generate_ids[-len_tokens:] == stop_tokens:
-            return generate_ids[:-len_tokens]
-        if not is_finished:
-            for i in range(len_tokens, 0, -1):
-                if generate_ids[-i:] == stop_tokens[:i]:
-                    return generate_ids[:-i]
-        return generate_ids
-
-    @staticmethod
-    def safe_decode(template: Template, generate_ids: List[int], is_finished: bool, **decode_kwargs) -> str:
-        # Do not print template_meta.suffix[-1] and eos_token.
-        tokenizer = template.tokenizer
-
-        if len(generate_ids) > 0 and generate_ids[-1] == tokenizer.eos_token_id:
-            generate_ids = generate_ids[:-1]
-        # skip suffix and eos_token
-        template_suffix = template.template_meta.suffix[-1]
-        if isinstance(template_suffix, str):
-            template_suffix = tokenizer.encode(template_suffix, add_special_tokens=False)
-        generate_ids = InferTools._skip_stop_tokens(generate_ids, template_suffix, is_finished)
-        return tokenizer.decode(generate_ids, **decode_kwargs)
-        # if not is_finished or is_finished and response[-len_suffix:] == template_suffix:
-        #     # To avoid response length being shorter than previous response length during streaming.
-        #     # TODO:check
-        #     # idx = max(len(response) - len_suffix, 0, self.print_idx)
-        #     response = response[:-len_suffix]
-

 class InferStreamer(InferTools):
diff --git a/swift/llm/infer/protocol.py b/swift/llm/infer/protocol.py
index fef39ddad..6892db340 100644
--- a/swift/llm/infer/protocol.py
+++ b/swift/llm/infer/protocol.py
@@ -204,6 +204,13 @@ class ChatMessage:
     tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None


+@dataclass
+class ImageObject:
+    b64_json: Optional[str] = None
+    revised_prompt: Optional[str] = None
+    url: Optional[str] = None
+
+
 @dataclass
 class ChatCompletionResponseChoice:
     index: int
@@ -239,6 +246,12 @@ def to_cmpl_response(self) -> 'CompletionResponse':
         return CompletionResponse(self.model, choices, deepcopy(self.usage), id_, created=self.created)


+@dataclass
+class ImagesResponse:
+    created: int
+    data: List[ImageObject]
+
+
 @dataclass
 class CompletionResponse:
     model: str
diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index 31759905c..50032761c 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -20,7 +20,7 @@
 from .agent import loss_scale_map, split_str_parts_by
 from .template_inputs import InferRequest, StdTemplateInputs, TemplateInputs
-from .utils import Context, ContextType, Prompt, Word, fetch_one, findall
+from .utils import Context, ContextType, GenerationProperty, Prompt, StopWordsCriteria, Word, fetch_one, findall
 from .vision_utils import load_batch, load_image, normalize_bbox, rescale_image

 logger = get_logger()
@@ -169,6 +169,40 @@ def encode(
     def post_encode(self, model: nn.Module, inputs: Dict[str, Any]) -> Dict[str, Any]:
         return {}

+    @staticmethod
+    def _skip_stop_tokens(generate_ids: List[int], stop_tokens: List[int], is_finished: bool) -> List[int]:
+        len_tokens = len(stop_tokens)
+        if is_finished and generate_ids[-len_tokens:] == stop_tokens:
+            return generate_ids[:-len_tokens]
+        if not is_finished:
+            for i in range(len_tokens, 0, -1):
+                if generate_ids[-i:] == stop_tokens[:i]:
+                    return generate_ids[:-i]
+        return generate_ids
+
+    def safe_decode(self, generate_ids: List[int], is_finished: bool, **decode_kwargs) -> Any:
+        # Do not print template_meta.suffix[-1] and eos_token.
+        tokenizer = self.tokenizer
+
+        if len(generate_ids) > 0 and generate_ids[-1] == tokenizer.eos_token_id:
+            generate_ids = generate_ids[:-1]
+        # Skip the template suffix and eos_token.
+        template_suffix = self.template_meta.suffix[-1]
+        if isinstance(template_suffix, str):
+            template_suffix = tokenizer.encode(template_suffix, add_special_tokens=False)
+        generate_ids = self._skip_stop_tokens(generate_ids, template_suffix, is_finished)
+        return tokenizer.decode(generate_ids, **decode_kwargs)
+        # if not is_finished or is_finished and response[-len_suffix:] == template_suffix:
+        #     # To avoid response length being shorter than previous response length during streaming.
+        #     # TODO: check
+        #     # idx = max(len(response) - len_suffix, 0, self.print_idx)
+        #     response = response[:-len_suffix]
+
+    def prepare_for_generation(self, example, model) -> GenerationProperty:
+        return GenerationProperty(
+            criterias=[StopWordsCriteria(self.tokenizer, model.generation_config.stop_words)])
+
     def _preprocess_objects(self, inputs: StdTemplateInputs, objects: List[Dict[str, Any]]):
         # Load image into PIL format
         images = inputs.images
diff --git a/swift/llm/template/template/template.py b/swift/llm/template/template/template.py
index b0b7330bb..03ad1985d 100644
--- a/swift/llm/template/template/template.py
+++ b/swift/llm/template/template/template.py
@@ -1,16 +1,20 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+import time
+from datetime import datetime
 from functools import partial
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple

 import json
 import torch
+from PIL import Image
 from transformers.dynamic_module_utils import get_class_from_dynamic_module

 from swift.utils import get_logger, upper_bound
 from ..base import Template
 from ..constant import TemplateType
 from ..register import register_template
-from ..utils import Context, align_image_inputs, findall
+from ..utils import Context, GenerationProperty, StopWordsCriteria, align_image_inputs, findall
 from ..vision_utils import (load_batch, load_image, load_video_cogvlm2, load_video_internvl,
                             load_video_minicpmv_mplug_owl3, transform_image)
 from .llama import Llama3Template, Llama3TemplateMixin
@@ -231,6 +234,123 @@ def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int]
               Template(['[INST] '], ['{{SYSTEM}}\n\n', '{{QUERY}}[/INST]'], ['[INST] '], ['']))


+class Emu3GenTemplate(Template):
+
+    NULL_PROMPT_PROB = 0.1
+
+    COOKBOOK_SIZE = 32768
+
+    APPLY_LOSS_ON_ONLY_VISION = True
+    NEGATIVE_PROMPT = os.environ.get('NEGATIVE_PROMPT')
+
+    def _init_template(self, *args, **kwargs):
+        super()._init_template(*args, **kwargs)
+        self.bov = self.tokenizer.encode(self.tokenizer.processor.visual_template[0].format(token_id=0))[0]
+        self.eov = self.tokenizer.encode(
+            self.tokenizer.processor.visual_template[0].format(token_id=self.COOKBOOK_SIZE - 1))[0]
+        self.config = kwargs.get('config')
+
+    def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        query = example['query']
+
+        kwargs = dict(
+            mode='U' if self._is_training else 'G',
+            ratio='1:1',
+            image_area=self.config.image_area,
+            return_tensors='pt',
+            padding='longest',
+        )
+
+        # image
+        raw_image = example.get('images', None)
+        inputs = self.tokenizer.processor(query, raw_image, **kwargs)
+        labels = inputs['input_ids']
+        if self.APPLY_LOSS_ON_ONLY_VISION:
+            # Compute the loss only on vision tokens; text tokens are masked with -100.
+            labels = torch.where(torch.logical_and(labels >= self.bov, labels <= self.eov), labels, -100)
+
+        inputs['labels'] = labels
+        for k, v in inputs.items():
+            inputs[k] = v.squeeze(0)
+        return inputs, {}
+
+    def prepare_for_output(self, output: str) -> str:
+        return output
+
+    def prepare_for_generation(self, example, model) -> GenerationProperty:
+        from transformers import (LogitsProcessorList, PrefixConstrainedLogitsProcessor,
+                                  UnbatchedClassifierFreeGuidanceLogitsProcessor)
+
+        kwargs = dict(
+            mode='G',
+            ratio='1:1',
+            image_area=self.config.image_area,
+            return_tensors='pt',
+            padding='longest',
+        )
+        negative_prompt = self.NEGATIVE_PROMPT
+        if 'negative_prompt' in example:
+            negative_prompt = example['negative_prompt']
+
+        classifier_free_guidance = 3.0
+        h, w = self.tokenizer.processor.calculate_generate_size(
+            '1:1', self.config.image_area, self.tokenizer.processor.vision_tokenizer.spatial_scale_factor)
+        # h = pos_inputs.image_size[:, 0]
+        # w = pos_inputs.image_size[:, 1]
+        neg_inputs = self.tokenizer.processor(text=negative_prompt, **kwargs)
+        constrained_fn = self.tokenizer.processor.build_prefix_constrained_fn(h, w)
+        logits_processors = LogitsProcessorList([
+            UnbatchedClassifierFreeGuidanceLogitsProcessor(
+                classifier_free_guidance,
+                model,
+                unconditional_ids=neg_inputs.input_ids.to(model.device),
+            ),
+            PrefixConstrainedLogitsProcessor(
+                constrained_fn,
+                num_beams=1,
+            ),
+        ])
+
+        return GenerationProperty(
+            logits_processors=logits_processors,
+            criterias=[StopWordsCriteria(self.tokenizer, model.generation_config.stop_words)])
+
+    def safe_decode(self, generate_ids: List[int], is_finished: bool, **decode_kwargs) -> Optional[Image.Image]:
+        # Return the first image decoded from the generated token ids, if any.
+        mm_list = self.tokenizer.processor.decode(generate_ids)
+        for im in mm_list:
+            if isinstance(im, Image.Image):
+                return im
+
+    def format_image_prompt(self, image_tokens):
+        h, w = image_tokens.shape
+        imgstr = self.tokenizer.processor.to_imgstr(image_tokens)
+
+        image_prompt = (
+            self.tokenizer.boi_token + f'{h}*{w}' + self.tokenizer.img_token + imgstr + self.tokenizer.eol_token
+            + self.tokenizer.eof_token + self.tokenizer.eoi_token)
+        return image_prompt
+
+    def smart_resize(self, image):
+        # Rescale the image so that its area matches the configured image_area while keeping the aspect ratio.
+        w, h = image.size
+        current_area = h * w
+        target_ratio = (self.tokenizer.config.image_area / current_area)**0.5
+
+        th = int(round(h * target_ratio))
+        tw = int(round(w * target_ratio))
+        image = image.resize((tw, th))
+        return image
+
+
 class ReflectionTemplate(Llama3TemplateMixin, Template):
     system = ('You are a world-class AI system, capable of complex reasoning and reflection. '
               'Reason through the query inside <thinking> tags, and then provide your final '
diff --git a/swift/llm/template/utils.py b/swift/llm/template/utils.py
index 40d3445a1..7cbcca6c4 100644
--- a/swift/llm/template/utils.py
+++ b/swift/llm/template/utils.py
@@ -1,9 +1,10 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import re
+from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

 import torch
-from transformers import PreTrainedTokenizerBase, StoppingCriteria
+from transformers import LogitsProcessor, PreTrainedTokenizerBase, StoppingCriteria

 from swift.llm import History
@@ -18,6 +19,13 @@ class ContextType:
     OTHER = 'other'


+@dataclass
+class GenerationProperty:
+    logits_processors: Optional[List[LogitsProcessor]] = None
+    criterias: Optional[List[StoppingCriteria]] = None
+
+
 class StopWordsCriteria(StoppingCriteria):
     """Adding extra stop words in template to prevent unstoppable generation
     Like suffixes and chat seps in the template.
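
Illustrative usage sketch (not part of the patch): the GenerationProperty returned by Template.prepare_for_generation is what PtEngine now forwards to model.generate as logits_processor/stopping_criteria, so a custom template only needs to override that hook to inject extra processors. The CustomTemplate class and the MinLengthLogitsProcessor choice below are assumptions for illustration; the hook, GenerationProperty, and StopWordsCriteria come from this diff.

# Hypothetical sketch, assuming the import paths introduced above.
from transformers import LogitsProcessorList, MinLengthLogitsProcessor

from swift.llm.template.base import Template
from swift.llm.template.utils import GenerationProperty, StopWordsCriteria


class CustomTemplate(Template):

    def prepare_for_generation(self, example, model) -> GenerationProperty:
        # Keep the default stop-word criteria and add one extra logits processor.
        return GenerationProperty(
            logits_processors=LogitsProcessorList(
                [MinLengthLogitsProcessor(8, eos_token_id=model.generation_config.eos_token_id)]),
            criterias=[StopWordsCriteria(self.tokenizer, model.generation_config.stop_words)])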