diff --git a/tests/async_engine/test_chat_template.py b/tests/async_engine/test_chat_template.py
index 4df6c02973284..61a6d77cd8756 100644
--- a/tests/async_engine/test_chat_template.py
+++ b/tests/async_engine/test_chat_template.py
@@ -1,6 +1,7 @@
 import pytest

-from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
+from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
+                                         load_chat_template)
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.transformers_utils.tokenizer import get_tokenizer

@@ -87,7 +88,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
         add_generation_prompt=add_generation_prompt)

     # Call the function and get the result
-    result = apply_chat_template(
+    result = apply_hf_chat_template(
         tokenizer,
         conversation=mock_request.messages,
         chat_template=mock_request.chat_template or template_content,
diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py
index f9f9536a7c160..a42ad81b3eef4 100644
--- a/vllm/entrypoints/chat_utils.py
+++ b/vllm/entrypoints/chat_utils.py
@@ -23,6 +23,7 @@
 # yapf: enable
 # pydantic needs the TypedDict from typing_extensions
 from pydantic import ConfigDict
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from typing_extensions import Required, TypeAlias, TypedDict

 from vllm.config import ModelConfig
@@ -31,7 +32,7 @@
 from vllm.multimodal.utils import (async_get_and_parse_audio,
                                    async_get_and_parse_image,
                                    get_and_parse_audio, get_and_parse_image)
-from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer

 logger = init_logger(__name__)

@@ -379,6 +380,9 @@ def _parse_chat_message_content_parts(
             audio_url = _AudioParser(part)["audio_url"]

             mm_parser.parse_audio(audio_url["url"])
+        elif part_type == "refusal":
+            text = _RefusalParser(part)["refusal"]
+            texts.append(text)
         else:
             raise NotImplementedError(f"Unknown part type: {part_type}")

@@ -433,6 +437,21 @@
     return result


+def _postprocess_messages(messages: List[ConversationMessage]) -> None:
+    # per the Transformers docs & maintainers, tool call arguments in
+    # assistant-role messages with tool_calls need to be dicts not JSON str -
+    # this is how tool-use chat templates will expect them moving forwards
+    # so, for messages that have tool_calls, parse the string (which we get
+    # from openAI format) to dict
+    for message in messages:
+        if (message["role"] == "assistant" and "tool_calls" in message
+                and isinstance(message["tool_calls"], list)):
+
+            for item in message["tool_calls"]:
+                item["function"]["arguments"] = json.loads(
+                    item["function"]["arguments"])
+
+
 def parse_chat_messages(
     messages: List[ChatCompletionMessageParam],
     model_config: ModelConfig,
@@ -446,6 +465,8 @@

         conversation.extend(sub_messages)

+    _postprocess_messages(conversation)
+
     return conversation, mm_tracker.all_mm_data()


@@ -462,41 +483,41 @@

         conversation.extend(sub_messages)

+    _postprocess_messages(conversation)
+
     return conversation, mm_tracker.all_mm_data()


-def apply_chat_template(
-    tokenizer: AnyTokenizer,
+def apply_hf_chat_template(
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     conversation: List[ConversationMessage],
     chat_template: Optional[str],
     *,
     tokenize: bool = False,  # Different from HF's default
     **kwargs: Any,
-) -> Union[str, List[int]]:
+) -> str:
     if chat_template is None and tokenizer.chat_template is None:
         raise ValueError(
             "As of transformers v4.44, default chat template is no longer "
             "allowed, so you must provide a chat template if the tokenizer "
             "does not define one.")

-    # per the Transformers docs & maintainers, tool call arguments in
-    # assistant-role messages with tool_calls need to be dicts not JSON str -
-    # this is how tool-use chat templates will expect them moving forwards
-    # so, for messages that have tool_calls, parse the string (which we get
-    # from openAI format) to dict
-    for message in conversation:
-        if (message["role"] == "assistant" and "tool_calls" in message
-                and isinstance(message["tool_calls"], list)):
+    return tokenizer.apply_chat_template(
+        conversation=conversation,  # type: ignore[arg-type]
+        chat_template=chat_template,
+        tokenize=tokenize,
+        **kwargs,
+    )

-            for i in range(len(message["tool_calls"])):
-                args: str = message["tool_calls"][i]["function"]["arguments"]
-                parsed_args: Dict = json.loads(args)
-                message["tool_calls"][i]["function"]["arguments"] = parsed_args

-    prompt = tokenizer.apply_chat_template(
-        conversation=conversation,
+def apply_mistral_chat_template(
+    tokenizer: MistralTokenizer,
+    messages: List[ChatCompletionMessageParam],
+    chat_template: Optional[str],
+    **kwargs: Any,
+) -> List[int]:
+    return tokenizer.apply_chat_template(
+        messages=messages,
         chat_template=chat_template,
-        tokenize=tokenize,
         **kwargs,
     )
-    return prompt
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index 1e4432eaaa665..b1d9f386b6c3e 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -6,7 +6,8 @@
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
-                                         apply_chat_template,
+                                         apply_hf_chat_template,
+                                         apply_mistral_chat_template,
                                          parse_chat_messages)
 from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
 from vllm.inputs.parse import parse_and_batch_prompt
@@ -19,7 +20,7 @@
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
-from vllm.transformers_utils.tokenizer import (AnyTokenizer,
+from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.usage.usage_lib import UsageContext
@@ -393,12 +394,21 @@ def chat(
         conversation, mm_data = parse_chat_messages(messages, model_config,
                                                     tokenizer)

-        prompt = apply_chat_template(
-            tokenizer,
-            conversation,
-            chat_template=chat_template,
-            add_generation_prompt=add_generation_prompt,
-        )
+        prompt: Union[str, List[int]]
+        if isinstance(tokenizer, MistralTokenizer):
+            prompt = apply_mistral_chat_template(
+                tokenizer,
+                messages=messages,
+                chat_template=chat_template,
+                add_generation_prompt=add_generation_prompt,
+            )
+        else:
+            prompt = apply_hf_chat_template(
+                tokenizer,
+                conversation=conversation,
+                chat_template=chat_template,
+                add_generation_prompt=add_generation_prompt,
+            )

         inputs: PromptInputs
         if is_list_of(prompt, int):
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 8ed81e9c88cb2..a81d2aa989aaf 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -11,7 +11,8 @@
 from vllm.config import ModelConfig
 from vllm.engine.protocol import AsyncEngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
-                                         apply_chat_template,
+                                         apply_hf_chat_template,
+                                         apply_mistral_chat_template,
                                          load_chat_template,
                                          parse_chat_messages_futures)
 from vllm.entrypoints.logger import RequestLogger
@@ -35,7 +36,7 @@
 from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
-from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
 from vllm.utils import iterate_with_cancellation, random_uuid

 logger = init_logger(__name__)
@@ -121,15 +122,27 @@ async def create_chat_completion(
                     tool.model_dump() for tool in request.tools
                 ]

-            prompt = apply_chat_template(
-                tokenizer,
-                conversation=conversation,
-                chat_template=request.chat_template or self.chat_template,
-                add_generation_prompt=request.add_generation_prompt,
-                tools=tool_dicts,
-                documents=request.documents,
-                **(request.chat_template_kwargs or {}),
-            )
+            prompt: Union[str, List[int]]
+            if isinstance(tokenizer, MistralTokenizer):
+                prompt = apply_mistral_chat_template(
+                    tokenizer,
+                    messages=request.messages,
+                    chat_template=request.chat_template or self.chat_template,
+                    add_generation_prompt=request.add_generation_prompt,
+                    tools=tool_dicts,
+                    documents=request.documents,
+                    **(request.chat_template_kwargs or {}),
+                )
+            else:
+                prompt = apply_hf_chat_template(
+                    tokenizer,
+                    conversation=conversation,
+                    chat_template=request.chat_template or self.chat_template,
+                    add_generation_prompt=request.add_generation_prompt,
+                    tools=tool_dicts,
+                    documents=request.documents,
+                    **(request.chat_template_kwargs or {}),
+                )
         except Exception as e:
             logger.error("Error in applying chat template from request: %s", e)
             return self.create_error_response(str(e))
@@ -307,11 +320,10 @@ async def chat_completion_stream_generator(
                 # Send response to echo the input portion of the
                 # last message
                 if request.echo:
-                    last_msg_content: Optional[str] = ""
-                    if conversation and conversation[-1].get(
-                            "content") and conversation[-1].get(
-                                "role") == role:
-                        last_msg_content = conversation[-1]["content"]
+                    last_msg_content: str = ""
+                    if conversation and "content" in conversation[
+                            -1] and conversation[-1].get("role") == role:
+                        last_msg_content = conversation[-1]["content"] or ""

                     if last_msg_content:
                         for i in range(num_choices):
@@ -659,8 +671,8 @@ async def chat_completion_full_generator(

         if request.echo:
             last_msg_content = ""
-            if conversation and conversation[-1].get(
-                    "content") and conversation[-1].get("role") == role:
+            if conversation and "content" in conversation[-1] and conversation[
+                    -1].get("role") == role:
                 last_msg_content = conversation[-1]["content"] or ""

             for choice in choices:
diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py
index 69a5ad5b62cfa..6e802b71ae2b4 100644
--- a/vllm/entrypoints/openai/serving_tokenization.py
+++ b/vllm/entrypoints/openai/serving_tokenization.py
@@ -2,7 +2,8 @@

 from vllm.config import ModelConfig
 from vllm.engine.protocol import AsyncEngineClient
-from vllm.entrypoints.chat_utils import (apply_chat_template,
+from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
+                                         apply_mistral_chat_template,
                                          load_chat_template,
                                          parse_chat_messages_futures)
 from vllm.entrypoints.logger import RequestLogger
@@ -18,6 +19,7 @@
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                     OpenAIServing)
 from vllm.logger import init_logger
+from vllm.transformers_utils.tokenizer import MistralTokenizer
 from vllm.utils import random_uuid

 logger = init_logger(__name__)
@@ -66,6 +68,7 @@ async def create_tokenize(

         tokenizer = await self.async_engine_client.get_tokenizer(lora_request)

+        prompt: Union[str, List[int]]
         if isinstance(request, TokenizeChatRequest):
             model_config = self.model_config

@@ -77,12 +80,20 @@ async def create_tokenize(
                 logger.warning(
                     "Multi-modal inputs are ignored during tokenization")

-            prompt = apply_chat_template(
-                tokenizer,
-                conversation=conversation,
-                chat_template=self.chat_template,
-                add_generation_prompt=request.add_generation_prompt,
-            )
+            if isinstance(tokenizer, MistralTokenizer):
+                prompt = apply_mistral_chat_template(
+                    tokenizer,
+                    messages=request.messages,
+                    chat_template=self.chat_template,
+                    add_generation_prompt=request.add_generation_prompt,
+                )
+            else:
+                prompt = apply_hf_chat_template(
+                    tokenizer,
+                    conversation=conversation,
+                    chat_template=self.chat_template,
+                    add_generation_prompt=request.add_generation_prompt,
+                )
         else:
             prompt = request.prompt

diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py
index 533a86b787325..17e318cb5e047 100644
--- a/vllm/transformers_utils/tokenizers/mistral.py
+++ b/vllm/transformers_utils/tokenizers/mistral.py
@@ -16,7 +16,7 @@
                                               Tekkenizer)

 if TYPE_CHECKING:
-    from vllm.entrypoints.chat_utils import ConversationMessage
+    from vllm.entrypoints.chat_utils import ChatCompletionMessageParam


 @dataclass
@@ -122,19 +122,19 @@ def get_added_vocab(self) -> List[str]:
         return []

     def encode(self, prompt: str) -> List[int]:
-        # `encode ` should only be used for prompt completion
+        # `encode` should only be used for prompt completion
         # it should never be used for chat_completion.
         # For chat completion use `apply_chat_template`
         return self.tokenizer.encode(prompt, bos=True, eos=False)

     def apply_chat_template(self,
-                            conversation: List["ConversationMessage"],
+                            messages: List["ChatCompletionMessageParam"],
                             tools: Optional[Dict[str, Any]] = None,
                             **kwargs) -> List[int]:
         assert tools is None, "`tools` are not yet supported."

         request = ChatCompletionRequest(
-            messages=conversation)  # type: ignore[type-var]
+            messages=messages)  # type: ignore[type-var]
         encoded = self.mistral.encode_chat_completion(request)

         # encode-decode to get clean prompt
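
Illustrative usage (not part of the patch): a minimal sketch of how a caller dispatches between the two new helpers after this change, mirroring the pattern used in LLM.chat() above. The render_prompt wrapper and its hardcoded add_generation_prompt=True are hypothetical; only the imported vLLM names come from this diff.

from typing import List, Optional, Union

from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
                                         parse_chat_messages)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer


def render_prompt(tokenizer: AnyTokenizer,
                  messages: List[ChatCompletionMessageParam],
                  model_config: ModelConfig,
                  chat_template: Optional[str] = None) -> Union[str, List[int]]:
    # Hypothetical helper: MistralTokenizer consumes the raw OpenAI-style
    # messages and returns token ids, while HF tokenizers consume the parsed
    # conversation and return a prompt string.
    if isinstance(tokenizer, MistralTokenizer):
        return apply_mistral_chat_template(
            tokenizer,
            messages=messages,
            chat_template=chat_template,
            add_generation_prompt=True,
        )
    # For HF tokenizers, convert the messages first (multimodal data ignored
    # here for brevity).
    conversation, _mm_data = parse_chat_messages(messages, model_config,
                                                 tokenizer)
    return apply_hf_chat_template(
        tokenizer,
        conversation=conversation,
        chat_template=chat_template,
        add_generation_prompt=True,
    )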