diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index f610495135121..b3cef7ec09ff2 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -16,6 +16,7 @@
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import MultiModalData
+from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.usage.usage_lib import UsageContext
 
 logger = init_logger(__name__)
@@ -378,6 +379,12 @@ def _error_callback(self, exc: Exception) -> None:
         self.set_errored(exc)
         self._request_tracker.propagate_exception(exc)
 
+    async def get_tokenizer_group(self) -> BaseTokenizerGroup:
+        if self.engine_use_ray:
+            return await self.engine.get_tokenizer_group.remote()
+        else:
+            return self.engine.get_tokenizer_group()
+
     async def get_tokenizer(self) -> "PreTrainedTokenizer":
         if self.engine_use_ray:
             return await self.engine.get_tokenizer.remote()
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 1c639af696544..b19a481eb0d6c 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -241,6 +241,9 @@ def __reduce__(self):
         # the closure used to initialize Ray worker actors
         raise RuntimeError("LLMEngine should not be pickled!")
 
+    def get_tokenizer_group(self) -> BaseTokenizerGroup:
+        return self.tokenizer
+
     def get_tokenizer(self) -> "PreTrainedTokenizer":
         return self.tokenizer.get_lora_tokenizer(None)
 
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 0980c3d3cb614..dcc04665f0952 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -63,10 +63,13 @@ async def create_chat_completion(
 
         request_id = f"cmpl-{random_uuid()}"
         try:
-            token_ids = self._validate_prompt_and_tokenize(request,
-                                                           prompt=prompt)
-            sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
+            token_ids = await self._validate_prompt_and_tokenize(
+                request,
+                request_id=request_id,
+                lora_request=lora_request,
+                prompt=prompt)
+            sampling_params = request.to_sampling_params()
             guided_decode_logits_processor = (
                 await get_guided_decoding_logits_processor(
                     request, await self.engine.get_tokenizer()))
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 06e7a9225fefb..021e47c323d6e 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -135,23 +135,20 @@ async def create_completion(self, request: CompletionRequest,
             prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
 
             for i, prompt in enumerate(prompts):
-                if prompt_is_tokens:
-                    input_ids = self._validate_prompt_and_tokenize(
-                        request,
-                        prompt_ids=prompt,
-                        truncate_prompt_tokens=sampling_params.
-                        truncate_prompt_tokens)
-                else:
-                    input_ids = self._validate_prompt_and_tokenize(
-                        request,
-                        prompt=prompt,
-                        truncate_prompt_tokens=sampling_params.
-                        truncate_prompt_tokens)
+                sub_request_id = f"{request_id}-{i}"
+                prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
+                input_ids = await self._validate_prompt_and_tokenize(
+                    request,
+                    request_id=sub_request_id,
+                    lora_request=lora_request,
+                    truncate_prompt_tokens=sampling_params.
+                    truncate_prompt_tokens,
+                    **{prompt_arg: prompt})
 
                 generators.append(
                     self.engine.generate(prompt,
                                          sampling_params,
-                                         f"{request_id}-{i}",
+                                         sub_request_id,
                                          prompt_token_ids=input_ids,
                                          lora_request=lora_request))
         except ValueError as e:
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 8f69388c0251e..20d8b1f4b05b4 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -64,12 +64,11 @@ async def _post_init(self):
         engine_model_config = await self.engine.get_model_config()
         self.max_model_len = engine_model_config.max_model_len
 
-        # A separate tokenizer to map token IDs to strings.
+        # A separate tokenizer for applying the chat template.
         self.tokenizer = get_tokenizer(
             engine_model_config.tokenizer,
             tokenizer_mode=engine_model_config.tokenizer_mode,
-            trust_remote_code=engine_model_config.trust_remote_code,
-            truncation_side="left")
+            trust_remote_code=engine_model_config.trust_remote_code)
 
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
@@ -163,25 +162,28 @@ def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
         # if _check_model has been called earlier, this will be unreachable
         raise ValueError("The model `{request.model}` does not exist.")
 
-    def _validate_prompt_and_tokenize(
+    async def _validate_prompt_and_tokenize(
             self,
             request: Union[ChatCompletionRequest, CompletionRequest],
+            request_id: Optional[str] = None,
+            lora_request: Optional[LoRARequest] = None,
             prompt: Optional[str] = None,
             prompt_ids: Optional[List[int]] = None,
             truncate_prompt_tokens: Optional[conint(ge=1)] = None
     ) -> List[int]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")
-        if (prompt and prompt_ids):
+        if prompt and prompt_ids:
             raise ValueError(
                 "Only one of prompt or prompt_ids should be provided.")
 
         if prompt_ids is None:
-            tokenizer_kwargs = {} if truncate_prompt_tokens is None else {
-                "truncation": True,
-                "max_length": truncate_prompt_tokens,
-            }
-            input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids
+            tokenizer = await self.engine.get_tokenizer_group()
+            input_ids = await tokenizer.encode_async(
+                prompt,
+                request_id,
+                lora_request,
+                truncate_to=truncate_prompt_tokens)
         elif truncate_prompt_tokens is not None:
             input_ids = prompt_ids[-truncate_prompt_tokens:]
         else:
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index e216a99af91f9..8ef27b161b992 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -66,6 +66,9 @@ def get_tokenizer(
             "Cannot use the fast tokenizer in slow tokenizer mode.")
         kwargs["use_fast"] = False
 
+    if "truncation_side" not in kwargs:
+        kwargs["truncation_side"] = "left"
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
diff --git a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
index 3cce96e06d1a0..d875b925d82d3 100644
--- a/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
+++ b/vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
@@ -25,16 +25,17 @@ def get_max_input_len(self,
     def encode(self,
                prompt: str,
                request_id: Optional[str] = None,
-               lora_request: Optional[LoRARequest] = None) -> List[int]:
+               lora_request: Optional[LoRARequest] = None,
+               truncate_to: Optional[int] = None) -> List[int]:
         """Encode a prompt using the tokenizer group."""
         pass
 
     @abstractmethod
-    async def encode_async(
-            self,
-            prompt: str,
-            request_id: Optional[str] = None,
-            lora_request: Optional[LoRARequest] = None) -> List[int]:
+    async def encode_async(self,
+                           prompt: str,
+                           request_id: Optional[str] = None,
+                           lora_request: Optional[LoRARequest] = None,
+                           truncate_to: Optional[int] = None) -> List[int]:
         """Encode a prompt using the tokenizer group."""
         pass
 
diff --git a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
index c00b02fdbbbc0..4b1457db9e414 100644
--- a/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
+++ b/vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
@@ -81,7 +81,8 @@ def _ensure_queue_initialized(self):
     def encode(self,
                prompt: str,
                request_id: Optional[str] = None,
-               lora_request: Optional[LoRARequest] = None) -> List[int]:
+               lora_request: Optional[LoRARequest] = None,
+               truncate_to: Optional[int] = None) -> List[int]:
         """Encode a prompt using the tokenizer group.
 
         We pick an idle actor and use it to encode the prompt.
@@ -97,7 +98,8 @@ def encode(self,
             ret = ray.get(
                 actor.encode.remote(request_id=request_id,
                                     prompt=prompt,
-                                    lora_request=lora_request))
+                                    lora_request=lora_request,
+                                    truncate_to=truncate_to))
         finally:
             # Put the actor back in the queue.
             # This is done in a finally block to ensure that the actor is
@@ -106,11 +108,11 @@ def encode(self,
             self._idle_actors.put_nowait(actor)
         return ret
 
-    async def encode_async(
-            self,
-            prompt: str,
-            request_id: Optional[str] = None,
-            lora_request: Optional[LoRARequest] = None) -> List[int]:
+    async def encode_async(self,
+                           prompt: str,
+                           request_id: Optional[str] = None,
+                           lora_request: Optional[LoRARequest] = None,
+                           truncate_to: Optional[int] = None) -> List[int]:
         """Encode a prompt using the tokenizer group.
 
         We pick an idle actor and use it to encode the prompt.
@@ -125,7 +127,8 @@ async def encode_async(
         try:
             ret = await actor.encode.remote(request_id=request_id,
                                             prompt=prompt,
-                                            lora_request=lora_request)
+                                            lora_request=lora_request,
+                                            truncate_to=truncate_to)
         finally:
             # Put the actor back in the queue.
             # This is done in a finally block to ensure that the actor is
diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
index 927cbeed073bf..5a1394c83c6fa 100644
--- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
+++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
@@ -37,17 +37,27 @@ def get_max_input_len(self,
     def encode(self,
                prompt: str,
                request_id: Optional[str] = None,
-               lora_request: Optional[LoRARequest] = None) -> List[int]:
+               lora_request: Optional[LoRARequest] = None,
+               truncate_to: Optional[int] = None) -> List[int]:
         tokenizer = self.get_lora_tokenizer(lora_request)
-        return tokenizer.encode(prompt)
+        return self._encode(tokenizer, prompt, truncate_to)
 
-    async def encode_async(
-            self,
-            prompt: str,
-            request_id: Optional[str] = None,
-            lora_request: Optional[LoRARequest] = None) -> List[int]:
+    async def encode_async(self,
+                           prompt: str,
+                           request_id: Optional[str] = None,
+                           lora_request: Optional[LoRARequest] = None,
+                           truncate_to: Optional[int] = None) -> List[int]:
         tokenizer = await self.get_lora_tokenizer_async(lora_request)
-        return tokenizer.encode(prompt)
+        return self._encode(tokenizer, prompt, truncate_to)
+
+    @staticmethod
+    def _encode(tokenizer: PreTrainedTokenizer, prompt: str,
+                truncate_to: Optional[int]) -> List[int]:
+        tokenizer_kwargs = {} if truncate_to is None else {
+            "truncation": True,
+            "max_length": truncate_to,
+        }
+        return tokenizer.encode(prompt, **tokenizer_kwargs)
 
     def get_lora_tokenizer(
             self,
@@ -60,8 +70,7 @@ def get_lora_tokenizer(
                 lora_request, **self.tokenizer_config) or self.tokenizer)
             self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
             return tokenizer
-        else:
-            return self.lora_tokenizers.get(lora_request.lora_int_id)
+        return self.lora_tokenizers.get(lora_request.lora_int_id)
 
     async def get_lora_tokenizer_async(
             self,
@@ -74,5 +83,4 @@ async def get_lora_tokenizer_async(
                 lora_request, **self.tokenizer_config) or self.tokenizer)
             self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
             return tokenizer
-        else:
-            return self.lora_tokenizers.get(lora_request.lora_int_id)
+        return self.lora_tokenizers.get(lora_request.lora_int_id)
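Note on the truncation semantics introduced above: `TokenizerGroup._encode` forwards `truncation=True` and `max_length=truncate_to` to the underlying Hugging Face tokenizer, and `get_tokenizer` now defaults `truncation_side` to `"left"`, so an over-long prompt keeps its most recent tokens rather than its oldest ones. The standalone sketch below reproduces that behavior directly with `transformers`; the model name, prompt, and limit are illustrative placeholders and not part of this diff.

```python
from transformers import AutoTokenizer

# Any HF tokenizer works; "gpt2" is just a small, convenient placeholder.
tokenizer = AutoTokenizer.from_pretrained("gpt2", truncation_side="left")

prompt = "one two three four five six seven eight"
truncate_to = 4  # illustrative limit; plays the role of truncate_prompt_tokens

# Mirrors TokenizerGroup._encode: no extra kwargs when truncate_to is None,
# otherwise keep at most `truncate_to` tokens, taken from the end of the
# prompt because truncation_side is "left".
tokenizer_kwargs = {} if truncate_to is None else {
    "truncation": True,
    "max_length": truncate_to,
}
input_ids = tokenizer.encode(prompt, **tokenizer_kwargs)
assert len(input_ids) <= truncate_to
print(input_ids)
```

This is the same result the OpenAI server now obtains through `encode_async(..., truncate_to=truncate_prompt_tokens)` on the engine's tokenizer group, instead of calling its own `self.tokenizer(prompt, truncation=True, ...)` as before.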