[BugFix][Frontend] Use LoRA tokenizer in OpenAI APIs
Currently the LoRA tokenizers aren't used in the OpenAI APIs, so behaviour is incorrect when adapters with custom added tokens are used. This PR addresses that by passing the per-request LoRA tokenizer through the OpenAI serving code. It mostly replaces vllm-project#3512.

More work is needed to address remaining inconsistencies in tokenization behaviour between the OpenAI front-end and standalone LLMEngine/AsyncLLMEngine use, including:
- Standalone cases don't honor the truncation and add_special_tokens request parameters
- OpenAI API cases don't make use of TokenizerGroups for possible parallelization of tokenization

as well as some other inefficiencies. These will be addressed in follow-on PRs.
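
For context, a minimal, self-contained sketch (not part of this commit) of the failure mode being fixed: a tokenizer that carries an adapter's added tokens encodes them as single ids, while the base tokenizer splits them into sub-tokens, so prompts targeting a LoRA adapter must be tokenized with that adapter's tokenizer. The model name and added token below are illustrative placeholders only.

from transformers import AutoTokenizer

# Stand-in base model; a real LoRA adapter would ship its own tokenizer files.
base = AutoTokenizer.from_pretrained("gpt2")

# Simulate an adapter that registered a custom token on top of the base vocab.
adapter = AutoTokenizer.from_pretrained("gpt2")
adapter.add_tokens(["<|custom_tok|>"])

text = "please call <|custom_tok|> now"
print(base.encode(text))     # added token is split into several base sub-token ids
print(adapter.encode(text))  # added token maps to one dedicated id

With this change, the OpenAI serving code resolves the tokenizer via engine.get_tokenizer(lora_request), so requests that target an adapter are encoded and decoded with the adapter's tokenizer rather than the base one.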
njhill committed Jul 8, 2024
1 parent f7a8fa3 commit 8ca4505
Showing 7 changed files with 144 additions and 116 deletions.
13 changes: 9 additions & 4 deletions vllm/engine/async_llm_engine.py
@@ -464,11 +464,16 @@ def _error_callback(self, exc: Exception) -> None:
         self.set_errored(exc)
         self._request_tracker.propagate_exception(exc)
 
-    async def get_tokenizer(self) -> "PreTrainedTokenizer":
+    async def get_tokenizer(
+        self,
+        lora_request: Optional[LoRARequest] = None,
+    ) -> "PreTrainedTokenizer":
         if self.engine_use_ray:
-            return await self.engine.get_tokenizer.remote()  # type: ignore
-        else:
-            return self.engine.get_tokenizer()
+            return await self.engine.get_tokenizer.remote(  # type: ignore
+                lora_request)
+
+        return await (self.engine.get_tokenizer_group().
+                      get_lora_tokenizer_async(lora_request))
 
     def start_background_loop(self) -> None:
         """Start the background loop."""
7 changes: 5 additions & 2 deletions vllm/engine/llm_engine.py
@@ -441,8 +441,11 @@ def get_tokenizer_group(
 
         return self.tokenizer
 
-    def get_tokenizer(self) -> "PreTrainedTokenizer":
-        return self.get_tokenizer_group().get_lora_tokenizer(None)
+    def get_tokenizer(
+        self,
+        lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
+        return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
 
     def get_tokenizer_for_seq(self,
                               sequence: Sequence) -> "PreTrainedTokenizer":
138 changes: 76 additions & 62 deletions vllm/entrypoints/openai/serving_chat.py
@@ -1,7 +1,7 @@
 import codecs
 import time
 from dataclasses import dataclass, field
-from functools import cached_property
+from functools import lru_cache
 from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable,
                     List, Optional)
 from typing import Sequence as GenericSequence
@@ -10,6 +10,7 @@
 from fastapi import Request
 from openai.types.chat import (ChatCompletionContentPartImageParam,
                                ChatCompletionContentPartTextParam)
+from transformers import PreTrainedTokenizer
 
 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
@@ -66,39 +67,36 @@ def __init__(self,
                          lora_modules=lora_modules)
 
         self.response_role = response_role
-        self._load_chat_template(chat_template)
 
-    def _load_chat_template(self, chat_template: Optional[str]):
-        tokenizer = self.tokenizer
+        # If this is None we use the tokenizer's default chat template
+        self.chat_template = self._load_chat_template(chat_template)
 
-        if chat_template is not None:
-            try:
-                with open(chat_template, "r") as f:
-                    tokenizer.chat_template = f.read()
-            except OSError as e:
-                JINJA_CHARS = "{}\n"
-                if not any(c in chat_template for c in JINJA_CHARS):
-                    msg = (f"The supplied chat template ({chat_template}) "
-                           f"looks like a file path, but it failed to be "
-                           f"opened. Reason: {e}")
-                    raise ValueError(msg) from e
-
-                # If opening a file fails, set chat template to be args to
-                # ensure we decode so our escape are interpreted correctly
-                tokenizer.chat_template = codecs.decode(
-                    chat_template, "unicode_escape")
-
-            logger.info("Using supplied chat template:\n%s",
-                        tokenizer.chat_template)
-        elif tokenizer.chat_template is not None:
-            logger.info("Using default chat template:\n%s",
-                        tokenizer.chat_template)
-        else:
-            logger.warning(
-                "No chat template provided. Chat API will not work.")
-
-    @cached_property
-    def image_token_str(self) -> Optional[str]:
+    @staticmethod
+    def _load_chat_template(chat_template: Optional[str]) -> Optional[str]:
+        if chat_template is None:
+            return None
+        try:
+            with open(chat_template, "r") as f:
+                resolved_chat_template = f.read()
+        except OSError as e:
+            JINJA_CHARS = "{}\n"
+            if not any(c in chat_template for c in JINJA_CHARS):
+                msg = (f"The supplied chat template ({chat_template}) "
+                       f"looks like a file path, but it failed to be "
+                       f"opened. Reason: {e}")
+                raise ValueError(msg) from e
+
+            # If opening a file fails, set chat template to be args to
+            # ensure we decode so our escape are interpreted correctly
+            resolved_chat_template = codecs.decode(chat_template,
+                                                   "unicode_escape")
+
+        logger.info("Using supplied chat template:\n%s",
+                    resolved_chat_template)
+        return resolved_chat_template
+
+    @lru_cache(maxsize=32)  # noqa: B019
+    def image_token_str(self, tokenizer: PreTrainedTokenizer) -> Optional[str]:
         # TODO: Let user specify how to insert image tokens into prompt
         # (similar to chat template)
         model_type = self.model_config.hf_config.model_type
@@ -110,7 +108,7 @@ def image_token_str(self) -> Optional[str]:
             # These models do not use image tokens in the prompt
             return None
         if model_type.startswith("llava"):
-            return self.tokenizer.decode(
+            return tokenizer.decode(
                 self.model_config.hf_config.image_token_index)
 
         else:
@@ -130,6 +128,7 @@ def _parse_chat_message_content_parts(
         self,
         role: str,
         parts: Iterable[ChatCompletionContentPartParam],
+        tokenizer: PreTrainedTokenizer,
     ) -> ChatMessageParseResult:
         texts: List[str] = []
         mm_futures: List[Awaitable[MultiModalDataDict]] = []
@@ -161,7 +160,7 @@ def _parse_chat_message_content_parts(
         text_prompt = "\n".join(texts)
 
         if mm_futures:
-            image_token_str = self.image_token_str
+            image_token_str = self.image_token_str(tokenizer)
             if image_token_str is not None:
                 if image_token_str in text_prompt:
                     logger.warning(
@@ -180,6 +179,7 @@
     def _parse_chat_message_content(
         self,
         message: ChatCompletionMessageParam,
+        tokenizer: PreTrainedTokenizer,
     ) -> ChatMessageParseResult:
         role = message["role"]
         content = message.get("content")
@@ -190,7 +190,7 @@ def _parse_chat_message_content(
             messages = [ConversationMessage(role=role, content=content)]
             return ChatMessageParseResult(messages=messages, mm_futures=[])
 
-        return self._parse_chat_message_content_parts(role, content)
+        return self._parse_chat_message_content_parts(role, content, tokenizer)
 
     async def create_chat_completion(
         self,
@@ -212,11 +212,15 @@ async def create_chat_completion(
             return error_check_ret
 
         try:
+            lora_request = self._maybe_get_lora(request)
+            tokenizer = await self.engine.get_tokenizer(lora_request)
+
             conversation: List[ConversationMessage] = []
             mm_futures: List[Awaitable[MultiModalDataDict]] = []
 
             for msg in request.messages:
-                chat_parsed_result = self._parse_chat_message_content(msg)
+                chat_parsed_result = self._parse_chat_message_content(
+                    msg, tokenizer)
 
                 conversation.extend(chat_parsed_result.messages)
                 mm_futures.extend(chat_parsed_result.mm_futures)
@@ -225,7 +229,9 @@ async def create_chat_completion(
                 tool.model_dump() for tool in request.tools
             ]
 
-            prompt = self.tokenizer.apply_chat_template(
+            if self.chat_template is not None:
+                tokenizer.chat_template = self.chat_template
+            prompt = tokenizer.apply_chat_template(
                 conversation=conversation,
                 tokenize=False,
                 add_generation_prompt=request.add_generation_prompt,
@@ -253,19 +259,19 @@ async def create_chat_completion(
         request_id = f"cmpl-{random_uuid()}"
         try:
             # Tokenize/detokenize depending on prompt format (string/token list)
-            prompt_ids, prompt_text = self._validate_prompt_and_tokenize(
+            prompt_ids, prompt_text = await self._validate_prompt_and_tokenize(
                 request,
+                tokenizer,
                 prompt=prompt,
                 add_special_tokens=request.add_special_tokens)
             sampling_params = request.to_sampling_params()
-            lora_request = self._maybe_get_lora(request)
             decoding_config = await self.engine.get_decoding_config()
             guided_decoding_backend = request.guided_decoding_backend \
                 or decoding_config.guided_decoding_backend
             guided_decode_logits_processor = (
-                await get_guided_decoding_logits_processor(
-                    guided_decoding_backend, request, await
-                    self.engine.get_tokenizer()))
+                await
+                get_guided_decoding_logits_processor(guided_decoding_backend,
+                                                     request, tokenizer))
             if guided_decode_logits_processor:
                 if sampling_params.logits_processors is None:
                     sampling_params.logits_processors = []
@@ -299,12 +305,12 @@ async def create_chat_completion(
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
-                request, result_generator, request_id, conversation)
+                request, result_generator, request_id, conversation, tokenizer)
         else:
             try:
                 return await self.chat_completion_full_generator(
                     request, raw_request, result_generator, request_id,
-                    conversation)
+                    conversation, tokenizer)
             except ValueError as e:
                 # TODO: Use a vllm-specific Validation Error
                 return self.create_error_response(str(e))
@@ -316,9 +322,12 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
         return request.messages[-1]["role"]
 
     async def chat_completion_stream_generator(
-            self, request: ChatCompletionRequest,
-            result_generator: AsyncIterator[RequestOutput], request_id: str,
-            conversation: List[ConversationMessage]
+        self,
+        request: ChatCompletionRequest,
+        result_generator: AsyncIterator[RequestOutput],
+        request_id: str,
+        conversation: List[ConversationMessage],
+        tokenizer: PreTrainedTokenizer,
     ) -> AsyncGenerator[str, None]:
         model_name = self.served_model_names[0]
         created_time = int(time.time())
@@ -405,6 +414,7 @@ async def chat_completion_stream_generator(
                     logprobs = self._create_chat_logprobs(
                         token_ids=delta_token_ids,
                         top_logprobs=out_logprobs,
+                        tokenizer=tokenizer,
                         num_output_top_logprobs=request.top_logprobs,
                     )
                 else:
@@ -493,9 +503,13 @@ async def chat_completion_stream_generator(
         yield "data: [DONE]\n\n"
 
     async def chat_completion_full_generator(
-            self, request: ChatCompletionRequest, raw_request: Optional[Request],
-            result_generator: AsyncIterator[RequestOutput], request_id: str,
-            conversation: List[ConversationMessage]
+        self,
+        request: ChatCompletionRequest,
+        raw_request: Optional[Request],
+        result_generator: AsyncIterator[RequestOutput],
+        request_id: str,
+        conversation: List[ConversationMessage],
+        tokenizer: PreTrainedTokenizer,
     ) -> Union[ErrorResponse, ChatCompletionResponse]:
 
         model_name = self.served_model_names[0]
@@ -523,6 +537,7 @@ async def chat_completion_full_generator(
                     token_ids=token_ids,
                     top_logprobs=out_logprobs,
                     num_output_top_logprobs=request.top_logprobs,
+                    tokenizer=tokenizer,
                 )
             else:
                 logprobs = None
@@ -577,16 +592,14 @@ async def chat_completion_full_generator(
         return response
 
     def _get_top_logprobs(
-            self, logprobs: Dict[int, Logprob],
-            top_logprobs: Optional[int]) -> List[ChatCompletionLogProb]:
+            self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int],
+            tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]:
         return [
             ChatCompletionLogProb(
-                token=self._get_decoded_token(p[1], p[0]),
+                token=(token := self._get_decoded_token(p[1], p[0],
+                                                        tokenizer)),
                 logprob=max(p[1].logprob, -9999.0),
-                bytes=list(
-                    self._get_decoded_token(p[1],
-                                            p[0]).encode("utf-8",
-                                                         errors="replace")))
+                bytes=list(token.encode("utf-8", errors="replace")))
             for i, p in enumerate(logprobs.items())
             if top_logprobs and i < top_logprobs
         ]
@@ -595,6 +608,7 @@ def _create_chat_logprobs(
         self,
         token_ids: GenericSequence[int],
         top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]],
+        tokenizer: PreTrainedTokenizer,
         num_output_top_logprobs: Optional[int] = None,
     ) -> ChatCompletionLogProbs:
         """Create OpenAI-style logprobs."""
@@ -604,12 +618,11 @@ def _create_chat_logprobs(
         for i, token_id in enumerate(token_ids):
             step_top_logprobs = top_logprobs[i]
             if step_top_logprobs is None:
+                token = tokenizer.decode(token_id)
                 logprobs_content.append(
                     ChatCompletionLogProbsContent(
-                        token=self.tokenizer.decode(token_id),
-                        bytes=list(
-                            self.tokenizer.decode(token_id).encode(
-                                "utf-8", errors="replace"))))
+                        token=token,
+                        bytes=list(token.encode("utf-8", errors="replace"))))
             else:
                 logprobs_content.append(
                     ChatCompletionLogProbsContent(
@@ -620,6 +633,7 @@ def _create_chat_logprobs(
                             step_top_logprobs[token_id].decoded_token.encode(
                                 "utf-8", errors="replace")),
                         top_logprobs=self._get_top_logprobs(
-                            step_top_logprobs, num_output_top_logprobs)))
+                            step_top_logprobs, num_output_top_logprobs,
+                            tokenizer)))
 
         return ChatCompletionLogProbs(content=logprobs_content)