Skip to content

Commit

Permalink
[Frontend] Clean up type annotations for mistral tokenizer (vllm-proj…
Browse files Browse the repository at this point in the history
…ect#8314)

Signed-off-by: Amit Garg <[email protected]>
  • Loading branch information
DarkLight1337 authored and garg-amit committed Oct 28, 2024
1 parent 8474dd6 commit 024f9c5
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 59 deletions.
5 changes: 3 additions & 2 deletions tests/async_engine/test_chat_template.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from vllm.entrypoints.chat_utils import apply_chat_template, load_chat_template
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
load_chat_template)
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.transformers_utils.tokenizer import get_tokenizer

Expand Down Expand Up @@ -87,7 +88,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
add_generation_prompt=add_generation_prompt)

# Call the function and get the result
result = apply_chat_template(
result = apply_hf_chat_template(
tokenizer,
conversation=mock_request.messages,
chat_template=mock_request.chat_template or template_content,
Expand Down
61 changes: 41 additions & 20 deletions vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from pydantic import ConfigDict
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Required, TypeAlias, TypedDict

from vllm.config import ModelConfig
Expand All @@ -31,7 +32,7 @@
from vllm.multimodal.utils import (async_get_and_parse_audio,
async_get_and_parse_image,
get_and_parse_audio, get_and_parse_image)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer

logger = init_logger(__name__)

Expand Down Expand Up @@ -379,6 +380,9 @@ def _parse_chat_message_content_parts(
audio_url = _AudioParser(part)["audio_url"]

mm_parser.parse_audio(audio_url["url"])
elif part_type == "refusal":
text = _RefusalParser(part)["refusal"]
texts.append(text)
else:
raise NotImplementedError(f"Unknown part type: {part_type}")

Expand Down Expand Up @@ -433,6 +437,21 @@ def _parse_chat_message_content(
return result


def _postprocess_messages(messages: List[ConversationMessage]) -> None:
# per the Transformers docs & maintainers, tool call arguments in
# assistant-role messages with tool_calls need to be dicts not JSON str -
# this is how tool-use chat templates will expect them moving forwards
# so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict
for message in messages:
if (message["role"] == "assistant" and "tool_calls" in message
and isinstance(message["tool_calls"], list)):

for item in message["tool_calls"]:
item["function"]["arguments"] = json.loads(
item["function"]["arguments"])


def parse_chat_messages(
messages: List[ChatCompletionMessageParam],
model_config: ModelConfig,
Expand All @@ -446,6 +465,8 @@ def parse_chat_messages(

conversation.extend(sub_messages)

_postprocess_messages(conversation)

return conversation, mm_tracker.all_mm_data()


Expand All @@ -462,41 +483,41 @@ def parse_chat_messages_futures(

conversation.extend(sub_messages)

_postprocess_messages(conversation)

return conversation, mm_tracker.all_mm_data()


def apply_chat_template(
tokenizer: AnyTokenizer,
def apply_hf_chat_template(
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: List[ConversationMessage],
chat_template: Optional[str],
*,
tokenize: bool = False, # Different from HF's default
**kwargs: Any,
) -> Union[str, List[int]]:
) -> str:
if chat_template is None and tokenizer.chat_template is None:
raise ValueError(
"As of transformers v4.44, default chat template is no longer "
"allowed, so you must provide a chat template if the tokenizer "
"does not define one.")

# per the Transformers docs & maintainers, tool call arguments in
# assistant-role messages with tool_calls need to be dicts not JSON str -
# this is how tool-use chat templates will expect them moving forwards
# so, for messages that have tool_calls, parse the string (which we get
# from openAI format) to dict
for message in conversation:
if (message["role"] == "assistant" and "tool_calls" in message
and isinstance(message["tool_calls"], list)):
return tokenizer.apply_chat_template(
conversation=conversation, # type: ignore[arg-type]
chat_template=chat_template,
tokenize=tokenize,
**kwargs,
)

for i in range(len(message["tool_calls"])):
args: str = message["tool_calls"][i]["function"]["arguments"]
parsed_args: Dict = json.loads(args)
message["tool_calls"][i]["function"]["arguments"] = parsed_args

prompt = tokenizer.apply_chat_template(
conversation=conversation,
def apply_mistral_chat_template(
tokenizer: MistralTokenizer,
messages: List[ChatCompletionMessageParam],
chat_template: Optional[str],
**kwargs: Any,
) -> List[int]:
return tokenizer.apply_chat_template(
messages=messages,
chat_template=chat_template,
tokenize=tokenize,
**kwargs,
)
return prompt
26 changes: 18 additions & 8 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
apply_chat_template,
apply_hf_chat_template,
apply_mistral_chat_template,
parse_chat_messages)
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
Expand All @@ -19,7 +20,7 @@
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.usage.usage_lib import UsageContext
Expand Down Expand Up @@ -393,12 +394,21 @@ def chat(
conversation, mm_data = parse_chat_messages(messages, model_config,
tokenizer)

prompt = apply_chat_template(
tokenizer,
conversation,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
)
prompt: Union[str, List[int]]
if isinstance(tokenizer, MistralTokenizer):
prompt = apply_mistral_chat_template(
tokenizer,
messages=messages,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
)
else:
prompt = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=chat_template,
add_generation_prompt=add_generation_prompt,
)

inputs: PromptInputs
if is_list_of(prompt, int):
Expand Down
48 changes: 30 additions & 18 deletions vllm/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (ConversationMessage,
apply_chat_template,
apply_hf_chat_template,
apply_mistral_chat_template,
load_chat_template,
parse_chat_messages_futures)
from vllm.entrypoints.logger import RequestLogger
Expand All @@ -35,7 +36,7 @@
from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import iterate_with_cancellation, random_uuid

logger = init_logger(__name__)
Expand Down Expand Up @@ -121,15 +122,27 @@ async def create_chat_completion(
tool.model_dump() for tool in request.tools
]

prompt = apply_chat_template(
tokenizer,
conversation=conversation,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
**(request.chat_template_kwargs or {}),
)
prompt: Union[str, List[int]]
if isinstance(tokenizer, MistralTokenizer):
prompt = apply_mistral_chat_template(
tokenizer,
messages=request.messages,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
**(request.chat_template_kwargs or {}),
)
else:
prompt = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=request.chat_template or self.chat_template,
add_generation_prompt=request.add_generation_prompt,
tools=tool_dicts,
documents=request.documents,
**(request.chat_template_kwargs or {}),
)
except Exception as e:
logger.error("Error in applying chat template from request: %s", e)
return self.create_error_response(str(e))
Expand Down Expand Up @@ -307,11 +320,10 @@ async def chat_completion_stream_generator(
# Send response to echo the input portion of the
# last message
if request.echo:
last_msg_content: Optional[str] = ""
if conversation and conversation[-1].get(
"content") and conversation[-1].get(
"role") == role:
last_msg_content = conversation[-1]["content"]
last_msg_content: str = ""
if conversation and "content" in conversation[
-1] and conversation[-1].get("role") == role:
last_msg_content = conversation[-1]["content"] or ""

if last_msg_content:
for i in range(num_choices):
Expand Down Expand Up @@ -659,8 +671,8 @@ async def chat_completion_full_generator(

if request.echo:
last_msg_content = ""
if conversation and conversation[-1].get(
"content") and conversation[-1].get("role") == role:
if conversation and "content" in conversation[-1] and conversation[
-1].get("role") == role:
last_msg_content = conversation[-1]["content"] or ""

for choice in choices:
Expand Down
25 changes: 18 additions & 7 deletions vllm/entrypoints/openai/serving_tokenization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from vllm.config import ModelConfig
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.chat_utils import (apply_chat_template,
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
apply_mistral_chat_template,
load_chat_template,
parse_chat_messages_futures)
from vllm.entrypoints.logger import RequestLogger
Expand All @@ -18,6 +19,7 @@
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import MistralTokenizer
from vllm.utils import random_uuid

logger = init_logger(__name__)
Expand Down Expand Up @@ -66,6 +68,7 @@ async def create_tokenize(

tokenizer = await self.async_engine_client.get_tokenizer(lora_request)

prompt: Union[str, List[int]]
if isinstance(request, TokenizeChatRequest):
model_config = self.model_config

Expand All @@ -77,12 +80,20 @@ async def create_tokenize(
logger.warning(
"Multi-modal inputs are ignored during tokenization")

prompt = apply_chat_template(
tokenizer,
conversation=conversation,
chat_template=self.chat_template,
add_generation_prompt=request.add_generation_prompt,
)
if isinstance(tokenizer, MistralTokenizer):
prompt = apply_mistral_chat_template(
tokenizer,
messages=request.messages,
chat_template=self.chat_template,
add_generation_prompt=request.add_generation_prompt,
)
else:
prompt = apply_hf_chat_template(
tokenizer,
conversation=conversation,
chat_template=self.chat_template,
add_generation_prompt=request.add_generation_prompt,
)
else:
prompt = request.prompt

Expand Down
8 changes: 4 additions & 4 deletions vllm/transformers_utils/tokenizers/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
Tekkenizer)

if TYPE_CHECKING:
from vllm.entrypoints.chat_utils import ConversationMessage
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam


@dataclass
Expand Down Expand Up @@ -122,19 +122,19 @@ def get_added_vocab(self) -> List[str]:
return []

def encode(self, prompt: str) -> List[int]:
# `encode ` should only be used for prompt completion
# `encode` should only be used for prompt completion
# it should never be used for chat_completion.
# For chat completion use `apply_chat_template`
return self.tokenizer.encode(prompt, bos=True, eos=False)

def apply_chat_template(self,
conversation: List["ConversationMessage"],
messages: List["ChatCompletionMessageParam"],
tools: Optional[Dict[str, Any]] = None,
**kwargs) -> List[int]:
assert tools is None, "`tools` are not yet supported."

request = ChatCompletionRequest(
messages=conversation) # type: ignore[type-var]
messages=messages) # type: ignore[type-var]
encoded = self.mistral.encode_chat_completion(request)

# encode-decode to get clean prompt
Expand Down

0 comments on commit 024f9c5

Please sign in to comment.