From ce770f4ef2e99c0b6256ea8e87d17f664e9ea500 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 12 Apr 2024 06:00:00 +0000 Subject: [PATCH 01/94] Use discriminated union in prompt parsing --- vllm/entrypoints/openai/serving_completion.py | 76 ++++++++++++------- 1 file changed, 48 insertions(+), 28 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index e24aa2489a80f..8db79123084e3 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -1,6 +1,6 @@ import time from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, - Optional, Tuple) + Literal, Optional, Tuple, TypedDict, Union) from fastapi import Request @@ -26,27 +26,45 @@ [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs] -def parse_prompt_format(prompt) -> Tuple[bool, list]: +class PromptStrings(TypedDict): + prompt: str + is_tokens: Literal[False] + + +class PromptTokens(TypedDict): + prompt: List[int] + is_tokens: Literal[True] + + +def _parse_prompt_element_format( + elem: Union[str, int, + List[int]]) -> Union[PromptStrings, PromptTokens]: + if isinstance(elem, str): + # case 2: array of strings + return PromptStrings(prompt=elem, is_tokens=False) + if isinstance(elem, int): + # case 3: array of tokens + return PromptTokens(prompt=[elem], is_tokens=True) + if isinstance(elem, list): + # case 4: array of token arrays + return PromptTokens(prompt=elem, is_tokens=True) + + +def parse_prompt_format( + prompt: Union[str, List[str], List[int], List[List[int]]] +) -> List[Union[PromptStrings, PromptTokens]]: # get the prompt, openai supports the following # "a string, array of strings, array of tokens, or array of token arrays." - prompt_is_tokens = False - prompts = [prompt] # case 1: a string + + if isinstance(prompt, str): + # case 1: a string + return [_parse_prompt_element_format(prompt)] + if isinstance(prompt, list): - if len(prompt) == 0: - raise ValueError("please provide at least one prompt") - elif isinstance(prompt[0], str): - prompt_is_tokens = False - prompts = prompt # case 2: array of strings - elif isinstance(prompt[0], int): - prompt_is_tokens = True - prompts = [prompt] # case 3: array of tokens - elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int): - prompt_is_tokens = True - prompts = prompt # case 4: array of token arrays - else: - raise ValueError("prompt must be a string, array of strings, " - "array of tokens, or array of token arrays") - return prompt_is_tokens, prompts + return [_parse_prompt_element_format(elem) for elem in prompt] + + raise ValueError("prompt must be a string, array of strings, " + "array of tokens, or array of token arrays") class OpenAIServingCompletion(OpenAIServing): @@ -84,7 +102,7 @@ async def create_completion(self, request: CompletionRequest, created_time = int(time.time()) # Schedule the request and get the result generator. 
- generators = [] + generators: List[AsyncIterator[RequestOutput]] = [] try: sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) @@ -96,21 +114,23 @@ async def create_completion(self, request: CompletionRequest, sampling_params.logits_processors = [] sampling_params.logits_processors.append( guided_decode_logit_processor) - prompt_is_tokens, prompts = parse_prompt_format(request.prompt) + + prompts = parse_prompt_format(request.prompt) + truncate_prompt_tokens = sampling_params.truncate_prompt_tokens for i, prompt in enumerate(prompts): - if prompt_is_tokens: + if prompt["is_tokens"]: prompt_formats = self._validate_prompt_and_tokenize( request, - prompt_ids=prompt, - truncate_prompt_tokens=sampling_params. - truncate_prompt_tokens) + prompt_ids=prompt["prompt"], + truncate_prompt_tokens=truncate_prompt_tokens, + ) else: prompt_formats = self._validate_prompt_and_tokenize( request, - prompt=prompt, - truncate_prompt_tokens=sampling_params. - truncate_prompt_tokens) + prompt=prompt["prompt"], + truncate_prompt_tokens=truncate_prompt_tokens, + ) prompt_ids, prompt_text = prompt_formats generators.append( From 6b016bc537e5622995a161ee25e6fc1c91fce396 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 12 Apr 2024 03:26:43 +0000 Subject: [PATCH 02/94] Fix some type errors along the way --- vllm/entrypoints/openai/protocol.py | 20 +++++++++--------- vllm/entrypoints/openai/serving_chat.py | 21 ++++++++++--------- vllm/entrypoints/openai/serving_completion.py | 14 +++++++------ vllm/entrypoints/openai/serving_engine.py | 4 ++-- 4 files changed, 31 insertions(+), 28 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index f94d22d279cc4..c06fc027d3c8c 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -30,7 +30,7 @@ class ModelPermission(BaseModel): allow_fine_tuning: bool = False organization: str = "*" group: Optional[str] = None - is_blocking: str = False + is_blocking: bool = False class ModelCard(BaseModel): @@ -56,7 +56,7 @@ class UsageInfo(BaseModel): class ResponseFormat(BaseModel): # type must be "json_object" or "text" - type: str = Literal["text", "json_object"] + type: Literal["text", "json_object"] class ChatCompletionRequest(BaseModel): @@ -339,8 +339,8 @@ class CompletionResponseChoice(BaseModel): index: int text: str logprobs: Optional[LogProbs] = None - finish_reason: Optional[Literal["stop", "length"]] = None - stop_reason: Union[None, int, str] = Field( + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( default=None, description=( "The stop string or token id that caused the completion " @@ -362,8 +362,8 @@ class CompletionResponseStreamChoice(BaseModel): index: int text: str logprobs: Optional[LogProbs] = None - finish_reason: Optional[Literal["stop", "length"]] = None - stop_reason: Union[None, int, str] = Field( + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( default=None, description=( "The stop string or token id that caused the completion " @@ -390,8 +390,8 @@ class ChatCompletionResponseChoice(BaseModel): index: int message: ChatMessage logprobs: Optional[LogProbs] = None - finish_reason: Optional[Literal["stop", "length"]] = None - stop_reason: Union[None, int, str] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = None class ChatCompletionResponse(BaseModel): @@ -412,8 +412,8 @@ class 
ChatCompletionResponseStreamChoice(BaseModel): index: int delta: DeltaMessage logprobs: Optional[LogProbs] = None - finish_reason: Optional[Literal["stop", "length"]] = None - stop_reason: Union[None, int, str] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = None class ChatCompletionStreamResponse(BaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index a03c5dc88108f..1b0758175416b 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -102,18 +102,19 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str: async def chat_completion_stream_generator( self, request: ChatCompletionRequest, - result_generator: AsyncIterator[RequestOutput], request_id: str - ) -> Union[ErrorResponse, AsyncGenerator[str, None]]: - + result_generator: AsyncIterator[RequestOutput], + request_id: str) -> AsyncGenerator[str, None]: model_name = request.model created_time = int(time.time()) chunk_object_type = "chat.completion.chunk" first_iteration = True # Send response for each token for each request.n (index) - previous_texts = [""] * request.n - previous_num_tokens = [0] * request.n - finish_reason_sent = [False] * request.n + num_choices = 1 if request.n is None else request.n + previous_texts = [""] * num_choices + previous_num_tokens = [0] * num_choices + finish_reason_sent = [False] * num_choices + try: async for res in result_generator: res: RequestOutput @@ -124,7 +125,7 @@ async def chat_completion_stream_generator( # Send first response for each request.n (index) with # the role role = self.get_chat_request_role(request) - for i in range(request.n): + for i in range(num_choices): choice_data = ChatCompletionResponseStreamChoice( index=i, delta=DeltaMessage(role=role), @@ -151,19 +152,19 @@ async def chat_completion_stream_generator( last_msg_content = request.messages[-1]["content"] if last_msg_content: - for i in range(request.n): + for i in range(num_choices): choice_data = ( ChatCompletionResponseStreamChoice( index=i, delta=DeltaMessage( content=last_msg_content), + logprobs=None, finish_reason=None)) chunk = ChatCompletionStreamResponse( id=request_id, object=chunk_object_type, created=created_time, choices=[choice_data], - logprobs=None, model=model_name) data = chunk.model_dump_json( exclude_unset=True) @@ -249,7 +250,7 @@ async def chat_completion_full_generator( model_name = request.model created_time = int(time.time()) - final_res: RequestOutput = None + final_res: Optional[RequestOutput] = None async for res in result_generator: if await raw_request.is_disconnected(): diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 8db79123084e3..7fb47ffdc8555 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -164,7 +164,7 @@ async def create_completion(self, request: CompletionRequest, num_prompts=len(prompts)) # Non-streaming response - final_res_batch: RequestOutput = [None] * len(prompts) + final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts) try: async for i, res in result_generator: if await raw_request.is_disconnected(): @@ -201,9 +201,10 @@ async def completion_stream_generator( model_name: str, num_prompts: int, ) -> AsyncGenerator[str, None]: - previous_texts = [""] * request.n * num_prompts - previous_num_tokens = [0] * request.n * num_prompts - has_echoed = [False] * request.n * num_prompts + 
num_choices = 1 if request.n is None else request.n + previous_texts = [""] * num_choices * num_prompts + previous_num_tokens = [0] * num_choices * num_prompts + has_echoed = [False] * num_choices * num_prompts try: async for prompt_idx, res in result_generator: @@ -214,7 +215,7 @@ async def completion_stream_generator( raise StopAsyncIteration() for output in res.outputs: - i = output.index + prompt_idx * request.n + i = output.index + prompt_idx * num_choices # TODO(simon): optimize the performance by avoiding full # text O(n^2) sending. @@ -295,9 +296,10 @@ def request_output_to_completion_response( created_time: int, model_name: str, ) -> CompletionResponse: - choices = [] + choices: List[CompletionResponseChoice] = [] num_prompt_tokens = 0 num_generated_tokens = 0 + for final_res in final_res_batch: assert final_res is not None prompt_token_ids = final_res.prompt_token_ids diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 77a568b564039..f785fb524388f 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -151,9 +151,9 @@ def create_streaming_error_response( async def _check_model(self, request) -> Optional[ErrorResponse]: if request.model == self.served_model: - return + return None if request.model in [lora.lora_name for lora in self.lora_requests]: - return + return None return self.create_error_response( message=f"The model `{request.model}` does not exist.", err_type="NotFoundError", From 7620354628a6a63dcf790dd5d7f0521bebcd2f5e Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 12 Apr 2024 06:18:56 +0000 Subject: [PATCH 03/94] Some more fixes --- vllm/entrypoints/openai/serving_chat.py | 2 +- vllm/entrypoints/openai/serving_engine.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 1b0758175416b..f189fa27d5826 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -260,7 +260,7 @@ async def chat_completion_full_generator( final_res = res assert final_res is not None - choices = [] + choices: List[ChatCompletionResponseChoice] = [] role = self.get_chat_request_role(request) for output in final_res.outputs: diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index f785fb524388f..a215e498ae63c 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -30,7 +30,7 @@ class OpenAIServing: def __init__(self, engine: AsyncLLMEngine, served_model: str, - lora_modules=Optional[List[LoRA]]): + lora_modules: Optional[List[LoRA]]): self.engine = engine self.served_model = served_model if lora_modules is None: From 7c3e6d91b332227eb5892de2ba10cc97b6167499 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 12 Apr 2024 06:24:34 +0000 Subject: [PATCH 04/94] Apply formatter --- vllm/entrypoints/openai/serving_engine.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index a215e498ae63c..b2d055bea352d 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -27,9 +27,7 @@ class LoRA: class OpenAIServing: - def __init__(self, - engine: AsyncLLMEngine, - served_model: str, + def __init__(self, engine: AsyncLLMEngine, served_model: str, lora_modules: Optional[List[LoRA]]): self.engine = engine 
self.served_model = served_model From 7bdc84eb4dbf482d7540d56289712b516ebd4451 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 12 Apr 2024 09:55:44 +0000 Subject: [PATCH 05/94] Refactor prompt parsing so that it can be shared between Chat Completions API and legacy Completions API --- vllm/entrypoints/openai/serving_chat.py | 23 ++-- vllm/entrypoints/openai/serving_completion.py | 89 +++----------- vllm/entrypoints/openai/serving_engine.py | 116 +++++++++++++++++- 3 files changed, 145 insertions(+), 83 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index f189fa27d5826..58856bd96f9f9 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -10,7 +10,8 @@ ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, UsageInfo) -from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing +from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, + OpenAIServing) from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) @@ -26,7 +27,7 @@ def __init__(self, engine: AsyncLLMEngine, served_model: str, response_role: str, - lora_modules: Optional[List[LoRA]] = None, + lora_modules: Optional[List[LoRAModulePath]] = None, chat_template=None): super().__init__(engine=engine, served_model=served_model, @@ -63,9 +64,6 @@ async def create_chat_completion( request_id = f"cmpl-{random_uuid()}" try: - # Tokenize/detokenize depending on prompt format (string/token list) - prompt_ids, prompt_text = self._validate_prompt_and_tokenize( - request, prompt=prompt) sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) guided_decode_logits_processor = ( @@ -76,12 +74,21 @@ async def create_chat_completion( sampling_params.logits_processors = [] sampling_params.logits_processors.append( guided_decode_logits_processor) + + prompt_ids, prompt_text = self._tokenize_input_text( + request, + prompt, + truncate_prompt_tokens=sampling_params.truncate_prompt_tokens, + ) + + result_generator = self.engine.generate(prompt_text, + sampling_params, + request_id, prompt_ids, + lora_request) except ValueError as e: + # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) - result_generator = self.engine.generate(prompt_text, sampling_params, - request_id, prompt_ids, - lora_request) # Streaming response if request.stream: return self.chat_completion_stream_generator( diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 7fb47ffdc8555..8b0bec7fbda69 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -1,6 +1,6 @@ import time from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, - Literal, Optional, Tuple, TypedDict, Union) + Optional, Tuple) from fastapi import Request @@ -11,7 +11,8 @@ CompletionResponseStreamChoice, CompletionStreamResponse, LogProbs, UsageInfo) -from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing +from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, + OpenAIServing) from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) @@ -26,53 +27,12 @@ [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs] -class PromptStrings(TypedDict): 
- prompt: str - is_tokens: Literal[False] - - -class PromptTokens(TypedDict): - prompt: List[int] - is_tokens: Literal[True] - - -def _parse_prompt_element_format( - elem: Union[str, int, - List[int]]) -> Union[PromptStrings, PromptTokens]: - if isinstance(elem, str): - # case 2: array of strings - return PromptStrings(prompt=elem, is_tokens=False) - if isinstance(elem, int): - # case 3: array of tokens - return PromptTokens(prompt=[elem], is_tokens=True) - if isinstance(elem, list): - # case 4: array of token arrays - return PromptTokens(prompt=elem, is_tokens=True) - - -def parse_prompt_format( - prompt: Union[str, List[str], List[int], List[List[int]]] -) -> List[Union[PromptStrings, PromptTokens]]: - # get the prompt, openai supports the following - # "a string, array of strings, array of tokens, or array of token arrays." - - if isinstance(prompt, str): - # case 1: a string - return [_parse_prompt_element_format(prompt)] - - if isinstance(prompt, list): - return [_parse_prompt_element_format(elem) for elem in prompt] - - raise ValueError("prompt must be a string, array of strings, " - "array of tokens, or array of token arrays") - - class OpenAIServingCompletion(OpenAIServing): def __init__(self, engine: AsyncLLMEngine, served_model: str, - lora_modules: Optional[List[LoRA]] = None): + lora_modules: Optional[List[LoRAModulePath]] = None): super().__init__(engine=engine, served_model=served_model, lora_modules=lora_modules) @@ -115,24 +75,13 @@ async def create_completion(self, request: CompletionRequest, sampling_params.logits_processors.append( guided_decode_logit_processor) - prompts = parse_prompt_format(request.prompt) - truncate_prompt_tokens = sampling_params.truncate_prompt_tokens - - for i, prompt in enumerate(prompts): - if prompt["is_tokens"]: - prompt_formats = self._validate_prompt_and_tokenize( - request, - prompt_ids=prompt["prompt"], - truncate_prompt_tokens=truncate_prompt_tokens, - ) - else: - prompt_formats = self._validate_prompt_and_tokenize( + for i, (prompt_ids, prompt_text) in enumerate( + self._tokenize_input_text_or_texts( request, - prompt=prompt["prompt"], - truncate_prompt_tokens=truncate_prompt_tokens, - ) - prompt_ids, prompt_text = prompt_formats - + request.prompt, + truncate_prompt_tokens=sampling_params. 
+ truncate_prompt_tokens, + )): generators.append( self.engine.generate(prompt_text, sampling_params, @@ -155,16 +104,18 @@ async def create_completion(self, request: CompletionRequest, # Streaming response if stream: - return self.completion_stream_generator(request, - raw_request, - result_generator, - request_id, - created_time, - model_name, - num_prompts=len(prompts)) + return self.completion_stream_generator( + request, + raw_request, + result_generator, + request_id, + created_time, + model_name, + num_prompts=len(generators)) # Non-streaming response - final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts) + final_res_batch: List[Optional[RequestOutput]] = [None + ] * len(generators) try: async for i, res in result_generator: if await raw_request.is_disconnected(): diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index b2d055bea352d..0e896c455c276 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -2,7 +2,8 @@ import json from dataclasses import dataclass from http import HTTPStatus -from typing import Dict, List, Optional, Tuple, Union +from typing import (Dict, Iterable, Iterator, List, Literal, Optional, Tuple, + TypedDict, Union) from pydantic import conint @@ -19,8 +20,18 @@ logger = init_logger(__name__) +class InputStrings(TypedDict): + input_text: str + is_tokens: Literal[False] + + +class InputTokens(TypedDict): + input_text: List[int] + is_tokens: Literal[True] + + @dataclass -class LoRA: +class LoRAModulePath: name: str local_path: str @@ -28,7 +39,7 @@ class LoRA: class OpenAIServing: def __init__(self, engine: AsyncLLMEngine, served_model: str, - lora_modules: Optional[List[LoRA]]): + lora_modules: Optional[List[LoRAModulePath]]): self.engine = engine self.served_model = served_model if lora_modules is None: @@ -147,7 +158,9 @@ def create_streaming_error_response( }) return json_str - async def _check_model(self, request) -> Optional[ErrorResponse]: + async def _check_model( + self, request: Union[CompletionRequest, ChatCompletionRequest] + ) -> Optional[ErrorResponse]: if request.model == self.served_model: return None if request.model in [lora.lora_name for lora in self.lora_requests]: @@ -157,9 +170,11 @@ async def _check_model(self, request) -> Optional[ErrorResponse]: err_type="NotFoundError", status_code=HTTPStatus.NOT_FOUND) - def _maybe_get_lora(self, request) -> Optional[LoRARequest]: + def _maybe_get_lora( + self, request: Union[CompletionRequest, ChatCompletionRequest] + ) -> Optional[LoRARequest]: if request.model == self.served_model: - return + return None for lora in self.lora_requests: if request.model == lora.lora_name: return lora @@ -207,3 +222,92 @@ def _validate_prompt_and_tokenize( f"Please reduce the length of the messages or completion.", ) else: return input_ids, input_text + + # https://platform.openai.com/docs/api-reference/embeddings/create + def _tokenize_input_text( + self, + request: Union[ChatCompletionRequest, CompletionRequest], + input_text: Union[str, List[int]], + truncate_prompt_tokens: Optional[conint(ge=1)] = None, + ) -> Tuple[List[int], str]: + return next( + self._tokenize_input_texts( + request, + [input_text], + truncate_prompt_tokens=truncate_prompt_tokens, + )) + + def _tokenize_input_texts( + self, + request: Union[ChatCompletionRequest, CompletionRequest], + input_texts: Iterable[Union[str, List[int]]], + truncate_prompt_tokens: Optional[conint(ge=1)] = None, + ) -> Iterator[Tuple[List[int], str]]: + 
for input_text in input_texts: + if isinstance(input_text, str): + yield self._validate_prompt_and_tokenize( + request, + prompt=input_text, + truncate_prompt_tokens=truncate_prompt_tokens, + ) + else: + yield self._validate_prompt_and_tokenize( + request, + prompt_ids=input_text, + truncate_prompt_tokens=truncate_prompt_tokens, + ) + + def _parse_input_element( + self, + elem: Union[str, int, List[int]], + ) -> Union[InputStrings, InputTokens]: + if isinstance(elem, str): + # case 2: array of strings + return InputStrings(prompt=elem, is_tokens=False) + if isinstance(elem, int): + # case 3: array of tokens + return InputTokens(prompt=[elem], is_tokens=True) + if isinstance(elem, list): + # case 4: array of token arrays + return InputTokens(prompt=elem, is_tokens=True) + + def _parse_input_text_or_texts( + self, + input_text_or_texts: Union[str, List[str], List[int], List[List[int]]], + ) -> List[Union[InputStrings, InputTokens]]: + # get the prompt, openai supports the following: + # a string, array of strings, array of tokens, or array of token arrays + + if isinstance(input_text_or_texts, str): + # case 1: a string + return [self._parse_input_element(input_text_or_texts)] + + if isinstance(input_text_or_texts, list): + return [self._parse_input_element(e) for e in input_text_or_texts] + + raise ValueError("prompt must be a string, array of strings, " + "array of tokens, or array of token arrays") + + def _tokenize_input_text_or_texts( + self, + request: Union[ChatCompletionRequest, CompletionRequest], + input_text_or_texts: Union[str, List[str], List[int], List[List[int]]], + truncate_prompt_tokens: Optional[conint(ge=1)] = None, + ) -> Iterator[Tuple[List[int], str]]: + for input_ in self._parse_input_text_or_texts(input_text_or_texts): + # Although our type checking is based on mypy, + # VSCode Pyright extension should still work properly + # "is True" is required for Pyright to perform type narrowing + # See: https://github.com/microsoft/pyright/issues/7672 + if input_["is_tokens"] is True: + yield self._validate_prompt_and_tokenize( + request, + prompt_ids=input_["input_text"], + truncate_prompt_tokens=truncate_prompt_tokens, + ) + else: + yield self._validate_prompt_and_tokenize( + request, + prompt=input_["input_text"], + truncate_prompt_tokens=truncate_prompt_tokens, + ) From a7d109853a93dd8114c984ae18bf9122e14c4128 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 12 Apr 2024 10:02:39 +0000 Subject: [PATCH 06/94] Make code more readable --- vllm/entrypoints/openai/serving_chat.py | 18 +++++----- vllm/entrypoints/openai/serving_completion.py | 34 +++++++++---------- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 58856bd96f9f9..664899e262c0f 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -124,7 +124,6 @@ async def chat_completion_stream_generator( try: async for res in result_generator: - res: RequestOutput # We need to do it here, because if there are exceptions in # the result_generator, it needs to be sent as the FIRST # response (by the try...catch). 
@@ -322,24 +321,25 @@ async def chat_completion_full_generator( return response - def _load_chat_template(self, chat_template): + def _load_chat_template(self, chat_template: str): + tokenizer = self.tokenizer + assert tokenizer is not None + if chat_template is not None: try: with open(chat_template, "r") as f: - self.tokenizer.chat_template = f.read() + tokenizer.chat_template = f.read() except OSError: # If opening a file fails, set chat template to be args to # ensure we decode so our escape are interpreted correctly - self.tokenizer.chat_template = codecs.decode( + tokenizer.chat_template = codecs.decode( chat_template, "unicode_escape") logger.info( - f"Using supplied chat template:\n{self.tokenizer.chat_template}" - ) - elif self.tokenizer.chat_template is not None: + f"Using supplied chat template:\n{tokenizer.chat_template}") + elif tokenizer.chat_template is not None: logger.info( - f"Using default chat template:\n{self.tokenizer.chat_template}" - ) + f"Using default chat template:\n{tokenizer.chat_template}") else: logger.warning( "No chat template provided. Chat API will not work.") diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 8b0bec7fbda69..7f1ebd53e85eb 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -75,13 +75,15 @@ async def create_completion(self, request: CompletionRequest, sampling_params.logits_processors.append( guided_decode_logit_processor) - for i, (prompt_ids, prompt_text) in enumerate( - self._tokenize_input_text_or_texts( - request, - request.prompt, - truncate_prompt_tokens=sampling_params. - truncate_prompt_tokens, - )): + prompts = list( + self._tokenize_input_text_or_texts( + request, + request.prompt, + truncate_prompt_tokens=sampling_params. 
+ truncate_prompt_tokens, + )) + + for i, (prompt_ids, prompt_text) in enumerate(prompts): generators.append( self.engine.generate(prompt_text, sampling_params, @@ -104,18 +106,16 @@ async def create_completion(self, request: CompletionRequest, # Streaming response if stream: - return self.completion_stream_generator( - request, - raw_request, - result_generator, - request_id, - created_time, - model_name, - num_prompts=len(generators)) + return self.completion_stream_generator(request, + raw_request, + result_generator, + request_id, + created_time, + model_name, + num_prompts=len(prompts)) # Non-streaming response - final_res_batch: List[Optional[RequestOutput]] = [None - ] * len(generators) + final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts) try: async for i, res in result_generator: if await raw_request.is_disconnected(): From 8b9d6368846e3c12cb591e4f71a37d746a759d49 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 12 Apr 2024 10:08:50 +0000 Subject: [PATCH 07/94] Move assertion to a more appropriate place --- vllm/entrypoints/openai/serving_completion.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 7f1ebd53e85eb..10e368ecfd7b9 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -123,8 +123,19 @@ async def create_completion(self, request: CompletionRequest, await self.engine.abort(f"{request_id}-{i}") return self.create_error_response("Client disconnected") final_res_batch[i] = res + + final_res_batch_checked: List[RequestOutput] = [] + for final_res in final_res_batch: + assert final_res is not None + final_res_batch_checked.append(final_res) + response = self.request_output_to_completion_response( - final_res_batch, request, request_id, created_time, model_name) + final_res_batch_checked, + request, + request_id, + created_time, + model_name, + ) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -252,7 +263,6 @@ def request_output_to_completion_response( num_generated_tokens = 0 for final_res in final_res_batch: - assert final_res is not None prompt_token_ids = final_res.prompt_token_ids prompt_logprobs = final_res.prompt_logprobs prompt_text = final_res.prompt From c48c13a4bc1e5f49a6b729fb177b81347e457466 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 12 Apr 2024 10:18:47 +0000 Subject: [PATCH 08/94] Add code documentation --- vllm/entrypoints/openai/serving_engine.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 0e896c455c276..d04a2af9ba744 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -223,13 +223,15 @@ def _validate_prompt_and_tokenize( else: return input_ids, input_text - # https://platform.openai.com/docs/api-reference/embeddings/create def _tokenize_input_text( self, request: Union[ChatCompletionRequest, CompletionRequest], input_text: Union[str, List[int]], truncate_prompt_tokens: Optional[conint(ge=1)] = None, ) -> Tuple[List[int], str]: + """A simpler implementation of + :meth:`~vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_input_text_or_texts` + that assumes single input.""" return next( self._tokenize_input_texts( request, @@ -243,6 +245,9 @@ def _tokenize_input_texts( 
input_texts: Iterable[Union[str, List[int]]], truncate_prompt_tokens: Optional[conint(ge=1)] = None, ) -> Iterator[Tuple[List[int], str]]: + """A simpler implementation of + :meth:`~vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_input_text_or_texts` + that assumes multiple input.""" for input_text in input_texts: if isinstance(input_text, str): yield self._validate_prompt_and_tokenize( @@ -275,9 +280,6 @@ def _parse_input_text_or_texts( self, input_text_or_texts: Union[str, List[str], List[int], List[List[int]]], ) -> List[Union[InputStrings, InputTokens]]: - # get the prompt, openai supports the following: - # a string, array of strings, array of tokens, or array of token arrays - if isinstance(input_text_or_texts, str): # case 1: a string return [self._parse_input_element(input_text_or_texts)] @@ -294,6 +296,12 @@ def _tokenize_input_text_or_texts( input_text_or_texts: Union[str, List[str], List[int], List[List[int]]], truncate_prompt_tokens: Optional[conint(ge=1)] = None, ) -> Iterator[Tuple[List[int], str]]: + """Tokenize/detokenize depending on the input format. + + According to `OpenAI API `_ + , each input can be a string or array of tokens. Note that each request + can pass one or more inputs. + """ for input_ in self._parse_input_text_or_texts(input_text_or_texts): # Although our type checking is based on mypy, # VSCode Pyright extension should still work properly From 35303626d45f7f7257ac3f5d385758cd52e94270 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 12 Apr 2024 10:47:36 +0000 Subject: [PATCH 09/94] Decompose `_validate_prompt_and_tokenize` --- vllm/entrypoints/openai/serving_chat.py | 2 +- vllm/entrypoints/openai/serving_completion.py | 2 +- vllm/entrypoints/openai/serving_engine.py | 146 +++++++++++------- 3 files changed, 89 insertions(+), 61 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 664899e262c0f..0b30366d716dd 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -75,7 +75,7 @@ async def create_chat_completion( sampling_params.logits_processors.append( guided_decode_logits_processor) - prompt_ids, prompt_text = self._tokenize_input_text( + prompt_ids, prompt_text = self._tokenize_prompt_input( request, prompt, truncate_prompt_tokens=sampling_params.truncate_prompt_tokens, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 10e368ecfd7b9..f80d1bd5c4484 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -76,7 +76,7 @@ async def create_completion(self, request: CompletionRequest, guided_decode_logit_processor) prompts = list( - self._tokenize_input_text_or_texts( + self._tokenize_prompt_input_or_inputs( request, request.prompt, truncate_prompt_tokens=sampling_params. 
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index d04a2af9ba744..7b3bbcc4c7b98 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -6,6 +6,7 @@ TypedDict, Union) from pydantic import conint +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, @@ -20,13 +21,13 @@ logger = init_logger(__name__) -class InputStrings(TypedDict): - input_text: str +class InputString(TypedDict): + text: str is_tokens: Literal[False] class InputTokens(TypedDict): - input_text: List[int] + text: List[int] is_tokens: Literal[True] @@ -181,32 +182,48 @@ def _maybe_get_lora( # if _check_model has been called earlier, this will be unreachable raise ValueError("The model `{request.model}` does not exist.") - def _validate_prompt_and_tokenize( + def _normalize_prompt_text_to_input( self, request: Union[ChatCompletionRequest, CompletionRequest], - prompt: Optional[str] = None, - prompt_ids: Optional[List[int]] = None, + prompt: str, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], truncate_prompt_tokens: Optional[conint(ge=1)] = None ) -> Tuple[List[int], str]: - if not (prompt or prompt_ids): - raise ValueError("Either prompt or prompt_ids should be provided.") - if (prompt and prompt_ids): - raise ValueError( - "Only one of prompt or prompt_ids should be provided.") - - if prompt_ids is None: - tokenizer_kwargs = {} if truncate_prompt_tokens is None else { - "truncation": True, - "max_length": truncate_prompt_tokens, - } - input_ids = self.tokenizer(prompt, **tokenizer_kwargs).input_ids - elif truncate_prompt_tokens is not None: - input_ids = prompt_ids[-truncate_prompt_tokens:] + if truncate_prompt_tokens is None: + encoded = tokenizer(prompt) else: + encoded = tokenizer(prompt, + truncation=True, + max_length=truncate_prompt_tokens) + + input_ids = encoded.input_ids + + input_text = prompt + + return self._validate_input(request, input_ids, input_text) + + def _normalize_prompt_tokens_to_input( + self, + request: Union[ChatCompletionRequest, CompletionRequest], + prompt_ids: List[int], + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + truncate_prompt_tokens: Optional[conint(ge=1)] = None + ) -> Tuple[List[int], str]: + if truncate_prompt_tokens is None: input_ids = prompt_ids + else: + input_ids = prompt_ids[-truncate_prompt_tokens:] + + input_text = tokenizer.decode(prompt_ids) - input_text = prompt if prompt is not None else self.tokenizer.decode( - prompt_ids) + return self._validate_input(request, input_ids, input_text) + + def _validate_input( + self, + request: Union[ChatCompletionRequest, CompletionRequest], + input_ids: List[int], + input_text: str, + ) -> Tuple[List[int], str]: token_num = len(input_ids) if request.max_tokens is None: @@ -220,80 +237,85 @@ def _validate_prompt_and_tokenize( f"({token_num} in the messages, " f"{request.max_tokens} in the completion). 
" f"Please reduce the length of the messages or completion.", ) - else: - return input_ids, input_text - def _tokenize_input_text( + return input_ids, input_text + + def _tokenize_prompt_input( self, request: Union[ChatCompletionRequest, CompletionRequest], - input_text: Union[str, List[int]], + prompt_input: Union[str, List[int]], truncate_prompt_tokens: Optional[conint(ge=1)] = None, ) -> Tuple[List[int], str]: """A simpler implementation of - :meth:`~vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_input_text_or_texts` + :meth:`~vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs` that assumes single input.""" return next( - self._tokenize_input_texts( + self._tokenize_prompt_inputs( request, - [input_text], + [prompt_input], truncate_prompt_tokens=truncate_prompt_tokens, )) - def _tokenize_input_texts( + def _tokenize_prompt_inputs( self, request: Union[ChatCompletionRequest, CompletionRequest], - input_texts: Iterable[Union[str, List[int]]], + prompt_inputs: Iterable[Union[str, List[int]]], truncate_prompt_tokens: Optional[conint(ge=1)] = None, ) -> Iterator[Tuple[List[int], str]]: """A simpler implementation of - :meth:`~vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_input_text_or_texts` - that assumes multiple input.""" - for input_text in input_texts: - if isinstance(input_text, str): - yield self._validate_prompt_and_tokenize( + :meth:`~vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs` + that assumes multiple inputs.""" + tokenizer = self.tokenizer + assert tokenizer is not None + + for text in prompt_inputs: + if isinstance(text, str): + yield self._normalize_prompt_text_to_input( request, - prompt=input_text, + prompt=text, + tokenizer=tokenizer, truncate_prompt_tokens=truncate_prompt_tokens, ) else: - yield self._validate_prompt_and_tokenize( + yield self._normalize_prompt_tokens_to_input( request, - prompt_ids=input_text, + prompt_ids=text, + tokenizer=tokenizer, truncate_prompt_tokens=truncate_prompt_tokens, ) - def _parse_input_element( + def _parse_prompt_element( self, elem: Union[str, int, List[int]], - ) -> Union[InputStrings, InputTokens]: + ) -> Union[InputString, InputTokens]: if isinstance(elem, str): # case 2: array of strings - return InputStrings(prompt=elem, is_tokens=False) + return InputString(text=elem, is_tokens=False) if isinstance(elem, int): # case 3: array of tokens - return InputTokens(prompt=[elem], is_tokens=True) + return InputTokens(text=[elem], is_tokens=True) if isinstance(elem, list): # case 4: array of token arrays - return InputTokens(prompt=elem, is_tokens=True) + return InputTokens(text=elem, is_tokens=True) - def _parse_input_text_or_texts( + def _parse_prompt_input_or_inputs( self, - input_text_or_texts: Union[str, List[str], List[int], List[List[int]]], - ) -> List[Union[InputStrings, InputTokens]]: - if isinstance(input_text_or_texts, str): + input_or_inputs: Union[str, List[str], List[int], List[List[int]]], + ) -> List[Union[InputString, InputTokens]]: + if isinstance(input_or_inputs, str): # case 1: a string - return [self._parse_input_element(input_text_or_texts)] + return [self._parse_prompt_element(input_or_inputs)] - if isinstance(input_text_or_texts, list): - return [self._parse_input_element(e) for e in input_text_or_texts] + if isinstance(input_or_inputs, list): + return [self._parse_prompt_element(e) for e in input_or_inputs] raise ValueError("prompt must be a string, array of strings, " "array of tokens, or array of token 
arrays") - def _tokenize_input_text_or_texts( + def _tokenize_prompt_input_or_inputs( self, request: Union[ChatCompletionRequest, CompletionRequest], - input_text_or_texts: Union[str, List[str], List[int], List[List[int]]], + input_or_inputs: Union[str, List[str], List[int], List[List[int]]], truncate_prompt_tokens: Optional[conint(ge=1)] = None, ) -> Iterator[Tuple[List[int], str]]: """Tokenize/detokenize depending on the input format. @@ -302,20 +324,26 @@ def _tokenize_input_text_or_texts( , each input can be a string or array of tokens. Note that each request can pass one or more inputs. """ - for input_ in self._parse_input_text_or_texts(input_text_or_texts): + tokenizer = self.tokenizer + assert tokenizer is not None + + for prompt_input in self._parse_prompt_input_or_inputs( + input_or_inputs): # Although our type checking is based on mypy, # VSCode Pyright extension should still work properly # "is True" is required for Pyright to perform type narrowing # See: https://github.com/microsoft/pyright/issues/7672 - if input_["is_tokens"] is True: - yield self._validate_prompt_and_tokenize( + if prompt_input["is_tokens"] is False: + yield self._normalize_prompt_text_to_input( request, - prompt_ids=input_["input_text"], + prompt=prompt_input["text"], + tokenizer=tokenizer, truncate_prompt_tokens=truncate_prompt_tokens, ) else: - yield self._validate_prompt_and_tokenize( + yield self._normalize_prompt_tokens_to_input( request, - prompt=input_["input_text"], + prompt_ids=prompt_input["text"], + tokenizer=tokenizer, truncate_prompt_tokens=truncate_prompt_tokens, ) From b8feec974209f87c78a93a18972e9c75fd68945e Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 12 Apr 2024 19:08:00 +0800 Subject: [PATCH 10/94] Fix missing import due to renaming --- vllm/entrypoints/openai/cli_args.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index cc71931b97955..0bd15b667c651 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -9,7 +9,7 @@ import ssl from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.entrypoints.openai.serving_engine import LoRA +from vllm.entrypoints.openai.serving_engine import LoRAModulePath class LoRAParserAction(argparse.Action): @@ -18,7 +18,7 @@ def __call__(self, parser, namespace, values, option_string=None): lora_list = [] for item in values: name, path = item.split('=') - lora_list.append(LoRA(name, path)) + lora_list.append(LoRAModulePath(name, path)) setattr(namespace, self.dest, lora_list) From cc1a5b3eeccaa50419082fc9cd80a8ac662b1fdc Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 13 Apr 2024 04:38:57 +0000 Subject: [PATCH 11/94] Fix bug when parsing array of tokens --- vllm/entrypoints/openai/serving_engine.py | 39 +++++++++++++---------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 7b3bbcc4c7b98..4f5258b048c43 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from http import HTTPStatus from typing import (Dict, Iterable, Iterator, List, Literal, Optional, Tuple, - TypedDict, Union) + TypedDict, Union, cast) from pydantic import conint from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -284,30 +284,35 @@ def _tokenize_prompt_inputs( truncate_prompt_tokens=truncate_prompt_tokens, ) - def 
_parse_prompt_element( - self, - elem: Union[str, int, List[int]], - ) -> Union[InputString, InputTokens]: - if isinstance(elem, str): - # case 2: array of strings - return InputString(text=elem, is_tokens=False) - if isinstance(elem, int): - # case 3: array of tokens - return InputTokens(text=[elem], is_tokens=True) - if isinstance(elem, list): - # case 4: array of token arrays - return InputTokens(text=elem, is_tokens=True) - def _parse_prompt_input_or_inputs( self, input_or_inputs: Union[str, List[str], List[int], List[List[int]]], ) -> List[Union[InputString, InputTokens]]: if isinstance(input_or_inputs, str): # case 1: a string - return [self._parse_prompt_element(input_or_inputs)] + elem = input_or_inputs + return [InputString(text=elem, is_tokens=False)] if isinstance(input_or_inputs, list): - return [self._parse_prompt_element(e) for e in input_or_inputs] + if len(input_or_inputs) == 0: + raise ValueError("please provide at least one prompt") + if isinstance(input_or_inputs[0], str): + # case 2: array of strings + return [ + InputString(text=elem, is_tokens=False) + for elem in cast(List[str], input_or_inputs) + ] + if isinstance(input_or_inputs[0], int): + # case 3: array of tokens + elem = cast(List[int], input_or_inputs) + return [InputTokens(text=elem, is_tokens=True)] + if isinstance(input_or_inputs[0], list) and isinstance( + input_or_inputs[0][0], int): + # case 4: array of token arrays + return [ + InputTokens(text=elem, is_tokens=True) + for elem in cast(List[List[int]], input_or_inputs) + ] raise ValueError("prompt must be a string, array of strings, " "array of tokens, or array of token arrays") From f9c1135e353921d5e00d22f177a5516c9317d87a Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 13 Apr 2024 05:48:37 +0000 Subject: [PATCH 12/94] Add token array to batch completions testing --- tests/entrypoints/test_openai_server.py | 92 +++++++++++++------------ 1 file changed, 47 insertions(+), 45 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 7940430b8b654..d83692d5cbdbb 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -413,50 +413,52 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI, ) async def test_batch_completions(server, client: openai.AsyncOpenAI, model_name: str): - # test simple list - batch = await client.completions.create( - model=model_name, - prompt=["Hello, my name is", "Hello, my name is"], - max_tokens=5, - temperature=0.0, - ) - assert len(batch.choices) == 2 - assert batch.choices[0].text == batch.choices[1].text - - # test n = 2 - batch = await client.completions.create( - model=model_name, - prompt=["Hello, my name is", "Hello, my name is"], - n=2, - max_tokens=5, - temperature=0.0, - extra_body=dict( - # NOTE: this has to be true for n > 1 in vLLM, but not necessary - # for official client. 
- use_beam_search=True), - ) - assert len(batch.choices) == 4 - assert batch.choices[0].text != batch.choices[ - 1].text, "beam search should be different" - assert batch.choices[0].text == batch.choices[ - 2].text, "two copies of the same prompt should be the same" - assert batch.choices[1].text == batch.choices[ - 3].text, "two copies of the same prompt should be the same" - - # test streaming - batch = await client.completions.create( - model=model_name, - prompt=["Hello, my name is", "Hello, my name is"], - max_tokens=5, - temperature=0.0, - stream=True, - ) - texts = [""] * 2 - async for chunk in batch: - assert len(chunk.choices) == 1 - choice = chunk.choices[0] - texts[choice.index] += choice.text - assert texts[0] == texts[1] + # test using text and token IDs + for prompts in (["Hello, my name is"] * 2, [[0, 0, 0, 0, 0]] * 2): + # test simple list + batch = await client.completions.create( + model=model_name, + prompt=prompts, + max_tokens=5, + temperature=0.0, + ) + assert len(batch.choices) == 2 + assert batch.choices[0].text == batch.choices[1].text + + # test n = 2 + batch = await client.completions.create( + model=model_name, + prompt=prompts, + n=2, + max_tokens=5, + temperature=0.0, + extra_body=dict( + # NOTE: this has to be true for n > 1 in vLLM, but not necessary + # for official client. + use_beam_search=True), + ) + assert len(batch.choices) == 4 + assert batch.choices[0].text != batch.choices[ + 1].text, "beam search should be different" + assert batch.choices[0].text == batch.choices[ + 2].text, "two copies of the same prompt should be the same" + assert batch.choices[1].text == batch.choices[ + 3].text, "two copies of the same prompt should be the same" + + # test streaming + batch = await client.completions.create( + model=model_name, + prompt=prompts, + max_tokens=5, + temperature=0.0, + stream=True, + ) + texts = [""] * 2 + async for chunk in batch: + assert len(chunk.choices) == 1 + choice = chunk.choices[0] + texts[choice.index] += choice.text + assert texts[0] == texts[1] async def test_logits_bias(server, client: openai.AsyncOpenAI): @@ -762,7 +764,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI, prompt_text = tokenizer.decode(prompt) if isinstance(prompt, list) else prompt assert (completion.choices[0].text is not None - and re.search(r"^" + prompt_text, completion.choices[0].text)) + and completion.choices[0].text.startswith(prompt_text)) logprobs = completion.choices[0].logprobs assert logprobs is not None assert len(logprobs.text_offset) > 5 From f2e818055e31170a7d726498464791777f0bf828 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 14 Apr 2024 04:48:14 +0000 Subject: [PATCH 13/94] Replace legacy `conint` with `Annotated` field --- vllm/entrypoints/openai/protocol.py | 5 +++-- vllm/entrypoints/openai/serving_engine.py | 13 +++++++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index c06fc027d3c8c..4358b6000d335 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -4,7 +4,8 @@ from typing import Dict, List, Literal, Optional, Union import torch -from pydantic import BaseModel, Field, conint, model_validator +from pydantic import BaseModel, Field, model_validator +from typing_extensions import Annotated from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid @@ -229,7 +230,7 @@ class CompletionRequest(BaseModel): min_tokens: Optional[int] = 0 
skip_special_tokens: Optional[bool] = True spaces_between_special_tokens: Optional[bool] = True - truncate_prompt_tokens: Optional[conint(ge=1)] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None # doc: end-completion-sampling-params # doc: begin-completion-extra-params diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 4f5258b048c43..3ebaa73157118 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -5,8 +5,9 @@ from typing import (Dict, Iterable, Iterator, List, Literal, Optional, Tuple, TypedDict, Union, cast) -from pydantic import conint +from pydantic import Field from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from typing_extensions import Annotated from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, @@ -187,7 +188,7 @@ def _normalize_prompt_text_to_input( request: Union[ChatCompletionRequest, CompletionRequest], prompt: str, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - truncate_prompt_tokens: Optional[conint(ge=1)] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None ) -> Tuple[List[int], str]: if truncate_prompt_tokens is None: encoded = tokenizer(prompt) @@ -207,7 +208,7 @@ def _normalize_prompt_tokens_to_input( request: Union[ChatCompletionRequest, CompletionRequest], prompt_ids: List[int], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - truncate_prompt_tokens: Optional[conint(ge=1)] = None + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None ) -> Tuple[List[int], str]: if truncate_prompt_tokens is None: input_ids = prompt_ids @@ -244,7 +245,7 @@ def _tokenize_prompt_input( self, request: Union[ChatCompletionRequest, CompletionRequest], prompt_input: Union[str, List[int]], - truncate_prompt_tokens: Optional[conint(ge=1)] = None, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, ) -> Tuple[List[int], str]: """A simpler implementation of :meth:`~vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs` @@ -260,7 +261,7 @@ def _tokenize_prompt_inputs( self, request: Union[ChatCompletionRequest, CompletionRequest], prompt_inputs: Iterable[Union[str, List[int]]], - truncate_prompt_tokens: Optional[conint(ge=1)] = None, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, ) -> Iterator[Tuple[List[int], str]]: """A simpler implementation of :meth:`~vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs` @@ -321,7 +322,7 @@ def _tokenize_prompt_input_or_inputs( self, request: Union[ChatCompletionRequest, CompletionRequest], input_or_inputs: Union[str, List[str], List[int], List[List[int]]], - truncate_prompt_tokens: Optional[conint(ge=1)] = None, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, ) -> Iterator[Tuple[List[int], str]]: """Tokenize/detokenize depending on the input format. 
From a1db4e0c1dc3631875a07ef1b1c41d4e08e2b1f7 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 23 Apr 2024 11:03:44 +0000 Subject: [PATCH 14/94] Fix `mypy` error --- vllm/entrypoints/openai/serving_engine.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 308f9ae7339de..f4eab07119fa3 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -297,8 +297,7 @@ def _parse_prompt_input_or_inputs( ) -> List[Union[InputString, InputTokens]]: if isinstance(input_or_inputs, str): # case 1: a string - elem = input_or_inputs - return [InputString(text=elem, is_tokens=False)] + return [InputString(text=input_or_inputs, is_tokens=False)] if isinstance(input_or_inputs, list): if len(input_or_inputs) == 0: From 5d42800ddd1882564a1edc39ffb3142fd525dbbb Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 24 Apr 2024 08:59:54 +0000 Subject: [PATCH 15/94] Combine prompt inputs --- benchmarks/benchmark_latency.py | 4 +- examples/llava_example.py | 10 +- tests/conftest.py | 18 +- tests/engine/test_skip_tokenizer_init.py | 2 +- tests/test_sequence.py | 10 +- tests/tokenization/test_detokenize.py | 7 +- vllm/engine/async_llm_engine.py | 108 ++++++------ vllm/engine/llm_engine.py | 161 +++++++++++------- vllm/entrypoints/llm.py | 66 ++----- vllm/entrypoints/openai/serving_chat.py | 12 +- vllm/entrypoints/openai/serving_completion.py | 17 +- vllm/inputs.py | 48 ++++++ vllm/outputs.py | 2 +- vllm/sequence.py | 38 +++-- 14 files changed, 295 insertions(+), 208 deletions(-) create mode 100644 vllm/inputs.py diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 44da3bad8d840..8b376a379c450 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -55,13 +55,13 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate(prompt_token_ids=dummy_prompt_token_ids, + llm.generate({"prompt_token_ids": dummy_prompt_token_ids}, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(prompt_token_ids=dummy_prompt_token_ids, + llm.generate({"prompt_token_ids": dummy_prompt_token_ids}, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() diff --git a/examples/llava_example.py b/examples/llava_example.py index 3d22b492654bf..31853d7e6ae79 100644 --- a/examples/llava_example.py +++ b/examples/llava_example.py @@ -25,9 +25,13 @@ def run_llava_pixel_values(): # This should be provided by another online or offline component. 
images = torch.load("images/stop_sign_pixel_values.pt") - outputs = llm.generate(prompt, - multi_modal_data=MultiModalData( - type=MultiModalData.Type.IMAGE, data=images)) + outputs = llm.generate({ + "prompt": + prompt, + "multi_modal_data": + MultiModalData(type=MultiModalData.Type.IMAGE, data=images), + }) + for o in outputs: generated_text = o.outputs[0].text print(generated_text) diff --git a/tests/conftest.py b/tests/conftest.py index 5c50fc2d1bab6..653bc69b174c5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,7 @@ from vllm import LLM, SamplingParams from vllm.config import TokenizerPoolConfig, VisionLanguageConfig from vllm.distributed import destroy_model_parallel +from vllm.inputs import PromptInputs from vllm.sequence import MultiModalData from vllm.transformers_utils.tokenizer import get_tokenizer @@ -320,12 +321,17 @@ def generate( ) -> List[Tuple[List[int], str]]: if images is not None: assert len(prompts) == images.shape[0] - req_outputs = self.model.generate( - prompts, - sampling_params=sampling_params, - multi_modal_data=MultiModalData(type=MultiModalData.Type.IMAGE, - data=images) - if images is not None else None) + + prompt_inputs: List[PromptInputs] = [{ + "prompt": + prompt, + "multi_modal_data": + MultiModalData(type=MultiModalData.Type.IMAGE, data=images) + if images is not None else None + } for prompt in prompts] + + req_outputs = self.model.generate(prompt_inputs, + sampling_params=sampling_params) outputs = [] for req_output in req_outputs: prompt_str = req_output.prompt diff --git a/tests/engine/test_skip_tokenizer_init.py b/tests/engine/test_skip_tokenizer_init.py index baa463a316902..338b208723ba9 100644 --- a/tests/engine/test_skip_tokenizer_init.py +++ b/tests/engine/test_skip_tokenizer_init.py @@ -14,7 +14,7 @@ def test_skip_tokenizer_initialization(model: str): with pytest.raises(ValueError) as err: llm.generate("abc", sampling_params) assert "prompts must be None if" in str(err.value) - outputs = llm.generate(prompt_token_ids=[[1, 2, 3]], + outputs = llm.generate({"prompt_token_ids": [1, 2, 3]}, sampling_params=sampling_params) assert len(outputs) > 0 completions = outputs[0].outputs diff --git a/tests/test_sequence.py b/tests/test_sequence.py index b16bdc141e57c..655fb388d95c2 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -24,7 +24,15 @@ def create_dummy_prompt( # and prompt "0 ... block_size". 
prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) - prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) + prompt = Sequence( + int(request_id), + inputs={ + "prompt": prompt_str, + "prompt_token_ids": prompt_tokens, + "multi_modal_data": None, + }, + block_size=block_size, + ) seq_group = SequenceGroup( request_id, [prompt], SamplingParams(use_beam_search=use_beam_search, best_of=best_of), diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 9bc9becb2a6f1..5b43578aad1f2 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -123,8 +123,11 @@ def create_sequence(prompt_token_ids=None): prompt_token_ids = prompt_token_ids or [1] return Sequence( seq_id=0, - prompt="", - prompt_token_ids=prompt_token_ids, + inputs={ + "prompt": None, + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, block_size=16, ) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 3a2f7db679358..dd31a2c4e26fb 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -11,11 +11,11 @@ from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.engine.ray_utils import initialize_ray_cluster, ray +from vllm.inputs import LLMInputs, PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import MultiModalData from vllm.usage.usage_lib import UsageContext logger = init_logger(__name__) @@ -230,46 +230,51 @@ async def step_async(self) -> List[RequestOutput]: async def encode_request_async( self, request_id: str, # pylint: disable=unused-argument - prompt: Optional[str], - prompt_token_ids: Optional[List[int]] = None, + inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, - ): - if prompt_token_ids is None: - assert prompt is not None - prompt_token_ids = await self.tokenizer.encode_async( + ) -> LLMInputs: + if isinstance(inputs, str): + inputs = {"prompt": inputs} + + if "prompt_token_ids" not in inputs: + tokenizer = self._require_tokenizer("prompts must be None if " + "skip_tokenizer_init is True") + + prompt_token_ids = await tokenizer.encode_async( request_id=request_id, - prompt=prompt, + prompt=inputs["prompt"], lora_request=lora_request) - return prompt_token_ids + else: + prompt_token_ids = inputs["prompt_token_ids"] + + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt"), + multi_modal_data=inputs.get("multi_modal_data")) async def add_request_async( self, request_id: str, - prompt: Optional[str], + inputs: PromptInputs, sampling_params: SamplingParams, - prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> None: if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") if arrival_time is None: arrival_time = time.time() - prompt_token_ids = await self.encode_request_async( + + processed_inputs = await self.encode_request_async( + request_id=request_id, inputs=inputs, lora_request=lora_request) + + return self._add_request( request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - lora_request=lora_request) - - return 
self.add_request(request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - sampling_params=sampling_params, - arrival_time=arrival_time, - lora_request=lora_request, - multi_modal_data=multi_modal_data) + processed_inputs=processed_inputs, + sampling_params=sampling_params, + arrival_time=arrival_time, + lora_request=lora_request, + ) async def check_health_async(self) -> None: self.model_executor.check_health() @@ -505,22 +510,26 @@ async def run_engine_loop(self): async def add_request( self, request_id: str, - prompt: Optional[str], + inputs: PromptInputs, sampling_params: SamplingParams, - prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> AsyncStream: if self.log_requests: - shortened_prompt = prompt - shortened_token_ids = prompt_token_ids - if self.max_log_len is not None: + if isinstance(inputs, str): + shortened_prompt = inputs + shortened_token_ids = None + else: + shortened_prompt = inputs.get("prompt") + shortened_token_ids = inputs.get("prompt_token_ids") + + max_log_len = self.max_log_len + if max_log_len is not None: if shortened_prompt is not None: - shortened_prompt = shortened_prompt[:self.max_log_len] + shortened_prompt = shortened_prompt[:max_log_len] if shortened_token_ids is not None: - shortened_token_ids = shortened_token_ids[:self. - max_log_len] + shortened_token_ids = shortened_token_ids[:max_log_len] + logger.info(f"Received request {request_id}: " f"prompt: {shortened_prompt!r}, " f"sampling_params: {sampling_params}, " @@ -541,39 +550,32 @@ async def add_request( arrival_time = time.time() if self.engine_use_ray: - prompt_token_ids = await ( - self.engine.encode_request_async.remote( # type: ignore - request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - lora_request=lora_request)) + processed_inputs = await self.engine.encode_request_async.remote( # type: ignore + request_id=request_id, + inputs=inputs, + lora_request=lora_request) else: - prompt_token_ids = await self.engine.encode_request_async( + processed_inputs = await self.engine.encode_request_async( request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, + inputs=inputs, lora_request=lora_request) stream = self._request_tracker.add_request( request_id, - prompt=prompt, + inputs=processed_inputs, sampling_params=sampling_params, - prompt_token_ids=prompt_token_ids, arrival_time=arrival_time, lora_request=lora_request, - multi_modal_data=multi_modal_data, ) return stream async def generate( self, - prompt: Optional[str], + inputs: PromptInputs, sampling_params: SamplingParams, request_id: str, - prompt_token_ids: Optional[List[int]] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -582,14 +584,10 @@ async def generate( from the LLMEngine to the caller. Args: - prompt: The prompt string. Can be None if prompt_token_ids is - provided. + inputs: The inputs to the LLM. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. - prompt_token_ids: The token IDs of the prompt. If None, we - use the tokenizer to convert the prompts to token IDs. lora_request: LoRA request to use for generation, if any. - multi_modal_data: Multi modal data per request. 
Yields: The output `RequestOutput` objects from the LLMEngine for the @@ -644,12 +642,10 @@ async def generate( try: stream = await self.add_request( request_id, - prompt, + inputs, sampling_params, - prompt_token_ids=prompt_token_ids, arrival_time=arrival_time, lora_request=lora_request, - multi_modal_data=multi_modal_data, ) async for request_output in stream: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 19e58fb1722cf..b981dc84cf164 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -17,12 +17,12 @@ from vllm.engine.output_processor.util import create_output_by_sequence_group from vllm.engine.ray_utils import initialize_ray_cluster from vllm.executor.executor_base import ExecutorBase +from vllm.inputs import LLMInputs, PromptInputs from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, - SequenceGroup, SequenceStage) +from vllm.sequence import SamplerOutput, Sequence, SequenceGroup, SequenceStage from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, get_tokenizer_group) @@ -78,6 +78,7 @@ class LLMEngine: log_stats: Whether to log statistics. usage_context: Specified entry point, used for usage info collection """ + tokenizer: Optional[BaseTokenizerGroup] def __init__( self, @@ -134,9 +135,8 @@ def __init__( self.log_stats = log_stats if not self.model_config.skip_tokenizer_init: - self.tokenizer: BaseTokenizerGroup - self._init_tokenizer() - self.detokenizer = Detokenizer(self.tokenizer) + tokenizer = self._init_tokenizer() + self.detokenizer = Detokenizer(tokenizer) else: self.detokenizer = None self.tokenizer = None @@ -287,12 +287,23 @@ def __reduce__(self): # the closure used to initialize Ray worker actors raise RuntimeError("LLMEngine should not be pickled!") + def _require_tokenizer(self, fail_msg: Optional[str] = None): + if self.tokenizer is None: + if fail_msg is None: + fail_msg = ("Unable to get tokenizer because " + "skip_tokenizer_init is True") + + raise ValueError(fail_msg) + + return self.tokenizer + def get_tokenizer(self) -> "PreTrainedTokenizer": - return self.tokenizer.get_lora_tokenizer(None) + return self._require_tokenizer().get_lora_tokenizer(None) def get_tokenizer_for_seq(self, sequence: Sequence) -> "PreTrainedTokenizer": - return self.tokenizer.get_lora_tokenizer(sequence.lora_request) + return self._require_tokenizer().get_lora_tokenizer( + sequence.lora_request) def _init_tokenizer(self, **tokenizer_init_kwargs): init_kwargs = dict( @@ -304,9 +315,12 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): trust_remote_code=self.model_config.trust_remote_code, revision=self.model_config.tokenizer_revision) init_kwargs.update(tokenizer_init_kwargs) + self.tokenizer = get_tokenizer_group( self.parallel_config.tokenizer_pool_config, **init_kwargs) + return self.tokenizer + def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) @@ -315,29 +329,81 @@ def _verify_args(self) -> None: self.lora_config.verify_with_scheduler_config( self.scheduler_config) + def _add_request( + self, + request_id: str, + processed_inputs: LLMInputs, + sampling_params: SamplingParams, + arrival_time: float, + lora_request: Optional[LoRARequest], + ) -> None: + max_logprobs 
= self.get_model_config().max_logprobs + if (sampling_params.logprobs + and sampling_params.logprobs > max_logprobs) or ( + sampling_params.prompt_logprobs + and sampling_params.prompt_logprobs > max_logprobs): + raise ValueError(f"Cannot request more than " + f"{max_logprobs} logprobs.") + + # Create the sequences. + block_size = self.cache_config.block_size + seq_id = next(self.seq_counter) + eos_token_id = None + if self.tokenizer: + eos_token_id = self.tokenizer.get_lora_tokenizer( + lora_request).eos_token_id + else: + logger.warning("Use None for EOS token id because tokenizer is " + "not initialized") + seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, + lora_request) + + # Defensive copy of SamplingParams, which are used by the sampler, + # this doesn't deep-copy LogitsProcessor objects + sampling_params = sampling_params.clone() + # inject the eos token id into the sampling_params to support min_tokens + # processing + sampling_params.eos_token_id = seq.eos_token_id + sampling_params.update_from_generation_config( + self.generation_config_fields) + + # Create the sequence group. + seq_group = SequenceGroup(request_id, [seq], sampling_params, + arrival_time, lora_request) + + # Add the sequence group to the scheduler. + self.scheduler.add_seq_group(seq_group) + def encode_request( self, - request_id: str, # pylint: disable=unused-argument - prompt: Optional[str], - prompt_token_ids: Optional[List[int]] = None, + request_id: str, + inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, - ): - if prompt_token_ids is None: - assert prompt is not None - prompt_token_ids = self.tokenizer.encode(request_id=request_id, - prompt=prompt, - lora_request=lora_request) - return prompt_token_ids + ) -> LLMInputs: + if isinstance(inputs, str): + inputs = {"prompt": inputs} + + if "prompt_token_ids" not in inputs: + tokenizer = self._require_tokenizer("prompts must be None if " + "skip_tokenizer_init is True") + + prompt_token_ids = tokenizer.encode(request_id=request_id, + prompt=inputs["prompt"], + lora_request=lora_request) + else: + prompt_token_ids = inputs["prompt_token_ids"] + + return LLMInputs(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt"), + multi_modal_data=inputs.get("multi_modal_data")) def add_request( self, request_id: str, - prompt: Optional[str], + inputs: PromptInputs, sampling_params: SamplingParams, - prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> None: """Add a request to the engine's request pool. @@ -347,14 +413,10 @@ def add_request( Args: request_id: The unique ID of the request. - prompt: The prompt string. Can be None if prompt_token_ids is - provided. + inputs: The inputs to the LLM. sampling_params: The sampling parameters for text generation. - prompt_token_ids: The token IDs of the prompt. If None, we - use the tokenizer to convert the prompts to token IDs. arrival_time: The arrival time of the request. If None, we use the current monotonic time. - multi_modal_data: Multi modal data per request. Details: - Set arrival_time to the current time if it is None. 
@@ -383,49 +445,20 @@ def add_request( if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " "not enabled!") - max_logprobs = self.get_model_config().max_logprobs - if (sampling_params.logprobs - and sampling_params.logprobs > max_logprobs) or ( - sampling_params.prompt_logprobs - and sampling_params.prompt_logprobs > max_logprobs): - raise ValueError(f"Cannot request more than " - f"{max_logprobs} logprobs.") if arrival_time is None: arrival_time = time.time() - prompt_token_ids = self.encode_request( - request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - lora_request=lora_request) - # Create the sequences. - block_size = self.cache_config.block_size - seq_id = next(self.seq_counter) - eos_token_id = None - if self.tokenizer: - eos_token_id = self.tokenizer.get_lora_tokenizer( - lora_request).eos_token_id - else: - logger.warning("Use None for EOS token id because tokenizer is " - "not initialized") - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, - eos_token_id, lora_request) - - # Defensive copy of SamplingParams, which are used by the sampler, - # this doesn't deep-copy LogitsProcessor objects - sampling_params = sampling_params.clone() - # inject the eos token id into the sampling_params to support min_tokens - # processing - sampling_params.eos_token_id = seq.eos_token_id - sampling_params.update_from_generation_config( - self.generation_config_fields) - - # Create the sequence group. - seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time, lora_request, multi_modal_data) + processed_inputs = self.encode_request(request_id=request_id, + inputs=inputs, + lora_request=lora_request) - # Add the sequence group to the scheduler. - self.scheduler.add_seq_group(seq_group) + return self._add_request( + request_id=request_id, + processed_inputs=processed_inputs, + sampling_params=sampling_params, + arrival_time=arrival_time, + lora_request=lora_request, + ) def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: """Aborts a request(s) with the given ID. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b022707794a78..b31b28a15fa4a 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,15 +1,14 @@ -from typing import List, Optional, Union +from typing import List, Optional, Sequence, Union -import torch from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine +from vllm.inputs import PromptStrictInputs from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import MultiModalData from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter @@ -131,13 +130,11 @@ def set_tokenizer( def generate( self, - prompts: Optional[Union[str, List[str]]] = None, + inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], sampling_params: Optional[Union[SamplingParams, - List[SamplingParams]]] = None, - prompt_token_ids: Optional[List[List[int]]] = None, + Sequence[SamplingParams]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -146,42 +143,24 @@ def generate( into a single list and pass it to this method. 
Args: - prompts: A list of prompts to generate completions for. + inputs: A list of inputs to generate completions for. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. When it is a list, the list must have the same length as the prompts and it is paired one by one with the prompt. - prompt_token_ids: A list of token IDs for the prompts. If None, we - use the tokenizer to convert the prompts to token IDs. use_tqdm: Whether to use tqdm to display the progress bar. lora_request: LoRA request to use for generation, if any. - multi_modal_data: Multi modal data. Returns: A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. """ - if prompts is None and prompt_token_ids is None: - raise ValueError("Either prompts or prompt_token_ids must be " - "provided.") - if self.llm_engine.model_config.skip_tokenizer_init \ - and prompts is not None: - raise ValueError("prompts must be None if skip_tokenizer_init " - "is True") - if isinstance(prompts, str): + if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. - prompts = [prompts] - if (prompts is not None and prompt_token_ids is not None - and len(prompts) != len(prompt_token_ids)): - raise ValueError("The lengths of prompts and prompt_token_ids " - "must be the same.") + inputs = [inputs] - if prompts is not None: - num_requests = len(prompts) - else: - assert prompt_token_ids is not None - num_requests = len(prompt_token_ids) + num_requests = len(inputs) if sampling_params is None: # Use default sampling params. @@ -191,43 +170,28 @@ def generate( list) and len(sampling_params) != num_requests: raise ValueError("The lengths of prompts and sampling_params " "must be the same.") - if multi_modal_data: - multi_modal_data.data = multi_modal_data.data.to(torch.float16) # Add requests to the engine. - for i in range(num_requests): - prompt = prompts[i] if prompts is not None else None - token_ids = None if prompt_token_ids is None else prompt_token_ids[ - i] + for i, request_inputs in enumerate(inputs): self._add_request( - prompt, + request_inputs, sampling_params[i] - if isinstance(sampling_params, list) else sampling_params, - token_ids, + if isinstance(sampling_params, Sequence) else sampling_params, lora_request=lora_request, - # Get ith image while maintaining the batch dim. - multi_modal_data=MultiModalData( - type=multi_modal_data.type, - data=multi_modal_data.data[i].unsqueeze(0)) - if multi_modal_data else None, ) return self._run_engine(use_tqdm) def _add_request( self, - prompt: Optional[str], + inputs: PromptStrictInputs, sampling_params: SamplingParams, - prompt_token_ids: Optional[List[int]], lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> None: request_id = str(next(self.request_counter)) self.llm_engine.add_request(request_id, - prompt, + inputs, sampling_params, - prompt_token_ids, - lora_request=lora_request, - multi_modal_data=multi_modal_data) + lora_request=lora_request) def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # Initialize tqdm. @@ -251,4 +215,4 @@ def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # This is necessary because some requests may be finished earlier than # its previous requests. 
outputs = sorted(outputs, key=lambda x: int(x.request_id)) - return outputs \ No newline at end of file + return outputs diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 2ff335eb71073..5eb7ad51b64a3 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -83,9 +83,15 @@ async def create_chat_completion( except ValueError as e: return self.create_error_response(str(e)) - result_generator = self.engine.generate(prompt_text, sampling_params, - request_id, prompt_ids, - lora_request) + result_generator = self.engine.generate( + { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids + }, + sampling_params, + request_id, + lora_request, + ) # Streaming response if request.stream: return self.chat_completion_stream_generator( diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 211b2e0424c3e..5786170e2f2a5 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -117,12 +117,17 @@ async def create_completion(self, request: CompletionRequest, truncate_prompt_tokens) prompt_ids, prompt_text = prompt_formats - generators.append( - self.engine.generate(prompt_text, - sampling_params, - f"{request_id}-{i}", - prompt_token_ids=prompt_ids, - lora_request=lora_request)) + generator = self.engine.generate( + { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids + }, + sampling_params, + f"{request_id}-{i}", + lora_request=lora_request, + ) + + generators.append(generator) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) diff --git a/vllm/inputs.py b/vllm/inputs.py new file mode 100644 index 0000000000000..bd61f959eeb6e --- /dev/null +++ b/vllm/inputs.py @@ -0,0 +1,48 @@ +from typing import TYPE_CHECKING, List, Optional, TypedDict, Union + +if TYPE_CHECKING: + from vllm.sequence import MultiModalData + + +class MultiModalPrompt(TypedDict, total=False): + multi_modal_data: Optional["MultiModalData"] + """Multi modal data.""" + + +class StringPrompt(MultiModalPrompt, TypedDict): + prompt: str + """The prompt string.""" + + +class TokensPrompt(MultiModalPrompt, TypedDict): + prompt_token_ids: List[int] + """The token IDs of the prompt. If None, we use the + tokenizer to convert the prompts to token IDs.""" + + +class StringTokensPrompt(MultiModalPrompt, TypedDict): + """It is assumed that :attr:`prompt` is consistent with + :attr:`prompt_token_ids`. This is currently used in + :class:`AsyncLLMEngine` for logging both the text and token IDs.""" + + prompt: str + """The prompt string.""" + + prompt_token_ids: List[int] + """The token IDs of the prompt. If None, we use the + tokenizer to convert the prompts to token IDs.""" + + +PromptStrictInputs = Union[str, StringPrompt, TokensPrompt] +"""The prompt string. 
More complex inputs should be represented by +:class:`StringPrompt` or :class:`TokensPrompt`.""" + +PromptInputs = Union[str, StringPrompt, TokensPrompt, StringTokensPrompt] +"""As :const:`PromptStrictInputs` but additionally accepts +:class:`StringTokensPrompt`.""" + + +class LLMInputs(TypedDict): + prompt_token_ids: List[int] + prompt: Optional[str] + multi_modal_data: Optional["MultiModalData"] diff --git a/vllm/outputs.py b/vllm/outputs.py index d01be0eb0efd2..78b70dfe107e3 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -74,7 +74,7 @@ class RequestOutput: def __init__( self, request_id: str, - prompt: str, + prompt: Optional[str], prompt_token_ids: List[int], prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], diff --git a/vllm/sequence.py b/vllm/sequence.py index b296b37a84f15..3ea3af0f7cba7 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Union from vllm.block import LogicalTokenBlock +from vllm.inputs import LLMInputs from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams @@ -193,8 +194,7 @@ class Sequence: Args: seq_id: The ID of the sequence. - prompt: The prompt of the sequence. - prompt_token_ids: The token IDs of the prompt. + inputs: The inputs of the sequence. block_size: The block size of the sequence. Should be the same as the block size used by the block manager and cache engine. lora_request: LoRA request. @@ -203,25 +203,24 @@ class Sequence: def __init__( self, seq_id: int, - prompt: str, - prompt_token_ids: List[int], + inputs: LLMInputs, block_size: int, eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, ) -> None: self.seq_id = seq_id - self.prompt = prompt + self.inputs = inputs self.block_size = block_size self.eos_token_id = eos_token_id self.lora_request = lora_request - self.data = SequenceData(prompt_token_ids) + self.data = SequenceData(self.prompt_token_ids) self.output_logprobs: SampleLogprobs = [] self.output_text = "" self.logical_token_blocks: List[LogicalTokenBlock] = [] # Initialize the logical token blocks with the prompt token ids. - self._append_tokens_to_blocks(prompt_token_ids) + self._append_tokens_to_blocks(self.prompt_token_ids) self.status = SequenceStatus.WAITING self.stop_reason: Union[int, str, None] = None @@ -231,6 +230,18 @@ def __init__( # Input + output tokens self.tokens: Optional[List[str]] = None + @property + def prompt(self) -> Optional[str]: + return self.inputs["prompt"] + + @property + def prompt_token_ids(self) -> List[int]: + return self.inputs["prompt_token_ids"] + + @property + def multi_modal_data(self) -> Optional["MultiModalData"]: + return self.inputs["multi_modal_data"] + @property def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 @@ -398,7 +409,6 @@ class SequenceGroup: sampling_params: The sampling parameters used to generate the outputs. arrival_time: The arrival time of the request. lora_request: LoRA request. - multi_modal_data: Multi modal data associated with the request. 
""" def __init__( @@ -408,7 +418,6 @@ def __init__( sampling_params: SamplingParams, arrival_time: float, lora_request: Optional[LoRARequest] = None, - multi_modal_data: Optional[MultiModalData] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -421,10 +430,9 @@ def __init__( self.lora_request = lora_request self.prompt_logprobs: Optional[PromptLogprobs] = None self.state = SequenceGroupState() - self.multi_modal_data = multi_modal_data @property - def prompt(self) -> str: + def prompt(self) -> Optional[str]: # All sequences in the group should have the same prompt. # We use the prompt of an arbitrary sequence. return next(iter(self.seqs_dict.values())).prompt @@ -433,7 +441,13 @@ def prompt(self) -> str: def prompt_token_ids(self) -> List[int]: # All sequences in the group should have the same prompt. # We use the prompt of an arbitrary sequence. - return next(iter(self.seqs_dict.values())).data.prompt_token_ids + return next(iter(self.seqs_dict.values())).prompt_token_ids + + @property + def multi_modal_data(self) -> Optional[MultiModalData]: + # All sequences in the group should have the same multi-modal data. + # We use the multi-modal data of an arbitrary sequence. + return next(iter(self.seqs_dict.values())).multi_modal_data @property def lora_int_id(self) -> int: From 5db2c5e03ed0b84cf55c26fd9c60f11e4f1bd4b0 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 25 Apr 2024 01:51:08 +0000 Subject: [PATCH 16/94] Fix a bunch of tests --- benchmarks/benchmark_throughput.py | 6 +----- tests/conftest.py | 13 ++++++++----- tests/core/test_block_manager.py | 15 ++++++++++++--- tests/core/utils.py | 15 ++++++++++++--- tests/samplers/test_logits_processor.py | 9 +++------ tests/samplers/test_seeded_generate.py | 6 +----- tests/test_cache_block_hashing.py | 11 +++++++++-- tests/tokenization/test_detokenize.py | 2 +- 8 files changed, 47 insertions(+), 30 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 6bb889d1eceba..ae05c3bf742f4 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -113,11 +113,7 @@ def run_vllm( max_tokens=output_len, ) # FIXME(woosuk): Do not use internal method. - llm._add_request( - prompt=prompt, - prompt_token_ids=None, - sampling_params=sampling_params, - ) + llm._add_request(prompt, sampling_params=sampling_params) start = time.perf_counter() # FIXME(woosuk): Do not use internal method. diff --git a/tests/conftest.py b/tests/conftest.py index 653bc69b174c5..e187a178d10c1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -322,12 +322,15 @@ def generate( if images is not None: assert len(prompts) == images.shape[0] + if images is None: + mm_data = None + else: + mm_data = MultiModalData(type=MultiModalData.Type.IMAGE, + data=images) + prompt_inputs: List[PromptInputs] = [{ - "prompt": - prompt, - "multi_modal_data": - MultiModalData(type=MultiModalData.Type.IMAGE, data=images) - if images is not None else None + "prompt": prompt, + "multi_modal_data": mm_data } for prompt in prompts] req_outputs = self.model.generate(prompt_inputs, diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index 62984ef4caabb..62da6c4850f65 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -132,8 +132,11 @@ def test_append_slot_cow(): # Allocate prompt to gpu block. There is one slot left in the block. 
prompt = Sequence(seq_id=1, - prompt="one two three", - prompt_token_ids=[1, 2, 3], + inputs={ + "prompt": "one two three", + "prompt_token_ids": [1, 2, 3], + "multi_modal_data": None + }, block_size=block_size) # Fork the sequence, such that a COW will be required when we append a new @@ -298,7 +301,13 @@ def test_sliding_window_multi_seq(): assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks - parent = Sequence(1, "one two three", [0, 1, 2], block_size) + parent = Sequence(seq_id=1, + inputs={ + "prompt": "one two three", + "prompt_token_ids": [0, 1, 2], + "multi_modal_data": None + }, + block_size=block_size) seq_group = SequenceGroup("1", [parent], SamplingParams(), time.time(), None) block_manager.allocate(seq_group) diff --git a/tests/core/utils.py b/tests/core/utils.py index 22c1d3826dff4..fd0edce6524f5 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -21,7 +21,13 @@ def create_dummy_prompt( # and prompt "0 ... block_size". prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) - prompt = Sequence(int(request_id), prompt_str, prompt_tokens, block_size) + prompt = Sequence(int(request_id), + inputs={ + "prompt": prompt_str, + "prompt_token_ids": prompt_tokens, + "multi_modal_data": None, + }, + block_size=block_size) seq_group = SequenceGroup( request_id, [prompt], SamplingParams(use_beam_search=use_beam_search, best_of=best_of), @@ -48,8 +54,11 @@ def create_seq_group( for seq_id_offset, output_len in enumerate(seq_output_lens): seq = Sequence( seq_id=seq_id_start + seq_id_offset, - prompt="", - prompt_token_ids=prompt_token_ids, + inputs={ + "prompt": "", + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, block_size=16, ) diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py index 3788e9e9752ff..8c877265e71a0 100644 --- a/tests/samplers/test_logits_processor.py +++ b/tests/samplers/test_logits_processor.py @@ -35,26 +35,23 @@ def pick_vllm(token_ids, logits): # test logits_processors when prompt_logprobs is not None vllm_model.model._add_request( - prompt=example_prompts[0], + example_prompts[0], sampling_params=params_with_logprobs, - prompt_token_ids=None, ) # test prompt_logprobs is not None vllm_model.model._add_request( - prompt=example_prompts[1], + example_prompts[1], sampling_params=SamplingParams( prompt_logprobs=3, max_tokens=max_tokens, ), - prompt_token_ids=None, ) # test grouped requests vllm_model.model._add_request( - prompt=example_prompts[2], + example_prompts[2], sampling_params=SamplingParams(max_tokens=max_tokens), - prompt_token_ids=None, ) outputs = vllm_model.model._run_engine(False) diff --git a/tests/samplers/test_seeded_generate.py b/tests/samplers/test_seeded_generate.py index 3cd659cef58da..ba8070cd16dbc 100644 --- a/tests/samplers/test_seeded_generate.py +++ b/tests/samplers/test_seeded_generate.py @@ -57,11 +57,7 @@ def test_random_sample_with_seed( sampling_params_seed_1, sampling_params_seed_2, ): - llm._add_request( - prompt=prompt, - prompt_token_ids=None, - sampling_params=params, - ) + llm._add_request(prompt, sampling_params=params) results = llm._run_engine(use_tqdm=False) all_outputs = [[out.token_ids for out in output.outputs] diff --git a/tests/test_cache_block_hashing.py b/tests/test_cache_block_hashing.py index 3b257ac062f56..97864af88e40a 100644 --- a/tests/test_cache_block_hashing.py +++ b/tests/test_cache_block_hashing.py @@ -70,8 +70,15 @@ def test_auto_prefix_caching(model: str, block_size: int, 
max_num_seqs: int, for prompt in prompts: hashes[-1].append([]) prompt_token_ids = tokenizer.encode(prompt) - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, - tokenizer.tokenizer.eos_token_id, lora_request) + seq = Sequence(seq_id, + inputs={ + "prompt": prompt, + "prompt_token_ids": prompt_token_ids, + "multi_modal_data": None, + }, + block_size=block_size, + eos_token_id=tokenizer.tokenizer.eos_token_id, + lora_request=lora_request) num_blocks = len(prompt_token_ids) // block_size for idx in range(num_blocks): diff --git a/tests/tokenization/test_detokenize.py b/tests/tokenization/test_detokenize.py index 5b43578aad1f2..1d4c74d6bd8da 100644 --- a/tests/tokenization/test_detokenize.py +++ b/tests/tokenization/test_detokenize.py @@ -124,7 +124,7 @@ def create_sequence(prompt_token_ids=None): return Sequence( seq_id=0, inputs={ - "prompt": None, + "prompt": "", "prompt_token_ids": prompt_token_ids, "multi_modal_data": None, }, From 74c5905d96707ef32a7846e32ceb0bfa4f27c1fa Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 25 Apr 2024 05:13:45 +0000 Subject: [PATCH 17/94] Fix LLaVA test --- examples/llava_example.py | 15 +++++++++------ tests/conftest.py | 20 +++++++++++--------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/examples/llava_example.py b/examples/llava_example.py index 31853d7e6ae79..60250c4303fbf 100644 --- a/examples/llava_example.py +++ b/examples/llava_example.py @@ -23,13 +23,13 @@ def run_llava_pixel_values(): "\nUSER: What is the content of this image?\nASSISTANT:") # This should be provided by another online or offline component. - images = torch.load("images/stop_sign_pixel_values.pt") + image = torch.load("images/stop_sign_pixel_values.pt") outputs = llm.generate({ "prompt": prompt, "multi_modal_data": - MultiModalData(type=MultiModalData.Type.IMAGE, data=images), + MultiModalData(type=MultiModalData.Type.IMAGE, data=image), }) for o in outputs: @@ -50,11 +50,14 @@ def run_llava_image_features(): "\nUSER: What is the content of this image?\nASSISTANT:") # This should be provided by another online or offline component. 
- images = torch.load("images/stop_sign_image_features.pt") + image = torch.load("images/stop_sign_image_features.pt") - outputs = llm.generate(prompt, - multi_modal_data=MultiModalData( - type=MultiModalData.Type.IMAGE, data=images)) + outputs = llm.generate({ + "prompt": + prompt, + "multi_modal_data": + MultiModalData(type=MultiModalData.Type.IMAGE, data=image), + }) for o in outputs: generated_text = o.outputs[0].text print(generated_text) diff --git a/tests/conftest.py b/tests/conftest.py index e187a178d10c1..b86e67c6c4da7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -322,16 +322,18 @@ def generate( if images is not None: assert len(prompts) == images.shape[0] - if images is None: - mm_data = None - else: - mm_data = MultiModalData(type=MultiModalData.Type.IMAGE, - data=images) + prompt_inputs: List[PromptInputs] = [] + for i, prompt in enumerate(prompts): + image = None if images is None else images[i:i + 1] + mm_data = None if image is None else MultiModalData( + type=MultiModalData.Type.IMAGE, + data=image, + ) - prompt_inputs: List[PromptInputs] = [{ - "prompt": prompt, - "multi_modal_data": mm_data - } for prompt in prompts] + prompt_inputs.append({ + "prompt": prompt, + "multi_modal_data": mm_data, + }) req_outputs = self.model.generate(prompt_inputs, sampling_params=sampling_params) From b49aba766d0c349e2bf2a2c186200ecfad751ce7 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 25 Apr 2024 05:19:57 +0000 Subject: [PATCH 18/94] Fix `benchmark_latency` test --- benchmarks/benchmark_latency.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 8b376a379c450..8932788cac119 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -2,13 +2,14 @@ import argparse import time from pathlib import Path -from typing import Optional +from typing import List, Optional import numpy as np import torch from tqdm import tqdm from vllm import LLM, SamplingParams +from vllm.inputs import PromptStrictInputs from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -44,7 +45,9 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_prompt_token_ids = dummy_prompt_token_ids.tolist() + dummy_inputs: List[PromptStrictInputs] = [{ + "prompt_token_ids": batch + } for batch in dummy_prompt_token_ids.tolist()] def run_to_completion(profile_dir: Optional[str] = None): if profile_dir: @@ -55,13 +58,13 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate({"prompt_token_ids": dummy_prompt_token_ids}, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate({"prompt_token_ids": dummy_prompt_token_ids}, + llm.generate(dummy_inputs, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() From c4f35401a58ab4e400131f2165fcb8474e63b459 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 3 May 2024 03:42:47 +0000 Subject: [PATCH 19/94] Clarify tokenizer usage --- vllm/engine/async_llm_engine.py | 4 ++-- vllm/engine/llm_engine.py | 29 +++++++++++++++-------------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 1d52e9265ff48..332dfd64ba37d 100644 --- 
a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -239,8 +239,8 @@ async def encode_request_async( inputs = {"prompt": inputs} if "prompt_token_ids" not in inputs: - tokenizer = self._require_tokenizer("prompts must be None if " - "skip_tokenizer_init is True") + tokenizer = self.get_tokenizer_group("prompts must be None if " + "skip_tokenizer_init is True") prompt_token_ids = await tokenizer.encode_async( request_id=request_id, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8c9ae99c4a78c..622b6d2819695 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -145,6 +145,8 @@ def __init__( self.decoding_config = decoding_config or DecodingConfig() self.log_stats = log_stats + self.tokenizer: Optional[BaseTokenizerGroup] + if not self.model_config.skip_tokenizer_init: tokenizer = self._init_tokenizer() self.detokenizer = Detokenizer(tokenizer) @@ -301,28 +303,27 @@ def __reduce__(self): # the closure used to initialize Ray worker actors raise RuntimeError("LLMEngine should not be pickled!") - def _require_tokenizer(self, fail_msg: Optional[str] = None): - if self.tokenizer is None: - if fail_msg is None: - fail_msg = ("Unable to get tokenizer because " - "skip_tokenizer_init is True") - - raise ValueError(fail_msg) - - return self.tokenizer - def __del__(self): # Shutdown model executor when engine is garbage collected # Use getattr since __init__ can fail before the field is set if model_executor := getattr(self, "model_executor", None): model_executor.shutdown() + MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because " + "skip_tokenizer_init is True") + + def get_tokenizer_group(self, fail_msg: str = MISSING_TOKENIZER_GROUP_MSG): + if self.tokenizer is None: + raise ValueError(fail_msg) + + return self.tokenizer + def get_tokenizer(self) -> "PreTrainedTokenizer": - return self._require_tokenizer().get_lora_tokenizer(None) + return self.get_tokenizer_group().get_lora_tokenizer(None) def get_tokenizer_for_seq(self, sequence: Sequence) -> "PreTrainedTokenizer": - return self._require_tokenizer().get_lora_tokenizer( + return self.get_tokenizer_group().get_lora_tokenizer( sequence.lora_request) def _init_tokenizer(self, **tokenizer_init_kwargs): @@ -405,8 +406,8 @@ def encode_request( inputs = {"prompt": inputs} if "prompt_token_ids" not in inputs: - tokenizer = self._require_tokenizer("prompts must be None if " - "skip_tokenizer_init is True") + tokenizer = self.get_tokenizer_group("prompts must be None if " + "skip_tokenizer_init is True") prompt_token_ids = tokenizer.encode(request_id=request_id, prompt=inputs["prompt"], From ab8182ce45c0897b22da7fa23244eb00bdddfc92 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 3 May 2024 03:55:21 +0000 Subject: [PATCH 20/94] Rename `encode_request -> process_model_inputs` --- tests/async_engine/test_async_llm_engine.py | 2 +- vllm/engine/async_llm_engine.py | 17 +++++++++-------- vllm/engine/llm_engine.py | 12 ++++++------ 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index b69cdc0a21409..10a46422887e3 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -25,7 +25,7 @@ async def step_async(self): return [RequestOutput( request_id=self.request_id)] if self.request_id else [] - async def encode_request_async(self, *args, **kwargs): + async def process_model_inputs_async(self, *args, **kwargs): pass def 
generate(self, request_id): diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 332dfd64ba37d..7d0ee745813d9 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -229,7 +229,7 @@ async def step_async(self) -> List[RequestOutput]: return request_outputs - async def encode_request_async( + async def process_model_inputs_async( self, request_id: str, # pylint: disable=unused-argument inputs: PromptInputs, @@ -267,10 +267,10 @@ async def add_request_async( if arrival_time is None: arrival_time = time.time() - processed_inputs = await self.encode_request_async( + processed_inputs = await self.process_model_inputs_async( request_id=request_id, inputs=inputs, lora_request=lora_request) - return self._add_request( + return self._add_processed_request( request_id=request_id, processed_inputs=processed_inputs, sampling_params=sampling_params, @@ -552,12 +552,13 @@ async def add_request( arrival_time = time.time() if self.engine_use_ray: - processed_inputs = await self.engine.encode_request_async.remote( # type: ignore - request_id=request_id, - inputs=inputs, - lora_request=lora_request) + processed_inputs = await self.engine.process_model_inputs_async \ + .remote( # type: ignore + request_id=request_id, + inputs=inputs, + lora_request=lora_request) else: - processed_inputs = await self.engine.encode_request_async( + processed_inputs = await self.engine.process_model_inputs_async( request_id=request_id, inputs=inputs, lora_request=lora_request) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 622b6d2819695..6a937347e0bec 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -350,7 +350,7 @@ def _verify_args(self) -> None: self.lora_config.verify_with_scheduler_config( self.scheduler_config) - def _add_request( + def _add_processed_request( self, request_id: str, processed_inputs: LLMInputs, @@ -396,7 +396,7 @@ def _add_request( # Add the sequence group to the scheduler. 
self.scheduler.add_seq_group(seq_group) - def encode_request( + def process_model_inputs( self, request_id: str, inputs: PromptInputs, @@ -470,11 +470,11 @@ def add_request( if arrival_time is None: arrival_time = time.time() - processed_inputs = self.encode_request(request_id=request_id, - inputs=inputs, - lora_request=lora_request) + processed_inputs = self.process_model_inputs(request_id=request_id, + inputs=inputs, + lora_request=lora_request) - return self._add_request( + return self._add_processed_request( request_id=request_id, processed_inputs=processed_inputs, sampling_params=sampling_params, From eac33e1f7d1747dd4147423d45ef3c9d65f14b95 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 3 May 2024 06:35:09 +0000 Subject: [PATCH 21/94] Support old API in `LLM.generate` --- .buildkite/test-pipeline.yaml | 1 + tests/test_inputs.py | 53 ++++++++++ vllm/entrypoints/llm.py | 183 +++++++++++++++++++++++++++++++++- vllm/inputs.py | 62 +++++++++++- vllm/utils.py | 28 +++++- 5 files changed, 321 insertions(+), 6 deletions(-) create mode 100644 tests/test_inputs.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index e49a5650c44ea..60b57ea7a5f78 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -53,6 +53,7 @@ steps: - label: Entrypoints Test commands: + - pytest -v -s test_inputs.py # these tests have to be separated, because each one will allocate all posible GPU memory - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py - pytest -v -s entrypoints/test_server_oot_registration.py diff --git a/tests/test_inputs.py b/tests/test_inputs.py new file mode 100644 index 0000000000000..887c7101decda --- /dev/null +++ b/tests/test_inputs.py @@ -0,0 +1,53 @@ +from typing import List + +import pytest + +from vllm.inputs import parse_and_batch_prompt + +STRING_INPUTS = [ + '', + 'foo', + 'foo bar', + 'foo baz bar', + 'foo bar qux baz', +] + +TOKEN_INPUTS = [ + [-1], + [1], + [1, 2], + [1, 3, 4], + [1, 2, 4, 3], +] + +INPUTS_SLICES = [ + slice(None, None, -1), + slice(None, None, 2), + slice(None, None, -2), +] + + +def test_parse_single_batch_empty(): + with pytest.raises(ValueError, match="at least one prompt"): + parse_and_batch_prompt([]) + + with pytest.raises(ValueError, match="at least one prompt"): + parse_and_batch_prompt([[]]) + + +@pytest.mark.parametrize('string_input', STRING_INPUTS) +def test_parse_single_batch_string_consistent(string_input: str): + assert parse_and_batch_prompt(string_input) \ + == parse_and_batch_prompt([string_input]) + + +@pytest.mark.parametrize('token_input', TOKEN_INPUTS) +def test_parse_single_batch_token_consistent(token_input: List[int]): + assert parse_and_batch_prompt(token_input) \ + == parse_and_batch_prompt([token_input]) + + +@pytest.mark.parametrize('inputs_slice', INPUTS_SLICES) +def test_parse_single_batch_string_slice(inputs_slice: slice): + assert parse_and_batch_prompt(STRING_INPUTS)[inputs_slice] \ + == parse_and_batch_prompt(STRING_INPUTS[inputs_slice]) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b31b28a15fa4a..90b0423698ac6 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,16 +1,18 @@ -from typing import List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Union, overload from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine -from vllm.inputs import 
PromptStrictInputs +from vllm.inputs import (PromptInputs, PromptStrictInputs, + parse_and_batch_prompt) from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams +from vllm.sequence import MultiModalData from vllm.usage.usage_lib import UsageContext -from vllm.utils import Counter +from vllm.utils import Counter, deprecate_kwargs class LLM: @@ -128,13 +130,96 @@ def set_tokenizer( ) -> None: self.llm_engine.tokenizer.tokenizer = tokenizer + @overload # DEPRECATED: single (prompt + optional token ids) + def generate( + self, + prompts: str, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + prompt_token_ids: Optional[List[int]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # DEPRECATED: multi (prompt + optional token ids) + def generate( + self, + prompts: List[str], + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + prompt_token_ids: Optional[List[List[int]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # DEPRECATED: single (token ids + optional prompt) + def generate( + self, + prompts: Optional[str] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + *, + prompt_token_ids: List[int], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # DEPRECATED: multi (token ids + optional prompt) + def generate( + self, + prompts: Optional[List[str]] = None, + sampling_params: Optional[Union[SamplingParams, + List[SamplingParams]]] = None, + *, + prompt_token_ids: List[List[int]], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload # DEPRECATED: single or multi token ids [pos-only] + def generate( + self, + prompts: None, + sampling_params: None, + prompt_token_ids: Union[List[int], List[List[int]]], + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, + ) -> List[RequestOutput]: + ... + + @overload def generate( self, inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + /, # We may enable `inputs` keyword after removing the old API + *, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, + ) -> List[RequestOutput]: + ... + + @deprecate_kwargs('prompts', 'prompt_token_ids', 'multi_modal_data') + def generate( + self, + prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + Optional[Union[str, List[str]]]] = None, + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]] = None, + prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -156,6 +241,96 @@ def generate( A list of `RequestOutput` objects containing the generated completions in the same order as the input prompts. 
""" + if prompt_token_ids is not None or multi_modal_data is not None: + return self._generate_v1( + prompts=prompts, # type: ignore + sampling_params=sampling_params, + prompt_token_ids=prompt_token_ids, + use_tqdm=use_tqdm, + lora_request=lora_request, + multi_modal_data=multi_modal_data, + ) + + return self._generate_v2( + inputs=prompts, # type: ignore + sampling_params=sampling_params, + use_tqdm=use_tqdm, + lora_request=lora_request, + ) + + # DEPRECATED + def _generate_v1( + self, + prompts: Optional[Union[str, List[str]]], + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]], + prompt_token_ids: Optional[Union[List[int], List[List[int]]]], + use_tqdm: bool, + lora_request: Optional[LoRARequest], + multi_modal_data: Optional[MultiModalData], + ) -> List[RequestOutput]: + # skip_tokenizer_init is now checked in engine + + if prompts is not None: + prompts = [p["text"] for p in parse_and_batch_prompt(prompts)] + if prompt_token_ids is not None: + prompt_token_ids = [ + p["text"] for p in parse_and_batch_prompt(prompt_token_ids) + ] + + num_requests = None + if prompts is not None: + num_requests = len(prompts) + if prompt_token_ids is not None: + if (num_requests is not None + and num_requests != len(prompt_token_ids)): + raise ValueError("The lengths of prompts and prompt_token_ids " + "must be the same.") + + num_requests = len(prompt_token_ids) + if num_requests is None: + raise ValueError("Either prompts or prompt_token_ids must be " + "provided.") + + inputs: List[PromptInputs] = [] + for i in range(num_requests): + if prompts is not None: + if prompt_token_ids is not None: + inputs.append({ + "prompt": prompts[i], + "prompt_token_ids": prompt_token_ids[i], + "multi_modal_data": multi_modal_data, + }) + else: + inputs.append({ + "prompt": prompts[i], + "multi_modal_data": multi_modal_data, + }) + else: + if prompt_token_ids is not None: + inputs.append({ + "prompt_token_ids": prompt_token_ids[i], + "multi_modal_data": multi_modal_data, + }) + else: + raise AssertionError + + # sampling_params is now checked in _generate_v2 + return self._generate_v2( + inputs, + sampling_params=sampling_params, + use_tqdm=use_tqdm, + lora_request=lora_request, + ) + + def _generate_v2( + self, + inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + sampling_params: Optional[Union[SamplingParams, + Sequence[SamplingParams]]], + use_tqdm: bool, + lora_request: Optional[LoRARequest], + ) -> List[RequestOutput]: if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. inputs = [inputs] @@ -183,7 +358,7 @@ def generate( def _add_request( self, - inputs: PromptStrictInputs, + inputs: PromptInputs, sampling_params: SamplingParams, lora_request: Optional[LoRARequest] = None, ) -> None: diff --git a/vllm/inputs.py b/vllm/inputs.py index bd61f959eeb6e..2b5ea1c0f3828 100644 --- a/vllm/inputs.py +++ b/vllm/inputs.py @@ -1,9 +1,69 @@ -from typing import TYPE_CHECKING, List, Optional, TypedDict, Union +from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence, + TypedDict, Union, cast, overload) if TYPE_CHECKING: from vllm.sequence import MultiModalData +class ParsedString(TypedDict): + text: str + is_tokens: Literal[False] + + +class ParsedTokens(TypedDict): + text: List[int] + is_tokens: Literal[True] + + +# https://github.com/vllm-project/vllm/pull/4028 +@overload +def parse_and_batch_prompt( + prompt: Union[str, List[str]]) -> Sequence[ParsedString]: + ... 
+ + +@overload +def parse_and_batch_prompt( + prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]: + ... + + +def parse_and_batch_prompt( + prompt: Union[str, List[str], List[int], List[List[int]]], +) -> Union[Sequence[ParsedString], Sequence[ParsedTokens]]: + if isinstance(prompt, str): + # case 1: a string + return [ParsedString(text=prompt, is_tokens=False)] + + if isinstance(prompt, list): + if len(prompt) == 0: + raise ValueError("please provide at least one prompt") + + if isinstance(prompt[0], str): + # case 2: array of strings + return [ + ParsedString(text=elem, is_tokens=False) + for elem in cast(List[str], prompt) + ] + if isinstance(prompt[0], int): + # case 3: array of tokens + elem = cast(List[int], prompt) + return [ParsedTokens(text=elem, is_tokens=True)] + if isinstance(prompt[0], list): + if len(prompt[0]) == 0: + raise ValueError("please provide at least one prompt") + + if isinstance(prompt[0][0], int): + # case 4: array of token arrays + return [ + ParsedTokens(text=elem, is_tokens=True) + for elem in cast(List[List[int]], prompt) + ] + + raise ValueError("prompt must be a string, array of strings, " + "array of tokens, or array of token arrays") + + class MultiModalPrompt(TypedDict, total=False): multi_modal_data: Optional["MultiModalData"] """Multi modal data.""" diff --git a/vllm/utils.py b/vllm/utils.py index ce55253ce2199..784b9bb29db8e 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -11,7 +11,7 @@ import uuid import warnings from collections import defaultdict -from functools import lru_cache, partial +from functools import lru_cache, partial, wraps from platform import uname from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Hashable, List, Optional, OrderedDict, Tuple, TypeVar, @@ -638,3 +638,29 @@ def enable_trace_function_call_for_thread() -> None: filename) os.makedirs(os.path.dirname(log_path), exist_ok=True) enable_trace_function_call(log_path) + + +F = TypeVar('F', bound=Callable[..., Any]) + + +def deprecate_kwargs(*kws: str) -> Callable[[F], F]: + + def wrapper(fn: F) -> F: + + @wraps(fn) + def inner(*args, **kwargs): + deprecated_kws = {k for k in kwargs if k in kws} + if deprecated_kws: + warnings.warn( + DeprecationWarning( + f"The keyword arguments {deprecated_kws}" + " are deprecated and will be removed in " + "a future update."), + stacklevel=3, # The inner function takes up one level + ) + + return fn(*args, **kwargs) + + return inner # type: ignore + + return wrapper From 703d318dea3479cd5f1edc8ecbe8beec2860cfa0 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 3 May 2024 15:16:32 +0000 Subject: [PATCH 22/94] Add tests to ensure old API still works - To facilitate equality tests, `CompletionOutput` is now a dataclass --- tests/entrypoints/__init__.py | 0 tests/entrypoints/test_llm_generate.py | 124 +++++++++++++++++++++---- vllm/outputs.py | 29 ++---- vllm/utils.py | 10 +- 4 files changed, 123 insertions(+), 40 deletions(-) create mode 100644 tests/entrypoints/__init__.py diff --git a/tests/entrypoints/__init__.py b/tests/entrypoints/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py index 5e8b7ca4d9977..42bd6ef39c440 100644 --- a/tests/entrypoints/test_llm_generate.py +++ b/tests/entrypoints/test_llm_generate.py @@ -1,21 +1,113 @@ +from typing import List + import pytest -from vllm import LLM, SamplingParams +from vllm import LLM, RequestOutput, SamplingParams + +from ..conftest import 
cleanup + +MODEL_NAME = "facebook/opt-125m" + +PROMPTS = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +TOKEN_IDS = [ + [0], + [0, 1], + [0, 2, 1], + [0, 3, 1, 2], +] -def test_multiple_sampling_params(): - llm = LLM(model="facebook/opt-125m", +@pytest.fixture(scope="module") +def llm(): + yield LLM(model="facebook/opt-125m", max_num_batched_tokens=4096, + enforce_eager=True, tensor_parallel_size=1) - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] + cleanup() + + +def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]): + assert [o.outputs for o in o1] == [o.outputs for o in o2] + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt', PROMPTS) +def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.generate(prompts=prompt, + sampling_params=sampling_params) + + v2_output = llm.generate(prompt, sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.generate({"prompt": prompt}, + sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) +def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, + prompt_token_ids): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.generate(prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params) + + v2_output = llm.generate({"prompt_token_ids": prompt_token_ids}, + sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.generate(prompts=PROMPTS, + sampling_params=sampling_params) + + v2_output = llm.generate(PROMPTS, sampling_params=sampling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.generate( + [{ + "prompt": p + } for p in PROMPTS], + sampling_params=sampling_params, + ) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.generate(prompt_token_ids=TOKEN_IDS, + sampling_params=sampling_params) + + v2_output = llm.generate( + [{ + "prompt_token_ids": p + } for p in TOKEN_IDS], + sampling_params=sampling_params, + ) + assert_outputs_equal(v1_output, v2_output) + +@pytest.mark.skip_global_cleanup +def test_multiple_sampling_params(llm: LLM): sampling_params = [ SamplingParams(temperature=0.01, top_p=0.95), SamplingParams(temperature=0.3, top_p=0.95), @@ -24,18 +116,18 @@ def test_multiple_sampling_params(): ] # Multiple SamplingParams should be matched with each prompt - outputs = llm.generate(prompts, sampling_params=sampling_params) - assert len(prompts) == len(outputs) + outputs = llm.generate(PROMPTS, sampling_params=sampling_params) + assert len(PROMPTS) == len(outputs) # Exception raised, if 
the size of params does not match the size of prompts with pytest.raises(ValueError): - outputs = llm.generate(prompts, sampling_params=sampling_params[:3]) + outputs = llm.generate(PROMPTS, sampling_params=sampling_params[:3]) # Single SamplingParams should be applied to every prompt single_sampling_params = SamplingParams(temperature=0.3, top_p=0.95) - outputs = llm.generate(prompts, sampling_params=single_sampling_params) - assert len(prompts) == len(outputs) + outputs = llm.generate(PROMPTS, sampling_params=single_sampling_params) + assert len(PROMPTS) == len(outputs) # sampling_params is None, default params should be applied - outputs = llm.generate(prompts, sampling_params=None) - assert len(prompts) == len(outputs) \ No newline at end of file + outputs = llm.generate(PROMPTS, sampling_params=None) + assert len(PROMPTS) == len(outputs) diff --git a/vllm/outputs.py b/vllm/outputs.py index 78b70dfe107e3..f137c9e89a673 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -1,4 +1,5 @@ import time +from dataclasses import dataclass from typing import List, Optional, Union from vllm.lora.request import LoRARequest @@ -6,6 +7,7 @@ SequenceGroup, SequenceStatus) +@dataclass class CompletionOutput: """The output data of one completion output of a request. @@ -24,25 +26,14 @@ class CompletionOutput: lora_request: The LoRA request that was used to generate the output. """ - def __init__( - self, - index: int, - text: str, - token_ids: List[int], - cumulative_logprob: float, - logprobs: Optional[SampleLogprobs], - finish_reason: Optional[str] = None, - stop_reason: Union[int, str, None] = None, - lora_request: Optional[LoRARequest] = None, - ) -> None: - self.index = index - self.text = text - self.token_ids = token_ids - self.cumulative_logprob = cumulative_logprob - self.logprobs = logprobs - self.finish_reason = finish_reason - self.stop_reason = stop_reason - self.lora_request = lora_request + index: int + text: str + token_ids: List[int] + cumulative_logprob: float + logprobs: Optional[SampleLogprobs] + finish_reason: Optional[str] = None + stop_reason: Union[int, str, None] = None + lora_request: Optional[LoRARequest] = None def finished(self) -> bool: return self.finish_reason is not None diff --git a/vllm/utils.py b/vllm/utils.py index 784b9bb29db8e..a3fb89612e54b 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -644,18 +644,18 @@ def enable_trace_function_call_for_thread() -> None: def deprecate_kwargs(*kws: str) -> Callable[[F], F]: + deprecated_kws = set(kws) def wrapper(fn: F) -> F: @wraps(fn) def inner(*args, **kwargs): - deprecated_kws = {k for k in kwargs if k in kws} - if deprecated_kws: + deprecated_kwargs = kwargs.keys() & deprecated_kws + if deprecated_kwargs: warnings.warn( DeprecationWarning( - f"The keyword arguments {deprecated_kws}" - " are deprecated and will be removed in " - "a future update."), + f"The keyword arguments {deprecated_kwargs} are " + "deprecated and will be removed in a future update."), stacklevel=3, # The inner function takes up one level ) From 19d85f990bd0878cc6ad151d76bad3d282d0c674 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 3 May 2024 17:21:02 +0000 Subject: [PATCH 23/94] Let all entrypoints tests be run at the same time --- .buildkite/test-pipeline.yaml | 4 +--- pyproject.toml | 5 +++++ tests/async_engine/test_openapi_server_ray.py | 4 ++-- tests/entrypoints/test_llm_generate.py | 5 +++-- tests/entrypoints/test_openai_server.py | 10 ++++------ tests/entrypoints/test_server_oot_registration.py | 7 ++++--- 6 files changed, 19 
insertions(+), 16 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 60b57ea7a5f78..5f569693e0af6 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -54,9 +54,7 @@ steps: - label: Entrypoints Test commands: - pytest -v -s test_inputs.py - # these tests have to be separated, because each one will allocate all posible GPU memory - - pytest -v -s entrypoints --ignore=entrypoints/test_server_oot_registration.py - - pytest -v -s entrypoints/test_server_oot_registration.py + - pytest -v -s entrypoints - label: Examples Test working_dir: "/vllm-workspace/examples" diff --git a/pyproject.toml b/pyproject.toml index 6a448defc16e1..ead64b7436121 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,3 +65,8 @@ skip = "./tests/prompts,./benchmarks/sonnet.txt" [tool.isort] use_parentheses = true skip_gitignore = true + +[tool.pytest.ini_options] +markers = [ + "skip_global_cleanup" +] diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py index 4b97af88012b9..2a754f5c4ccab 100644 --- a/tests/async_engine/test_openapi_server_ray.py +++ b/tests/async_engine/test_openapi_server_ray.py @@ -55,7 +55,7 @@ def __del__(self): self.proc.terminate() -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def server(): ray.init() server_runner = ServerRunner.remote([ @@ -74,7 +74,7 @@ def server(): ray.shutdown() -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def client(): client = openai.AsyncOpenAI( base_url="http://localhost:8000/v1", diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py index 42bd6ef39c440..fe3d3fdf9a93d 100644 --- a/tests/entrypoints/test_llm_generate.py +++ b/tests/entrypoints/test_llm_generate.py @@ -27,8 +27,9 @@ def llm(): yield LLM(model="facebook/opt-125m", max_num_batched_tokens=4096, - enforce_eager=True, - tensor_parallel_size=1) + tensor_parallel_size=1, + gpu_memory_utilization=0.10, + enforce_eager=True) cleanup() diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 1323dba469117..60411c9e767d1 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -121,7 +121,7 @@ def zephyr_lora_files(): return snapshot_download(repo_id=LORA_NAME) -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def server(zephyr_lora_files): ray.init() server_runner = ServerRunner.remote([ @@ -133,6 +133,8 @@ def server(zephyr_lora_files): "--max-model-len", "8192", "--enforce-eager", + "--gpu-memory-utilization", + "0.75", # lora config below "--enable-lora", "--lora-modules", @@ -150,7 +152,7 @@ def server(zephyr_lora_files): ray.shutdown() -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def client(): client = openai.AsyncOpenAI( base_url="http://localhost:8000/v1", @@ -888,7 +890,3 @@ async def test_long_seed(server, client: openai.AsyncOpenAI): assert ("greater_than_equal" in exc_info.value.message or "less_than_equal" in exc_info.value.message) - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/tests/entrypoints/test_server_oot_registration.py b/tests/entrypoints/test_server_oot_registration.py index 22e65bf7e7da1..dd43a5cf0a248 100644 --- a/tests/entrypoints/test_server_oot_registration.py +++ b/tests/entrypoints/test_server_oot_registration.py @@ -26,15 +26,16 @@ def server_function(port): # register our dummy model 
ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM) sys.argv = ["placeholder.py"] + \ - ("--model facebook/opt-125m --dtype" - f" float32 --api-key token-abc123 --port {port}").split() + ("--model facebook/opt-125m --gpu-memory-utilization 0.10 " + f"--dtype float32 --api-key token-abc123 --port {port}").split() import runpy runpy.run_module('vllm.entrypoints.openai.api_server', run_name='__main__') def test_oot_registration_for_api_server(): port = get_open_port() - server = multiprocessing.Process(target=server_function, args=(port, )) + ctx = multiprocessing.get_context("spawn") + server = ctx.Process(target=server_function, args=(port, )) server.start() client = OpenAI( base_url=f"http://localhost:{port}/v1", From 5759dfa619f0310b5ccb2957a2d259a90254f87e Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 14 May 2024 03:01:39 +0000 Subject: [PATCH 24/94] Add tests for LLM.encode and fix corresponding bugs --- tests/entrypoints/test_llm_encode.py | 137 +++++++++++++++++++++++++ tests/entrypoints/test_llm_generate.py | 2 +- vllm/outputs.py | 9 +- 3 files changed, 141 insertions(+), 7 deletions(-) create mode 100644 tests/entrypoints/test_llm_encode.py diff --git a/tests/entrypoints/test_llm_encode.py b/tests/entrypoints/test_llm_encode.py new file mode 100644 index 0000000000000..b6ae4bc498dc6 --- /dev/null +++ b/tests/entrypoints/test_llm_encode.py @@ -0,0 +1,137 @@ +from typing import List + +import pytest + +from vllm import LLM, EmbeddingRequestOutput, PoolingParams + +from ..conftest import cleanup + +MODEL_NAME = "intfloat/e5-mistral-7b-instruct" + +PROMPTS = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +TOKEN_IDS = [ + # Using ID={0, 1, 2, 3} results in NaN values, + # so we add this offset of 1000 + [1000], + [1000, 1001], + [1000, 1002, 1001], + [1000, 1003, 1001, 1002], +] + + +@pytest.fixture(scope="module") +def llm(): + yield LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.90, + enforce_eager=True) + + cleanup() + + +def assert_outputs_equal(o1: List[EmbeddingRequestOutput], + o2: List[EmbeddingRequestOutput]): + assert [o.outputs for o in o1] == [o.outputs for o in o2] + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt', PROMPTS) +def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompts'"): + v1_output = llm.encode(prompts=prompt, + pooling_params=pooling_params) + + v2_output = llm.encode(prompt, pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.encode({"prompt": prompt}, + pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +@pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) +def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, + prompt_token_ids): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.encode(prompt_token_ids=prompt_token_ids, + pooling_params=pooling_params) + + v2_output = llm.encode({"prompt_token_ids": prompt_token_ids}, + pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, 
match="'prompts'"): + v1_output = llm.encode(prompts=PROMPTS, + pooling_params=pooling_params) + + v2_output = llm.encode(PROMPTS, pooling_params=pooling_params) + assert_outputs_equal(v1_output, v2_output) + + v2_output = llm.encode( + [{ + "prompt": p + } for p in PROMPTS], + pooling_params=pooling_params, + ) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): + pooling_params = PoolingParams() + + with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): + v1_output = llm.encode(prompt_token_ids=TOKEN_IDS, + pooling_params=pooling_params) + + v2_output = llm.encode( + [{ + "prompt_token_ids": p + } for p in TOKEN_IDS], + pooling_params=pooling_params, + ) + assert_outputs_equal(v1_output, v2_output) + + +@pytest.mark.skip_global_cleanup +def test_multiple_pooling_params(llm: LLM): + pooling_params = [ + PoolingParams(), + PoolingParams(), + PoolingParams(), + PoolingParams(), + ] + + # Multiple PoolingParams should be matched with each prompt + outputs = llm.encode(PROMPTS, pooling_params=pooling_params) + assert len(PROMPTS) == len(outputs) + + # Exception raised, if the size of params does not match the size of prompts + with pytest.raises(ValueError): + outputs = llm.encode(PROMPTS, pooling_params=pooling_params[:3]) + + # Single PoolingParams should be applied to every prompt + single_pooling_params = PoolingParams() + outputs = llm.encode(PROMPTS, pooling_params=single_pooling_params) + assert len(PROMPTS) == len(outputs) + + # pooling_params is None, default params should be applied + outputs = llm.encode(PROMPTS, pooling_params=None) + assert len(PROMPTS) == len(outputs) diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py index fe3d3fdf9a93d..8ee08f8e83961 100644 --- a/tests/entrypoints/test_llm_generate.py +++ b/tests/entrypoints/test_llm_generate.py @@ -25,7 +25,7 @@ @pytest.fixture(scope="module") def llm(): - yield LLM(model="facebook/opt-125m", + yield LLM(model=MODEL_NAME, max_num_batched_tokens=4096, tensor_parallel_size=1, gpu_memory_utilization=0.10, diff --git a/vllm/outputs.py b/vllm/outputs.py index 8bf3e236d532d..49f526b5f9300 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -48,6 +48,7 @@ def __repr__(self) -> str: f"stop_reason={self.stop_reason})") +@dataclass class EmbeddingOutput: """The output data of one completion output of a request. @@ -56,15 +57,11 @@ class EmbeddingOutput: length of vector depends on the model as listed in the embedding guide. 
""" - def __init__( - self, - embedding: List[float], - ) -> None: - self.embedding = embedding + embedding: List[float] def __repr__(self) -> str: return (f"EmbeddingOutput(" - f"embedding={len(self.embedding)}") + f"embedding={len(self.embedding)})") class RequestOutput: From cc4bfb5416957336833ecd439d4da51d95b084e5 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 14 May 2024 03:03:41 +0000 Subject: [PATCH 25/94] Apply formatter --- tests/entrypoints/test_llm_encode.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/entrypoints/test_llm_encode.py b/tests/entrypoints/test_llm_encode.py index b6ae4bc498dc6..4cf8b7fbafa8e 100644 --- a/tests/entrypoints/test_llm_encode.py +++ b/tests/entrypoints/test_llm_encode.py @@ -47,14 +47,12 @@ def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): pooling_params = PoolingParams() with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.encode(prompts=prompt, - pooling_params=pooling_params) + v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params) v2_output = llm.encode(prompt, pooling_params=pooling_params) assert_outputs_equal(v1_output, v2_output) - v2_output = llm.encode({"prompt": prompt}, - pooling_params=pooling_params) + v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params) assert_outputs_equal(v1_output, v2_output) @@ -66,10 +64,10 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): v1_output = llm.encode(prompt_token_ids=prompt_token_ids, - pooling_params=pooling_params) + pooling_params=pooling_params) v2_output = llm.encode({"prompt_token_ids": prompt_token_ids}, - pooling_params=pooling_params) + pooling_params=pooling_params) assert_outputs_equal(v1_output, v2_output) @@ -78,8 +76,7 @@ def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): pooling_params = PoolingParams() with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.encode(prompts=PROMPTS, - pooling_params=pooling_params) + v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params) v2_output = llm.encode(PROMPTS, pooling_params=pooling_params) assert_outputs_equal(v1_output, v2_output) @@ -99,7 +96,7 @@ def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): with pytest.warns(DeprecationWarning, match="'prompt_token_ids'"): v1_output = llm.encode(prompt_token_ids=TOKEN_IDS, - pooling_params=pooling_params) + pooling_params=pooling_params) v2_output = llm.encode( [{ From d5c9731f0d64b3bdd0c03df621a1f45b954a9eae Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 14 May 2024 03:10:22 +0000 Subject: [PATCH 26/94] Rename `_add_requests` to `_validate_and_add_requests` to be more similar to the original `_validate_and_prepare_requests` --- vllm/entrypoints/llm.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index a6175ff4485d1..59709e050325b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -265,7 +265,7 @@ def generate( # Use default sampling params. sampling_params = SamplingParams() - return self._add_requests( + return self._validate_and_add_requests( inputs=inputs, params=sampling_params, use_tqdm=use_tqdm, @@ -395,7 +395,7 @@ def encode( # Use default pooling params. 
pooling_params = PoolingParams() - return self._add_requests( + return self._validate_and_add_requests( inputs=inputs, params=pooling_params, use_tqdm=use_tqdm, @@ -458,7 +458,7 @@ def _convert_v1_inputs( return inputs @overload - def _add_requests( + def _validate_and_add_requests( self, inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], params: Union[SamplingParams, Sequence[SamplingParams]], @@ -468,7 +468,7 @@ def _add_requests( ... @overload - def _add_requests( # type: ignore[misc] + def _validate_and_add_requests( # type: ignore[misc] self, inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], params: Union[PoolingParams, Sequence[PoolingParams]], @@ -477,7 +477,7 @@ def _add_requests( # type: ignore[misc] ) -> List[EmbeddingRequestOutput]: ... - def _add_requests( + def _validate_and_add_requests( self, inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, From 4f218a52ec830b6d39436ba84d7e2404851b1c18 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 14 May 2024 03:15:52 +0000 Subject: [PATCH 27/94] Separate `entrypoints` tests into two groups --- .buildkite/test-pipeline.yaml | 3 ++- tests/entrypoints/openai/test_serving_chat.py | 4 +++ tests/entrypoints/test_guided_processors.py | 2 ++ tests/entrypoints/test_llm_encode.py | 2 ++ tests/entrypoints/test_llm_generate.py | 2 ++ tests/entrypoints/test_openai_server.py | 27 ++++++++++++++++++- .../test_server_oot_registration.py | 3 +++ 7 files changed, 41 insertions(+), 2 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 176fa1d39db46..f0bab7c87ad0f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -52,7 +52,8 @@ steps: - label: Entrypoints Test commands: - pytest -v -s test_inputs.py - - pytest -v -s entrypoints + - pytest -v -s entrypoints -m llm + - pytest -v -s entrypoints -m openai - label: Examples Test working_dir: "/vllm-workspace/examples" diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 74b49726734b5..c45f02fe564a3 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -1,11 +1,15 @@ import asyncio from dataclasses import dataclass +import pytest + from vllm.entrypoints.openai.serving_chat import OpenAIServingChat MODEL_NAME = "openai-community/gpt2" CHAT_TEMPLATE = "Dummy chat template for testing {}" +pytestmark = pytest.mark.openai + @dataclass class MockModelConfig: diff --git a/tests/entrypoints/test_guided_processors.py b/tests/entrypoints/test_guided_processors.py index 41c871ca40bc8..5d4163e96fd87 100644 --- a/tests/entrypoints/test_guided_processors.py +++ b/tests/entrypoints/test_guided_processors.py @@ -52,6 +52,8 @@ TEST_REGEX = (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") +pytestmark = pytest.mark.openai + def test_guided_logits_processors(): """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" diff --git a/tests/entrypoints/test_llm_encode.py b/tests/entrypoints/test_llm_encode.py index 4cf8b7fbafa8e..c9833b0c315bf 100644 --- a/tests/entrypoints/test_llm_encode.py +++ b/tests/entrypoints/test_llm_encode.py @@ -24,6 +24,8 @@ [1000, 1003, 1001, 1002], ] +pytestmark = pytest.mark.llm + @pytest.fixture(scope="module") def llm(): diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py index 8ee08f8e83961..e21e4b8136746 
100644 --- a/tests/entrypoints/test_llm_generate.py +++ b/tests/entrypoints/test_llm_generate.py @@ -22,6 +22,8 @@ [0, 3, 1, 2], ] +pytestmark = pytest.mark.llm + @pytest.fixture(scope="module") def llm(): diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 75f0fb0be4f08..7050cc8ebe3d1 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -71,7 +71,7 @@ "Swift", "Kotlin" ] -pytestmark = pytest.mark.asyncio +pytestmark = pytest.mark.openai @pytest.fixture(scope="session") @@ -138,6 +138,7 @@ def client(): yield client +@pytest.mark.asyncio async def test_check_models(server, client: openai.AsyncOpenAI): models = await client.models.list() models = models.data @@ -149,6 +150,7 @@ async def test_check_models(server, client: openai.AsyncOpenAI): assert lora_models[1].id == "zephyr-lora2" +@pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras "model_name", @@ -180,6 +182,7 @@ async def test_single_completion(server, client: openai.AsyncOpenAI, completion.choices[0].text) >= 5 +@pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras "model_name", @@ -201,6 +204,7 @@ async def test_zero_logprobs(server, client: openai.AsyncOpenAI, assert choice.logprobs.top_logprobs is None +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -245,6 +249,7 @@ async def test_single_chat_session(server, client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, model_name: str): @@ -300,6 +305,7 @@ async def test_too_many_logprobs(server, client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -337,6 +343,7 @@ async def test_completion_streaming(server, client: openai.AsyncOpenAI, assert "".join(chunks) == single_output +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -387,6 +394,7 @@ async def test_chat_streaming(server, client: openai.AsyncOpenAI, assert "".join(chunks) == output +@pytest.mark.asyncio @pytest.mark.parametrize( # just test 1 lora hereafter "model_name", @@ -440,6 +448,7 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI, assert texts[0] == texts[1] +@pytest.mark.asyncio async def test_logits_bias(server, client: openai.AsyncOpenAI): prompt = "Hello, my name is" max_tokens = 5 @@ -487,6 +496,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI): assert first_response != completion.choices[0].text +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_json_completion(server, client: openai.AsyncOpenAI, @@ -509,6 +519,7 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI, jsonschema.validate(instance=output_json, schema=TEST_SCHEMA) +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_json_chat(server, client: openai.AsyncOpenAI, @@ -555,6 +566,7 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI, assert json1["age"] != json2["age"] +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", 
"lm-format-enforcer"]) async def test_guided_regex_completion(server, client: openai.AsyncOpenAI, @@ -575,6 +587,7 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI, assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_regex_chat(server, client: openai.AsyncOpenAI, @@ -612,6 +625,7 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI, assert ip1 != ip2 +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_choice_completion(server, client: openai.AsyncOpenAI, @@ -631,6 +645,7 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI, assert completion.choices[i].text in TEST_CHOICE +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_choice_chat(server, client: openai.AsyncOpenAI, @@ -669,6 +684,7 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI, assert choice1 != choice2 +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI, @@ -704,6 +720,7 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI, extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA)) +@pytest.mark.asyncio @pytest.mark.parametrize("guided_decoding_backend", ["outlines", "lm-format-enforcer"]) async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI, @@ -734,6 +751,7 @@ async def test_guided_choice_chat_logprobs(server, client: openai.AsyncOpenAI, for token, logprob in token_dict.items()) +@pytest.mark.asyncio async def test_response_format_json_object(server, client: openai.AsyncOpenAI): for _ in range(2): resp = await client.chat.completions.create( @@ -751,6 +769,7 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI): assert loaded == {"result": 2}, loaded +@pytest.mark.asyncio async def test_extra_fields(server, client: openai.AsyncOpenAI): with pytest.raises(BadRequestError) as exc_info: await client.chat.completions.create( @@ -766,6 +785,7 @@ async def test_extra_fields(server, client: openai.AsyncOpenAI): assert "extra_forbidden" in exc_info.value.message +@pytest.mark.asyncio async def test_complex_message_content(server, client: openai.AsyncOpenAI): resp = await client.chat.completions.create( model=MODEL_NAME, @@ -785,6 +805,7 @@ async def test_complex_message_content(server, client: openai.AsyncOpenAI): assert content == "2" +@pytest.mark.asyncio async def test_guided_grammar(server, client: openai.AsyncOpenAI): simple_sql_grammar = """ start: select_statement @@ -819,6 +840,7 @@ async def test_guided_grammar(server, client: openai.AsyncOpenAI): assert content.strip() == ground_truth +@pytest.mark.asyncio @pytest.mark.parametrize( # first test base model, then test loras "model_name", @@ -850,6 +872,7 @@ async def test_echo_logprob_completion(server, client: openai.AsyncOpenAI, assert len(logprobs.tokens) > 5 +@pytest.mark.asyncio async def test_long_seed(server, client: openai.AsyncOpenAI): for seed in [ torch.iinfo(torch.long).min - 1, @@ -869,6 +892,7 @@ async def test_long_seed(server, client: openai.AsyncOpenAI): or "less_than_equal" in exc_info.value.message) +@pytest.mark.asyncio 
@pytest.mark.parametrize( "model_name", [EMBEDDING_MODEL_NAME], @@ -907,6 +931,7 @@ async def test_single_embedding(embedding_server, client: openai.AsyncOpenAI, assert embeddings.usage.total_tokens == 5 +@pytest.mark.asyncio @pytest.mark.parametrize( "model_name", [EMBEDDING_MODEL_NAME], diff --git a/tests/entrypoints/test_server_oot_registration.py b/tests/entrypoints/test_server_oot_registration.py index dd43a5cf0a248..52dc1a0b898de 100644 --- a/tests/entrypoints/test_server_oot_registration.py +++ b/tests/entrypoints/test_server_oot_registration.py @@ -2,6 +2,7 @@ import sys import time +import pytest import torch from openai import OpenAI, OpenAIError @@ -10,6 +11,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.utils import get_open_port +pytestmark = pytest.mark.openai + class MyOPTForCausalLM(OPTForCausalLM): From a9201d0251790b0ebdb6981d4f7b39ab149ce2f0 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 14 May 2024 07:17:59 +0000 Subject: [PATCH 28/94] Fix memory profiling error --- .buildkite/test-pipeline.yaml | 3 ++- pyproject.toml | 5 ++++- tests/entrypoints/test_llm_encode.py | 9 +++------ tests/entrypoints/test_llm_generate.py | 7 ++----- 4 files changed, 11 insertions(+), 13 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f0bab7c87ad0f..cb4e1ba935880 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -52,7 +52,8 @@ steps: - label: Entrypoints Test commands: - pytest -v -s test_inputs.py - - pytest -v -s entrypoints -m llm + - pytest -v -s entrypoints -m llm_generate + - pytest -v -s entrypoints -m llm_encode - pytest -v -s entrypoints -m openai - label: Examples Test diff --git a/pyproject.toml b/pyproject.toml index ead64b7436121..97ff04f854ad9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,5 +68,8 @@ skip_gitignore = true [tool.pytest.ini_options] markers = [ - "skip_global_cleanup" + "skip_global_cleanup", + "llm_encode: run tests for vLLM embedding API only", + "llm_generate: run tests for vLLM generate API only", + "openai: run tests for OpenAI API only", ] diff --git a/tests/entrypoints/test_llm_encode.py b/tests/entrypoints/test_llm_encode.py index c9833b0c315bf..fd1995a71f7dd 100644 --- a/tests/entrypoints/test_llm_encode.py +++ b/tests/entrypoints/test_llm_encode.py @@ -4,8 +4,6 @@ from vllm import LLM, EmbeddingRequestOutput, PoolingParams -from ..conftest import cleanup - MODEL_NAME = "intfloat/e5-mistral-7b-instruct" PROMPTS = [ @@ -24,19 +22,18 @@ [1000, 1003, 1001, 1002], ] -pytestmark = pytest.mark.llm +pytestmark = pytest.mark.llm_encode @pytest.fixture(scope="module") def llm(): + # pytest caches the fixture so we cannot GC it yield LLM(model=MODEL_NAME, max_num_batched_tokens=32768, tensor_parallel_size=1, - gpu_memory_utilization=0.90, + gpu_memory_utilization=0.75, enforce_eager=True) - cleanup() - def assert_outputs_equal(o1: List[EmbeddingRequestOutput], o2: List[EmbeddingRequestOutput]): diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py index e21e4b8136746..b973cebea4e71 100644 --- a/tests/entrypoints/test_llm_generate.py +++ b/tests/entrypoints/test_llm_generate.py @@ -4,8 +4,6 @@ from vllm import LLM, RequestOutput, SamplingParams -from ..conftest import cleanup - MODEL_NAME = "facebook/opt-125m" PROMPTS = [ @@ -22,19 +20,18 @@ [0, 3, 1, 2], ] -pytestmark = pytest.mark.llm +pytestmark = pytest.mark.llm_generate @pytest.fixture(scope="module") def llm(): + # pytest caches the fixture so 
we cannot GC it yield LLM(model=MODEL_NAME, max_num_batched_tokens=4096, tensor_parallel_size=1, gpu_memory_utilization=0.10, enforce_eager=True) - cleanup() - def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]): assert [o.outputs for o in o1] == [o.outputs for o in o2] From ceebfa684c9bbb1401a5d1d042471b23635e5798 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 15 May 2024 02:04:01 +0000 Subject: [PATCH 29/94] Fix memory usage for embedding server --- tests/entrypoints/test_openai_server.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 7050cc8ebe3d1..2944d887d8896 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -120,9 +120,11 @@ def embedding_server(zephyr_lora_files): # use half precision for speed and memory savings in CI environment "--dtype", "bfloat16", + "--enforce-eager", + "--gpu-memory-utilization", + "0.75", "--max-model-len", "8192", - "--enforce-eager", ]) ray.get(server_runner.ready.remote()) yield server_runner From 7d991cde83a872641cf58862a57a0f3697866e42 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 15 May 2024 02:43:16 +0000 Subject: [PATCH 30/94] Update embeddings API to use new imputs --- vllm/engine/async_llm_engine.py | 8 ++-- vllm/entrypoints/openai/serving_embedding.py | 42 ++++++++++++-------- vllm/entrypoints/openai/serving_engine.py | 6 ++- 3 files changed, 33 insertions(+), 23 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index d4991df122325..f4ae3fe64e85b 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -584,7 +584,7 @@ async def add_request( async def generate( self, inputs: PromptInputs, - params: Union[SamplingParams, PoolingParams], + sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, ) -> AsyncIterator[RequestOutput]: @@ -596,9 +596,7 @@ async def generate( Args: inputs: The inputs to the LLM. - params: Parameters for sampling or pooling. - :class:`~vllm.SamplingParams` for text generation. - :class:`~vllm.PoolingParams` for pooling. + sampling_params: The sampling parameters of the request. request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. 
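
A minimal sketch of a call site for the signature above, assuming an already-constructed
`engine: AsyncLLMEngine` and an illustrative request id (these names and values are
assumptions for illustration, not code from the patch):

    from vllm import SamplingParams

    async def run_one(engine, request_id: str):
        # `inputs` is a PromptInputs value; sampling parameters are now passed
        # separately instead of the former Union[SamplingParams, PoolingParams].
        async for output in engine.generate(
                {"prompt": "Hello, my name is"},
                SamplingParams(max_tokens=16),
                request_id):
            print(output.outputs[0].text)
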
@@ -652,7 +650,7 @@ async def generate( async for output in self.process_request( request_id, inputs, - params, + sampling_params, lora_request=lora_request, ): yield output diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 7a57be0c88915..5a3448de3d7a4 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -1,5 +1,5 @@ import time -from typing import AsyncIterator, List, Tuple +from typing import AsyncIterator, List, Optional, Tuple from fastapi import Request @@ -100,11 +100,16 @@ async def create_embedding(self, request: EmbeddingRequest, prompt_ids, prompt_text = prompt_formats - generators.append( - self.engine.generate(prompt_text, - pooling_params, - f"{request_id}-{i}", - prompt_token_ids=prompt_ids)) + generator = self.engine.encode( + { + "prompt": prompt_text, + "prompt_token_ids": prompt_ids + }, + pooling_params, + f"{request_id}-{i}", + ) + + generators.append(generator) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -113,16 +118,21 @@ async def create_embedding(self, request: EmbeddingRequest, int, EmbeddingRequestOutput]] = merge_async_iterators(*generators) # Non-streaming response - final_res_batch: EmbeddingRequestOutput = [None] * len(prompts) - async for i, res in result_generator: - if await raw_request.is_disconnected(): - # Abort the request if the client disconnects. - await self.engine.abort(f"{request_id}-{i}") - # TODO: Use a vllm-specific Validation Error - return self.create_error_response("Client disconnected") - final_res_batch[i] = res - response = request_output_to_embedding_response( - final_res_batch, request_id, created_time, model_name) + final_res_batch: List[Optional[EmbeddingRequestOutput]] + final_res_batch = [None] * len(prompts) + try: + async for i, res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. 
+ await self.engine.abort(f"{request_id}-{i}") + # TODO: Use a vllm-specific Validation Error + return self.create_error_response("Client disconnected") + final_res_batch[i] = res + response = request_output_to_embedding_response( + final_res_batch, request_id, created_time, model_name) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) return response diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 58a1c2f7e73fe..a50d91e8d4fd4 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -142,7 +142,8 @@ def create_streaming_error_response( return json_str async def _check_model( - self, request: Union[CompletionRequest, ChatCompletionRequest] + self, request: Union[CompletionRequest, ChatCompletionRequest, + EmbeddingRequest] ) -> Optional[ErrorResponse]: if request.model in self.served_model_names: return None @@ -154,7 +155,8 @@ async def _check_model( status_code=HTTPStatus.NOT_FOUND) def _maybe_get_lora( - self, request: Union[CompletionRequest, ChatCompletionRequest] + self, request: Union[CompletionRequest, ChatCompletionRequest, + EmbeddingRequest] ) -> Optional[LoRARequest]: if request.model in self.served_model_names: return None From 07c0a2e77f8d4048a06089d152982c4ae8f2fdea Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 15 May 2024 06:11:32 +0000 Subject: [PATCH 31/94] Remove unnecessary commas --- vllm/entrypoints/openai/serving_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 660f04e117632..e2f74b51028b5 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -230,7 +230,7 @@ def _validate_input( f"This model's maximum context length is " f"{self.max_model_len} tokens. However, you requested " f"{token_num} tokens in the input for embedding " - f"generation. Please reduce the length of the input.", ) + f"generation. Please reduce the length of the input.") return input_ids, input_text if request.max_tokens is None: @@ -239,7 +239,7 @@ def _validate_input( f"This model's maximum context length is " f"{self.max_model_len} tokens. However, you requested " f"{token_num} tokens in the messages, " - f"Please reduce the length of the messages.", ) + f"Please reduce the length of the messages.") request.max_tokens = self.max_model_len - token_num if token_num + request.max_tokens > self.max_model_len: @@ -249,7 +249,7 @@ def _validate_input( f"{request.max_tokens + token_num} tokens " f"({token_num} in the messages, " f"{request.max_tokens} in the completion). 
" - f"Please reduce the length of the messages or completion.", ) + f"Please reduce the length of the messages or completion.") return input_ids, input_text From 30975825f3d1259277aade493d12e3e081d625ca Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 20 May 2024 03:46:41 +0000 Subject: [PATCH 32/94] Merge `llm` groups back into one by enabling gc --- .buildkite/test-pipeline.yaml | 3 +-- pyproject.toml | 3 +-- tests/entrypoints/test_llm_encode.py | 15 ++++++++++++--- tests/entrypoints/test_llm_generate.py | 15 ++++++++++++--- 4 files changed, 26 insertions(+), 10 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f1deb7fdcf698..206fb814abf5a 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -63,8 +63,7 @@ steps: #mirror_hardwares: [amd] commands: - pytest -v -s test_inputs.py - - pytest -v -s entrypoints -m llm_generate - - pytest -v -s entrypoints -m llm_encode + - pytest -v -s entrypoints -m llm - pytest -v -s entrypoints -m openai - label: Examples Test diff --git a/pyproject.toml b/pyproject.toml index c529f51ab93de..ab3fbfc92642c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,6 @@ skip_gitignore = true [tool.pytest.ini_options] markers = [ "skip_global_cleanup", - "llm_encode: run tests for vLLM embedding API only", - "llm_generate: run tests for vLLM generate API only", + "llm: run tests for vLLM API only", "openai: run tests for OpenAI API only", ] diff --git a/tests/entrypoints/test_llm_encode.py b/tests/entrypoints/test_llm_encode.py index fd1995a71f7dd..24da218b5adb2 100644 --- a/tests/entrypoints/test_llm_encode.py +++ b/tests/entrypoints/test_llm_encode.py @@ -1,9 +1,12 @@ +import weakref from typing import List import pytest from vllm import LLM, EmbeddingRequestOutput, PoolingParams +from ..conftest import cleanup + MODEL_NAME = "intfloat/e5-mistral-7b-instruct" PROMPTS = [ @@ -22,18 +25,24 @@ [1000, 1003, 1001, 1002], ] -pytestmark = pytest.mark.llm_encode +pytestmark = pytest.mark.llm @pytest.fixture(scope="module") def llm(): - # pytest caches the fixture so we cannot GC it - yield LLM(model=MODEL_NAME, + # pytest caches the fixture so we use weakref for garbage collection to work + llm = LLM(model=MODEL_NAME, max_num_batched_tokens=32768, tensor_parallel_size=1, gpu_memory_utilization=0.75, enforce_eager=True) + yield weakref.proxy(llm) + + del llm + + cleanup() + def assert_outputs_equal(o1: List[EmbeddingRequestOutput], o2: List[EmbeddingRequestOutput]): diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py index b973cebea4e71..4c2e52e64d54c 100644 --- a/tests/entrypoints/test_llm_generate.py +++ b/tests/entrypoints/test_llm_generate.py @@ -1,9 +1,12 @@ +import weakref from typing import List import pytest from vllm import LLM, RequestOutput, SamplingParams +from ..conftest import cleanup + MODEL_NAME = "facebook/opt-125m" PROMPTS = [ @@ -20,18 +23,24 @@ [0, 3, 1, 2], ] -pytestmark = pytest.mark.llm_generate +pytestmark = pytest.mark.llm @pytest.fixture(scope="module") def llm(): - # pytest caches the fixture so we cannot GC it - yield LLM(model=MODEL_NAME, + # pytest caches the fixture so we use weakref for garbage collection to work + llm = LLM(model=MODEL_NAME, max_num_batched_tokens=4096, tensor_parallel_size=1, gpu_memory_utilization=0.10, enforce_eager=True) + yield weakref.proxy(llm) + + del llm + + cleanup() + def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]): assert [o.outputs for o in o1] == [o.outputs 
for o in o2] From 7bbd123dd23fd6c0266a24331f404244462ed7af Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 20 May 2024 10:11:50 +0000 Subject: [PATCH 33/94] Improve documentation for LLM/engine --- vllm/engine/async_llm_engine.py | 18 +++++++++--------- vllm/engine/llm_engine.py | 10 +++++----- vllm/entrypoints/llm.py | 6 ++++-- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 3e684c883e1eb..b2f2fa02642d1 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -291,15 +291,15 @@ async def check_health_async(self) -> None: class AsyncLLMEngine: - """An asynchronous wrapper for LLMEngine. + """An asynchronous wrapper for :class:`LLMEngine`. - This class is used to wrap the LLMEngine class to make it asynchronous. It - uses asyncio to create a background loop that keeps processing incoming - requests. The LLMEngine is kicked by the generate method when there - are requests in the waiting queue. The generate method yields the outputs - from the LLMEngine to the caller. + This class is used to wrap the :class:`LLMEngine` class to make it + asynchronous. It uses asyncio to create a background loop that keeps + processing incoming requests. The :class:`LLMEngine` is kicked by the + generate method when there are requests in the waiting queue. The generate + method yields the outputs from the :class:`LLMEngine` to the caller. - NOTE: For the comprehensive list of arguments, see `LLMEngine`. + NOTE: For the comprehensive list of arguments, see :class:`LLMEngine`. Args: worker_use_ray: Whether to use Ray for model workers. Required for @@ -313,8 +313,8 @@ class AsyncLLMEngine: being printed in log. start_engine_loop: If True, the background task to run the engine will be automatically started in the generate call. - *args: Arguments for LLMEngine. - *kwargs: Arguments for LLMEngine. + *args: Arguments for :class:`LLMEngine`. + **kwargs: Arguments for :class:`LLMEngine`. """ _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 5e64117ae6c5a..850c8096f58d3 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -63,11 +63,11 @@ class LLMEngine: iteration-level scheduling and efficient memory management to maximize the serving throughput. - The `LLM` class wraps this class for offline batched inference and the - `AsyncLLMEngine` class wraps this class for online serving. + The :class:`~vllm.LLM` class wraps this class for offline batched inference + and the :class:`AsyncLLMEngine` class wraps this class for online serving. - NOTE: The config arguments are derived from the `EngineArgs` class. For the - comprehensive list of arguments, see `EngineArgs`. + NOTE: The config arguments are derived from the :class:`~vllm.EngineArgs` + class. For the comprehensive list of arguments, see :ref:`engine_args`. Args: model_config: The configuration related to the LLM model. @@ -84,7 +84,7 @@ class LLMEngine: executor_class: The model executor class for managing distributed execution. log_stats: Whether to log statistics. - usage_context: Specified entry point, used for usage info collection + usage_context: Specified entry point, used for usage info collection. 
""" tokenizer: Optional[BaseTokenizerGroup] diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 59709e050325b..89d49b4741ccc 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -29,8 +29,10 @@ class LLM: mechanism and efficient memory management. NOTE: This class is intended to be used for offline inference. For online - serving, use the `AsyncLLMEngine` class instead. - NOTE: For the comprehensive list of arguments, see `EngineArgs`. + serving, use the :class:`~vllm.AsyncLLMEngine` class instead. + + NOTE: For the comprehensive list of arguments, see + :class:`~vllm.EngineArgs`. Args: model: The name or path of a HuggingFace Transformers model. From 056eb6168989fc61fe625cde6cf2de1cd65765c1 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 22 May 2024 07:59:34 +0000 Subject: [PATCH 34/94] Direct readers to the `PromptInputs` class --- docs/source/index.rst | 1 + docs/source/offline_inference/llm.rst | 2 +- docs/source/offline_inference/llm_inputs.rst | 14 +++++ vllm/__init__.py | 4 ++ vllm/engine/async_llm_engine.py | 8 ++- vllm/engine/llm_engine.py | 6 +- vllm/entrypoints/llm.py | 4 +- vllm/inputs.py | 64 +++++++++++++------- 8 files changed, 74 insertions(+), 29 deletions(-) create mode 100644 docs/source/offline_inference/llm_inputs.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index bab00e28e4018..6383680f2b512 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -73,6 +73,7 @@ Documentation :caption: Offline Inference offline_inference/llm + offline_inference/llm_inputs offline_inference/sampling_params .. toctree:: diff --git a/docs/source/offline_inference/llm.rst b/docs/source/offline_inference/llm.rst index 1a443ea406994..83ba1b6987c6d 100644 --- a/docs/source/offline_inference/llm.rst +++ b/docs/source/offline_inference/llm.rst @@ -1,5 +1,5 @@ LLM Class -========== +========= .. autoclass:: vllm.LLM :members: diff --git a/docs/source/offline_inference/llm_inputs.rst b/docs/source/offline_inference/llm_inputs.rst new file mode 100644 index 0000000000000..31c3d16a3c8eb --- /dev/null +++ b/docs/source/offline_inference/llm_inputs.rst @@ -0,0 +1,14 @@ +LLM Inputs +========== + +.. autodata:: vllm.inputs.PromptStrictInputs + +.. autoclass:: vllm.inputs.TextPrompt + :show-inheritance: + :members: + :member-order: bysource + +.. autoclass:: vllm.inputs.TokensPrompt + :show-inheritance: + :members: + :member-order: bysource diff --git a/vllm/__init__.py b/vllm/__init__.py index 74674ca0d12af..a0e154d24087c 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -5,6 +5,7 @@ from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster +from vllm.inputs import PromptStrictInputs, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry from vllm.outputs import (CompletionOutput, EmbeddingOutput, EmbeddingRequestOutput, RequestOutput) @@ -16,6 +17,9 @@ __all__ = [ "LLM", "ModelRegistry", + "PromptStrictInputs", + "TextPrompt", + "TokensPrompt", "SamplingParams", "RequestOutput", "CompletionOutput", diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index b2f2fa02642d1..8212b9c6e2027 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -599,7 +599,9 @@ async def generate( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. + inputs: The inputs to the LLM. 
See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. sampling_params: The sampling parameters of the request. request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. @@ -673,7 +675,9 @@ async def encode( from the LLMEngine to the caller. Args: - inputs: The inputs to the LLM. + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. pooling_params: The pooling parameters of the request. request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 850c8096f58d3..89f99aa8f7098 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -152,8 +152,6 @@ def __init__( self.decoding_config = decoding_config or DecodingConfig() self.log_stats = log_stats - self.tokenizer: Optional[BaseTokenizerGroup] - if not self.model_config.skip_tokenizer_init: tokenizer = self._init_tokenizer() self.detokenizer = Detokenizer(tokenizer) @@ -446,7 +444,9 @@ def add_request( Args: request_id: The unique ID of the request. - inputs: The inputs to the LLM. + inputs: The inputs to the LLM. See + :class:`~vllm.inputs.PromptInputs` + for more details about the format of each input. params: Parameters for sampling or pooling. :class:`~vllm.SamplingParams` for text generation. :class:`~vllm.PoolingParams` for pooling. diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 89d49b4741ccc..f84de86e4617b 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -372,7 +372,9 @@ def encode( into a single list and pass it to this method. Args: - inputs: The inputs to the LLM. + inputs: The inputs to the LLM. You may pass a sequence of inputs for + batch inference. See :class:`~vllm.inputs.PromptStrictInputs` + for more details about the format of each input. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. use_tqdm: Whether to use tqdm to display the progress bar. diff --git a/vllm/inputs.py b/vllm/inputs.py index 2b5ea1c0f3828..e4bdb18c2f49a 100644 --- a/vllm/inputs.py +++ b/vllm/inputs.py @@ -5,7 +5,7 @@ from vllm.sequence import MultiModalData -class ParsedString(TypedDict): +class ParsedText(TypedDict): text: str is_tokens: Literal[False] @@ -18,7 +18,7 @@ class ParsedTokens(TypedDict): # https://github.com/vllm-project/vllm/pull/4028 @overload def parse_and_batch_prompt( - prompt: Union[str, List[str]]) -> Sequence[ParsedString]: + prompt: Union[str, List[str]]) -> Sequence[ParsedText]: ... 
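
Taken together with the re-exports added in the `vllm/__init__.py` hunk above, a caller
can pass either prompt form straight to `LLM.generate`. A minimal sketch, reusing the
example model and prompts from the tests earlier in this series (not part of this patch):

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")
    outputs = llm.generate(
        [
            {"prompt": "Hello, my name is"},   # shaped like TextPrompt
            {"prompt_token_ids": [0, 2, 1]},   # shaped like TokensPrompt
        ],
        sampling_params=SamplingParams(max_tokens=16),
    )
    for o in outputs:
        print(o.outputs[0].text)
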
@@ -30,10 +30,10 @@ def parse_and_batch_prompt( def parse_and_batch_prompt( prompt: Union[str, List[str], List[int], List[List[int]]], -) -> Union[Sequence[ParsedString], Sequence[ParsedTokens]]: +) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]: if isinstance(prompt, str): # case 1: a string - return [ParsedString(text=prompt, is_tokens=False)] + return [ParsedText(text=prompt, is_tokens=False)] if isinstance(prompt, list): if len(prompt) == 0: @@ -42,7 +42,7 @@ def parse_and_batch_prompt( if isinstance(prompt[0], str): # case 2: array of strings return [ - ParsedString(text=elem, is_tokens=False) + ParsedText(text=elem, is_tokens=False) for elem in cast(List[str], prompt) ] if isinstance(prompt[0], int): @@ -64,42 +64,62 @@ def parse_and_batch_prompt( "array of tokens, or array of token arrays") -class MultiModalPrompt(TypedDict, total=False): - multi_modal_data: Optional["MultiModalData"] - """Multi modal data.""" - +class TextPrompt(TypedDict): + """Schema for a text prompt.""" -class StringPrompt(MultiModalPrompt, TypedDict): prompt: str - """The prompt string.""" + """The input text to be tokenized before passing to the model.""" + + multi_modal_data: Optional["MultiModalData"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + +class TokensPrompt(TypedDict): + """Schema for a tokenized prompt.""" -class TokensPrompt(MultiModalPrompt, TypedDict): prompt_token_ids: List[int] - """The token IDs of the prompt. If None, we use the - tokenizer to convert the prompts to token IDs.""" + """A list of token IDs to pass to the model.""" + + multi_modal_data: Optional["MultiModalData"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ -class StringTokensPrompt(MultiModalPrompt, TypedDict): +class TextTokensPrompt(TypedDict): """It is assumed that :attr:`prompt` is consistent with :attr:`prompt_token_ids`. This is currently used in :class:`AsyncLLMEngine` for logging both the text and token IDs.""" prompt: str - """The prompt string.""" + """The prompt text.""" prompt_token_ids: List[int] """The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs.""" + multi_modal_data: Optional["MultiModalData"] + """ + Optional multi-modal data to pass to the model, + if the model supports it. + """ + + +PromptStrictInputs = Union[str, TextPrompt, TokensPrompt] +""" +The inputs to the LLM, which can take one of the following forms: -PromptStrictInputs = Union[str, StringPrompt, TokensPrompt] -"""The prompt string. 
More complex inputs should be represented by -:class:`StringPrompt` or :class:`TokensPrompt`.""" +- A text prompt (:class:`str` or :class:`TextPrompt`) +- A tokenized prompt (:class:`TokensPrompt`) +""" -PromptInputs = Union[str, StringPrompt, TokensPrompt, StringTokensPrompt] -"""As :const:`PromptStrictInputs` but additionally accepts -:class:`StringTokensPrompt`.""" +PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt] +"""Same as :const:`PromptStrictInputs` but additionally accepts +:class:`TextTokensPrompt`.""" class LLMInputs(TypedDict): From b3b990a7ac93c08e7734b02a7bee7cae14711c81 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 22 May 2024 08:18:55 +0000 Subject: [PATCH 35/94] Separate `_run_engine` from `_validate_and_add_requests` --- tests/lora/test_long_context.py | 4 +- tests/samplers/test_logits_processor.py | 4 +- tests/samplers/test_seeded_generate.py | 4 +- vllm/entrypoints/llm.py | 51 +++++++++---------------- 4 files changed, 23 insertions(+), 40 deletions(-) diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py index 15189f421a539..3dd9b98ed911b 100644 --- a/tests/lora/test_long_context.py +++ b/tests/lora/test_long_context.py @@ -5,7 +5,7 @@ import pytest import vllm -from vllm import SamplingParams +from vllm import RequestOutput, SamplingParams from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLora from vllm.lora.request import LoRARequest from vllm.model_executor.layers.rotary_embedding import ( @@ -100,7 +100,7 @@ def batched_generate( # Add requests to the engine and run the engine for request_data in requests_data: llm._add_request(**request_data) - outputs = llm._run_engine(use_tqdm=True) + outputs = llm._run_engine(RequestOutput, use_tqdm=True) return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))] diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py index 0724622d5f3c7..1b63c1dab98d2 100644 --- a/tests/samplers/test_logits_processor.py +++ b/tests/samplers/test_logits_processor.py @@ -1,7 +1,7 @@ import pytest import torch -from vllm import SamplingParams +from vllm import RequestOutput, SamplingParams MODELS = ["facebook/opt-125m"] @@ -54,6 +54,6 @@ def pick_vllm(token_ids, logits): params=SamplingParams(max_tokens=max_tokens), ) - outputs = vllm_model.model._run_engine(False) + outputs = vllm_model.model._run_engine(RequestOutput, use_tqdm=False) assert outputs[0].outputs[0].text == enforced_answers * repeat_times diff --git a/tests/samplers/test_seeded_generate.py b/tests/samplers/test_seeded_generate.py index fef5ff3fb9e8e..fca2b0e05c335 100644 --- a/tests/samplers/test_seeded_generate.py +++ b/tests/samplers/test_seeded_generate.py @@ -8,7 +8,7 @@ import pytest -from vllm import SamplingParams +from vllm import RequestOutput, SamplingParams from vllm.model_executor.utils import set_random_seed MODEL = "facebook/opt-125m" @@ -59,7 +59,7 @@ def test_random_sample_with_seed( ): llm._add_request(prompt, params=params) - results = llm._run_engine(use_tqdm=False) + results = llm._run_engine(RequestOutput, use_tqdm=False) all_outputs = [[out.token_ids for out in output.outputs] for output in results] diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index f84de86e4617b..d21d54783b139 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,4 +1,5 @@ -from typing import List, Optional, Sequence, Union, cast, overload +from typing import (List, Optional, Sequence, Type, TypeVar, Union, cast, + overload) from tqdm import 
tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -18,6 +19,8 @@ logger = init_logger(__name__) +_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) + class LLM: """An LLM for generating texts from given prompts and sampling parameters. @@ -267,13 +270,15 @@ def generate( # Use default sampling params. sampling_params = SamplingParams() - return self._validate_and_add_requests( + self._validate_and_add_requests( inputs=inputs, params=sampling_params, use_tqdm=use_tqdm, lora_request=lora_request, ) + return self._run_engine(RequestOutput, use_tqdm=use_tqdm) + @overload # DEPRECATED: single (prompt + optional token ids) def encode( self, @@ -399,13 +404,15 @@ def encode( # Use default pooling params. pooling_params = PoolingParams() - return self._validate_and_add_requests( + self._validate_and_add_requests( inputs=inputs, params=pooling_params, use_tqdm=use_tqdm, lora_request=lora_request, ) + return self._run_engine(EmbeddingRequestOutput, use_tqdm=use_tqdm) + # DEPRECATED def _convert_v1_inputs( self, @@ -461,26 +468,6 @@ def _convert_v1_inputs( return inputs - @overload - def _validate_and_add_requests( - self, - inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], - params: Union[SamplingParams, Sequence[SamplingParams]], - use_tqdm: bool, - lora_request: Optional[LoRARequest], - ) -> List[RequestOutput]: - ... - - @overload - def _validate_and_add_requests( # type: ignore[misc] - self, - inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], - params: Union[PoolingParams, Sequence[PoolingParams]], - use_tqdm: bool, - lora_request: Optional[LoRARequest], - ) -> List[EmbeddingRequestOutput]: - ... - def _validate_and_add_requests( self, inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], @@ -488,7 +475,7 @@ def _validate_and_add_requests( Sequence[PoolingParams]], use_tqdm: bool, lora_request: Optional[LoRARequest], - ) -> Union[List[RequestOutput], List[EmbeddingRequestOutput]]: + ) -> None: if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. inputs = [inputs] @@ -507,8 +494,6 @@ def _validate_and_add_requests( lora_request=lora_request, ) - return self._run_engine(use_tqdm) - def _add_request( self, inputs: PromptInputs, @@ -521,9 +506,8 @@ def _add_request( params, lora_request=lora_request) - def _run_engine( - self, use_tqdm: bool - ) -> Union[List[RequestOutput], List[EmbeddingRequestOutput]]: + def _run_engine(self, output_type: Type[_O], *, + use_tqdm: bool) -> List[_O]: # Initialize tqdm. if use_tqdm: num_requests = self.llm_engine.get_num_unfinished_requests() @@ -556,9 +540,8 @@ def _run_engine( # its previous requests. 
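The new `_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)` constraint is what lets a single engine loop hand back a precisely typed list. A self-contained sketch of the pattern with stand-in classes (not the real vLLM types):

    from typing import List, Sequence, Type, TypeVar

    class RequestOutput: ...
    class EmbeddingRequestOutput: ...

    # Constrained TypeVar: _O can only ever be one of these two classes.
    _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)

    def run_engine(output_type: Type[_O], raw: Sequence[object]) -> List[_O]:
        # The runtime isinstance check mirrors the static constraint, so
        # generate() gets List[RequestOutput] and encode() gets
        # List[EmbeddingRequestOutput] without a cast at every call site.
        out: List[_O] = []
        for item in raw:
            if not isinstance(item, output_type):
                raise TypeError(f"Expected {output_type}, got {type(item)}")
            out.append(item)
        return out

    results = run_engine(RequestOutput, [RequestOutput(), RequestOutput()])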
outputs = sorted(outputs, key=lambda x: int(x.request_id)) - if len(outputs) > 0: - first, *rest = outputs - assert all(isinstance(r, type(first)) for r in rest), ( - f"Expected all outputs to be of the same type {type(first)}") + if len(outputs) > 0 and not isinstance(outputs[0], output_type): + raise TypeError(f"Expected output type to be {output_type}, " + f"but found type {type(outputs[0])}") - return outputs # type: ignore + return cast(List[_O], outputs) From 2169defbff298b09a405a1bd124976adf7ac574f Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 22 May 2024 08:41:31 +0000 Subject: [PATCH 36/94] Add flag for deprecating legacy API --- tests/entrypoints/test_llm_encode.py | 22 ++++++------ tests/entrypoints/test_llm_generate.py | 22 ++++++------ vllm/entrypoints/llm.py | 49 ++++++++++++++++++-------- vllm/utils.py | 30 +++++++++++----- 4 files changed, 79 insertions(+), 44 deletions(-) diff --git a/tests/entrypoints/test_llm_encode.py b/tests/entrypoints/test_llm_encode.py index 24da218b5adb2..872707b54e3f7 100644 --- a/tests/entrypoints/test_llm_encode.py +++ b/tests/entrypoints/test_llm_encode.py @@ -30,16 +30,18 @@ @pytest.fixture(scope="module") def llm(): - # pytest caches the fixture so we use weakref for garbage collection to work - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=32768, - tensor_parallel_size=1, - gpu_memory_utilization=0.75, - enforce_eager=True) - - yield weakref.proxy(llm) - - del llm + with LLM.deprecate_legacy_ctx(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True) + + yield weakref.proxy(llm) + + del llm cleanup() diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py index 4c2e52e64d54c..37d1ea7e8745b 100644 --- a/tests/entrypoints/test_llm_generate.py +++ b/tests/entrypoints/test_llm_generate.py @@ -28,16 +28,18 @@ @pytest.fixture(scope="module") def llm(): - # pytest caches the fixture so we use weakref for garbage collection to work - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=4096, - tensor_parallel_size=1, - gpu_memory_utilization=0.10, - enforce_eager=True) - - yield weakref.proxy(llm) - - del llm + with LLM.deprecate_legacy_ctx(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=4096, + tensor_parallel_size=1, + gpu_memory_utilization=0.10, + enforce_eager=True) + + yield weakref.proxy(llm) + + del llm cleanup() diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index d21d54783b139..2f76979ce5ebe 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,5 +1,6 @@ -from typing import (List, Optional, Sequence, Type, TypeVar, Union, cast, - overload) +from contextlib import contextmanager +from typing import (ClassVar, List, Optional, Sequence, Type, TypeVar, Union, + cast, overload) from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -87,6 +88,18 @@ class LLM: disable_custom_all_reduce: See ParallelConfig """ + DEPRECATE_LEGACY: ClassVar[bool] = False + """A flag to toggle whether to deprecate the legacy generate/encode API.""" + + @staticmethod + @contextmanager + def deprecate_legacy_ctx(): + LLM.DEPRECATE_LEGACY = True + + yield + + LLM.DEPRECATE_LEGACY = False + def __init__( self, model: str, @@ -144,7 +157,7 @@ def set_tokenizer( ) -> None: 
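The class-level flag is toggled through a context manager so tests can opt in to the deprecation warnings temporarily; a minimal self-contained sketch of the pattern with the real constructor stripped out (patch 46 later renames this to the `deprecate_legacy_api` classmethod):

    from contextlib import contextmanager
    from typing import ClassVar

    class LLM:
        DEPRECATE_LEGACY: ClassVar[bool] = False

        @staticmethod
        @contextmanager
        def deprecate_legacy_ctx():
            LLM.DEPRECATE_LEGACY = True
            yield
            LLM.DEPRECATE_LEGACY = False

    # Legacy keyword arguments only warn while the flag is raised.
    assert LLM.DEPRECATE_LEGACY is False
    with LLM.deprecate_legacy_ctx():
        assert LLM.DEPRECATE_LEGACY is True
    assert LLM.DEPRECATE_LEGACY is False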
self.llm_engine.tokenizer.tokenizer = tokenizer - @overload # DEPRECATED: single (prompt + optional token ids) + @overload # LEGACY: single (prompt + optional token ids) def generate( self, prompts: str, @@ -157,7 +170,7 @@ def generate( ) -> List[RequestOutput]: ... - @overload # DEPRECATED: multi (prompt + optional token ids) + @overload # LEGACY: multi (prompt + optional token ids) def generate( self, prompts: List[str], @@ -170,7 +183,7 @@ def generate( ) -> List[RequestOutput]: ... - @overload # DEPRECATED: single (token ids + optional prompt) + @overload # LEGACY: single (token ids + optional prompt) def generate( self, prompts: Optional[str] = None, @@ -184,7 +197,7 @@ def generate( ) -> List[RequestOutput]: ... - @overload # DEPRECATED: multi (token ids + optional prompt) + @overload # LEGACY: multi (token ids + optional prompt) def generate( self, prompts: Optional[List[str]] = None, @@ -198,7 +211,7 @@ def generate( ) -> List[RequestOutput]: ... - @overload # DEPRECATED: single or multi token ids [pos-only] + @overload # LEGACY: single or multi token ids [pos-only] def generate( self, prompts: None, @@ -223,7 +236,10 @@ def generate( ) -> List[RequestOutput]: ... - @deprecate_kwargs('prompts', 'prompt_token_ids', 'multi_modal_data') + @deprecate_kwargs('prompts', + 'prompt_token_ids', + 'multi_modal_data', + is_deprecated=lambda: LLM.DEPRECATE_LEGACY) def generate( self, prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], @@ -279,7 +295,7 @@ def generate( return self._run_engine(RequestOutput, use_tqdm=use_tqdm) - @overload # DEPRECATED: single (prompt + optional token ids) + @overload # LEGACY: single (prompt + optional token ids) def encode( self, prompts: str, @@ -292,7 +308,7 @@ def encode( ) -> List[EmbeddingRequestOutput]: ... - @overload # DEPRECATED: multi (prompt + optional token ids) + @overload # LEGACY: multi (prompt + optional token ids) def encode( self, prompts: List[str], @@ -305,7 +321,7 @@ def encode( ) -> List[EmbeddingRequestOutput]: ... - @overload # DEPRECATED: single (token ids + optional prompt) + @overload # LEGACY: single (token ids + optional prompt) def encode( self, prompts: Optional[str] = None, @@ -319,7 +335,7 @@ def encode( ) -> List[EmbeddingRequestOutput]: ... - @overload # DEPRECATED: multi (token ids + optional prompt) + @overload # LEGACY: multi (token ids + optional prompt) def encode( self, prompts: Optional[List[str]] = None, @@ -333,7 +349,7 @@ def encode( ) -> List[EmbeddingRequestOutput]: ... - @overload # DEPRECATED: single or multi token ids [pos-only] + @overload # LEGACY: single or multi token ids [pos-only] def encode( self, prompts: None, @@ -358,7 +374,10 @@ def encode( ) -> List[EmbeddingRequestOutput]: ... 
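In practice the legacy and new call forms that these overloads cover look roughly like this (hedged sketch; the model name and token IDs are placeholders):

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # placeholder model
    params = SamplingParams(temperature=0.0, max_tokens=16)

    # LEGACY: keyword prompts / prompt_token_ids, now gated behind the
    # LLM.DEPRECATE_LEGACY deprecation warnings.
    legacy = llm.generate(prompts=["Hello, my name is"],
                          sampling_params=params)

    # NEW: a PromptStrictInputs value (str, TextPrompt or TokensPrompt).
    new_text = llm.generate({"prompt": "Hello, my name is"}, params)
    new_tokens = llm.generate({"prompt_token_ids": [2, 31414, 232]},  # example IDs
                              params)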
- @deprecate_kwargs('prompts', 'prompt_token_ids', 'multi_modal_data') + @deprecate_kwargs('prompts', + 'prompt_token_ids', + 'multi_modal_data', + is_deprecated=lambda: LLM.DEPRECATE_LEGACY) def encode( self, prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], @@ -413,7 +432,7 @@ def encode( return self._run_engine(EmbeddingRequestOutput, use_tqdm=use_tqdm) - # DEPRECATED + # LEGACY def _convert_v1_inputs( self, prompts: Optional[Union[str, List[str]]], diff --git a/vllm/utils.py b/vllm/utils.py index 506f4f3ae2be6..4480c6c6960de 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -655,24 +655,36 @@ def enable_trace_function_call_for_thread() -> None: enable_trace_function_call(log_path) +def identity(value: T) -> T: + return value + + F = TypeVar('F', bound=Callable[..., Any]) -def deprecate_kwargs(*kws: str) -> Callable[[F], F]: +def deprecate_kwargs( + *kws: str, + is_deprecated: Union[bool, Callable[[], + bool]] = True) -> Callable[[F], F]: deprecated_kws = set(kws) + if not callable(is_deprecated): + is_deprecated = partial(identity, is_deprecated) + def wrapper(fn: F) -> F: @wraps(fn) def inner(*args, **kwargs): - deprecated_kwargs = kwargs.keys() & deprecated_kws - if deprecated_kwargs: - warnings.warn( - DeprecationWarning( - f"The keyword arguments {deprecated_kwargs} are " - "deprecated and will be removed in a future update."), - stacklevel=3, # The inner function takes up one level - ) + if is_deprecated(): + deprecated_kwargs = kwargs.keys() & deprecated_kws + if deprecated_kwargs: + warnings.warn( + DeprecationWarning( + f"The keyword arguments {deprecated_kwargs} are " + "deprecated and will be removed in a future update." + ), + stacklevel=3, # The inner function takes up one level + ) return fn(*args, **kwargs) From 3dbded140243682140d6231fb273ca86495ad104 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 22 May 2024 09:01:08 +0000 Subject: [PATCH 37/94] Add tests for `deprecate_kwargs` --- .buildkite/test-pipeline.yaml | 3 +++ tests/test_utils.py | 51 +++++++++++++++++++++++++++++++++++ tests/utils.py | 14 ++++++++++ 3 files changed, 68 insertions(+) create mode 100644 tests/test_utils.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 206fb814abf5a..af22e404361aa 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -109,6 +109,9 @@ steps: mirror_hardwares: [amd] command: pytest -v -s test_logits_processor.py +- label: Utils Test + command: pytest -v -s test_utils.py + - label: Worker Test mirror_hardwares: [amd] command: pytest -v -s worker diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000000000..7f84fc7f6a454 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,51 @@ +import pytest + +from vllm.utils import deprecate_kwargs + +from .utils import error_on_warning + + +def test_deprecate_kwargs_always(): + @deprecate_kwargs("old_arg", is_deprecated=True) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="'old_arg'"): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + +def test_deprecate_kwargs_never(): + @deprecate_kwargs("old_arg", is_deprecated=False) + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with error_on_warning(): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + +def test_deprecate_kwargs_func(): + is_deprecated = True + + @deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated) + def dummy(*, 
old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="'old_arg'"): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) + + is_deprecated = False + + with error_on_warning(): + dummy(old_arg=1) + + with error_on_warning(): + dummy(new_arg=1) diff --git a/tests/utils.py b/tests/utils.py index 689d8c8c5ba8a..329842911e159 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,6 +2,8 @@ import subprocess import sys import time +import warnings +from contextlib import contextmanager import ray import requests @@ -87,3 +89,15 @@ def multi_process_tensor_parallel( ray.get(refs) ray.shutdown() + + +@contextmanager +def error_on_warning(): + """ + Within the scope of this context manager, tests will fail if any warning + is emitted. + """ + with warnings.catch_warnings(): + warnings.simplefilter("error") + + yield From 8e20317bbfe9a25dc1b062cf058a68382e7d5f17 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 22 May 2024 09:04:34 +0000 Subject: [PATCH 38/94] Apply formatter --- tests/test_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 7f84fc7f6a454..988dc5ba2bf29 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -6,6 +6,7 @@ def test_deprecate_kwargs_always(): + @deprecate_kwargs("old_arg", is_deprecated=True) def dummy(*, old_arg: object = None, new_arg: object = None): pass @@ -18,6 +19,7 @@ def dummy(*, old_arg: object = None, new_arg: object = None): def test_deprecate_kwargs_never(): + @deprecate_kwargs("old_arg", is_deprecated=False) def dummy(*, old_arg: object = None, new_arg: object = None): pass @@ -41,7 +43,7 @@ def dummy(*, old_arg: object = None, new_arg: object = None): with error_on_warning(): dummy(new_arg=1) - + is_deprecated = False with error_on_warning(): From fdccaa21066010e7c8d265528e34f21979421109 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Wed, 22 May 2024 09:04:47 +0000 Subject: [PATCH 39/94] Rename attribute to be less misleading --- vllm/entrypoints/llm.py | 4 ++-- vllm/inputs.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 2f76979ce5ebe..8943929371ae9 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -442,10 +442,10 @@ def _convert_v1_inputs( # skip_tokenizer_init is now checked in engine if prompts is not None: - prompts = [p["text"] for p in parse_and_batch_prompt(prompts)] + prompts = [p["content"] for p in parse_and_batch_prompt(prompts)] if prompt_token_ids is not None: prompt_token_ids = [ - p["text"] for p in parse_and_batch_prompt(prompt_token_ids) + p["content"] for p in parse_and_batch_prompt(prompt_token_ids) ] num_requests = None diff --git a/vllm/inputs.py b/vllm/inputs.py index e4bdb18c2f49a..80011b6dd1d61 100644 --- a/vllm/inputs.py +++ b/vllm/inputs.py @@ -6,12 +6,12 @@ class ParsedText(TypedDict): - text: str + content: str is_tokens: Literal[False] class ParsedTokens(TypedDict): - text: List[int] + content: List[int] is_tokens: Literal[True] @@ -33,7 +33,7 @@ def parse_and_batch_prompt( ) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]: if isinstance(prompt, str): # case 1: a string - return [ParsedText(text=prompt, is_tokens=False)] + return [ParsedText(content=prompt, is_tokens=False)] if isinstance(prompt, list): if len(prompt) == 0: @@ -42,13 +42,13 @@ def parse_and_batch_prompt( if isinstance(prompt[0], str): # case 2: array of strings return [ - ParsedText(text=elem, 
is_tokens=False) + ParsedText(content=elem, is_tokens=False) for elem in cast(List[str], prompt) ] if isinstance(prompt[0], int): # case 3: array of tokens elem = cast(List[int], prompt) - return [ParsedTokens(text=elem, is_tokens=True)] + return [ParsedTokens(content=elem, is_tokens=True)] if isinstance(prompt[0], list): if len(prompt[0]) == 0: raise ValueError("please provide at least one prompt") @@ -56,7 +56,7 @@ def parse_and_batch_prompt( if isinstance(prompt[0][0], int): # case 4: array of token arrays return [ - ParsedTokens(text=elem, is_tokens=True) + ParsedTokens(content=elem, is_tokens=True) for elem in cast(List[List[int]], prompt) ] From 77ee1c87707590ad015eb4636500a02277e30e83 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 23 May 2024 06:50:39 +0000 Subject: [PATCH 40/94] Renable using `'fork'` start method and improve speed by using `torch.multiprocessing` wrapper instead of stdlib `multiprocessing` --- tests/entrypoints/test_server_oot_registration.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/entrypoints/test_server_oot_registration.py b/tests/entrypoints/test_server_oot_registration.py index 52dc1a0b898de..3e55d7f4297fb 100644 --- a/tests/entrypoints/test_server_oot_registration.py +++ b/tests/entrypoints/test_server_oot_registration.py @@ -1,4 +1,3 @@ -import multiprocessing import sys import time @@ -37,7 +36,7 @@ def server_function(port): def test_oot_registration_for_api_server(): port = get_open_port() - ctx = multiprocessing.get_context("spawn") + ctx = torch.multiprocessing.get_context() server = ctx.Process(target=server_function, args=(port, )) server.start() client = OpenAI( From b1bcdd17002f940a9b64a64925f318090ebfbc0b Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 23 May 2024 07:52:51 +0000 Subject: [PATCH 41/94] Simplify logic of casting request output --- vllm/engine/async_llm_engine.py | 76 +++++++++++++++------------------ vllm/engine/llm_engine.py | 1 - vllm/entrypoints/llm.py | 23 ++++++---- 3 files changed, 49 insertions(+), 51 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 8212b9c6e2027..7fe758c170b0c 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -2,7 +2,7 @@ import time from functools import partial from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional, - Set, Tuple, Type, Union, overload) + Set, Tuple, Type, TypeVar, Union) from transformers import PreTrainedTokenizer @@ -290,6 +290,9 @@ async def check_health_async(self) -> None: self.model_executor.check_health() +_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) + + class AsyncLLMEngine: """An asynchronous wrapper for :class:`LLMEngine`. @@ -653,10 +656,11 @@ async def generate( >>> # Process and return the final output >>> ... """ - async for output in self.process_request( + async for output in self._process_request( request_id, inputs, sampling_params, + output_type=RequestOutput, lora_request=lora_request, ): yield output @@ -727,63 +731,53 @@ async def encode( >>> # Process and return the final output >>> ... """ - async for output in self.process_request( + async for output in self._process_request( request_id, inputs, pooling_params, + output_type=EmbeddingRequestOutput, lora_request=lora_request, ): yield output - @overload - def process_request( - self, - request_id: str, - inputs: PromptInputs, - params: SamplingParams, - lora_request: Optional[LoRARequest] = None, - ) -> AsyncIterator[RequestOutput]: - ... 
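On the async side, the single `_process_request` path is consumed the same way as before; a hedged sketch of a caller (engine arguments and prompt are placeholders):

    import asyncio

    from vllm import SamplingParams
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.engine.async_llm_engine import AsyncLLMEngine

    async def main() -> None:
        engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(model="facebook/opt-125m"))  # placeholder model
        params = SamplingParams(max_tokens=16)

        # generate() is an async generator of RequestOutput; the stream is
        # type-checked against RequestOutput internally.
        async for output in engine.generate({"prompt": "Hello"},
                                            params,
                                            request_id="req-0"):
            if output.finished:
                print(output.outputs[0].text)

    asyncio.run(main())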
- - @overload - def process_request( # type: ignore[misc] - self, - request_id: str, - inputs: PromptInputs, - params: PoolingParams, - lora_request: Optional[LoRARequest] = None, - ) -> AsyncIterator[EmbeddingRequestOutput]: - ... - - def process_request( + async def _process_request( self, request_id: str, inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], + *, + output_type: Type[_O], lora_request: Optional[LoRARequest] = None, - ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: + ) -> AsyncIterator[_O]: """Common logic to process requests with SamplingParams or PoolingParams.""" + arrival_time = time.time() - async def generator(): - arrival_time = time.time() + stream = await self.add_request( + request_id, + inputs, + params, + arrival_time=arrival_time, + lora_request=lora_request, + ) - stream = await self.add_request( - request_id, - inputs, - params, - arrival_time=arrival_time, - lora_request=lora_request, - ) + try: + is_first = True - try: - async for request_output in stream: - yield request_output - except (Exception, asyncio.CancelledError) as e: - self._abort(request_id) - raise e + async for request_output in stream: + # To improve performance, we only check the first result + if is_first: + if not isinstance(request_output, output_type): + raise TypeError( + f"Expected output of type {output_type}, " + f"but found type {type(request_output)}") + + is_first = False - return generator() + yield request_output # type: ignore + except (Exception, asyncio.CancelledError) as e: + self._abort(request_id) + raise e async def abort(self, request_id: str) -> None: """Abort a request. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 89f99aa8f7098..f4644a999745e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -346,7 +346,6 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): trust_remote_code=self.model_config.trust_remote_code, revision=self.model_config.tokenizer_revision) init_kwargs.update(tokenizer_init_kwargs) - self.tokenizer = get_tokenizer_group( self.parallel_config.tokenizer_pool_config, **init_kwargs) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 8943929371ae9..eb48df5e4266a 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -537,13 +537,24 @@ def _run_engine(self, output_type: Type[_O], *, postfix=f"Generation Speed: {0:.2f} toks/s", ) # Run the engine. - outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] + outputs: List[_O] = [] total_toks = 0 while self.llm_engine.has_unfinished_requests(): step_outputs = self.llm_engine.step() + is_first = True + for output in step_outputs: + # To improve performance, we only check the first result + if is_first: + if not isinstance(outputs[0], output_type): + raise TypeError( + f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + is_first = False + if output.finished: - outputs.append(output) + outputs.append(output) # type: ignore if use_tqdm: if isinstance(output, RequestOutput): # Calculate tokens only for RequestOutput @@ -557,10 +568,4 @@ def _run_engine(self, output_type: Type[_O], *, # Sort the outputs by request ID. # This is necessary because some requests may be finished earlier than # its previous requests. 
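The pooling branch of `_process_request` is what backs `LLM.encode`; on an embedding model it is used roughly like this (hedged sketch; the model name is a placeholder and must be an embedding model):

    from vllm import LLM

    # encode() is the pooling counterpart of generate(); it returns
    # EmbeddingRequestOutput objects instead of RequestOutput.
    llm = LLM(model="intfloat/e5-mistral-7b-instruct")  # placeholder model
    outputs = llm.encode(["Hello, my name is"])
    print(len(outputs[0].outputs.embedding))  # length of the embedding vector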
- outputs = sorted(outputs, key=lambda x: int(x.request_id)) - - if len(outputs) > 0 and not isinstance(outputs[0], output_type): - raise TypeError(f"Expected output type to be {output_type}, " - f"but found type {type(outputs[0])}") - - return cast(List[_O], outputs) + return sorted(outputs, key=lambda x: int(x.request_id)) From 44b4681f42c74b4761b5ece6d4e0d7dee5b3c261 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 23 May 2024 07:56:31 +0000 Subject: [PATCH 42/94] Improve code readability --- vllm/engine/async_llm_engine.py | 2 +- vllm/engine/llm_engine.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7fe758c170b0c..11684949e7075 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -278,7 +278,7 @@ async def add_request_async( processed_inputs = await self.process_model_inputs_async( request_id=request_id, inputs=inputs, lora_request=lora_request) - return self._add_processed_request( + self._add_processed_request( request_id=request_id, processed_inputs=processed_inputs, params=params, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f4644a999745e..65606215b997f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -336,7 +336,7 @@ def get_tokenizer_for_seq(self, return self.get_tokenizer_group().get_lora_tokenizer( sequence.lora_request) - def _init_tokenizer(self, **tokenizer_init_kwargs): + def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup: init_kwargs = dict( tokenizer_id=self.model_config.tokenizer, enable_lora=bool(self.lora_config), @@ -486,7 +486,7 @@ def add_request( inputs=inputs, lora_request=lora_request) - return self._add_processed_request( + self._add_processed_request( request_id=request_id, processed_inputs=processed_inputs, params=params, From 50343cb5fe52dde05a558f932938fd292fbb987c Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 23 May 2024 08:34:17 +0000 Subject: [PATCH 43/94] Fix `multi_modal_data` being a required key --- vllm/engine/async_llm_engine.py | 2 +- vllm/inputs.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 11684949e7075..7869d2f1e3b3e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -239,7 +239,7 @@ async def step_async( async def process_model_inputs_async( self, - request_id: str, # pylint: disable=unused-argument + request_id: str, inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, ) -> LLMInputs: diff --git a/vllm/inputs.py b/vllm/inputs.py index 80011b6dd1d61..b6cd23f0c907f 100644 --- a/vllm/inputs.py +++ b/vllm/inputs.py @@ -1,6 +1,8 @@ from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence, TypedDict, Union, cast, overload) +from typing_extensions import NotRequired + if TYPE_CHECKING: from vllm.sequence import MultiModalData @@ -70,7 +72,7 @@ class TextPrompt(TypedDict): prompt: str """The input text to be tokenized before passing to the model.""" - multi_modal_data: Optional["MultiModalData"] + multi_modal_data: NotRequired[Optional["MultiModalData"]] """ Optional multi-modal data to pass to the model, if the model supports it. 
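A minimal, self-contained sketch of why `NotRequired` matters for these schemas: with it, a text-only prompt no longer has to carry a `multi_modal_data` key at all (toy TypedDict mirroring the one above):

    from typing_extensions import NotRequired, TypedDict

    class TextPrompt(TypedDict):
        prompt: str
        multi_modal_data: NotRequired[object]  # stand-in for MultiModalData

    # Both literals now satisfy the schema; before this change the second
    # one was rejected because every TypedDict key was required.
    with_data: TextPrompt = {"prompt": "Hello", "multi_modal_data": object()}
    text_only: TextPrompt = {"prompt": "Hello"}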
@@ -83,7 +85,7 @@ class TokensPrompt(TypedDict): prompt_token_ids: List[int] """A list of token IDs to pass to the model.""" - multi_modal_data: Optional["MultiModalData"] + multi_modal_data: NotRequired[Optional["MultiModalData"]] """ Optional multi-modal data to pass to the model, if the model supports it. @@ -102,7 +104,7 @@ class TextTokensPrompt(TypedDict): """The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs.""" - multi_modal_data: Optional["MultiModalData"] + multi_modal_data: NotRequired[Optional["MultiModalData"]] """ Optional multi-modal data to pass to the model, if the model supports it. From 45aa42017eac55cd51377d5ac8b1b26552dd9e88 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 23 May 2024 08:39:14 +0000 Subject: [PATCH 44/94] Fix index out of range error --- vllm/entrypoints/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index eb48df5e4266a..d67a439603d5e 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -546,7 +546,7 @@ def _run_engine(self, output_type: Type[_O], *, for output in step_outputs: # To improve performance, we only check the first result if is_first: - if not isinstance(outputs[0], output_type): + if not isinstance(output, output_type): raise TypeError( f"Expected output of type {output_type}, " f"but found type {type(output)}") From d4e2589be107f58276b55ef197bc3f2eb7df0cae Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 23 May 2024 16:07:43 +0000 Subject: [PATCH 45/94] Use a flag to control whether to check output types --- vllm/engine/async_llm_engine.py | 20 +++++++------------- vllm/engine/llm_engine.py | 6 +++++- vllm/entrypoints/llm.py | 20 +++++++------------- 3 files changed, 19 insertions(+), 27 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7869d2f1e3b3e..a63a94628dc9e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,8 +1,8 @@ import asyncio import time from functools import partial -from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional, - Set, Tuple, Type, TypeVar, Union) +from typing import (TYPE_CHECKING, AsyncIterator, Callable, Dict, Iterable, + List, Optional, Set, Tuple, Type, TypeVar, Union) from transformers import PreTrainedTokenizer @@ -762,19 +762,13 @@ async def _process_request( ) try: - is_first = True - async for request_output in stream: - # To improve performance, we only check the first result - if is_first: - if not isinstance(request_output, output_type): - raise TypeError( - f"Expected output of type {output_type}, " - f"but found type {type(request_output)}") - - is_first = False + if ((TYPE_CHECKING or LLMEngine.VALIDATE_OUTPUT_TYPES) + and not isinstance(request_output, output_type)): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(request_output)}") - yield request_output # type: ignore + yield request_output except (Exception, asyncio.CancelledError) as e: self._abort(request_id) raise e diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 65606215b997f..76c006a05f307 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,5 +1,5 @@ import time -from typing import Iterable, List, Optional +from typing import ClassVar, Iterable, List, Optional from typing import Sequence as GenericSequence from typing import Type, Union @@ -86,6 +86,10 @@ class LLMEngine: log_stats: Whether to log 
statistics. usage_context: Specified entry point, used for usage info collection. """ + + VALIDATE_OUTPUT_TYPES: ClassVar[bool] = False + """A flag to toggle whether to validate the type of request output.""" + tokenizer: Optional[BaseTokenizerGroup] def __init__( diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index d67a439603d5e..489dbe1266451 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,6 +1,6 @@ from contextlib import contextmanager -from typing import (ClassVar, List, Optional, Sequence, Type, TypeVar, Union, - cast, overload) +from typing import (TYPE_CHECKING, ClassVar, List, Optional, Sequence, Type, + TypeVar, Union, cast, overload) from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -541,20 +541,14 @@ def _run_engine(self, output_type: Type[_O], *, total_toks = 0 while self.llm_engine.has_unfinished_requests(): step_outputs = self.llm_engine.step() - is_first = True - for output in step_outputs: - # To improve performance, we only check the first result - if is_first: - if not isinstance(output, output_type): - raise TypeError( - f"Expected output of type {output_type}, " - f"but found type {type(output)}") - - is_first = False + if ((TYPE_CHECKING or LLMEngine.VALIDATE_OUTPUT_TYPES) + and not isinstance(output, output_type)): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") if output.finished: - outputs.append(output) # type: ignore + outputs.append(output) if use_tqdm: if isinstance(output, RequestOutput): # Calculate tokens only for RequestOutput From c07b5798bd4c2c25e93f666cc4ba84468a1429f8 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 23 May 2024 22:05:32 +0000 Subject: [PATCH 46/94] Simplify flags --- tests/entrypoints/test_llm_encode.py | 18 +++++++++--------- tests/entrypoints/test_llm_generate.py | 18 +++++++++--------- vllm/engine/async_llm_engine.py | 4 +++- vllm/engine/llm_engine.py | 10 ++++++++++ vllm/entrypoints/llm.py | 12 +++++++----- 5 files changed, 38 insertions(+), 24 deletions(-) diff --git a/tests/entrypoints/test_llm_encode.py b/tests/entrypoints/test_llm_encode.py index 872707b54e3f7..39fc7c2e0f0b9 100644 --- a/tests/entrypoints/test_llm_encode.py +++ b/tests/entrypoints/test_llm_encode.py @@ -30,15 +30,15 @@ @pytest.fixture(scope="module") def llm(): - with LLM.deprecate_legacy_ctx(): - # pytest caches the fixture so we use weakref.proxy to - # enable garbage collection - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=32768, - tensor_parallel_size=1, - gpu_memory_utilization=0.75, - enforce_eager=True) - + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True) + + with llm.deprecate_legacy_api(): yield weakref.proxy(llm) del llm diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py index 37d1ea7e8745b..44f5feb1aa0a2 100644 --- a/tests/entrypoints/test_llm_generate.py +++ b/tests/entrypoints/test_llm_generate.py @@ -28,15 +28,15 @@ @pytest.fixture(scope="module") def llm(): - with LLM.deprecate_legacy_ctx(): - # pytest caches the fixture so we use weakref.proxy to - # enable garbage collection - llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=4096, - tensor_parallel_size=1, - gpu_memory_utilization=0.10, - enforce_eager=True) - + # pytest caches the fixture so we use weakref.proxy to + # enable 
garbage collection + llm = LLM(model=MODEL_NAME, + max_num_batched_tokens=4096, + tensor_parallel_size=1, + gpu_memory_utilization=0.10, + enforce_eager=True) + + with llm.deprecate_legacy_api(): yield weakref.proxy(llm) del llm diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index a63a94628dc9e..8e93394bd0d9f 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -761,9 +761,11 @@ async def _process_request( lora_request=lora_request, ) + validate_output_types = LLMEngine.VALIDATE_OUTPUT_TYPES + try: async for request_output in stream: - if ((TYPE_CHECKING or LLMEngine.VALIDATE_OUTPUT_TYPES) + if ((TYPE_CHECKING or validate_output_types) and not isinstance(request_output, output_type)): raise TypeError(f"Expected output of type {output_type}, " f"but found type {type(request_output)}") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 76c006a05f307..ced5f98820b3e 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,4 +1,5 @@ import time +from contextlib import contextmanager from typing import ClassVar, Iterable, List, Optional from typing import Sequence as GenericSequence from typing import Type, Union @@ -90,6 +91,15 @@ class LLMEngine: VALIDATE_OUTPUT_TYPES: ClassVar[bool] = False """A flag to toggle whether to validate the type of request output.""" + @classmethod + @contextmanager + def validate_output_types(cls): + cls.VALIDATE_OUTPUT_TYPES = True + + yield + + cls.VALIDATE_OUTPUT_TYPES = False + tokenizer: Optional[BaseTokenizerGroup] def __init__( diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 489dbe1266451..c9d2d6eff0497 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -91,14 +91,14 @@ class LLM: DEPRECATE_LEGACY: ClassVar[bool] = False """A flag to toggle whether to deprecate the legacy generate/encode API.""" - @staticmethod + @classmethod @contextmanager - def deprecate_legacy_ctx(): - LLM.DEPRECATE_LEGACY = True + def deprecate_legacy_api(cls): + cls.DEPRECATE_LEGACY = True yield - LLM.DEPRECATE_LEGACY = False + cls.DEPRECATE_LEGACY = False def __init__( self, @@ -541,8 +541,10 @@ def _run_engine(self, output_type: Type[_O], *, total_toks = 0 while self.llm_engine.has_unfinished_requests(): step_outputs = self.llm_engine.step() + validate_output_types = LLMEngine.VALIDATE_OUTPUT_TYPES + for output in step_outputs: - if ((TYPE_CHECKING or LLMEngine.VALIDATE_OUTPUT_TYPES) + if ((TYPE_CHECKING or validate_output_types) and not isinstance(output, output_type)): raise TypeError(f"Expected output of type {output_type}, " f"but found type {type(output)}") From 9d56eb0667b107af30f89604fb6ff5818b451f49 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 23 May 2024 22:31:18 +0000 Subject: [PATCH 47/94] Move output validation to a more appropriate location --- tests/entrypoints/test_llm_encode.py | 8 ++-- tests/entrypoints/test_llm_generate.py | 8 ++-- tests/lora/test_long_context.py | 4 +- tests/samplers/test_logits_processor.py | 4 +- tests/samplers/test_seeded_generate.py | 4 +- vllm/engine/async_llm_engine.py | 23 +++-------- vllm/engine/llm_engine.py | 52 ++++++++++++++++++++++--- vllm/entrypoints/llm.py | 25 +++++------- 8 files changed, 74 insertions(+), 54 deletions(-) diff --git a/tests/entrypoints/test_llm_encode.py b/tests/entrypoints/test_llm_encode.py index 39fc7c2e0f0b9..7c3fbe43a8384 100644 --- a/tests/entrypoints/test_llm_encode.py +++ b/tests/entrypoints/test_llm_encode.py @@ -33,10 +33,10 @@ def llm(): # pytest 
caches the fixture so we use weakref.proxy to # enable garbage collection llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=32768, - tensor_parallel_size=1, - gpu_memory_utilization=0.75, - enforce_eager=True) + max_num_batched_tokens=32768, + tensor_parallel_size=1, + gpu_memory_utilization=0.75, + enforce_eager=True) with llm.deprecate_legacy_api(): yield weakref.proxy(llm) diff --git a/tests/entrypoints/test_llm_generate.py b/tests/entrypoints/test_llm_generate.py index 44f5feb1aa0a2..a00fff91a310e 100644 --- a/tests/entrypoints/test_llm_generate.py +++ b/tests/entrypoints/test_llm_generate.py @@ -31,10 +31,10 @@ def llm(): # pytest caches the fixture so we use weakref.proxy to # enable garbage collection llm = LLM(model=MODEL_NAME, - max_num_batched_tokens=4096, - tensor_parallel_size=1, - gpu_memory_utilization=0.10, - enforce_eager=True) + max_num_batched_tokens=4096, + tensor_parallel_size=1, + gpu_memory_utilization=0.10, + enforce_eager=True) with llm.deprecate_legacy_api(): yield weakref.proxy(llm) diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py index 3dd9b98ed911b..15189f421a539 100644 --- a/tests/lora/test_long_context.py +++ b/tests/lora/test_long_context.py @@ -5,7 +5,7 @@ import pytest import vllm -from vllm import RequestOutput, SamplingParams +from vllm import SamplingParams from vllm.lora.layers import LinearScalingRotaryEmbeddingWithLora from vllm.lora.request import LoRARequest from vllm.model_executor.layers.rotary_embedding import ( @@ -100,7 +100,7 @@ def batched_generate( # Add requests to the engine and run the engine for request_data in requests_data: llm._add_request(**request_data) - outputs = llm._run_engine(RequestOutput, use_tqdm=True) + outputs = llm._run_engine(use_tqdm=True) return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))] diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py index 1b63c1dab98d2..0ccbabfff6403 100644 --- a/tests/samplers/test_logits_processor.py +++ b/tests/samplers/test_logits_processor.py @@ -1,7 +1,7 @@ import pytest import torch -from vllm import RequestOutput, SamplingParams +from vllm import SamplingParams MODELS = ["facebook/opt-125m"] @@ -54,6 +54,6 @@ def pick_vllm(token_ids, logits): params=SamplingParams(max_tokens=max_tokens), ) - outputs = vllm_model.model._run_engine(RequestOutput, use_tqdm=False) + outputs = vllm_model.model._run_engine(use_tqdm=False) assert outputs[0].outputs[0].text == enforced_answers * repeat_times diff --git a/tests/samplers/test_seeded_generate.py b/tests/samplers/test_seeded_generate.py index fca2b0e05c335..fef5ff3fb9e8e 100644 --- a/tests/samplers/test_seeded_generate.py +++ b/tests/samplers/test_seeded_generate.py @@ -8,7 +8,7 @@ import pytest -from vllm import RequestOutput, SamplingParams +from vllm import SamplingParams from vllm.model_executor.utils import set_random_seed MODEL = "facebook/opt-125m" @@ -59,7 +59,7 @@ def test_random_sample_with_seed( ): llm._add_request(prompt, params=params) - results = llm._run_engine(RequestOutput, use_tqdm=False) + results = llm._run_engine(use_tqdm=False) all_outputs = [[out.token_ids for out in output.outputs] for output in results] diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 8e93394bd0d9f..53d8f4421ad72 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,8 +1,8 @@ import asyncio import time from functools import partial -from typing import (TYPE_CHECKING, AsyncIterator, Callable, 
Dict, Iterable, - List, Optional, Set, Tuple, Type, TypeVar, Union) +from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional, + Set, Tuple, Type, Union) from transformers import PreTrainedTokenizer @@ -290,9 +290,6 @@ async def check_health_async(self) -> None: self.model_executor.check_health() -_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) - - class AsyncLLMEngine: """An asynchronous wrapper for :class:`LLMEngine`. @@ -660,10 +657,9 @@ async def generate( request_id, inputs, sampling_params, - output_type=RequestOutput, lora_request=lora_request, ): - yield output + yield LLMEngine.validate_output(output, RequestOutput) async def encode( self, @@ -735,10 +731,9 @@ async def encode( request_id, inputs, pooling_params, - output_type=EmbeddingRequestOutput, lora_request=lora_request, ): - yield output + yield LLMEngine.validate_output(output, EmbeddingRequestOutput) async def _process_request( self, @@ -746,9 +741,8 @@ async def _process_request( inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], *, - output_type: Type[_O], lora_request: Optional[LoRARequest] = None, - ) -> AsyncIterator[_O]: + ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: """Common logic to process requests with SamplingParams or PoolingParams.""" arrival_time = time.time() @@ -761,15 +755,8 @@ async def _process_request( lora_request=lora_request, ) - validate_output_types = LLMEngine.VALIDATE_OUTPUT_TYPES - try: async for request_output in stream: - if ((TYPE_CHECKING or validate_output_types) - and not isinstance(request_output, output_type)): - raise TypeError(f"Expected output of type {output_type}, " - f"but found type {type(request_output)}") - yield request_output except (Exception, asyncio.CancelledError) as e: self._abort(request_id) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index ced5f98820b3e..7520e7eb1c4cc 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,8 +1,8 @@ import time from contextlib import contextmanager -from typing import ClassVar, Iterable, List, Optional +from typing import TYPE_CHECKING, ClassVar, Iterable, List, Optional from typing import Sequence as GenericSequence -from typing import Type, Union +from typing import Type, TypeVar, Union from transformers import GenerationConfig, PreTrainedTokenizer @@ -54,6 +54,9 @@ def _load_generation_config_dict(model_config: ModelConfig): return {} +_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) + + class LLMEngine: """An LLM engine that receives requests and generates texts. @@ -88,17 +91,54 @@ class LLMEngine: usage_context: Specified entry point, used for usage info collection. 
""" - VALIDATE_OUTPUT_TYPES: ClassVar[bool] = False + DO_VALIDATE_OUTPUT: ClassVar[bool] = False """A flag to toggle whether to validate the type of request output.""" @classmethod @contextmanager - def validate_output_types(cls): - cls.VALIDATE_OUTPUT_TYPES = True + def enable_output_validation(cls): + cls.DO_VALIDATE_OUTPUT = True yield - cls.VALIDATE_OUTPUT_TYPES = False + cls.DO_VALIDATE_OUTPUT = False + + @classmethod + def validate_output( + cls, + output: object, + output_type: Type[_O], + ) -> _O: + do_validate = cls.DO_VALIDATE_OUTPUT + + if ((TYPE_CHECKING or do_validate) + and not isinstance(output, output_type)): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + return output + + @classmethod + def validate_outputs( + cls, + outputs: GenericSequence[object], + output_type: Type[_O], + ) -> List[_O]: + do_validate = cls.DO_VALIDATE_OUTPUT + + outputs_: List[_O] + if TYPE_CHECKING or do_validate: + outputs_ = [] + for output in outputs: + if not isinstance(output, output_type): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + outputs_.append(output) + else: + outputs_ = outputs + + return outputs_ tokenizer: Optional[BaseTokenizerGroup] diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index c9d2d6eff0497..4efc6b27f6f49 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,6 +1,5 @@ from contextlib import contextmanager -from typing import (TYPE_CHECKING, ClassVar, List, Optional, Sequence, Type, - TypeVar, Union, cast, overload) +from typing import ClassVar, List, Optional, Sequence, Union, cast, overload from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -20,8 +19,6 @@ logger = init_logger(__name__) -_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) - class LLM: """An LLM for generating texts from given prompts and sampling parameters. @@ -293,7 +290,8 @@ def generate( lora_request=lora_request, ) - return self._run_engine(RequestOutput, use_tqdm=use_tqdm) + outputs = self._run_engine(use_tqdm=use_tqdm) + return LLMEngine.validate_outputs(outputs, RequestOutput) @overload # LEGACY: single (prompt + optional token ids) def encode( @@ -430,7 +428,8 @@ def encode( lora_request=lora_request, ) - return self._run_engine(EmbeddingRequestOutput, use_tqdm=use_tqdm) + outputs = self._run_engine(use_tqdm=use_tqdm) + return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput) # LEGACY def _convert_v1_inputs( @@ -525,8 +524,9 @@ def _add_request( params, lora_request=lora_request) - def _run_engine(self, output_type: Type[_O], *, - use_tqdm: bool) -> List[_O]: + def _run_engine( + self, *, use_tqdm: bool + ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # Initialize tqdm. if use_tqdm: num_requests = self.llm_engine.get_num_unfinished_requests() @@ -537,18 +537,11 @@ def _run_engine(self, output_type: Type[_O], *, postfix=f"Generation Speed: {0:.2f} toks/s", ) # Run the engine. 
- outputs: List[_O] = [] + outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] total_toks = 0 while self.llm_engine.has_unfinished_requests(): step_outputs = self.llm_engine.step() - validate_output_types = LLMEngine.VALIDATE_OUTPUT_TYPES - for output in step_outputs: - if ((TYPE_CHECKING or validate_output_types) - and not isinstance(output, output_type)): - raise TypeError(f"Expected output of type {output_type}, " - f"but found type {type(output)}") - if output.finished: outputs.append(output) if use_tqdm: From bc05031fe237193f048189fa9be33c0067ed40ba Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 23 May 2024 22:37:39 +0000 Subject: [PATCH 48/94] Add message to deprecation notice --- tests/test_utils.py | 11 ++++++++++- vllm/entrypoints/llm.py | 20 ++++++++++++-------- vllm/utils.py | 14 ++++++++------ 3 files changed, 30 insertions(+), 15 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 988dc5ba2bf29..df993d2665b64 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -31,7 +31,7 @@ def dummy(*, old_arg: object = None, new_arg: object = None): dummy(new_arg=1) -def test_deprecate_kwargs_func(): +def test_deprecate_kwargs_dynamic(): is_deprecated = True @deprecate_kwargs("old_arg", is_deprecated=lambda: is_deprecated) @@ -51,3 +51,12 @@ def dummy(*, old_arg: object = None, new_arg: object = None): with error_on_warning(): dummy(new_arg=1) + + +def test_deprecate_kwargs_additional_message(): + @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd") + def dummy(*, old_arg: object = None, new_arg: object = None): + pass + + with pytest.warns(DeprecationWarning, match="abcd"): + dummy(old_arg=1) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 4efc6b27f6f49..05aea9aac6456 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -233,10 +233,12 @@ def generate( ) -> List[RequestOutput]: ... - @deprecate_kwargs('prompts', - 'prompt_token_ids', - 'multi_modal_data', - is_deprecated=lambda: LLM.DEPRECATE_LEGACY) + @deprecate_kwargs("prompts", + "prompt_token_ids", + "multi_modal_data", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'inputs' parameter " + "instead.") def generate( self, prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], @@ -372,10 +374,12 @@ def encode( ) -> List[EmbeddingRequestOutput]: ... 
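With `additional_message` wired through, a legacy keyword call produces a warning that ends with the pointer to the new parameter; a hedged sketch of what that looks like in isolation (dummy function, not the real `generate`):

    import warnings

    from vllm.utils import deprecate_kwargs

    @deprecate_kwargs("prompts",
                      is_deprecated=True,
                      additional_message="Please use the 'inputs' parameter "
                      "instead.")
    def dummy(*, prompts=None, inputs=None):
        return prompts or inputs

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        dummy(prompts="Hello")

    print(caught[0].message)
    # The keyword arguments {'prompts'} are deprecated and will be removed
    # in a future update. Please use the 'inputs' parameter instead.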
- @deprecate_kwargs('prompts', - 'prompt_token_ids', - 'multi_modal_data', - is_deprecated=lambda: LLM.DEPRECATE_LEGACY) + @deprecate_kwargs("prompts", + "prompt_token_ids", + "multi_modal_data", + is_deprecated=lambda: LLM.DEPRECATE_LEGACY, + additional_message="Please use the 'inputs' parameter " + "instead.") def encode( self, prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], diff --git a/vllm/utils.py b/vllm/utils.py index 4480c6c6960de..979e15568a0dc 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -664,8 +664,8 @@ def identity(value: T) -> T: def deprecate_kwargs( *kws: str, - is_deprecated: Union[bool, Callable[[], - bool]] = True) -> Callable[[F], F]: + is_deprecated: Union[bool, Callable[[], bool]] = True, + additional_message: Optional[str] = None) -> Callable[[F], F]: deprecated_kws = set(kws) if not callable(is_deprecated): @@ -678,11 +678,13 @@ def inner(*args, **kwargs): if is_deprecated(): deprecated_kwargs = kwargs.keys() & deprecated_kws if deprecated_kwargs: + msg = (f"The keyword arguments {deprecated_kwargs} are " + "deprecated and will be removed in a future update.") + if additional_message is not None: + msg += f" {additional_message}" + warnings.warn( - DeprecationWarning( - f"The keyword arguments {deprecated_kwargs} are " - "deprecated and will be removed in a future update." - ), + DeprecationWarning(msg), stacklevel=3, # The inner function takes up one level ) From 95d41303edd2b086f88afbc78939d03e4fe995c6 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Thu, 23 May 2024 22:40:22 +0000 Subject: [PATCH 49/94] Apply formatter --- tests/test_utils.py | 1 + vllm/entrypoints/llm.py | 4 ++-- vllm/utils.py | 5 +++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index df993d2665b64..54dc5c6f5bfba 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -54,6 +54,7 @@ def dummy(*, old_arg: object = None, new_arg: object = None): def test_deprecate_kwargs_additional_message(): + @deprecate_kwargs("old_arg", is_deprecated=True, additional_message="abcd") def dummy(*, old_arg: object = None, new_arg: object = None): pass diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 05aea9aac6456..53091cdc6ee42 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -238,7 +238,7 @@ def generate( "multi_modal_data", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, additional_message="Please use the 'inputs' parameter " - "instead.") + "instead.") def generate( self, prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], @@ -379,7 +379,7 @@ def encode( "multi_modal_data", is_deprecated=lambda: LLM.DEPRECATE_LEGACY, additional_message="Please use the 'inputs' parameter " - "instead.") + "instead.") def encode( self, prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], diff --git a/vllm/utils.py b/vllm/utils.py index 979e15568a0dc..1d99f0be8d3be 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -678,8 +678,9 @@ def inner(*args, **kwargs): if is_deprecated(): deprecated_kwargs = kwargs.keys() & deprecated_kws if deprecated_kwargs: - msg = (f"The keyword arguments {deprecated_kwargs} are " - "deprecated and will be removed in a future update.") + msg = ( + f"The keyword arguments {deprecated_kwargs} are " + "deprecated and will be removed in a future update.") if additional_message is not None: msg += f" {additional_message}" From cc84f65c0ea3b45b77236b836bb4a422eac28066 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 24 May 2024 00:50:30 
+0000 Subject: [PATCH 50/94] Remove unused parameter in `_validate_and_add_requests` and fix test --- tests/lora/test_long_context.py | 8 +++----- vllm/entrypoints/llm.py | 3 --- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py index 15189f421a539..4361e5452cdff 100644 --- a/tests/lora/test_long_context.py +++ b/tests/lora/test_long_context.py @@ -86,20 +86,18 @@ def generate( def batched_generate( - llm, + llm: vllm.LLM, inputs: List[Tuple[str, SamplingParams, Optional[LoRARequest]]], ): for input in inputs: prompt, sampling_param, lora_req = input - requests_data = llm._validate_and_prepare_requests( + # Add requests to the engine and run the engine + llm._validate_and_add_requests( prompt, sampling_param, lora_request=lora_req, ) - # Add requests to the engine and run the engine - for request_data in requests_data: - llm._add_request(**request_data) outputs = llm._run_engine(use_tqdm=True) return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))] diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 53091cdc6ee42..40ce9b1a992e5 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -288,7 +288,6 @@ def generate( self._validate_and_add_requests( inputs=inputs, params=sampling_params, - use_tqdm=use_tqdm, lora_request=lora_request, ) @@ -428,7 +427,6 @@ def encode( self._validate_and_add_requests( inputs=inputs, params=pooling_params, - use_tqdm=use_tqdm, lora_request=lora_request, ) @@ -495,7 +493,6 @@ def _validate_and_add_requests( inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, Sequence[PoolingParams]], - use_tqdm: bool, lora_request: Optional[LoRARequest], ) -> None: if isinstance(inputs, (str, dict)): From 6c5d4a6cd70f576a809db6eab95f11a09f92db4f Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 25 May 2024 02:37:25 +0000 Subject: [PATCH 51/94] Simplify code --- vllm/engine/llm_engine.py | 3 ++- vllm/entrypoints/llm.py | 26 ++++++++++++-------------- vllm/inputs.py | 6 +++--- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7520e7eb1c4cc..7ce8021a205ee 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -424,11 +424,12 @@ def _add_processed_request( # Create the sequences. 
block_size = self.cache_config.block_size seq_id = next(self.seq_counter) - eos_token_id = None + if self.tokenizer: eos_token_id = self.tokenizer.get_lora_tokenizer( lora_request).eos_token_id else: + eos_token_id = None logger.warning("Use None for EOS token id because tokenizer is " "not initialized") seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 40ce9b1a992e5..9759d05577796 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -6,7 +6,8 @@ from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine -from vllm.inputs import (PromptInputs, PromptStrictInputs, +from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt, + TextTokensPrompt, TokensPrompt, parse_and_batch_prompt) from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -467,25 +468,22 @@ def _convert_v1_inputs( for i in range(num_requests): if prompts is not None: if prompt_token_ids is not None: - inputs.append({ - "prompt": prompts[i], - "prompt_token_ids": prompt_token_ids[i], - "multi_modal_data": multi_modal_data, - }) + item = TextTokensPrompt( + prompt=prompts[i], + prompt_token_ids=prompt_token_ids[i]) else: - inputs.append({ - "prompt": prompts[i], - "multi_modal_data": multi_modal_data, - }) + item = TextPrompt(prompt=prompts[i]) else: if prompt_token_ids is not None: - inputs.append({ - "prompt_token_ids": prompt_token_ids[i], - "multi_modal_data": multi_modal_data, - }) + item = TokensPrompt(prompt_token_ids=prompt_token_ids[i]) else: raise AssertionError + if multi_modal_data is not None: + item["multi_modal_data"] = multi_modal_data + + inputs.append(item) + return inputs def _validate_and_add_requests( diff --git a/vllm/inputs.py b/vllm/inputs.py index b6cd23f0c907f..f5d99b1b66b70 100644 --- a/vllm/inputs.py +++ b/vllm/inputs.py @@ -72,7 +72,7 @@ class TextPrompt(TypedDict): prompt: str """The input text to be tokenized before passing to the model.""" - multi_modal_data: NotRequired[Optional["MultiModalData"]] + multi_modal_data: NotRequired["MultiModalData"] """ Optional multi-modal data to pass to the model, if the model supports it. @@ -85,7 +85,7 @@ class TokensPrompt(TypedDict): prompt_token_ids: List[int] """A list of token IDs to pass to the model.""" - multi_modal_data: NotRequired[Optional["MultiModalData"]] + multi_modal_data: NotRequired["MultiModalData"] """ Optional multi-modal data to pass to the model, if the model supports it. @@ -104,7 +104,7 @@ class TextTokensPrompt(TypedDict): """The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs.""" - multi_modal_data: NotRequired[Optional["MultiModalData"]] + multi_modal_data: NotRequired["MultiModalData"] """ Optional multi-modal data to pass to the model, if the model supports it. 
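The `_convert_v1_inputs` rewrite in PATCH 51 above swaps the ad-hoc dict literals for the `TextPrompt`, `TokensPrompt`, and `TextTokensPrompt` TypedDicts and only attaches `multi_modal_data` when it is actually supplied. The following is a minimal, self-contained sketch of that construction pattern, not the real implementation: the TypedDict definitions are simplified stand-ins for the ones in `vllm/inputs.py` (here `multi_modal_data` is typed as a bare `object` placeholder rather than `MultiModalData`), and the helper name `convert_v1_inputs` is illustrative only.

```python
# Sketch of the prompt-construction pattern from PATCH 51 (illustrative only).
from typing import List, Optional, TypedDict, Union


class _HasMultiModalData(TypedDict, total=False):
    # Optional field; stands in for NotRequired["MultiModalData"].
    multi_modal_data: object


class TextPrompt(_HasMultiModalData):
    prompt: str


class TokensPrompt(_HasMultiModalData):
    prompt_token_ids: List[int]


class TextTokensPrompt(_HasMultiModalData):
    prompt: str
    prompt_token_ids: List[int]


PromptItem = Union[TextPrompt, TokensPrompt, TextTokensPrompt]


def convert_v1_inputs(
    prompts: Optional[List[str]],
    prompt_token_ids: Optional[List[List[int]]],
    multi_modal_data: Optional[object] = None,
) -> List[PromptItem]:
    """Mirror of the branching above: text wins, then tokens, and
    multi_modal_data is attached only when provided."""
    num_requests = len(prompts) if prompts is not None else len(prompt_token_ids or [])
    inputs: List[PromptItem] = []
    for i in range(num_requests):
        if prompts is not None:
            if prompt_token_ids is not None:
                item: PromptItem = TextTokensPrompt(
                    prompt=prompts[i], prompt_token_ids=prompt_token_ids[i])
            else:
                item = TextPrompt(prompt=prompts[i])
        elif prompt_token_ids is not None:
            item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
        else:
            raise AssertionError("need prompts and/or prompt_token_ids")

        if multi_modal_data is not None:
            item["multi_modal_data"] = multi_modal_data

        inputs.append(item)
    return inputs


print(convert_v1_inputs(["Hello"], None))          # [{'prompt': 'Hello'}]
print(convert_v1_inputs(None, [[1, 2, 3]]))        # [{'prompt_token_ids': [1, 2, 3]}]
print(convert_v1_inputs(["Hi"], [[42]], "IMAGE"))  # both keys plus multi_modal_data
```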
From fd2da125ea7bde3a83cbb65d25460c5abeff5b83 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 25 May 2024 02:42:08 +0000 Subject: [PATCH 52/94] Move attribute assignment outside `_init_tokenizer` --- vllm/engine/llm_engine.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7ce8021a205ee..0be3d3140b6b1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -207,11 +207,11 @@ def __init__( self.log_stats = log_stats if not self.model_config.skip_tokenizer_init: - tokenizer = self._init_tokenizer() + self.tokenizer = tokenizer = self._init_tokenizer() self.detokenizer = Detokenizer(tokenizer) else: - self.detokenizer = None self.tokenizer = None + self.detokenizer = None self.seq_counter = Counter() self.generation_config_fields = _load_generation_config_dict( @@ -376,7 +376,9 @@ def __del__(self): MISSING_TOKENIZER_GROUP_MSG = ("Unable to get tokenizer because " "skip_tokenizer_init is True") - def get_tokenizer_group(self, fail_msg: str = MISSING_TOKENIZER_GROUP_MSG): + def get_tokenizer_group( + self, + fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup: if self.tokenizer is None: raise ValueError(fail_msg) @@ -400,10 +402,9 @@ def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup: trust_remote_code=self.model_config.trust_remote_code, revision=self.model_config.tokenizer_revision) init_kwargs.update(tokenizer_init_kwargs) - self.tokenizer = get_tokenizer_group( - self.parallel_config.tokenizer_pool_config, **init_kwargs) - return self.tokenizer + return get_tokenizer_group(self.parallel_config.tokenizer_pool_config, + **init_kwargs) def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) From d78de94c1c8be82f19213a0d68860925e40edf75 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 25 May 2024 03:12:31 +0000 Subject: [PATCH 53/94] Only emit warning once --- vllm/engine/llm_engine.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0be3d3140b6b1..a0898562d4ccf 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -213,6 +213,8 @@ def __init__( self.tokenizer = None self.detokenizer = None + self._eos_warn_count = 0 + self.seq_counter = Counter() self.generation_config_fields = _load_generation_config_dict( model_config) @@ -414,6 +416,18 @@ def _verify_args(self) -> None: self.lora_config.verify_with_scheduler_config( self.scheduler_config) + def _get_eos_token_id( + self, lora_request: Optional[LoRARequest]) -> Optional[int]: + if self.tokenizer: + return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id + else: + if self._eos_warn_count == 0: + logger.warning("Using None for EOS token id because tokenizer " + "is not initialized") + + self._eos_warn_count += 1 + return None + def _add_processed_request( self, request_id: str, @@ -425,14 +439,8 @@ def _add_processed_request( # Create the sequences. 
block_size = self.cache_config.block_size seq_id = next(self.seq_counter) + eos_token_id = self._get_eos_token_id(lora_request) - if self.tokenizer: - eos_token_id = self.tokenizer.get_lora_tokenizer( - lora_request).eos_token_id - else: - eos_token_id = None - logger.warning("Use None for EOS token id because tokenizer is " - "not initialized") seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, lora_request) From 8a868299fab4f9505f3a7568114b54a78513abf7 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 25 May 2024 03:16:24 +0000 Subject: [PATCH 54/94] Simplify assignment expression --- vllm/engine/llm_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a0898562d4ccf..3c50033e46e30 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -207,8 +207,8 @@ def __init__( self.log_stats = log_stats if not self.model_config.skip_tokenizer_init: - self.tokenizer = tokenizer = self._init_tokenizer() - self.detokenizer = Detokenizer(tokenizer) + self.tokenizer = self._init_tokenizer() + self.detokenizer = Detokenizer(self.tokenizer) else: self.tokenizer = None self.detokenizer = None From 731ac0e2cf03006ef653fe3079a9812644a02825 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 25 May 2024 03:19:11 +0000 Subject: [PATCH 55/94] Place special case at the start --- vllm/engine/llm_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3c50033e46e30..cbc9bf741476a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -418,9 +418,7 @@ def _verify_args(self) -> None: def _get_eos_token_id( self, lora_request: Optional[LoRARequest]) -> Optional[int]: - if self.tokenizer: - return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id - else: + if self.tokenizer is None: if self._eos_warn_count == 0: logger.warning("Using None for EOS token id because tokenizer " "is not initialized") @@ -428,6 +426,8 @@ def _get_eos_token_id( self._eos_warn_count += 1 return None + return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id + def _add_processed_request( self, request_id: str, From 2d1a0bccf100b6e9ca6912b69caac72b8ba81ded Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sat, 25 May 2024 12:24:44 -0700 Subject: [PATCH 56/94] move API reference to under developer doc --- docs/source/{ => dev}/offline_inference/llm.rst | 0 .../{ => dev}/offline_inference/llm_inputs.rst | 0 docs/source/dev/offline_inference/offline_index.rst | 8 ++++++++ .../{offline_inference => dev}/sampling_params.rst | 0 docs/source/index.rst | 12 +++--------- 5 files changed, 11 insertions(+), 9 deletions(-) rename docs/source/{ => dev}/offline_inference/llm.rst (100%) rename docs/source/{ => dev}/offline_inference/llm_inputs.rst (100%) create mode 100644 docs/source/dev/offline_inference/offline_index.rst rename docs/source/{offline_inference => dev}/sampling_params.rst (100%) diff --git a/docs/source/offline_inference/llm.rst b/docs/source/dev/offline_inference/llm.rst similarity index 100% rename from docs/source/offline_inference/llm.rst rename to docs/source/dev/offline_inference/llm.rst diff --git a/docs/source/offline_inference/llm_inputs.rst b/docs/source/dev/offline_inference/llm_inputs.rst similarity index 100% rename from docs/source/offline_inference/llm_inputs.rst rename to docs/source/dev/offline_inference/llm_inputs.rst diff --git 
a/docs/source/dev/offline_inference/offline_index.rst b/docs/source/dev/offline_inference/offline_index.rst new file mode 100644 index 0000000000000..27dfb0e9df90e --- /dev/null +++ b/docs/source/dev/offline_inference/offline_index.rst @@ -0,0 +1,8 @@ +Offline Inference +================================= + +.. toctree:: + :maxdepth: 1 + + llm + llm_inputs diff --git a/docs/source/offline_inference/sampling_params.rst b/docs/source/dev/sampling_params.rst similarity index 100% rename from docs/source/offline_inference/sampling_params.rst rename to docs/source/dev/sampling_params.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index 6383680f2b512..acf02c1c22251 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -68,14 +68,6 @@ Documentation getting_started/quickstart getting_started/examples/examples_index -.. toctree:: - :maxdepth: 1 - :caption: Offline Inference - - offline_inference/llm - offline_inference/llm_inputs - offline_inference/sampling_params - .. toctree:: :maxdepth: 1 :caption: Serving @@ -109,7 +101,9 @@ Documentation .. toctree:: :maxdepth: 2 :caption: Developer Documentation - + + dev/sampling_params + dev/offline_inference/offline_index dev/engine/engine_index dev/kernel/paged_attention dev/dockerfile/dockerfile From 7b8ce2c271c2f5d2d726fc084dbe8942be0d5199 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 26 May 2024 01:32:29 +0000 Subject: [PATCH 57/94] Fix links in docs --- docs/source/serving/openai_compatible_server.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index a775c6addf1d9..15a8761eb5738 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -48,7 +48,7 @@ completion = client.chat.completions.create( ``` ### Extra Parameters for Chat API -The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported. +The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported. ```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py :language: python @@ -65,7 +65,7 @@ The following extra parameters are supported: ``` ### Extra Parameters for Completions API -The following [sampling parameters (click through to see documentation)](../offline_inference/sampling_params.rst) are supported. +The following [sampling parameters (click through to see documentation)](../dev/sampling_params.rst) are supported. 
```{literalinclude} ../../../vllm/entrypoints/openai/protocol.py :language: python From fff21a1ec0e7520f5b4f46dcb49c36e305e80894 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 26 May 2024 01:34:08 +0000 Subject: [PATCH 58/94] Remove unnecessary code to avoid repeated warning --- vllm/engine/llm_engine.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index cbc9bf741476a..0dd42a1867c46 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -213,8 +213,6 @@ def __init__( self.tokenizer = None self.detokenizer = None - self._eos_warn_count = 0 - self.seq_counter = Counter() self.generation_config_fields = _load_generation_config_dict( model_config) @@ -419,11 +417,8 @@ def _verify_args(self) -> None: def _get_eos_token_id( self, lora_request: Optional[LoRARequest]) -> Optional[int]: if self.tokenizer is None: - if self._eos_warn_count == 0: - logger.warning("Using None for EOS token id because tokenizer " - "is not initialized") - - self._eos_warn_count += 1 + logger.warning("Using None for EOS token id because tokenizer " + "is not initialized") return None return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id From fab7f9233232c9830ac9dd19e51acd61a7da9a23 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 26 May 2024 01:38:26 +0000 Subject: [PATCH 59/94] Parse and batch the prompt using #4328 --- vllm/entrypoints/openai/serving_engine.py | 53 +++-------------------- 1 file changed, 5 insertions(+), 48 deletions(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 5e26eac77c856..dc1933fd2e5df 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,8 +1,7 @@ import json from dataclasses import dataclass from http import HTTPStatus -from typing import (Dict, Iterable, Iterator, List, Literal, Optional, Tuple, - TypedDict, Union, cast) +from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union from pydantic import Field from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -15,6 +14,7 @@ EmbeddingRequest, ErrorResponse, LogProbs, ModelCard, ModelList, ModelPermission) +from vllm.inputs import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import Logprob @@ -23,16 +23,6 @@ logger = init_logger(__name__) -class InputString(TypedDict): - text: str - is_tokens: Literal[False] - - -class InputTokens(TypedDict): - text: List[int] - is_tokens: Literal[True] - - @dataclass class LoRAModulePath: name: str @@ -304,38 +294,6 @@ def _tokenize_prompt_inputs( truncate_prompt_tokens=truncate_prompt_tokens, ) - def _parse_prompt_input_or_inputs( - self, - input_or_inputs: Union[str, List[str], List[int], List[List[int]]], - ) -> List[Union[InputString, InputTokens]]: - if isinstance(input_or_inputs, str): - # case 1: a string - return [InputString(text=input_or_inputs, is_tokens=False)] - - if isinstance(input_or_inputs, list): - if len(input_or_inputs) == 0: - raise ValueError("please provide at least one prompt") - if isinstance(input_or_inputs[0], str): - # case 2: array of strings - return [ - InputString(text=elem, is_tokens=False) - for elem in cast(List[str], input_or_inputs) - ] - if isinstance(input_or_inputs[0], int): - # case 3: array of tokens - elem = cast(List[int], input_or_inputs) - return [InputTokens(text=elem, is_tokens=True)] - if 
isinstance(input_or_inputs[0], list) and isinstance( - input_or_inputs[0][0], int): - # case 4: array of token arrays - return [ - InputTokens(text=elem, is_tokens=True) - for elem in cast(List[List[int]], input_or_inputs) - ] - - raise ValueError("prompt must be a string, array of strings, " - "array of tokens, or array of token arrays") - def _tokenize_prompt_input_or_inputs( self, request: Union[ChatCompletionRequest, CompletionRequest, @@ -352,8 +310,7 @@ def _tokenize_prompt_input_or_inputs( """ tokenizer = self.tokenizer - for prompt_input in self._parse_prompt_input_or_inputs( - input_or_inputs): + for prompt_input in parse_and_batch_prompt(input_or_inputs): # Although our type checking is based on mypy, # VSCode Pyright extension should still work properly # "is True" is required for Pyright to perform type narrowing @@ -361,7 +318,7 @@ def _tokenize_prompt_input_or_inputs( if prompt_input["is_tokens"] is False: yield self._normalize_prompt_text_to_input( request, - prompt=prompt_input["text"], + prompt=prompt_input["content"], tokenizer=tokenizer, truncate_prompt_tokens=truncate_prompt_tokens, add_special_tokens=add_special_tokens, @@ -369,7 +326,7 @@ def _tokenize_prompt_input_or_inputs( else: yield self._normalize_prompt_tokens_to_input( request, - prompt_ids=prompt_input["text"], + prompt_ids=prompt_input["content"], tokenizer=tokenizer, truncate_prompt_tokens=truncate_prompt_tokens, ) From cb057ebaf8dbb55cbf723a687b606dfbc5e5e00c Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 26 May 2024 02:07:25 +0000 Subject: [PATCH 60/94] Move logging from async engine to server --- benchmarks/benchmark_latency.py | 4 +- .../dev/offline_inference/llm_inputs.rst | 2 +- vllm/__init__.py | 4 +- vllm/engine/arg_utils.py | 6 -- vllm/engine/async_llm_engine.py | 35 +++------- vllm/entrypoints/llm.py | 37 ++++------ vllm/entrypoints/openai/api_server.py | 25 +++++-- vllm/entrypoints/openai/cli_args.py | 8 +++ vllm/entrypoints/openai/run_batch.py | 15 ++++ vllm/entrypoints/openai/serving_chat.py | 35 ++++++---- vllm/entrypoints/openai/serving_completion.py | 33 ++++++--- vllm/entrypoints/openai/serving_embedding.py | 31 +++++--- vllm/entrypoints/openai/serving_engine.py | 70 +++++++++++++++---- vllm/inputs.py | 25 +------ 14 files changed, 195 insertions(+), 135 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 3146fb33cc27e..0271ec9593683 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -10,7 +10,7 @@ from tqdm import tqdm from vllm import LLM, SamplingParams -from vllm.inputs import PromptStrictInputs +from vllm.inputs import PromptInputs from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -49,7 +49,7 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_inputs: List[PromptStrictInputs] = [{ + dummy_inputs: List[PromptInputs] = [{ "prompt_token_ids": batch } for batch in dummy_prompt_token_ids.tolist()] diff --git a/docs/source/dev/offline_inference/llm_inputs.rst b/docs/source/dev/offline_inference/llm_inputs.rst index 31c3d16a3c8eb..9adf82d43f3e0 100644 --- a/docs/source/dev/offline_inference/llm_inputs.rst +++ b/docs/source/dev/offline_inference/llm_inputs.rst @@ -1,7 +1,7 @@ LLM Inputs ========== -.. autodata:: vllm.inputs.PromptStrictInputs +.. autodata:: vllm.inputs.PromptInputs .. 
autoclass:: vllm.inputs.TextPrompt :show-inheritance: diff --git a/vllm/__init__.py b/vllm/__init__.py index a0e154d24087c..16ad660087526 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -5,7 +5,7 @@ from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptStrictInputs, TextPrompt, TokensPrompt +from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry from vllm.outputs import (CompletionOutput, EmbeddingOutput, EmbeddingRequestOutput, RequestOutput) @@ -17,7 +17,7 @@ __all__ = [ "LLM", "ModelRegistry", - "PromptStrictInputs", + "PromptInputs", "TextPrompt", "TokensPrompt", "SamplingParams", diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 538e3427e37fb..76f920aba09ea 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -678,12 +678,6 @@ def add_cli_args(parser: argparse.ArgumentParser, parser.add_argument('--disable-log-requests', action='store_true', help='Disable logging requests.') - parser.add_argument('--max-log-len', - type=int, - default=None, - help='Max number of prompt characters or prompt ' - 'ID numbers being printed in log.' - '\n\nDefault: Unlimited') return parser diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index d4289c715d9e6..de265e123c32e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -135,7 +135,10 @@ def process_exception(self, logger.info("Finished request %s.", request_id) self.abort_request(request_id) - def add_request(self, request_id: str, + def add_request(self, + request_id: str, + *, + verbose: bool = False, **engine_add_request_kwargs) -> AsyncStream: """Add a request to be sent to the engine on the next background loop iteration.""" @@ -150,6 +153,9 @@ def add_request(self, request_id: str, self.new_requests_event.set() + if verbose: + logger.info("Added request %s.", request_id) + return stream def abort_request(self, request_id: str, *, verbose: bool = False) -> None: @@ -317,8 +323,6 @@ class AsyncLLMEngine: async frontend will be executed in a separate process as the model workers. log_requests: Whether to log the requests. - max_log_len: Maximum number of prompt characters or prompt ID numbers - being printed in log. start_engine_loop: If True, the background task to run the engine will be automatically started in the generate call. *args: Arguments for :class:`LLMEngine`. 
@@ -332,13 +336,11 @@ def __init__(self, engine_use_ray: bool, *args, log_requests: bool = True, - max_log_len: Optional[int] = None, start_engine_loop: bool = True, **kwargs) -> None: self.worker_use_ray = worker_use_ray self.engine_use_ray = engine_use_ray self.log_requests = log_requests - self.max_log_len = max_log_len self.engine = self._init_engine(*args, **kwargs) self.background_loop: Optional[asyncio.Future] = None @@ -392,7 +394,6 @@ def from_engine_args( executor_class=executor_class, log_requests=not engine_args.disable_log_requests, log_stats=not engine_args.disable_log_stats, - max_log_len=engine_args.max_log_len, start_engine_loop=start_engine_loop, usage_context=usage_context, ) @@ -537,27 +538,6 @@ async def add_request( arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, ) -> AsyncStream: - if self.log_requests: - if isinstance(inputs, str): - shortened_prompt = inputs - shortened_token_ids = None - else: - shortened_prompt = inputs.get("prompt") - shortened_token_ids = inputs.get("prompt_token_ids") - - max_log_len = self.max_log_len - if max_log_len is not None: - if shortened_prompt is not None: - shortened_prompt = shortened_prompt[:max_log_len] - if shortened_token_ids is not None: - shortened_token_ids = shortened_token_ids[:max_log_len] - - logger.info( - "Received request %s: prompt: %r, " - "params: %s, prompt_token_ids: %s, " - "lora_request: %s.", request_id, shortened_prompt, params, - shortened_token_ids, lora_request) - if not self.is_running: if self.start_engine_loop: self.start_background_loop() @@ -585,6 +565,7 @@ async def add_request( stream = self._request_tracker.add_request( request_id, + verbose=self.log_requests, inputs=processed_inputs, params=params, arrival_time=arrival_time, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 9759d05577796..3e2216a4dc051 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -6,8 +6,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine -from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt, - TextTokensPrompt, TokensPrompt, +from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt, parse_and_batch_prompt) from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -224,7 +223,7 @@ def generate( @overload def generate( self, - inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + inputs: Union[PromptInputs, Sequence[PromptInputs]], /, # We may enable `inputs` keyword after removing the old API *, sampling_params: Optional[Union[SamplingParams, @@ -242,7 +241,7 @@ def generate( "instead.") def generate( self, - prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], Optional[Union[str, List[str]]]] = None, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, @@ -278,9 +277,7 @@ def generate( multi_modal_data=multi_modal_data, ) else: - inputs = cast( - Union[PromptStrictInputs, Sequence[PromptStrictInputs]], - prompts) + inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) if sampling_params is None: # Use default sampling params. 
@@ -364,7 +361,7 @@ def encode( @overload def encode( self, - inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + inputs: Union[PromptInputs, Sequence[PromptInputs]], /, # We may enable `inputs` keyword after removing the old API *, pooling_params: Optional[Union[PoolingParams, @@ -382,7 +379,7 @@ def encode( "instead.") def encode( self, - prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], Optional[Union[str, List[str]]]] = None, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, @@ -399,7 +396,7 @@ def encode( Args: inputs: The inputs to the LLM. You may pass a sequence of inputs for - batch inference. See :class:`~vllm.inputs.PromptStrictInputs` + batch inference. See :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. @@ -417,9 +414,7 @@ def encode( multi_modal_data=multi_modal_data, ) else: - inputs = cast( - Union[PromptStrictInputs, Sequence[PromptStrictInputs]], - prompts) + inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) if pooling_params is None: # Use default pooling params. @@ -467,17 +462,11 @@ def _convert_v1_inputs( inputs: List[PromptInputs] = [] for i in range(num_requests): if prompts is not None: - if prompt_token_ids is not None: - item = TextTokensPrompt( - prompt=prompts[i], - prompt_token_ids=prompt_token_ids[i]) - else: - item = TextPrompt(prompt=prompts[i]) + item = TextPrompt(prompt=prompts[i]) + elif prompt_token_ids is not None: + item = TokensPrompt(prompt_token_ids=prompt_token_ids[i]) else: - if prompt_token_ids is not None: - item = TokensPrompt(prompt_token_ids=prompt_token_ids[i]) - else: - raise AssertionError + raise AssertionError if multi_modal_data is not None: item["multi_modal_data"] = multi_modal_data @@ -488,7 +477,7 @@ def _convert_v1_inputs( def _validate_and_add_requests( self, - inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + inputs: Union[PromptInputs, Sequence[PromptInputs]], params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, Sequence[PoolingParams]], lora_request: Optional[LoRARequest], diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 97b35262329ee..b37943aeb276f 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -200,15 +200,30 @@ async def authentication(request: Request, call_next): # When using single vLLM without engine_use_ray model_config = asyncio.run(engine.get_model_config()) - openai_serving_chat = OpenAIServingChat(engine, model_config, + log_requests = not args.disable_log_requests + max_log_len = args.max_log_len + + openai_serving_chat = OpenAIServingChat(engine, + model_config, served_model_names, args.response_role, args.lora_modules, - args.chat_template) + args.chat_template, + log_requests=log_requests, + max_log_len=max_log_len) openai_serving_completion = OpenAIServingCompletion( - engine, model_config, served_model_names, args.lora_modules) - openai_serving_embedding = OpenAIServingEmbedding(engine, model_config, - served_model_names) + engine, + model_config, + served_model_names, + args.lora_modules, + log_requests=log_requests, + max_log_len=max_log_len) + openai_serving_embedding = OpenAIServingEmbedding( + engine, + model_config, + served_model_names, + log_requests=log_requests, + max_log_len=max_log_len) 
app.root_path = args.root_path uvicorn.run(app, host=args.host, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 4c0cb1e4f3e49..1667af7f5469f 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -112,4 +112,12 @@ def make_arg_parser(): "using app.add_middleware(). ") parser = AsyncEngineArgs.add_cli_args(parser) + + parser.add_argument('--max-log-len', + type=int, + default=None, + help='Max number of prompt characters or prompt ' + 'ID numbers being printed in log.' + '\n\nDefault: Unlimited') + return parser diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 731f4f4a4028a..f05eeae692cf2 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -46,6 +46,14 @@ def parse_args(): "`request.add_generation_prompt=true`.") parser = AsyncEngineArgs.add_cli_args(parser) + + parser.add_argument('--max-log-len', + type=int, + default=None, + help='Max number of prompt characters or prompt ' + 'ID numbers being printed in log.' + '\n\nDefault: Unlimited') + return parser.parse_args() @@ -106,11 +114,18 @@ async def main(args): # When using single vLLM without engine_use_ray model_config = await engine.get_model_config() + log_requests = not args.disable_log_requests + max_log_len = args.max_log_len + openai_serving_chat = OpenAIServingChat( engine, model_config, served_model_names, args.response_role, + lora_modules=None, + chat_template=None, + log_requests=log_requests, + max_log_len=max_log_len, ) # Submit all requests in the file to the engine "concurrently". diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 3a3329ee79b9d..5ec992e0a2d0b 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -39,17 +39,24 @@ class ChatMessageParseResult: class OpenAIServingChat(OpenAIServing): - def __init__(self, - engine: AsyncLLMEngine, - model_config: ModelConfig, - served_model_names: List[str], - response_role: str, - lora_modules: Optional[List[LoRAModulePath]] = None, - chat_template: Optional[str] = None): + def __init__( + self, + engine: AsyncLLMEngine, + model_config: ModelConfig, + served_model_names: List[str], + response_role: str, + lora_modules: Optional[List[LoRAModulePath]], + chat_template: Optional[str], + *, + log_requests: bool, + max_log_len: Optional[int], + ): super().__init__(engine=engine, model_config=model_config, served_model_names=served_model_names, - lora_modules=lora_modules) + lora_modules=lora_modules, + log_requests=log_requests, + max_log_len=max_log_len) self.response_role = response_role self._load_chat_template(chat_template) @@ -171,18 +178,20 @@ async def create_chat_completion( sampling_params.logits_processors.append( guided_decode_logits_processor) - prompt_ids, prompt_text = self._tokenize_prompt_input( + prompt_inputs = self._tokenize_prompt_input( request, prompt, truncate_prompt_tokens=sampling_params.truncate_prompt_tokens, add_special_tokens=False, ) + self._log_inputs(request_id, + prompt_inputs, + sampling_params, + lora_request=lora_request) + result_generator = self.engine.generate( - { - "prompt": prompt_text, - "prompt_token_ids": prompt_ids - }, + {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, sampling_params, request_id, lora_request, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 28d2901e2971b..675ca6ea82bfb 100644 
--- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -30,13 +30,22 @@ class OpenAIServingCompletion(OpenAIServing): - def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, - served_model_names: List[str], - lora_modules: Optional[List[LoRAModulePath]]): + def __init__( + self, + engine: AsyncLLMEngine, + model_config: ModelConfig, + served_model_names: List[str], + lora_modules: Optional[List[LoRAModulePath]], + *, + log_requests: bool, + max_log_len: Optional[int], + ): super().__init__(engine=engine, model_config=model_config, served_model_names=served_model_names, - lora_modules=lora_modules) + lora_modules=lora_modules, + log_requests=log_requests, + max_log_len=max_log_len) async def create_completion(self, request: CompletionRequest, raw_request: Request): @@ -88,14 +97,18 @@ async def create_completion(self, request: CompletionRequest, truncate_prompt_tokens, )) - for i, (prompt_ids, prompt_text) in enumerate(prompts): + for i, prompt_inputs in enumerate(prompts): + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + prompt_inputs, + sampling_params, + lora_request=lora_request) + generator = self.engine.generate( - { - "prompt": prompt_text, - "prompt_token_ids": prompt_ids - }, + {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, sampling_params, - f"{request_id}-{i}", + request_id_item, lora_request=lora_request, ) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 01187c1ad3df9..ad3b0897151fd 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -20,12 +20,21 @@ class OpenAIServingEmbedding(OpenAIServing): - def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, - served_model_names: List[str]): + def __init__( + self, + engine: AsyncLLMEngine, + model_config: ModelConfig, + served_model_names: List[str], + *, + log_requests: bool, + max_log_len: Optional[int], + ): super().__init__(engine=engine, model_config=model_config, served_model_names=served_model_names, - lora_modules=None) + lora_modules=None, + log_requests=log_requests, + max_log_len=max_log_len) self._check_embedding_mode(model_config.embedding_mode) async def create_embedding(self, request: EmbeddingRequest, @@ -62,14 +71,18 @@ async def create_embedding(self, request: EmbeddingRequest, request.input, )) - for i, (prompt_ids, prompt_text) in enumerate(prompts): + for i, prompt_inputs in enumerate(prompts): + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + prompt_inputs, + pooling_params, + lora_request=None) + generator = self.engine.encode( - { - "prompt": prompt_text, - "prompt_token_ids": prompt_ids - }, + {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, pooling_params, - f"{request_id}-{i}", + request_id_item, ) generators.append(generator) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index dc1933fd2e5df..6f24a714cee7e 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -1,7 +1,7 @@ import json from dataclasses import dataclass from http import HTTPStatus -from typing import Dict, Iterable, Iterator, List, Optional, Tuple, Union +from typing import Dict, Iterable, Iterator, List, Optional, TypedDict, Union from pydantic import Field from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -17,12 +17,19 @@ from vllm.inputs 
import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import SamplingParams from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import get_tokenizer logger = init_logger(__name__) +class TextTokensPrompt(TypedDict): + prompt: str + prompt_token_ids: List[int] + + @dataclass class LoRAModulePath: name: str @@ -31,9 +38,16 @@ class LoRAModulePath: class OpenAIServing: - def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, - served_model_names: List[str], - lora_modules: Optional[List[LoRAModulePath]]): + def __init__( + self, + engine: AsyncLLMEngine, + model_config: ModelConfig, + served_model_names: List[str], + lora_modules: Optional[List[LoRAModulePath]], + *, + log_requests: bool, + max_log_len: Optional[int], + ): self.engine = engine self.max_model_len = model_config.max_model_len @@ -58,6 +72,9 @@ def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, ) for i, lora in enumerate(lora_modules, start=1) ] + self.log_requests = log_requests + self.max_log_len = max_log_len + async def show_available_models(self) -> ModelList: """Show available models. Right now we only have one model.""" model_cards = [ @@ -174,7 +191,7 @@ def _normalize_prompt_text_to_input( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]], add_special_tokens: bool, - ) -> Tuple[List[int], str]: + ) -> TextTokensPrompt: if truncate_prompt_tokens is None: encoded = tokenizer(prompt, add_special_tokens=add_special_tokens) else: @@ -196,7 +213,7 @@ def _normalize_prompt_tokens_to_input( prompt_ids: List[int], tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]], - ) -> Tuple[List[int], str]: + ) -> TextTokensPrompt: if truncate_prompt_tokens is None: input_ids = prompt_ids else: @@ -212,7 +229,7 @@ def _validate_input( EmbeddingRequest], input_ids: List[int], input_text: str, - ) -> Tuple[List[int], str]: + ) -> TextTokensPrompt: token_num = len(input_ids) # Note: EmbeddingRequest doesn't have max_tokens @@ -223,7 +240,8 @@ def _validate_input( f"{self.max_model_len} tokens. However, you requested " f"{token_num} tokens in the input for embedding " f"generation. Please reduce the length of the input.") - return input_ids, input_text + return TextTokensPrompt(prompt=input_text, + prompt_token_ids=input_ids) if request.max_tokens is None: if token_num >= self.max_model_len: @@ -243,7 +261,7 @@ def _validate_input( f"{request.max_tokens} in the completion). 
" f"Please reduce the length of the messages or completion.") - return input_ids, input_text + return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) def _tokenize_prompt_input( self, @@ -252,7 +270,7 @@ def _tokenize_prompt_input( prompt_input: Union[str, List[int]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = True, - ) -> Tuple[List[int], str]: + ) -> TextTokensPrompt: """A simpler implementation of :meth:`~vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs` that assumes single input.""" @@ -271,7 +289,7 @@ def _tokenize_prompt_inputs( prompt_inputs: Iterable[Union[str, List[int]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = True, - ) -> Iterator[Tuple[List[int], str]]: + ) -> Iterator[TextTokensPrompt]: """A simpler implementation of :meth:`~vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs` that assumes multiple inputs.""" @@ -301,7 +319,7 @@ def _tokenize_prompt_input_or_inputs( input_or_inputs: Union[str, List[str], List[int], List[List[int]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = True, - ) -> Iterator[Tuple[List[int], str]]: + ) -> Iterator[TextTokensPrompt]: """Tokenize/detokenize depending on the input format. According to `OpenAI API `_ @@ -330,3 +348,31 @@ def _tokenize_prompt_input_or_inputs( tokenizer=tokenizer, truncate_prompt_tokens=truncate_prompt_tokens, ) + + def _log_inputs( + self, + request_id: str, + inputs: TextTokensPrompt, + params: Union[SamplingParams, PoolingParams], + lora_request: Optional[LoRARequest], + ) -> None: + if self.log_requests: + if isinstance(inputs, str): + shortened_prompt = inputs + shortened_token_ids = None + else: + shortened_prompt = inputs.get("prompt") + shortened_token_ids = inputs.get("prompt_token_ids") + + max_log_len = self.max_log_len + if max_log_len is not None: + if shortened_prompt is not None: + shortened_prompt = shortened_prompt[:max_log_len] + if shortened_token_ids is not None: + shortened_token_ids = shortened_token_ids[:max_log_len] + + logger.info( + "Received request %s: prompt: %r, " + "params: %s, prompt_token_ids: %s, " + "lora_request: %s.", request_id, shortened_prompt, params, + shortened_token_ids, lora_request) diff --git a/vllm/inputs.py b/vllm/inputs.py index f5d99b1b66b70..6aead2e122254 100644 --- a/vllm/inputs.py +++ b/vllm/inputs.py @@ -92,26 +92,7 @@ class TokensPrompt(TypedDict): """ -class TextTokensPrompt(TypedDict): - """It is assumed that :attr:`prompt` is consistent with - :attr:`prompt_token_ids`. This is currently used in - :class:`AsyncLLMEngine` for logging both the text and token IDs.""" - - prompt: str - """The prompt text.""" - - prompt_token_ids: List[int] - """The token IDs of the prompt. If None, we use the - tokenizer to convert the prompts to token IDs.""" - - multi_modal_data: NotRequired["MultiModalData"] - """ - Optional multi-modal data to pass to the model, - if the model supports it. 
- """ - - -PromptStrictInputs = Union[str, TextPrompt, TokensPrompt] +PromptInputs = Union[str, TextPrompt, TokensPrompt] """ The inputs to the LLM, which can take one of the following forms: @@ -119,10 +100,6 @@ class TextTokensPrompt(TypedDict): - A tokenized prompt (:class:`TokensPrompt`) """ -PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt] -"""Same as :const:`PromptStrictInputs` but additionally accepts -:class:`TextTokensPrompt`.""" - class LLMInputs(TypedDict): prompt_token_ids: List[int] From b682dfab8cccc88a9b65096c50b74fb1a7e63e5e Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 26 May 2024 05:18:58 +0000 Subject: [PATCH 61/94] Fix missing args in test --- tests/entrypoints/openai/test_serving_chat.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index c45f02fe564a3..6be5392b2e444 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -36,7 +36,10 @@ async def _async_serving_chat_init(): model_config, served_model_names=[MODEL_NAME], response_role="assistant", - chat_template=CHAT_TEMPLATE) + chat_template=CHAT_TEMPLATE, + lora_modules=None, + log_requests=False, + max_log_len=None) return serving_completion From d87ae34eafddd613cdc9e79f17fe7f42db1e4ab9 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 26 May 2024 05:29:39 +0000 Subject: [PATCH 62/94] Improve control flow --- vllm/entrypoints/openai/serving_engine.py | 36 ++++++++++------------- 1 file changed, 16 insertions(+), 20 deletions(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 33ce6626b5a4a..aa3fdd9557534 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -357,23 +357,19 @@ def _log_inputs( params: Union[SamplingParams, PoolingParams], lora_request: Optional[LoRARequest], ) -> None: - if self.log_requests: - if isinstance(inputs, str): - shortened_prompt = inputs - shortened_token_ids = None - else: - shortened_prompt = inputs.get("prompt") - shortened_token_ids = inputs.get("prompt_token_ids") - - max_log_len = self.max_log_len - if max_log_len is not None: - if shortened_prompt is not None: - shortened_prompt = shortened_prompt[:max_log_len] - if shortened_token_ids is not None: - shortened_token_ids = shortened_token_ids[:max_log_len] - - logger.info( - "Received request %s: prompt: %r, " - "params: %s, prompt_token_ids: %s, " - "lora_request: %s.", request_id, shortened_prompt, params, - shortened_token_ids, lora_request) + if not self.log_requests: + return + + shortened_prompt = inputs["prompt"] + shortened_token_ids = inputs["prompt_token_ids"] + + max_log_len = self.max_log_len + if max_log_len is not None: + shortened_prompt = shortened_prompt[:max_log_len] + shortened_token_ids = shortened_token_ids[:max_log_len] + + logger.info( + "Received request %s: prompt: %r, " + "params: %s, prompt_token_ids: %s, " + "lora_request: %s.", request_id, shortened_prompt, params, + shortened_token_ids, lora_request) From c73e9724cbdb64c264fb7a2c71e8fa36c9dc4969 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 26 May 2024 06:06:07 +0000 Subject: [PATCH 63/94] Fix missing input text when processing Completions API output --- vllm/entrypoints/openai/serving_completion.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_completion.py 
b/vllm/entrypoints/openai/serving_completion.py index 675ca6ea82bfb..93c86b456224f 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -148,8 +148,15 @@ async def create_completion(self, request: CompletionRequest, final_res_batch[i] = res final_res_batch_checked: List[RequestOutput] = [] - for final_res in final_res_batch: + for i, final_res in enumerate(final_res_batch): assert final_res is not None + + # The output should contain the input text + # We did not pass it into vLLM engine to avoid being redundant + # with the inputs token IDs + if final_res.prompt is None: + final_res.prompt = prompts[i]["prompt"] + final_res_batch_checked.append(final_res) response = self.request_output_to_completion_response( From 7d2c08b697ddcbf7365038b77a1352fa0f0a26fe Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 26 May 2024 09:25:39 +0000 Subject: [PATCH 64/94] Remove extra attribute --- vllm/engine/arg_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 76f920aba09ea..d6a7f704f69c1 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -664,7 +664,6 @@ class AsyncEngineArgs(EngineArgs): """Arguments for asynchronous vLLM engine.""" engine_use_ray: bool = False disable_log_requests: bool = False - max_log_len: Optional[int] = None @staticmethod def add_cli_args(parser: argparse.ArgumentParser, From 2936ee0ba076e79b9e707c0cce60e7643fc9efe2 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sat, 8 Jun 2024 02:22:47 +0000 Subject: [PATCH 65/94] Update docs --- vllm/entrypoints/openai/serving_engine.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 1c1ad20597370..a9ea108f6ee81 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -228,9 +228,10 @@ def _tokenize_prompt_input( truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = True, ) -> TextTokensPrompt: - """A simpler implementation of - :meth:`~vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs` - that assumes single input.""" + """ + A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs` + that assumes single input. + """ return next( self._tokenize_prompt_inputs( request, @@ -247,9 +248,10 @@ def _tokenize_prompt_inputs( truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = True, ) -> Iterator[TextTokensPrompt]: - """A simpler implementation of - :meth:`~vllm.entrypoints.openai.serving_engine.OpenAIServing._tokenize_prompt_input_or_inputs` - that assumes multiple inputs.""" + """ + A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs` + that assumes multiple inputs. + """ tokenizer = self.tokenizer for text in prompt_inputs: @@ -277,7 +279,8 @@ def _tokenize_prompt_input_or_inputs( truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = True, ) -> Iterator[TextTokensPrompt]: - """Tokenize/detokenize depending on the input format. + """ + Tokenize/detokenize depending on the input format. According to `OpenAI API `_ , each input can be a string or array of tokens. 
Note that each request From b0fa8ff726b0c0d05dfea99f09dcb6b83330a973 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 25 Jun 2024 08:37:19 +0000 Subject: [PATCH 66/94] Fix type errors --- vllm/engine/async_llm_engine.py | 14 +++++++------- vllm/engine/llm_engine.py | 8 ++++---- vllm/sequence.py | 4 ++-- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index d3f43caf07bd9..6cced3ef1b0fb 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,8 +1,8 @@ import asyncio import time from functools import partial -from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional, - Set, Tuple, Type, Union) +from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping, + Optional, Set, Tuple, Type, Union) from transformers import PreTrainedTokenizer @@ -295,7 +295,7 @@ async def add_request_async( params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Dict[str, str]] = None, + trace_headers: Optional[Mapping[str, str]] = None, ) -> None: if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " @@ -563,7 +563,7 @@ async def add_request( params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Dict[str, str]] = None, + trace_headers: Optional[Mapping[str, str]] = None, ) -> AsyncStream: if not self.is_running: if self.start_engine_loop: @@ -596,7 +596,7 @@ async def generate( sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Dict[str, str]] = None, + trace_headers: Optional[Mapping[str, str]] = None, ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -675,7 +675,7 @@ async def encode( pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Dict[str, str]] = None, + trace_headers: Optional[Mapping[str, str]] = None, ) -> AsyncIterator[EmbeddingRequestOutput]: """Generate outputs for a request from an embedding model. @@ -753,7 +753,7 @@ async def _process_request( params: Union[SamplingParams, PoolingParams], *, lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Dict[str, str]] = None, + trace_headers: Optional[Mapping[str, str]] = None, ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: """Common logic to process requests with SamplingParams or PoolingParams.""" diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f7eae257fdd16..add3abd754728 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,6 +1,6 @@ import time from contextlib import contextmanager -from typing import TYPE_CHECKING, ClassVar, Dict, Iterable, List, Optional +from typing import TYPE_CHECKING, ClassVar, Iterable, List, Mapping, Optional from typing import Sequence as GenericSequence from typing import Set, Type, TypeVar, Union @@ -457,7 +457,7 @@ def _add_processed_request( params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], - trace_headers: Optional[Dict[str, str]] = None, + trace_headers: Optional[Mapping[str, str]] = None, ) -> None: # Create the sequences. 
block_size = self.cache_config.block_size @@ -522,7 +522,7 @@ def add_request( params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Dict[str, str]] = None, + trace_headers: Optional[Mapping[str, str]] = None, ) -> None: """Add a request to the engine's request pool. @@ -592,7 +592,7 @@ def _create_sequence_group_with_sampling( sampling_params: SamplingParams, arrival_time: float, lora_request: Optional[LoRARequest], - trace_headers: Optional[Dict[str, str]] = None, + trace_headers: Optional[Mapping[str, str]] = None, ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs diff --git a/vllm/sequence.py b/vllm/sequence.py index 287e1b9df6165..8542fd0afbecc 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -3,7 +3,7 @@ import enum from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union import torch @@ -427,7 +427,7 @@ def __init__( embeddings: Optional[List[float]] = None, pooling_params: Optional[PoolingParams] = None, encoder_seq: Optional[Sequence] = None, - trace_headers: Optional[Dict[str, str]] = None, + trace_headers: Optional[Mapping[str, str]] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} From 4fc018cc8a7afc4bbb8d9ec7c39deea45ee1cf8a Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Tue, 2 Jul 2024 02:45:36 +0000 Subject: [PATCH 67/94] Fix bad merge --- docs/source/dev/multimodal/multimodal_index.rst | 2 +- docs/source/models/vlm.rst | 2 +- vllm/engine/llm_engine.py | 4 ++-- vllm/entrypoints/openai/serving_embedding.py | 2 +- vllm/inputs/__init__.py | 7 +++---- 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index f6fdfc1debffb..4336a0199a53f 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -5,7 +5,7 @@ Multi-Modality vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package. -:class:`vllm.inputs.PromptStrictInputs` accepts an additional attribute ``multi_modal_data`` +:class:`vllm.inputs.PromptInputs` accepts an additional attribute ``multi_modal_data`` which allows you to pass in multi-modal input alongside text and token prompts. By default, vLLM models do not support multi-modal inputs. To enable multi-modal support for a model, diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index 1837dd2aa89f7..d6ce58f8771a5 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -46,7 +46,7 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM`` We will remove most of the vision-specific arguments in a future release as they can be inferred from the HuggingFace configuration. -To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`: +To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`: * ``prompt``: The prompt should have a number of ```` tokens equal to ``image_feature_size``. * ``multi_modal_data``: This should be an instance of :class:`~vllm.multimodal.image.ImagePixelData` or :class:`~vllm.multimodal.image.ImageFeatureData`. 
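The `vlm.rst` hunk above only names the pieces of a vision-language prompt. Below is a rough usage sketch of how they are meant to fit together; several details are assumptions not shown in these diffs: the literal placeholder token (written here as `<image>`), the `ImagePixelData` constructor argument (assumed to wrap a `PIL.Image`), the example model name and feature size, and the VLM-specific `LLM(...)` arguments that `vlm.rst` says must also be passed, which are omitted for brevity.

```python
# Illustrative sketch only: mirrors the documented PromptInputs shape for a VLM.
# Assumptions (not confirmed by the diffs above): the placeholder token is
# "<image>", ImagePixelData wraps a PIL image, and the model/feature-size
# values are placeholders.
from PIL import Image

from vllm import LLM
from vllm.multimodal.image import ImagePixelData

image_feature_size = 576  # assumed value; must match the model configuration

# Hypothetical model; the vision-specific constructor arguments described in
# vlm.rst are omitted from this sketch.
llm = LLM(model="llava-hf/llava-1.5-7b-hf")

inputs = {
    # One placeholder token per image feature, as vlm.rst describes.
    "prompt": "<image>" * image_feature_size + "\nUSER: What is shown?\nASSISTANT:",
    # Must be an ImagePixelData or ImageFeatureData instance per vlm.rst.
    "multi_modal_data": ImagePixelData(Image.open("example.jpg")),
}

outputs = llm.generate(inputs)
print(outputs[0].outputs[0].text)
```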
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7e567e6ee23c1..6ea029a508d08 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,7 +1,7 @@ import time from contextlib import contextmanager -from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Mapping, - Optional) +from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, + Mapping, Optional) from typing import Sequence as GenericSequence from typing import Set, Type, TypeVar, Union diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 95acc3f16dd38..24131d9b78e47 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -31,7 +31,7 @@ def request_output_to_embedding_response( prompt_token_ids = final_res.prompt_token_ids embedding = final_res.outputs.embedding if encoding_format == "base64": - embedding = base64.b64encode(np.array(embedding)) + embedding = base64.b64encode(np.array(embedding)).decode("utf-8") embedding_data = EmbeddingResponseData(index=idx, embedding=embedding) data.append(embedding_data) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index d094156962955..b13d9acf93d3b 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,6 +1,5 @@ from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs, - PromptStrictInputs, TextPrompt, TextTokensPrompt, - TokensPrompt, parse_and_batch_prompt) + TextPrompt, TokensPrompt, parse_and_batch_prompt) from .registry import InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() @@ -14,6 +13,6 @@ __all__ = [ "ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt", - "TokensPrompt", "TextTokensPrompt", "PromptStrictInputs", "PromptInputs", - "LLMInputs", "INPUT_REGISTRY", "InputContext", "InputRegistry" + "TokensPrompt", "PromptInputs", "LLMInputs", "INPUT_REGISTRY", + "InputContext", "InputRegistry" ] From 59870cfed0afa86a00b4df86c1941fd07418cfa6 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 18 Jul 2024 07:48:07 +0000 Subject: [PATCH 68/94] Fix linter errors --- vllm/entrypoints/openai/api_server.py | 3 ++- vllm/entrypoints/openai/serving_engine.py | 2 +- vllm/entrypoints/openai/serving_tokenization.py | 14 ++++++++------ 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 36995fb111126..e42a231f5cc8a 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -249,7 +249,8 @@ def run_server(args, llm_engine=None): global openai_serving_embedding global openai_serving_tokenization - openai_serving_chat = OpenAIServingChat(engine, model_config, + openai_serving_chat = OpenAIServingChat(engine, + model_config, served_model_names, args.response_role, args.lora_modules, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 3e2521fb0499f..0eced77580597 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -20,8 +20,8 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams -from vllm.sampling_params import SamplingParams from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams from vllm.sequence import Logprob logger = init_logger(__name__) diff --git 
a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 344d7e008220a..b3c88a973f44b 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -7,8 +7,7 @@ parse_chat_message_content) from vllm.entrypoints.openai.protocol import (DetokenizeRequest, DetokenizeResponse, - ErrorResponse, - TokenizeRequest, + ErrorResponse, TokenizeRequest, TokenizeResponse) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) @@ -16,7 +15,8 @@ class OpenAIServingTokenization(OpenAIServing): - def __init__(self, + def __init__( + self, engine: AsyncLLMEngine, model_config: ModelConfig, served_model_names: List[str], @@ -62,21 +62,23 @@ async def create_tokenize( tokenizer) conversation.extend(result.messages) - request.prompt = tokenizer.apply_chat_template( + prompt = tokenizer.apply_chat_template( add_generation_prompt=request.add_generation_prompt, conversation=conversation, tokenize=False, chat_template=self.chat_template) + else: + assert request.prompt is not None + prompt = request.prompt prompt_input = self._tokenize_prompt_input( request, tokenizer, - request.prompt, + prompt, add_special_tokens=request.add_special_tokens, ) input_ids = prompt_input["prompt_token_ids"] - return TokenizeResponse(tokens=input_ids, count=len(input_ids), max_model_len=self.max_model_len) From e07c2beb9304493221ca267df9fafa034997ac0f Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 18 Jul 2024 09:42:14 +0000 Subject: [PATCH 69/94] Fix docs --- docs/source/dev/multimodal/multimodal_index.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index 6713dcf08d9f0..7cdbec2c9e3d4 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -8,7 +8,7 @@ Multi-Modality vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package. Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models ` -via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptStrictInputs`. +via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`. Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities by following :ref:`this guide `. 
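
For readers following the documentation changes above, a minimal sketch of how ``multi_modal_data`` is passed through :class:`vllm.inputs.PromptInputs` may help. It follows the pattern described in ``vlm.rst`` at this point in the series (a text prompt repeating the image placeholder token ``image_feature_size`` times, plus an :class:`~vllm.multimodal.image.ImagePixelData` instance); the model name and the vision-specific ``LLM`` arguments are illustrative assumptions rather than values taken from these patches.

.. code-block:: python

    from PIL import Image

    from vllm import LLM
    from vllm.multimodal.image import ImagePixelData

    # Assumed values for a LLaVA-1.5-style checkpoint; consult the VLM docs
    # for the numbers required by the model you actually serve.
    llm = LLM(
        model="llava-hf/llava-1.5-7b-hf",
        image_input_type="pixel_values",
        image_token_id=32000,
        image_input_shape="1,3,336,336",
        image_feature_size=576,
    )

    # The prompt carries one <image> placeholder per image feature
    # (576 in this assumed configuration).
    prompt = "<image>" * 576 + "\nUSER: What is in this image?\nASSISTANT:"

    # multi_modal_data is supplied alongside the text prompt in the same
    # prompt-inputs dict accepted by LLM.generate().
    outputs = llm.generate({
        "prompt": prompt,
        "multi_modal_data": ImagePixelData(Image.open("example.jpg")),
    })

    for o in outputs:
        print(o.outputs[0].text)
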
From 1b41a29ed0a9a10f7129ceeb5836abc77b741e37 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 18 Jul 2024 09:46:05 +0000 Subject: [PATCH 70/94] Remove logging params from tokenization endpoint --- vllm/entrypoints/openai/api_server.py | 2 -- vllm/entrypoints/openai/serving_tokenization.py | 7 ++----- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index e42a231f5cc8a..6fe4dd774ffcc 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -279,8 +279,6 @@ def run_server(args, llm_engine=None): served_model_names, args.lora_modules, args.chat_template, - log_requests=log_requests, - max_log_len=max_log_len, ) app.root_path = args.root_path diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index b3c88a973f44b..1c5eb54e7390a 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -22,16 +22,13 @@ def __init__( served_model_names: List[str], lora_modules: Optional[List[LoRAModulePath]] = None, chat_template: Optional[str] = None, - *, - log_requests: bool, - max_log_len: Optional[int], ): super().__init__(engine=engine, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, - log_requests=log_requests, - max_log_len=max_log_len) + log_requests=False, + max_log_len=None) # If this is None we use the tokenizer's default chat template self.chat_template = load_chat_template(chat_template) From 3e88444ec4b09ca19df36de0ea327f4a6141554f Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 18 Jul 2024 09:52:09 +0000 Subject: [PATCH 71/94] Remove duplicated function --- vllm/entrypoints/openai/serving_embedding.py | 32 -------------------- 1 file changed, 32 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 4baa4c53e2739..193a40afd551f 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -158,35 +158,3 @@ def _check_embedding_mode(self, embedding_mode: bool): "embedding_mode is False. 
Embedding API will not work.") else: logger.info("Activating the server engine with embedding enabled.") - - def request_output_to_embedding_response( - self, - final_res_batch: List[EmbeddingRequestOutput], - request_id: str, - created_time: int, - model_name: str, - ) -> EmbeddingResponse: - data = [] - num_prompt_tokens = 0 - for idx, final_res in enumerate(final_res_batch): - assert final_res is not None - prompt_token_ids = final_res.prompt_token_ids - - embedding_data = EmbeddingResponseData( - index=idx, embedding=final_res.outputs.embedding) - data.append(embedding_data) - - num_prompt_tokens += len(prompt_token_ids) - - usage = UsageInfo( - prompt_tokens=num_prompt_tokens, - total_tokens=num_prompt_tokens, - ) - - return EmbeddingResponse( - id=request_id, - created=created_time, - model=model_name, - data=data, - usage=usage, - ) From 0b2eac0db5eee7bbe5bcbf19a9ac6a3601ed4279 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 18 Jul 2024 10:44:05 +0000 Subject: [PATCH 72/94] Fix type errors --- vllm/entrypoints/openai/protocol.py | 4 +- vllm/entrypoints/openai/serving_chat.py | 12 +++- vllm/entrypoints/openai/serving_completion.py | 13 ++-- vllm/entrypoints/openai/serving_embedding.py | 3 +- vllm/entrypoints/openai/serving_engine.py | 71 +++++++++---------- .../openai/serving_tokenization.py | 4 +- 6 files changed, 56 insertions(+), 51 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 2faf061192307..a6fb372569a4f 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -704,8 +704,8 @@ class BatchRequestInput(OpenAIBaseModel): # /v1/chat/completions is supported. url: str - # The parameteters of the request. - body: Union[ChatCompletionRequest, ] + # The parameters of the request. + body: ChatCompletionRequest class BatchResponseData(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 32b9db7048336..5b1fb401e0224 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -81,7 +81,11 @@ async def create_chat_completion( return error_check_ret try: - _, lora_request = self._maybe_get_adapter(request) + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + tokenizer = await self.engine.get_tokenizer(lora_request) conversation: List[ConversationMessage] = [] @@ -150,7 +154,8 @@ async def create_chat_completion( self._log_inputs(request_id, prompt_inputs, sampling_params, - lora_request=lora_request) + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) engine_inputs: PromptInputs = { "prompt_token_ids": prompt_inputs["prompt_token_ids"], @@ -170,8 +175,9 @@ async def create_chat_completion( engine_inputs, sampling_params, request_id, - lora_request, + lora_request=lora_request, trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, ) except ValueError as e: # TODO: Use a vllm-specific Validation Error diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index c4448d53e9ab2..0d9870878ed5d 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -87,12 +87,10 @@ async def create_completion(self, request: CompletionRequest, # Schedule the request and get the result generator. 
generators: List[AsyncIterator[RequestOutput]] = [] try: - adapter_type, adapter_request = self._maybe_get_adapter(request) - lora_request, prompt_adapter_request = None, None - if adapter_type == 'LoRA': - lora_request, prompt_adapter_request = adapter_request, None - elif adapter_type == 'PromptAdapter': - lora_request, prompt_adapter_request = None, adapter_request + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) tokenizer = await self.engine.get_tokenizer(lora_request) sampling_params = request.to_sampling_params() @@ -124,7 +122,8 @@ async def create_completion(self, request: CompletionRequest, self._log_inputs(request_id_item, prompt_inputs, sampling_params, - lora_request=lora_request) + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) is_tracing_enabled = await self.engine.is_tracing_enabled() trace_headers = None diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 193a40afd551f..6d76b07172aed 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -31,7 +31,8 @@ def request_output_to_embedding_response( prompt_token_ids = final_res.prompt_token_ids embedding = final_res.outputs.embedding if encoding_format == "base64": - embedding = base64.b64encode(np.array(embedding)).decode("utf-8") + embedding_bytes = np.array(embedding).tobytes() + embedding = base64.b64encode(embedding_bytes).decode("utf-8") embedding_data = EmbeddingResponseData(index=idx, embedding=embedding) data.append(embedding_data) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 0eced77580597..1d358a134f0ad 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -27,11 +27,6 @@ logger = init_logger(__name__) -class TextTokensPrompt(TypedDict): - prompt: str - prompt_token_ids: List[int] - - @dataclass class PromptAdapterPath: name: str @@ -44,6 +39,17 @@ class LoRAModulePath: local_path: str +AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest, + EmbeddingRequest, TokenizeRequest] + +AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] + + +class TextTokensPrompt(TypedDict): + prompt: str + prompt_token_ids: List[int] + + class OpenAIServing: def __init__( @@ -140,9 +146,8 @@ def create_streaming_error_response( return json_str async def _check_model( - self, request: Union[ChatCompletionRequest, CompletionRequest, - DetokenizeRequest, EmbeddingRequest, - TokenizeRequest] + self, + request: AnyRequest, ) -> Optional[ErrorResponse]: if request.model in self.served_model_names: return None @@ -158,28 +163,25 @@ async def _check_model( err_type="NotFoundError", status_code=HTTPStatus.NOT_FOUND) - def _maybe_get_adapter( - self, request: Union[CompletionRequest, ChatCompletionRequest, - EmbeddingRequest, TokenizeRequest, - DetokenizeRequest] - ) -> Tuple[Optional[str], Optional[Union[LoRARequest, - PromptAdapterRequest]]]: + def _maybe_get_adapters( + self, request: AnyRequest + ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[ + None, PromptAdapterRequest]]: if request.model in self.served_model_names: return None, None for lora in self.lora_requests: if request.model == lora.lora_name: - return 'LoRA', lora + return lora, None for prompt_adapter in self.prompt_adapter_requests: if request.model == prompt_adapter.prompt_adapter_name: - return 'PromptAdapter', prompt_adapter + return 
None, prompt_adapter # if _check_model has been called earlier, this will be unreachable raise ValueError(f"The model `{request.model}` does not exist.") def _normalize_prompt_text_to_input( self, - request: Union[ChatCompletionRequest, CompletionRequest, - DetokenizeRequest, EmbeddingRequest, TokenizeRequest], - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + request: AnyRequest, + tokenizer: AnyTokenizer, prompt: str, truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]], add_special_tokens: bool, @@ -200,9 +202,8 @@ def _normalize_prompt_text_to_input( def _normalize_prompt_tokens_to_input( self, - request: Union[ChatCompletionRequest, CompletionRequest, - DetokenizeRequest, EmbeddingRequest, TokenizeRequest], - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + request: AnyRequest, + tokenizer: AnyTokenizer, prompt_ids: List[int], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]], ) -> TextTokensPrompt: @@ -217,8 +218,7 @@ def _normalize_prompt_tokens_to_input( def _validate_input( self, - request: Union[ChatCompletionRequest, CompletionRequest, - DetokenizeRequest, EmbeddingRequest, TokenizeRequest], + request: AnyRequest, input_ids: List[int], input_text: str, ) -> TextTokensPrompt: @@ -263,9 +263,8 @@ def _validate_input( def _tokenize_prompt_input( self, - request: Union[ChatCompletionRequest, CompletionRequest, - DetokenizeRequest, EmbeddingRequest, TokenizeRequest], - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + request: AnyRequest, + tokenizer: AnyTokenizer, prompt_input: Union[str, List[int]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = True, @@ -285,9 +284,8 @@ def _tokenize_prompt_input( def _tokenize_prompt_inputs( self, - request: Union[ChatCompletionRequest, CompletionRequest, - DetokenizeRequest, EmbeddingRequest, TokenizeRequest], - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + request: AnyRequest, + tokenizer: AnyTokenizer, prompt_inputs: Iterable[Union[str, List[int]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = True, @@ -315,9 +313,8 @@ def _tokenize_prompt_inputs( def _tokenize_prompt_input_or_inputs( self, - request: Union[ChatCompletionRequest, CompletionRequest, - DetokenizeRequest, EmbeddingRequest, TokenizeRequest], - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + request: AnyRequest, + tokenizer: AnyTokenizer, input_or_inputs: Union[str, List[str], List[int], List[List[int]]], truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, add_special_tokens: bool = True, @@ -356,6 +353,7 @@ def _log_inputs( inputs: TextTokensPrompt, params: Union[SamplingParams, PoolingParams], lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], ) -> None: if not self.log_requests: return @@ -371,14 +369,15 @@ def _log_inputs( logger.info( "Received request %s: prompt: %r, " "params: %s, prompt_token_ids: %s, " - "lora_request: %s.", request_id, shortened_prompt, params, - shortened_token_ids, lora_request) + "lora_request: %s, prompt_adapter_request: %s.", request_id, + shortened_prompt, params, shortened_token_ids, lora_request, + prompt_adapter_request) @staticmethod def _get_decoded_token( logprob: Logprob, token_id: int, - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + tokenizer: AnyTokenizer, ) -> str: if logprob.decoded_token is not None: return logprob.decoded_token diff --git 
a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 1c5eb54e7390a..a8d92af87a7bb 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -49,7 +49,7 @@ async def create_tokenize( return self.create_error_response( "Only one of `prompt` or `messages` should be provided.") - _, lora_request = self._maybe_get_adapter(request) + lora_request, _ = self._maybe_get_adapters(request) tokenizer = await self.engine.get_tokenizer(lora_request) if request.messages: conversation: List[ConversationMessage] = [] @@ -88,7 +88,7 @@ async def create_detokenize( if error_check_ret is not None: return error_check_ret - _, lora_request = self._maybe_get_adapter(request) + lora_request, _ = self._maybe_get_adapters(request) tokenizer = await self.engine.get_tokenizer(lora_request) prompt_input = self._tokenize_prompt_input( From 165c0b1e2c637e58057077171d2810cb39fb6a59 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 18 Jul 2024 10:44:26 +0000 Subject: [PATCH 73/94] Handle lora and prompt adapters for embeddings --- vllm/entrypoints/openai/serving_embedding.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 6d76b07172aed..946c051a0f56c 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -95,8 +95,13 @@ async def create_embedding(self, request: EmbeddingRequest, # Schedule the request and get the result generator. generators: List[AsyncIterator[EmbeddingRequestOutput]] = [] try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + tokenizer = await self.engine.get_tokenizer(lora_request) + pooling_params = request.to_pooling_params() - tokenizer = await self.engine.get_tokenizer() prompts = list( self._tokenize_prompt_input_or_inputs( @@ -111,12 +116,20 @@ async def create_embedding(self, request: EmbeddingRequest, self._log_inputs(request_id_item, prompt_inputs, pooling_params, - lora_request=None) + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + if prompt_adapter_request is not None: + raise NotImplementedError( + "Prompt adapter is not supported " + "for embedding models") generator = self.engine.encode( {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, pooling_params, request_id_item, + lora_request=lora_request, + # prompt_adapter_request=prompt_adapter_request, ) generators.append(generator) From 1dc3d219462cddf30e4c71ce012db9f2599fdb26 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 18 Jul 2024 10:51:50 +0000 Subject: [PATCH 74/94] Remove redundant code --- vllm/entrypoints/openai/serving_embedding.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 946c051a0f56c..80d5e244597d1 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -27,7 +27,6 @@ def request_output_to_embedding_response( data: List[EmbeddingResponseData] = [] num_prompt_tokens = 0 for idx, final_res in enumerate(final_res_batch): - assert final_res is not None prompt_token_ids = final_res.prompt_token_ids embedding = final_res.outputs.embedding if encoding_format == "base64": From 6264c82d48cfb57575e5567f8fcc7c0773db261f Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 18 Jul 2024 
11:16:35 +0000 Subject: [PATCH 75/94] Use type union for tokenization requests --- vllm/entrypoints/openai/protocol.py | 16 +++++++++---- vllm/entrypoints/openai/serving_chat.py | 3 ++- vllm/entrypoints/openai/serving_engine.py | 11 +++++++-- .../openai/serving_tokenization.py | 23 +++++++++---------- 4 files changed, 34 insertions(+), 19 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index a6fb372569a4f..ce657bac9a211 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -737,12 +737,20 @@ class BatchRequestOutput(OpenAIBaseModel): error: Optional[Any] -class TokenizeRequest(OpenAIBaseModel): +class TokenizeCompletionRequest(OpenAIBaseModel): + model: str + prompt: str + + +class TokenizeChatRequest(OpenAIBaseModel): + model: str + messages: List[ChatCompletionMessageParam] + add_generation_prompt: bool = Field(default=True) add_special_tokens: bool = Field(default=False) - prompt: Optional[str] = Field(default=None) - messages: Optional[List[ChatCompletionMessageParam]] = Field(default=None) - model: str + + +TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest] class TokenizeResponse(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 5b1fb401e0224..9605714562707 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -86,6 +86,7 @@ async def create_chat_completion( prompt_adapter_request, ) = self._maybe_get_adapters(request) + model_config = self.model_config tokenizer = await self.engine.get_tokenizer(lora_request) conversation: List[ConversationMessage] = [] @@ -93,7 +94,7 @@ async def create_chat_completion( for msg in request.messages: chat_parsed_result = parse_chat_message_content( - msg, self.model_config, tokenizer) + msg, model_config, tokenizer) conversation.extend(chat_parsed_result.messages) mm_futures.extend(chat_parsed_result.mm_futures) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 1d358a134f0ad..5470fe87a76fb 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -10,12 +10,18 @@ from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine +# yapf conflicts with isort for this block +# yapf: disable from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest, DetokenizeRequest, EmbeddingRequest, ErrorResponse, ModelCard, ModelList, - ModelPermission, TokenizeRequest) + ModelPermission, + TokenizeChatRequest, + TokenizeCompletionRequest, + TokenizeRequest) +# yapf: enable from vllm.inputs import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -237,7 +243,8 @@ def _validate_input( # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens # and does not require model context length validation - if isinstance(request, (TokenizeRequest, DetokenizeRequest)): + if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest, + DetokenizeRequest)): return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index a8d92af87a7bb..08d3411b1f2d2 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -5,10 +5,15 @@ from 
vllm.entrypoints.openai.chat_utils import (ConversationMessage, load_chat_template, parse_chat_message_content) +# yapf conflicts with isort for this block +# yapf: disable from vllm.entrypoints.openai.protocol import (DetokenizeRequest, DetokenizeResponse, - ErrorResponse, TokenizeRequest, + ErrorResponse, + TokenizeChatRequest, + TokenizeRequest, TokenizeResponse) +# yapf: enable from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) @@ -41,21 +46,16 @@ async def create_tokenize( if error_check_ret is not None: return error_check_ret - if not (request.prompt or request.messages): - return self.create_error_response( - "Either `prompt` or `messages` should be provided.") - - if (request.prompt and request.messages): - return self.create_error_response( - "Only one of `prompt` or `messages` should be provided.") - lora_request, _ = self._maybe_get_adapters(request) tokenizer = await self.engine.get_tokenizer(lora_request) - if request.messages: + + if isinstance(request, TokenizeChatRequest): + model_config = self.model_config + conversation: List[ConversationMessage] = [] for message in request.messages: - result = parse_chat_message_content(message, self.model_config, + result = parse_chat_message_content(message, model_config, tokenizer) conversation.extend(result.messages) @@ -65,7 +65,6 @@ async def create_tokenize( tokenize=False, chat_template=self.chat_template) else: - assert request.prompt is not None prompt = request.prompt prompt_input = self._tokenize_prompt_input( From 0492d792466c41a514f0aadc20138dbc88c7a254 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 18 Jul 2024 11:20:19 +0000 Subject: [PATCH 76/94] Fix `request_id` and add logging --- vllm/entrypoints/openai/serving_chat.py | 4 +- vllm/entrypoints/openai/serving_completion.py | 3 +- vllm/entrypoints/openai/serving_embedding.py | 5 ++- vllm/entrypoints/openai/serving_engine.py | 25 +++++++----- .../openai/serving_tokenization.py | 40 ++++++++++++++++++- 5 files changed, 61 insertions(+), 16 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 9605714562707..3da71f5434dba 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -128,7 +128,7 @@ async def create_chat_completion( logger.error("Error in loading multi-modal data: %s", e) return self.create_error_response(str(e)) - request_id = f"cmpl-{random_uuid()}" + request_id = f"chat-{random_uuid()}" try: sampling_params = request.to_sampling_params() decoding_config = await self.engine.get_decoding_config() @@ -154,7 +154,7 @@ async def create_chat_completion( self._log_inputs(request_id, prompt_inputs, - sampling_params, + params=sampling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 0d9870878ed5d..afcd9c09cfee3 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -91,6 +91,7 @@ async def create_completion(self, request: CompletionRequest, lora_request, prompt_adapter_request, ) = self._maybe_get_adapters(request) + tokenizer = await self.engine.get_tokenizer(lora_request) sampling_params = request.to_sampling_params() @@ -121,7 +122,7 @@ async def create_completion(self, request: CompletionRequest, self._log_inputs(request_id_item, prompt_inputs, - sampling_params, + params=sampling_params, lora_request=lora_request, 
prompt_adapter_request=prompt_adapter_request) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 80d5e244597d1..23d1b6e9719a9 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -88,7 +88,7 @@ async def create_embedding(self, request: EmbeddingRequest, "dimensions is currently not supported") model_name = request.model - request_id = f"cmpl-{random_uuid()}" + request_id = f"embed-{random_uuid()}" created_time = int(time.monotonic()) # Schedule the request and get the result generator. @@ -98,6 +98,7 @@ async def create_embedding(self, request: EmbeddingRequest, lora_request, prompt_adapter_request, ) = self._maybe_get_adapters(request) + tokenizer = await self.engine.get_tokenizer(lora_request) pooling_params = request.to_pooling_params() @@ -114,7 +115,7 @@ async def create_embedding(self, request: EmbeddingRequest, self._log_inputs(request_id_item, prompt_inputs, - pooling_params, + params=pooling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 5470fe87a76fb..a2753debe8312 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -357,21 +357,28 @@ def _tokenize_prompt_input_or_inputs( def _log_inputs( self, request_id: str, - inputs: TextTokensPrompt, - params: Union[SamplingParams, PoolingParams], + inputs: Union[str, List[int], TextTokensPrompt], + params: Optional[Union[SamplingParams, PoolingParams]], lora_request: Optional[LoRARequest], prompt_adapter_request: Optional[PromptAdapterRequest], ) -> None: if not self.log_requests: return + + if isinstance(inputs, str): + shortened_prompt = inputs + shortened_token_ids = None + elif isinstance(inputs, list): + shortened_prompt = None + shortened_token_ids = inputs + else: + shortened_prompt = inputs["prompt"] + shortened_token_ids = inputs["prompt_token_ids"] - shortened_prompt = inputs["prompt"] - shortened_token_ids = inputs["prompt_token_ids"] - - max_log_len = self.max_log_len - if max_log_len is not None: - shortened_prompt = shortened_prompt[:max_log_len] - shortened_token_ids = shortened_token_ids[:max_log_len] + max_log_len = self.max_log_len + if max_log_len is not None: + shortened_prompt = shortened_prompt[:max_log_len] + shortened_token_ids = shortened_token_ids[:max_log_len] logger.info( "Received request %s: prompt: %r, " diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 08d3411b1f2d2..dd1ea49491697 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -16,6 +16,7 @@ # yapf: enable from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) +from vllm.utils import random_uuid class OpenAIServingTokenization(OpenAIServing): @@ -46,7 +47,13 @@ async def create_tokenize( if error_check_ret is not None: return error_check_ret - lora_request, _ = self._maybe_get_adapters(request) + request_id = f"tok-{random_uuid()}" + + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + tokenizer = await self.engine.get_tokenizer(lora_request) if isinstance(request, TokenizeChatRequest): @@ -64,9 +71,21 @@ async def create_tokenize( conversation=conversation, tokenize=False, chat_template=self.chat_template) + assert isinstance(prompt, str) else: prompt = 
request.prompt + self._log_inputs(request_id, + prompt, + params=None, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + if prompt_adapter_request is not None: + raise NotImplementedError( + "Prompt adapter is not supported " + "for tokenization") + prompt_input = self._tokenize_prompt_input( request, tokenizer, @@ -87,9 +106,26 @@ async def create_detokenize( if error_check_ret is not None: return error_check_ret - lora_request, _ = self._maybe_get_adapters(request) + request_id = f"tok-{random_uuid()}" + + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + tokenizer = await self.engine.get_tokenizer(lora_request) + self._log_inputs(request_id, + request.tokens, + params=None, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + if prompt_adapter_request is not None: + raise NotImplementedError( + "Prompt adapter is not supported " + "for tokenization") + prompt_input = self._tokenize_prompt_input( request, tokenizer, From 4aa552acc5c506ad108e53934160eaaa59004d88 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 18 Jul 2024 11:23:35 +0000 Subject: [PATCH 77/94] yapf --- vllm/entrypoints/openai/serving_engine.py | 2 +- vllm/entrypoints/openai/serving_tokenization.py | 10 ++++------ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index a2753debe8312..af8172e171dc4 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -364,7 +364,7 @@ def _log_inputs( ) -> None: if not self.log_requests: return - + if isinstance(inputs, str): shortened_prompt = inputs shortened_token_ids = None diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index dd1ea49491697..bc7c8c4ea0ce5 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -82,9 +82,8 @@ async def create_tokenize( prompt_adapter_request=prompt_adapter_request) if prompt_adapter_request is not None: - raise NotImplementedError( - "Prompt adapter is not supported " - "for tokenization") + raise NotImplementedError("Prompt adapter is not supported " + "for tokenization") prompt_input = self._tokenize_prompt_input( request, @@ -122,9 +121,8 @@ async def create_detokenize( prompt_adapter_request=prompt_adapter_request) if prompt_adapter_request is not None: - raise NotImplementedError( - "Prompt adapter is not supported " - "for tokenization") + raise NotImplementedError("Prompt adapter is not supported " + "for tokenization") prompt_input = self._tokenize_prompt_input( request, From 994f3eeddadb5f8863beb88acce64ad25a288118 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 18 Jul 2024 11:24:29 +0000 Subject: [PATCH 78/94] Update request id --- vllm/entrypoints/openai/serving_embedding.py | 2 +- vllm/entrypoints/openai/serving_tokenization.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 23d1b6e9719a9..692786a1d62dc 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -88,7 +88,7 @@ async def create_embedding(self, request: EmbeddingRequest, "dimensions is currently not supported") model_name = request.model - request_id = f"embed-{random_uuid()}" + request_id = f"embd-{random_uuid()}" created_time = 
int(time.monotonic()) # Schedule the request and get the result generator. diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index bc7c8c4ea0ce5..b32e22e1f55da 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -47,7 +47,7 @@ async def create_tokenize( if error_check_ret is not None: return error_check_ret - request_id = f"tok-{random_uuid()}" + request_id = f"tokn-{random_uuid()}" ( lora_request, @@ -105,7 +105,7 @@ async def create_detokenize( if error_check_ret is not None: return error_check_ret - request_id = f"tok-{random_uuid()}" + request_id = f"tokn-{random_uuid()}" ( lora_request, From fe0629dfdf17e07f230e98ca5e6bc4830b7f77cc Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 18 Jul 2024 11:27:54 +0000 Subject: [PATCH 79/94] Enable logging --- vllm/entrypoints/openai/api_server.py | 2 ++ vllm/entrypoints/openai/serving_tokenization.py | 7 +++++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 6fe4dd774ffcc..e42a231f5cc8a 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -279,6 +279,8 @@ def run_server(args, llm_engine=None): served_model_names, args.lora_modules, args.chat_template, + log_requests=log_requests, + max_log_len=max_log_len, ) app.root_path = args.root_path diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index b32e22e1f55da..485ac9f559e65 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -28,13 +28,16 @@ def __init__( served_model_names: List[str], lora_modules: Optional[List[LoRAModulePath]] = None, chat_template: Optional[str] = None, + *, + log_requests: bool, + max_log_len: Optional[int], ): super().__init__(engine=engine, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, - log_requests=False, - max_log_len=None) + log_requests=log_requests, + max_log_len=max_log_len) # If this is None we use the tokenizer's default chat template self.chat_template = load_chat_template(chat_template) From d85ba521e75876efb5472c644ce9aa261ca27711 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Thu, 18 Jul 2024 11:37:31 +0000 Subject: [PATCH 80/94] Fix invalid attribute access --- vllm/entrypoints/openai/serving_tokenization.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 485ac9f559e65..0e10b95057e22 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -75,8 +75,11 @@ async def create_tokenize( tokenize=False, chat_template=self.chat_template) assert isinstance(prompt, str) + + add_special_tokens = request.add_special_tokens else: prompt = request.prompt + add_special_tokens = False self._log_inputs(request_id, prompt, @@ -92,7 +95,7 @@ async def create_tokenize( request, tokenizer, prompt, - add_special_tokens=request.add_special_tokens, + add_special_tokens=add_special_tokens, ) input_ids = prompt_input["prompt_token_ids"] From d0fd1f4b5d368c0f0620b48e731e903d7ba00be8 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 19 Jul 2024 02:05:35 +0000 Subject: [PATCH 81/94] Add `add_special_tokens` to Completions API to simplify the logic --- 
vllm/entrypoints/openai/protocol.py | 18 +++++++++++++----- vllm/entrypoints/openai/run_batch.py | 2 +- vllm/entrypoints/openai/serving_chat.py | 2 +- vllm/entrypoints/openai/serving_completion.py | 1 + .../entrypoints/openai/serving_tokenization.py | 5 +---- 5 files changed, 17 insertions(+), 11 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index ce657bac9a211..423c862390c8d 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -172,20 +172,20 @@ class ChatCompletionRequest(OpenAIBaseModel): echo: Optional[bool] = Field( default=False, description=( - "If true, the new message will be prepended with the last message " + "If True, the new message will be prepended with the last message " "if they belong to the same role."), ) - add_generation_prompt: Optional[bool] = Field( + add_generation_prompt: bool = Field( default=True, description= - ("If true, the generation prompt will be added to the chat template. " + ("If True, the generation prompt will be added to the chat template. " "This is a parameter used by chat template in tokenizer config of the " "model."), ) - add_special_tokens: Optional[bool] = Field( + add_special_tokens: bool = Field( default=False, description=( - "If true, special tokens (e.g. BOS) will be added to the prompt " + "If True, special tokens (e.g. BOS) will be added to the prompt " "on top of what is added by the chat template. " "For most models, the chat template takes care of adding the " "special tokens so this should be set to False (as is the " @@ -397,6 +397,12 @@ class CompletionRequest(OpenAIBaseModel): # doc: end-completion-sampling-params # doc: begin-completion-extra-params + add_special_tokens: bool = Field( + default=True, + description=( + "If True (the default), special tokens (e.g. BOS) will be added to " + "the prompt."), + ) include_stop_str_in_output: Optional[bool] = Field( default=False, description=( @@ -741,6 +747,8 @@ class TokenizeCompletionRequest(OpenAIBaseModel): model: str prompt: str + add_special_tokens: bool = Field(default=True) + class TokenizeChatRequest(OpenAIBaseModel): model: str diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index a0afce96dc88d..9d1ceaa251bcd 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -44,7 +44,7 @@ def parse_args(): type=nullable_str, default="assistant", help="The role name to return if " - "`request.add_generation_prompt=true`.") + "`request.add_generation_prompt=True`.") parser = AsyncEngineArgs.add_cli_args(parser) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 3da71f5434dba..cda2754e833fc 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -149,7 +149,7 @@ async def create_chat_completion( tokenizer, prompt, truncate_prompt_tokens=sampling_params.truncate_prompt_tokens, - add_special_tokens=False, + add_special_tokens=request.add_special_tokens, ) self._log_inputs(request_id, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index afcd9c09cfee3..7ad07bb9072e0 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -115,6 +115,7 @@ async def create_completion(self, request: CompletionRequest, request.prompt, truncate_prompt_tokens=sampling_params. 
truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, )) for i, prompt_inputs in enumerate(prompts): diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 0e10b95057e22..485ac9f559e65 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -75,11 +75,8 @@ async def create_tokenize( tokenize=False, chat_template=self.chat_template) assert isinstance(prompt, str) - - add_special_tokens = request.add_special_tokens else: prompt = request.prompt - add_special_tokens = False self._log_inputs(request_id, prompt, @@ -95,7 +92,7 @@ async def create_tokenize( request, tokenizer, prompt, - add_special_tokens=add_special_tokens, + add_special_tokens=request.add_special_tokens, ) input_ids = prompt_input["prompt_token_ids"] From 2c38ccfe5f5b30a8c0188a079de8105fcfbd421b Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 19 Jul 2024 02:23:07 +0000 Subject: [PATCH 82/94] Factor out logging args --- vllm/entrypoints/logger.py | 47 +++++++++++++++++ vllm/entrypoints/openai/api_server.py | 42 ++++++++------- vllm/entrypoints/openai/run_batch.py | 11 ++-- vllm/entrypoints/openai/serving_chat.py | 13 ++--- vllm/entrypoints/openai/serving_completion.py | 8 ++- vllm/entrypoints/openai/serving_embedding.py | 9 ++-- vllm/entrypoints/openai/serving_engine.py | 52 ++++++++----------- .../openai/serving_tokenization.py | 13 +++-- 8 files changed, 117 insertions(+), 78 deletions(-) create mode 100644 vllm/entrypoints/logger.py diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py new file mode 100644 index 0000000000000..1d3b4cef8fba7 --- /dev/null +++ b/vllm/entrypoints/logger.py @@ -0,0 +1,47 @@ +from typing import List, Optional, TypedDict, Union + +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.pooling_params import PoolingParams +from vllm.sampling_params import SamplingParams + +logger = init_logger(__name__) + + +class TextTokensPrompt(TypedDict): + prompt: str + prompt_token_ids: List[int] + + +class RequestLogger: + + def __init__(self, *, max_log_len: Optional[int]) -> None: + super().__init__() + + self.max_log_len = max_log_len + + def log_inputs( + self, + request_id: str, + prompt: Optional[str], + prompt_token_ids: Optional[List[int]], + params: Optional[Union[SamplingParams, PoolingParams]], + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> None: + + max_log_len = self.max_log_len + if max_log_len is not None: + if prompt is not None: + prompt = prompt[:max_log_len] + + if prompt_token_ids is not None: + prompt_token_ids = prompt_token_ids[:max_log_len] + + logger.info( + "Received request %s: prompt: %r, " + "params: %s, prompt_token_ids: %s, " + "lora_request: %s, prompt_adapter_request: %s.", request_id, + prompt, params, prompt_token_ids, lora_request, + prompt_adapter_request) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index e42a231f5cc8a..65437d7b18591 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -18,6 +18,7 @@ import vllm.envs as envs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import make_arg_parser # yapf conflicts with isort 
for this block # yapf: disable @@ -241,46 +242,47 @@ def run_server(args, llm_engine=None): # When using single vLLM without engine_use_ray model_config = asyncio.run(engine.get_model_config()) - log_requests = not args.disable_log_requests - max_log_len = args.max_log_len + if args.disable_log_requests: + request_logger = None + else: + request_logger = RequestLogger(max_log_len=args.max_log_len) global openai_serving_chat global openai_serving_completion global openai_serving_embedding global openai_serving_tokenization - openai_serving_chat = OpenAIServingChat(engine, - model_config, - served_model_names, - args.response_role, - args.lora_modules, - args.chat_template, - log_requests=log_requests, - max_log_len=max_log_len) + openai_serving_chat = OpenAIServingChat( + engine, + model_config, + served_model_names, + args.response_role, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + request_logger=request_logger, + chat_template=args.chat_template, + ) openai_serving_completion = OpenAIServingCompletion( engine, model_config, served_model_names, - args.lora_modules, - args.prompt_adapters, - log_requests=log_requests, - max_log_len=max_log_len, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + request_logger=request_logger, ) openai_serving_embedding = OpenAIServingEmbedding( engine, model_config, served_model_names, - log_requests=log_requests, - max_log_len=max_log_len, + request_logger=request_logger, ) openai_serving_tokenization = OpenAIServingTokenization( engine, model_config, served_model_names, - args.lora_modules, - args.chat_template, - log_requests=log_requests, - max_log_len=max_log_len, + lora_modules=args.lora_modules, + request_logger=request_logger, + chat_template=args.chat_template, ) app.root_path = args.root_path diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 9d1ceaa251bcd..3c5e5e651b54d 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -6,6 +6,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (BatchRequestInput, BatchRequestOutput, BatchResponseData, @@ -122,8 +123,10 @@ async def main(args): # When using single vLLM without engine_use_ray model_config = await engine.get_model_config() - log_requests = not args.disable_log_requests - max_log_len = args.max_log_len + if args.disable_log_requests: + request_logger = None + else: + request_logger = RequestLogger(max_log_len=args.max_log_len) openai_serving_chat = OpenAIServingChat( engine, @@ -131,9 +134,9 @@ async def main(args): served_model_names, args.response_role, lora_modules=None, + prompt_adapters=None, + request_logger=request_logger, chat_template=None, - log_requests=log_requests, - max_log_len=max_log_len, ) # Submit all requests in the file to the engine "concurrently". 
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index cda2754e833fc..b922bcf927c4e 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -9,6 +9,7 @@ from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.chat_utils import (ConversationMessage, load_chat_template, parse_chat_message_content) @@ -20,7 +21,8 @@ ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, FunctionCall, ToolCall, UsageInfo) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, - OpenAIServing) + OpenAIServing, + PromptAdapterPath) from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( @@ -44,17 +46,16 @@ def __init__( served_model_names: List[str], response_role: str, lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]], + request_logger: Optional[RequestLogger], chat_template: Optional[str], - *, - log_requests: bool, - max_log_len: Optional[int], ): super().__init__(engine=engine, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, - log_requests=log_requests, - max_log_len=max_log_len) + prompt_adapters=prompt_adapters, + request_logger=request_logger) self.response_role = response_role diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 7ad07bb9072e0..808294c632030 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -9,6 +9,7 @@ from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable from vllm.entrypoints.openai.protocol import (CompletionLogProbs, @@ -48,17 +49,14 @@ def __init__( served_model_names: List[str], lora_modules: Optional[List[LoRAModulePath]], prompt_adapters: Optional[List[PromptAdapterPath]], - *, - log_requests: bool, - max_log_len: Optional[int], + request_logger: Optional[RequestLogger], ): super().__init__(engine=engine, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, prompt_adapters=prompt_adapters, - log_requests=log_requests, - max_log_len=max_log_len) + request_logger=request_logger) async def create_completion(self, request: CompletionRequest, raw_request: Request): diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 692786a1d62dc..3ec3b4bcfe568 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -7,6 +7,7 @@ from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (EmbeddingRequest, EmbeddingResponse, EmbeddingResponseData, UsageInfo) @@ -58,16 +59,14 @@ def __init__( engine: AsyncLLMEngine, model_config: ModelConfig, served_model_names: List[str], - *, - log_requests: bool, - max_log_len: Optional[int], + request_logger: Optional[RequestLogger], ): super().__init__(engine=engine, model_config=model_config, served_model_names=served_model_names, lora_modules=None, - log_requests=log_requests, - max_log_len=max_log_len) + prompt_adapters=None, 
+ request_logger=request_logger) self._check_embedding_mode(model_config.embedding_mode) async def create_embedding(self, request: EmbeddingRequest, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index af8172e171dc4..bb998696c11dd 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -2,7 +2,7 @@ import pathlib from dataclasses import dataclass from http import HTTPStatus -from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union +from typing import Iterable, Iterator, List, Optional, Tuple, Union from pydantic import Field from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -10,6 +10,7 @@ from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.logger import RequestLogger, TextTokensPrompt # yapf conflicts with isort for this block # yapf: disable from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, @@ -51,11 +52,6 @@ class LoRAModulePath: AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] -class TextTokensPrompt(TypedDict): - prompt: str - prompt_token_ids: List[int] - - class OpenAIServing: def __init__( @@ -64,10 +60,8 @@ def __init__( model_config: ModelConfig, served_model_names: List[str], lora_modules: Optional[List[LoRAModulePath]], - prompt_adapters: Optional[List[PromptAdapterPath]] = None, - *, - log_requests: bool, - max_log_len: Optional[int], + prompt_adapters: Optional[List[PromptAdapterPath]], + request_logger: Optional[RequestLogger], ): super().__init__() @@ -101,8 +95,7 @@ def __init__( prompt_adapter_local_path=prompt_adapter.local_path, prompt_adapter_num_virtual_tokens=num_virtual_tokens)) - self.log_requests = log_requests - self.max_log_len = max_log_len + self.request_logger = request_logger async def show_available_models(self) -> ModelList: """Show available models. 
Right now we only have one model.""" @@ -362,30 +355,27 @@ def _log_inputs( lora_request: Optional[LoRARequest], prompt_adapter_request: Optional[PromptAdapterRequest], ) -> None: - if not self.log_requests: + if self.request_logger is None: return if isinstance(inputs, str): - shortened_prompt = inputs - shortened_token_ids = None + prompt = inputs + prompt_token_ids = None elif isinstance(inputs, list): - shortened_prompt = None - shortened_token_ids = inputs + prompt = None + prompt_token_ids = inputs else: - shortened_prompt = inputs["prompt"] - shortened_token_ids = inputs["prompt_token_ids"] - - max_log_len = self.max_log_len - if max_log_len is not None: - shortened_prompt = shortened_prompt[:max_log_len] - shortened_token_ids = shortened_token_ids[:max_log_len] - - logger.info( - "Received request %s: prompt: %r, " - "params: %s, prompt_token_ids: %s, " - "lora_request: %s, prompt_adapter_request: %s.", request_id, - shortened_prompt, params, shortened_token_ids, lora_request, - prompt_adapter_request) + prompt = inputs["prompt"] + prompt_token_ids = inputs["prompt_token_ids"] + + self.request_logger.log_inputs( + request_id, + prompt, + prompt_token_ids, + params=params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) @staticmethod def _get_decoded_token( diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 485ac9f559e65..4f76c2276deae 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -2,6 +2,7 @@ from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.chat_utils import (ConversationMessage, load_chat_template, parse_chat_message_content) @@ -26,18 +27,16 @@ def __init__( engine: AsyncLLMEngine, model_config: ModelConfig, served_model_names: List[str], - lora_modules: Optional[List[LoRAModulePath]] = None, - chat_template: Optional[str] = None, - *, - log_requests: bool, - max_log_len: Optional[int], + lora_modules: Optional[List[LoRAModulePath]], + request_logger: Optional[RequestLogger], + chat_template: Optional[str], ): super().__init__(engine=engine, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, - log_requests=log_requests, - max_log_len=max_log_len) + prompt_adapters=None, + request_logger=request_logger) # If this is None we use the tokenizer's default chat template self.chat_template = load_chat_template(chat_template) From b0f9595b7b859ee607826dbaf47610284d3bc76c Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 19 Jul 2024 02:25:48 +0000 Subject: [PATCH 83/94] Make optional args keyword-only; cleanup --- vllm/entrypoints/logger.py | 6 +++--- vllm/entrypoints/openai/serving_chat.py | 1 + vllm/entrypoints/openai/serving_completion.py | 1 + vllm/entrypoints/openai/serving_embedding.py | 1 + vllm/entrypoints/openai/serving_engine.py | 1 + vllm/entrypoints/openai/serving_tokenization.py | 1 + 6 files changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index 1d3b4cef8fba7..a89feada4a4e5 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -2,8 +2,8 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request 
import PromptAdapterRequest from vllm.sampling_params import SamplingParams logger = init_logger(__name__) @@ -20,7 +20,7 @@ def __init__(self, *, max_log_len: Optional[int]) -> None: super().__init__() self.max_log_len = max_log_len - + def log_inputs( self, request_id: str, @@ -35,7 +35,7 @@ def log_inputs( if max_log_len is not None: if prompt is not None: prompt = prompt[:max_log_len] - + if prompt_token_ids is not None: prompt_token_ids = prompt_token_ids[:max_log_len] diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index b922bcf927c4e..d19c396da8e78 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -45,6 +45,7 @@ def __init__( model_config: ModelConfig, served_model_names: List[str], response_role: str, + *, lora_modules: Optional[List[LoRAModulePath]], prompt_adapters: Optional[List[PromptAdapterPath]], request_logger: Optional[RequestLogger], diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 808294c632030..76c1921be4198 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -47,6 +47,7 @@ def __init__( engine: AsyncLLMEngine, model_config: ModelConfig, served_model_names: List[str], + *, lora_modules: Optional[List[LoRAModulePath]], prompt_adapters: Optional[List[PromptAdapterPath]], request_logger: Optional[RequestLogger], diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 3ec3b4bcfe568..61e1ec08c58f1 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -59,6 +59,7 @@ def __init__( engine: AsyncLLMEngine, model_config: ModelConfig, served_model_names: List[str], + *, request_logger: Optional[RequestLogger], ): super().__init__(engine=engine, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index bb998696c11dd..5cd5e0625e312 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -59,6 +59,7 @@ def __init__( engine: AsyncLLMEngine, model_config: ModelConfig, served_model_names: List[str], + *, lora_modules: Optional[List[LoRAModulePath]], prompt_adapters: Optional[List[PromptAdapterPath]], request_logger: Optional[RequestLogger], diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 4f76c2276deae..19c74224fedb9 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -27,6 +27,7 @@ def __init__( engine: AsyncLLMEngine, model_config: ModelConfig, served_model_names: List[str], + *, lora_modules: Optional[List[LoRAModulePath]], request_logger: Optional[RequestLogger], chat_template: Optional[str], From 4fddfa0337dfcfd076e13fcab2c6854a307bb3e1 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 19 Jul 2024 02:26:16 +0000 Subject: [PATCH 84/94] Remove extra line --- vllm/entrypoints/logger.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index a89feada4a4e5..f24952e05e9c8 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -30,7 +30,6 @@ def log_inputs( lora_request: Optional[LoRARequest], prompt_adapter_request: Optional[PromptAdapterRequest], ) -> None: - max_log_len = self.max_log_len if max_log_len is not None: if prompt is not None: From 
18a1facee617783268263f1341c80a025cb0336b Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 19 Jul 2024 02:26:57 +0000 Subject: [PATCH 85/94] Move definition back --- vllm/entrypoints/logger.py | 7 +------ vllm/entrypoints/openai/serving_engine.py | 9 +++++++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py index f24952e05e9c8..091896e1c7a69 100644 --- a/vllm/entrypoints/logger.py +++ b/vllm/entrypoints/logger.py @@ -1,4 +1,4 @@ -from typing import List, Optional, TypedDict, Union +from typing import List, Optional, Union from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -9,11 +9,6 @@ logger = init_logger(__name__) -class TextTokensPrompt(TypedDict): - prompt: str - prompt_token_ids: List[int] - - class RequestLogger: def __init__(self, *, max_log_len: Optional[int]) -> None: diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 5cd5e0625e312..7578dc9dc3c0c 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -2,7 +2,7 @@ import pathlib from dataclasses import dataclass from http import HTTPStatus -from typing import Iterable, Iterator, List, Optional, Tuple, Union +from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union from pydantic import Field from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @@ -10,7 +10,7 @@ from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.entrypoints.logger import RequestLogger, TextTokensPrompt +from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, @@ -52,6 +52,11 @@ class LoRAModulePath: AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] +class TextTokensPrompt(TypedDict): + prompt: str + prompt_token_ids: List[int] + + class OpenAIServing: def __init__( From 61cf999b65f8d308b79a49baa2b199dcb44192b7 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 19 Jul 2024 02:44:54 +0000 Subject: [PATCH 86/94] Avoid creating new list --- vllm/entrypoints/openai/serving_completion.py | 6 +++--- vllm/entrypoints/openai/serving_embedding.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 76c1921be4198..6aef4c9f96150 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -2,7 +2,7 @@ from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, Optional) from typing import Sequence as GenericSequence -from typing import Tuple +from typing import Tuple, cast from fastapi import Request from transformers import PreTrainedTokenizer @@ -179,7 +179,6 @@ async def create_completion(self, request: CompletionRequest, return self.create_error_response("Client disconnected") final_res_batch[i] = res - final_res_batch_checked: List[RequestOutput] = [] for i, final_res in enumerate(final_res_batch): assert final_res is not None @@ -189,7 +188,8 @@ async def create_completion(self, request: CompletionRequest, if final_res.prompt is None: final_res.prompt = prompts[i]["prompt"] - final_res_batch_checked.append(final_res) + final_res_batch_checked = cast(List[RequestOutput], + final_res_batch) response = self.request_output_to_completion_response( 
final_res_batch_checked, diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 61e1ec08c58f1..a56a168e77cd9 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -1,6 +1,6 @@ import base64 import time -from typing import AsyncIterator, List, Optional, Tuple +from typing import AsyncIterator, List, Optional, Tuple, cast import numpy as np from fastapi import Request @@ -148,14 +148,14 @@ async def create_embedding(self, request: EmbeddingRequest, if await raw_request.is_disconnected(): # Abort the request if the client disconnects. await self.engine.abort(f"{request_id}-{i}") - # TODO: Use a vllm-specific Validation Error return self.create_error_response("Client disconnected") final_res_batch[i] = res - final_res_batch_checked: List[EmbeddingRequestOutput] = [] for final_res in final_res_batch: assert final_res is not None - final_res_batch_checked.append(final_res) + + final_res_batch_checked = cast(List[EmbeddingRequestOutput], + final_res_batch) response = request_output_to_embedding_response( final_res_batch_checked, request_id, created_time, model_name, From 76124a96a82629e447abf056c8934548df3dab65 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Fri, 19 Jul 2024 12:33:53 +0800 Subject: [PATCH 87/94] Update args in test_serving_chat.py --- tests/entrypoints/openai/test_serving_chat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index db864f986e91a..464465494b714 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -34,8 +34,8 @@ async def _async_serving_chat_init(): response_role="assistant", chat_template=CHAT_TEMPLATE, lora_modules=None, - log_requests=False, - max_log_len=None) + prompt_adapters=None, + request_logger=None) return serving_completion From 68d2c96ae2e9716c74770c10699711ff00fd4670 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 19 Jul 2024 23:54:59 +0000 Subject: [PATCH 88/94] Silently ignore prompt adapter for tokenization --- vllm/entrypoints/openai/serving_tokenization.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 19c74224fedb9..d0b330ac46e4a 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -84,9 +84,7 @@ async def create_tokenize( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - if prompt_adapter_request is not None: - raise NotImplementedError("Prompt adapter is not supported " - "for tokenization") + # Silently ignore prompt adapter since it does not affect tokenization prompt_input = self._tokenize_prompt_input( request, From 7a92e6e075a6ab0025c2aae8bce888e8bcae4537 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 19 Jul 2024 23:55:13 +0000 Subject: [PATCH 89/94] Clean --- vllm/entrypoints/openai/serving_embedding.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index a56a168e77cd9..bccc90894e79f 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -129,7 +129,6 @@ async def create_embedding(self, request: EmbeddingRequest, pooling_params, request_id_item, lora_request=lora_request, 
- # prompt_adapter_request=prompt_adapter_request, ) generators.append(generator) From e78dd2891b88078853d000054484f0978d3b0c5d Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 19 Jul 2024 23:56:52 +0000 Subject: [PATCH 90/94] Use HTTP boolean format --- vllm/entrypoints/openai/protocol.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 423c862390c8d..ab376b1674575 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -172,23 +172,23 @@ class ChatCompletionRequest(OpenAIBaseModel): echo: Optional[bool] = Field( default=False, description=( - "If True, the new message will be prepended with the last message " + "If true, the new message will be prepended with the last message " "if they belong to the same role."), ) add_generation_prompt: bool = Field( default=True, description= - ("If True, the generation prompt will be added to the chat template. " + ("If true, the generation prompt will be added to the chat template. " "This is a parameter used by chat template in tokenizer config of the " "model."), ) add_special_tokens: bool = Field( default=False, description=( - "If True, special tokens (e.g. BOS) will be added to the prompt " + "If true, special tokens (e.g. BOS) will be added to the prompt " "on top of what is added by the chat template. " "For most models, the chat template takes care of adding the " - "special tokens so this should be set to False (as is the " + "special tokens so this should be set to false (as is the " "default)."), ) documents: Optional[List[Dict[str, str]]] = Field( @@ -400,7 +400,7 @@ class CompletionRequest(OpenAIBaseModel): add_special_tokens: bool = Field( default=True, description=( - "If True (the default), special tokens (e.g. BOS) will be added to " + "If true (the default), special tokens (e.g. 
BOS) will be added to " "the prompt."), ) include_stop_str_in_output: Optional[bool] = Field( @@ -529,7 +529,7 @@ def check_logprobs(cls, data): def validate_stream_options(cls, data): if data.get("stream_options") and not data.get("stream"): raise ValueError( - "Stream options can only be defined when stream is True.") + "Stream options can only be defined when stream is true.") return data From d56c9cc645770ab90b138ec7da0735c441643294 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 20 Jul 2024 00:05:30 +0000 Subject: [PATCH 91/94] Fix incorrectly allowing some sampling params to be `None` --- .../output_processor/test_stop_checker.py | 4 +- vllm/entrypoints/openai/protocol.py | 56 ++++++++----------- 2 files changed, 25 insertions(+), 35 deletions(-) diff --git a/tests/engine/output_processor/test_stop_checker.py b/tests/engine/output_processor/test_stop_checker.py index f795403e3d8ad..0d84443c51f99 100644 --- a/tests/engine/output_processor/test_stop_checker.py +++ b/tests/engine/output_processor/test_stop_checker.py @@ -35,8 +35,8 @@ def sequence_with_eos(text: str, eos_token: str, @pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [ ("This text ends with EOS token", "", 2), ]) -@pytest.mark.parametrize("ignore_eos", [True, False, None]) -@pytest.mark.parametrize("include_stop_str_in_output", [True, False, None]) +@pytest.mark.parametrize("ignore_eos", [True, False]) +@pytest.mark.parametrize("include_stop_str_in_output", [True, False]) @pytest.mark.skip_global_cleanup def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int, ignore_eos: bool, include_stop_str_in_output: bool): diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index ab376b1674575..3735e8bd2610b 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -155,21 +155,22 @@ class ChatCompletionRequest(OpenAIBaseModel): # doc: begin-chat-completion-sampling-params best_of: Optional[int] = None - use_beam_search: Optional[bool] = False - top_k: Optional[int] = -1 - min_p: Optional[float] = 0.0 - repetition_penalty: Optional[float] = 1.0 - length_penalty: Optional[float] = 1.0 - early_stopping: Optional[bool] = False - ignore_eos: Optional[bool] = False - min_tokens: Optional[int] = 0 + use_beam_search: bool = False + top_k: int = -1 + min_p: float = 0.0 + repetition_penalty: float = 1.0 + length_penalty: float = 1.0 + early_stopping: bool = False stop_token_ids: Optional[List[int]] = Field(default_factory=list) - skip_special_tokens: Optional[bool] = True - spaces_between_special_tokens: Optional[bool] = True + include_stop_str_in_output: bool = False + ignore_eos: bool = False + min_tokens: int = 0 + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True # doc: end-chat-completion-sampling-params # doc: begin-chat-completion-extra-params - echo: Optional[bool] = Field( + echo: bool = Field( default=False, description=( "If true, the new message will be prepended with the last message " @@ -212,12 +213,6 @@ class ChatCompletionRequest(OpenAIBaseModel): description=("Additional kwargs to pass to the template renderer. " "Will be accessible by the chat template."), ) - include_stop_str_in_output: Optional[bool] = Field( - default=False, - description=( - "Whether to include the stop string in the output. 
" - "This is only applied when the stop or stop_token_ids is set."), - ) guided_json: Optional[Union[str, dict, BaseModel]] = Field( default=None, description=("If specified, the output will follow the JSON schema."), @@ -382,17 +377,18 @@ class CompletionRequest(OpenAIBaseModel): user: Optional[str] = None # doc: begin-completion-sampling-params - use_beam_search: Optional[bool] = False - top_k: Optional[int] = -1 - min_p: Optional[float] = 0.0 - repetition_penalty: Optional[float] = 1.0 - length_penalty: Optional[float] = 1.0 - early_stopping: Optional[bool] = False + use_beam_search: bool = False + top_k: int = -1 + min_p: float = 0.0 + repetition_penalty: float = 1.0 + length_penalty: float = 1.0 + early_stopping: bool = False stop_token_ids: Optional[List[int]] = Field(default_factory=list) - ignore_eos: Optional[bool] = False - min_tokens: Optional[int] = 0 - skip_special_tokens: Optional[bool] = True - spaces_between_special_tokens: Optional[bool] = True + include_stop_str_in_output: bool = False + ignore_eos: bool = False + min_tokens: int = 0 + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None # doc: end-completion-sampling-params @@ -403,12 +399,6 @@ class CompletionRequest(OpenAIBaseModel): "If true (the default), special tokens (e.g. BOS) will be added to " "the prompt."), ) - include_stop_str_in_output: Optional[bool] = Field( - default=False, - description=( - "Whether to include the stop string in the output. " - "This is only applied when the stop or stop_token_ids is set."), - ) response_format: Optional[ResponseFormat] = Field( default=None, description= From f62edef50659f5d3affe56fc9297970de1e3cbe2 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 20 Jul 2024 00:09:49 +0000 Subject: [PATCH 92/94] Fix inconsistent arg availability and ordering --- vllm/entrypoints/openai/protocol.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 3735e8bd2610b..b5d142a5465d3 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -167,6 +167,7 @@ class ChatCompletionRequest(OpenAIBaseModel): min_tokens: int = 0 skip_special_tokens: bool = True spaces_between_special_tokens: bool = True + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None # doc: end-chat-completion-sampling-params # doc: begin-chat-completion-extra-params @@ -273,22 +274,22 @@ def logit_bias_logits_processor( return SamplingParams( n=self.n, + best_of=self.best_of, presence_penalty=self.presence_penalty, frequency_penalty=self.frequency_penalty, repetition_penalty=self.repetition_penalty, temperature=self.temperature, top_p=self.top_p, + top_k=self.top_k, min_p=self.min_p, seed=self.seed, stop=self.stop, stop_token_ids=self.stop_token_ids, - max_tokens=self.max_tokens, - min_tokens=self.min_tokens, logprobs=self.top_logprobs if self.logprobs else None, prompt_logprobs=self.top_logprobs if self.echo else None, - best_of=self.best_of, - top_k=self.top_k, ignore_eos=self.ignore_eos, + max_tokens=self.max_tokens, + min_tokens=self.min_tokens, use_beam_search=self.use_beam_search, early_stopping=self.early_stopping, skip_special_tokens=self.skip_special_tokens, @@ -296,6 +297,7 @@ def logit_bias_logits_processor( include_stop_str_in_output=self.include_stop_str_in_output, length_penalty=self.length_penalty, logits_processors=logits_processors, + 
truncate_prompt_tokens=self.truncate_prompt_tokens, ) @model_validator(mode='before') @@ -477,15 +479,15 @@ def logit_bias_logits_processor( seed=self.seed, stop=self.stop, stop_token_ids=self.stop_token_ids, + logprobs=self.logprobs, ignore_eos=self.ignore_eos, max_tokens=self.max_tokens if not echo_without_generation else 1, min_tokens=self.min_tokens, - logprobs=self.logprobs, use_beam_search=self.use_beam_search, early_stopping=self.early_stopping, prompt_logprobs=self.logprobs if self.echo else None, skip_special_tokens=self.skip_special_tokens, - spaces_between_special_tokens=(self.spaces_between_special_tokens), + spaces_between_special_tokens=self.spaces_between_special_tokens, include_stop_str_in_output=self.include_stop_str_in_output, length_penalty=self.length_penalty, logits_processors=logits_processors, From 032eeecb0cfa8ab86bbc19ca8a7c0cc14909e789 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 20 Jul 2024 00:12:39 +0000 Subject: [PATCH 93/94] Fix wrong model class --- vllm/entrypoints/openai/protocol.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index b5d142a5465d3..aeff79363398f 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -525,7 +525,7 @@ def validate_stream_options(cls, data): return data -class EmbeddingRequest(BaseModel): +class EmbeddingRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/embeddings model: str @@ -597,13 +597,13 @@ class CompletionStreamResponse(OpenAIBaseModel): usage: Optional[UsageInfo] = Field(default=None) -class EmbeddingResponseData(BaseModel): +class EmbeddingResponseData(OpenAIBaseModel): index: int object: str = "embedding" embedding: Union[List[float], str] -class EmbeddingResponse(BaseModel): +class EmbeddingResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") object: str = "list" created: int = Field(default_factory=lambda: int(time.time())) From 0a9b0d8fb2b6d983397e6f04fe1fb7fca5d5e497 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 21 Jul 2024 01:06:52 +0000 Subject: [PATCH 94/94] isort --- vllm/entrypoints/openai/serving_chat.py | 2 +- vllm/entrypoints/openai/serving_tokenization.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 0e58b8c73bb8a..b21c2bc513186 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -9,10 +9,10 @@ from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.chat_utils import (ConversationMessage, load_chat_template, parse_chat_message_content) +from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( ChatCompletionLogProb, ChatCompletionLogProbs, ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam, diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 49b720b9ca6b8..94e1b03ed4036 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -2,12 +2,12 @@ from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.entrypoints.logger import RequestLogger # yapf conflicts 
with isort for this block # yapf: disable from vllm.entrypoints.chat_utils import (ConversationMessage, load_chat_template, parse_chat_message_content) +from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (DetokenizeRequest, DetokenizeResponse, ErrorResponse,