[BugFix][Frontend] Use correct, shared tokenizer in OpenAI server
The front-end server code currently doesn't use LoRA-specific tokenizers.

It also doesn't make use of the recently introduced parallel async tokenization when that is enabled.
njhill committed Apr 10, 2024
1 parent 0258b7a commit 1db1b92
Showing 9 changed files with 79 additions and 52 deletions.
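
At a high level, the front end now asks the engine for its shared tokenizer group and awaits an async, LoRA-aware encode instead of tokenizing with a separately constructed tokenizer. Below is a rough, self-contained sketch of that call path; the toy classes only mimic the vLLM ones named in the diff and are not the actual implementation.

import asyncio
from typing import List, Optional

class ToyTokenizerGroup:
    """Stands in for vLLM's TokenizerGroup: picks a per-LoRA tokenizer."""
    def __init__(self, base_vocab):
        self.base_vocab = base_vocab
        self.lora_vocabs = {}  # lora_int_id -> vocab override

    async def encode_async(self, prompt: str,
                           request_id: Optional[str] = None,
                           lora_request=None,
                           truncate_to: Optional[int] = None) -> List[int]:
        vocab = self.lora_vocabs.get(
            getattr(lora_request, "lora_int_id", None), self.base_vocab)
        ids = [vocab.setdefault(tok, len(vocab)) for tok in prompt.split()]
        # Left-truncate, mirroring truncation_side="left" in the real code.
        return ids if truncate_to is None else ids[-truncate_to:]

class ToyEngine:
    def __init__(self):
        self._tokenizer_group = ToyTokenizerGroup({})

    async def get_tokenizer_group(self) -> ToyTokenizerGroup:
        return self._tokenizer_group

async def serve_one(prompt: str):
    engine = ToyEngine()
    tokenizer = await engine.get_tokenizer_group()  # shared, not re-created
    return await tokenizer.encode_async(prompt, request_id="cmpl-1",
                                        lora_request=None, truncate_to=4)

print(asyncio.run(serve_one("a b c d e f")))  # prints the last 4 token ids
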
7 changes: 7 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -16,6 +16,7 @@
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.usage.usage_lib import UsageContext

logger = init_logger(__name__)
@@ -378,6 +379,12 @@ def _error_callback(self, exc: Exception) -> None:
self.set_errored(exc)
self._request_tracker.propagate_exception(exc)

async def get_tokenizer_group(self) -> BaseTokenizerGroup:
if self.engine_use_ray:
return await self.engine.get_tokenizer_group.remote()
else:
return self.engine.get_tokenizer_group()

async def get_tokenizer(self) -> "PreTrainedTokenizer":
if self.engine_use_ray:
return await self.engine.get_tokenizer.remote()
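
For callers, the new accessor is used like the existing get_tokenizer(). A usage sketch, assuming a vLLM build that includes this commit, a GPU, and that the example model (chosen here only for illustration) is available:

import asyncio
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

async def main():
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    # Works whether the engine runs in-process or as a Ray actor
    # (engine_use_ray); the method awaits the remote call in the latter case.
    tokenizer_group = await engine.get_tokenizer_group()
    ids = await tokenizer_group.encode_async("Hello world",
                                             request_id="example-0",
                                             lora_request=None,
                                             truncate_to=None)
    print(ids)

asyncio.run(main())
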
3 changes: 3 additions & 0 deletions vllm/engine/llm_engine.py
@@ -241,6 +241,9 @@ def __reduce__(self):
# the closure used to initialize Ray worker actors
raise RuntimeError("LLMEngine should not be pickled!")

def get_tokenizer_group(self) -> BaseTokenizerGroup:
return self.tokenizer

def get_tokenizer(self) -> "PreTrainedTokenizer":
return self.tokenizer.get_lora_tokenizer(None)

9 changes: 6 additions & 3 deletions vllm/entrypoints/openai/serving_chat.py
@@ -63,10 +63,13 @@ async def create_chat_completion(

request_id = f"cmpl-{random_uuid()}"
try:
token_ids = self._validate_prompt_and_tokenize(request,
prompt=prompt)
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
token_ids = await self._validate_prompt_and_tokenize(
request,
request_id=request_id,
lora_request=lora_request,
prompt=prompt)
sampling_params = request.to_sampling_params()
guided_decode_logits_processor = (
await get_guided_decoding_logits_processor(
request, await self.engine.get_tokenizer()))
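
The practical effect is that a chat request addressed to a served LoRA adapter is now tokenized with that adapter's tokenizer rather than the base model's. A client-side illustration; the adapter name is a placeholder and the server is assumed to have been started with its LoRA adapters registered (e.g. via --lora-modules):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="my-lora-adapter",  # placeholder name of a served LoRA adapter
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)
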
23 changes: 10 additions & 13 deletions vllm/entrypoints/openai/serving_completion.py
@@ -135,23 +135,20 @@ async def create_completion(self, request: CompletionRequest,
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)

for i, prompt in enumerate(prompts):
if prompt_is_tokens:
input_ids = self._validate_prompt_and_tokenize(
request,
prompt_ids=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
else:
input_ids = self._validate_prompt_and_tokenize(
request,
prompt=prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens)
sub_request_id = f"{request_id}-{i}"
prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
input_ids = await self._validate_prompt_and_tokenize(
request,
request_id=sub_request_id,
lora_request=lora_request,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens,
**{prompt_arg: prompt})

generators.append(
self.engine.generate(prompt,
sampling_params,
f"{request_id}-{i}",
sub_request_id,
prompt_token_ids=input_ids,
lora_request=lora_request))
except ValueError as e:
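
The two branches collapse into one call by choosing the keyword name at runtime. A tiny self-contained illustration of that **{...} dispatch pattern (toy function, not the server code):

from typing import List, Optional

def validate(prompt: Optional[str] = None,
             prompt_ids: Optional[List[int]] = None) -> List[int]:
    # Mirrors the server's rule: exactly one of prompt / prompt_ids.
    if (prompt is None) == (prompt_ids is None):
        raise ValueError("provide exactly one of prompt or prompt_ids")
    return prompt_ids if prompt_ids is not None else [ord(c) for c in prompt]

for item in ["hi", [104, 105]]:
    prompt_is_tokens = isinstance(item, list)
    prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
    print(validate(**{prompt_arg: item}))  # same call site for both cases
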
22 changes: 12 additions & 10 deletions vllm/entrypoints/openai/serving_engine.py
@@ -64,12 +64,11 @@ async def _post_init(self):
engine_model_config = await self.engine.get_model_config()
self.max_model_len = engine_model_config.max_model_len

# A separate tokenizer to map token IDs to strings.
# A separate tokenizer for applying the chat template.
self.tokenizer = get_tokenizer(
engine_model_config.tokenizer,
tokenizer_mode=engine_model_config.tokenizer_mode,
trust_remote_code=engine_model_config.trust_remote_code,
truncation_side="left")
trust_remote_code=engine_model_config.trust_remote_code)

async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
@@ -163,25 +162,28 @@ def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
# if _check_model has been called earlier, this will be unreachable
raise ValueError("The model `{request.model}` does not exist.")

def _validate_prompt_and_tokenize(
async def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest],
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None,
truncate_prompt_tokens: Optional[conint(ge=1)] = None
) -> List[int]:
if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.")
if (prompt and prompt_ids):
if prompt and prompt_ids:
raise ValueError(
"Only one of prompt or prompt_ids should be provided.")

if prompt_ids is None:
tokenizer_kwargs = {} if truncate_prompt_tokens is None else {
"truncation": True,
"max_length": truncate_prompt_tokens,
}
input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
tokenizer = await self.engine.get_tokenizer_group()
input_ids = await tokenizer.encode_async(
prompt,
request_id,
lora_request,
truncate_to=truncate_prompt_tokens)
elif truncate_prompt_tokens is not None:
input_ids = prompt_ids[-truncate_prompt_tokens:]
else:
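
In both paths truncation keeps the tail of the prompt: token-id prompts are sliced with [-truncate_prompt_tokens:], while text prompts rely on the tokenizer's truncation_side="left" default (see the tokenizer.py change below). A small self-contained illustration of that intended equivalence; the whitespace "tokenizer" is purely for demonstration:

from typing import List

def encode(prompt: str) -> List[int]:
    # Stand-in tokenizer: one "token" per whitespace-separated word.
    return list(range(len(prompt.split())))

def truncate_text(prompt: str, max_len: int) -> List[int]:
    # What truncation=True + max_length does with truncation_side="left".
    return encode(prompt)[-max_len:]

def truncate_ids(prompt_ids: List[int], max_len: int) -> List[int]:
    # What the server does when the client sends token ids directly.
    return prompt_ids[-max_len:]

ids = encode("one two three four five")
assert truncate_text("one two three four five", 3) == truncate_ids(ids, 3)
print(truncate_ids(ids, 3))  # [2, 3, 4] -- the last three positions
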
3 changes: 3 additions & 0 deletions vllm/transformers_utils/tokenizer.py
@@ -66,6 +66,9 @@ def get_tokenizer(
"Cannot use the fast tokenizer in slow tokenizer mode.")
kwargs["use_fast"] = False

if "truncation_side" not in kwargs:
kwargs["truncation_side"] = "left"

try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
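
Defaulting truncation_side to "left" means that when a caller later passes truncation=True with max_length, the oldest tokens are dropped and the end of the prompt is kept. A sketch using Hugging Face transformers directly; the model name is just an example and the tokenizer must be downloadable:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2", truncation_side="left")
text = "one two three four five six seven eight"
ids = tok(text, truncation=True, max_length=4).input_ids
# With truncation_side="left", the kept ids decode to the end of the text.
print(tok.decode(ids))
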
13 changes: 7 additions & 6 deletions vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
@@ -25,16 +25,17 @@ def get_max_input_len(self,
def encode(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
lora_request: Optional[LoRARequest] = None,
truncate_to: Optional[int] = None) -> List[int]:
"""Encode a prompt using the tokenizer group."""
pass

@abstractmethod
async def encode_async(
self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
async def encode_async(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
truncate_to: Optional[int] = None) -> List[int]:
"""Encode a prompt using the tokenizer group."""
pass

19 changes: 11 additions & 8 deletions vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
@@ -81,7 +81,8 @@ def _ensure_queue_initialized(self):
def encode(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
lora_request: Optional[LoRARequest] = None,
truncate_to: Optional[int] = None) -> List[int]:
"""Encode a prompt using the tokenizer group.
We pick an idle actor and use it to encode the prompt.
@@ -97,7 +98,8 @@ def encode(self,
ret = ray.get(
actor.encode.remote(request_id=request_id,
prompt=prompt,
lora_request=lora_request))
lora_request=lora_request,
truncate_to=truncate_to))
finally:
# Put the actor back in the queue.
# This is done in a finally block to ensure that the actor is
@@ -106,11 +108,11 @@ def encode(self,
self._idle_actors.put_nowait(actor)
return ret

async def encode_async(
self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
async def encode_async(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
truncate_to: Optional[int] = None) -> List[int]:
"""Encode a prompt using the tokenizer group.
We pick an idle actor and use it to encode the prompt.
@@ -125,7 +127,8 @@ async def encode_async(
try:
ret = await actor.encode.remote(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
lora_request=lora_request,
truncate_to=truncate_to)
finally:
# Put the actor back in the queue.
# This is done in a finally block to ensure that the actor is
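
The Ray group keeps a pool of tokenizer actors and always returns the borrowed actor in a finally block, even when encoding fails. A self-contained asyncio sketch of that borrow/return pattern; plain worker objects and an asyncio.Queue stand in for the Ray actors and actor pool:

import asyncio
from typing import List

class TokenizerWorker:
    async def encode(self, prompt: str, truncate_to=None) -> List[int]:
        ids = [len(w) for w in prompt.split()]  # placeholder tokenizer
        return ids if truncate_to is None else ids[-truncate_to:]

async def encode_with_pool(pool: "asyncio.Queue[TokenizerWorker]",
                           prompt: str, truncate_to=None) -> List[int]:
    worker = await pool.get()  # wait for an idle worker
    try:
        return await worker.encode(prompt, truncate_to=truncate_to)
    finally:
        pool.put_nowait(worker)  # always hand the worker back

async def main():
    pool: asyncio.Queue = asyncio.Queue()
    for _ in range(2):
        pool.put_nowait(TokenizerWorker())
    results = await asyncio.gather(
        *(encode_with_pool(pool, f"prompt number {i}", truncate_to=2)
          for i in range(5)))
    print(results)

asyncio.run(main())
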
32 changes: 20 additions & 12 deletions vllm/transformers_utils/tokenizer_group/tokenizer_group.py
@@ -37,17 +37,27 @@ def get_max_input_len(self,
def encode(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
lora_request: Optional[LoRARequest] = None,
truncate_to: Optional[int] = None) -> List[int]:
tokenizer = self.get_lora_tokenizer(lora_request)
return tokenizer.encode(prompt)
return self._encode(tokenizer, prompt, truncate_to)

async def encode_async(
self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
async def encode_async(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
truncate_to: Optional[int] = None) -> List[int]:
tokenizer = await self.get_lora_tokenizer_async(lora_request)
return tokenizer.encode(prompt)
return self._encode(tokenizer, prompt, truncate_to)

@staticmethod
def _encode(tokenizer: PreTrainedTokenizer, prompt: str,
truncate_to: Optional[int]) -> List[int]:
tokenizer_kwargs = {} if truncate_to is None else {
"truncation": True,
"max_length": truncate_to,
}
return tokenizer.encode(prompt, **tokenizer_kwargs)

def get_lora_tokenizer(
self,
@@ -60,8 +70,7 @@ def get_lora_tokenizer(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
return self.lora_tokenizers.get(lora_request.lora_int_id)

async def get_lora_tokenizer_async(
self,
@@ -74,5 +83,4 @@ async def get_lora_tokenizer_async(
lora_request, **self.tokenizer_config) or self.tokenizer)
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
return tokenizer
else:
return self.lora_tokenizers.get(lora_request.lora_int_id)
return self.lora_tokenizers.get(lora_request.lora_int_id)
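
The local group lazily builds one tokenizer per LoRA adapter, keyed by lora_int_id, and falls back to the base tokenizer when the adapter ships none. A self-contained sketch of that cache-or-fallback pattern, using functools.lru_cache in place of vLLM's LRU cache and strings in place of tokenizer objects:

from functools import lru_cache
from typing import Optional

BASE_TOKENIZER = "base-tokenizer"  # placeholder objects
LORA_TOKENIZERS = {7: "adapter-7-tokenizer"}  # adapters with their own tokenizer

@lru_cache(maxsize=32)
def get_lora_tokenizer(lora_int_id: Optional[int]) -> str:
    if lora_int_id is None:
        return BASE_TOKENIZER
    # Fall back to the base tokenizer if the adapter has no tokenizer of its own.
    return LORA_TOKENIZERS.get(lora_int_id, BASE_TOKENIZER)

print(get_lora_tokenizer(None))  # base
print(get_lora_tokenizer(7))     # adapter-specific
print(get_lora_tokenizer(9))     # adapter without a tokenizer -> base
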
