From 95931c0f2318a47e5a5c659d9c48dccdc56c6db3 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 23 Jul 2024 01:13:53 +0800 Subject: [PATCH] [Frontend] Refactor prompt processing (#4028) Co-authored-by: Roger Wang --- benchmarks/benchmark_latency.py | 4 +- .../dev/multimodal/multimodal_index.rst | 2 +- .../dev/offline_inference/llm_inputs.rst | 2 +- docs/source/models/vlm.rst | 2 +- .../output_processor/test_stop_checker.py | 4 +- tests/entrypoints/openai/test_serving_chat.py | 5 +- vllm/__init__.py | 4 +- vllm/engine/arg_utils.py | 7 - vllm/engine/async_llm_engine.py | 63 ++-- vllm/engine/llm_engine.py | 9 +- vllm/entrypoints/llm.py | 37 +-- vllm/entrypoints/logger.py | 41 +++ vllm/entrypoints/openai/api_server.py | 47 ++- vllm/entrypoints/openai/cli_args.py | 8 + vllm/entrypoints/openai/protocol.py | 108 +++---- vllm/entrypoints/openai/run_batch.py | 20 +- vllm/entrypoints/openai/serving_chat.py | 119 +++++--- vllm/entrypoints/openai/serving_completion.py | 115 ++++---- vllm/entrypoints/openai/serving_embedding.py | 78 +++-- vllm/entrypoints/openai/serving_engine.py | 271 ++++++++++++++---- .../openai/serving_tokenization.py | 106 +++++-- vllm/inputs/__init__.py | 7 +- vllm/inputs/data.py | 24 +- vllm/sequence.py | 5 +- 24 files changed, 698 insertions(+), 390 deletions(-) create mode 100644 vllm/entrypoints/logger.py diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 8d0554b0f4f05..97afd301c8f24 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -11,7 +11,7 @@ from vllm import LLM, SamplingParams from vllm.engine.arg_utils import EngineArgs -from vllm.inputs import PromptStrictInputs +from vllm.inputs import PromptInputs from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser @@ -61,7 +61,7 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_inputs: List[PromptStrictInputs] = [{ + dummy_inputs: List[PromptInputs] = [{ "prompt_token_ids": batch } for batch in dummy_prompt_token_ids.tolist()] diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index 6713dcf08d9f0..7cdbec2c9e3d4 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -8,7 +8,7 @@ Multi-Modality vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package. Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models ` -via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptStrictInputs`. +via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`. Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities by following :ref:`this guide `. diff --git a/docs/source/dev/offline_inference/llm_inputs.rst b/docs/source/dev/offline_inference/llm_inputs.rst index 31c3d16a3c8eb..9adf82d43f3e0 100644 --- a/docs/source/dev/offline_inference/llm_inputs.rst +++ b/docs/source/dev/offline_inference/llm_inputs.rst @@ -1,7 +1,7 @@ LLM Inputs ========== -.. autodata:: vllm.inputs.PromptStrictInputs +.. autodata:: vllm.inputs.PromptInputs .. 
autoclass:: vllm.inputs.TextPrompt :show-inheritance: diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index 92aca168dadf2..ef4ce0d44a162 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -30,7 +30,7 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM`` internally for each model. -To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`: +To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`: * ``prompt``: The prompt should follow the format that is documented on HuggingFace. * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`. diff --git a/tests/engine/output_processor/test_stop_checker.py b/tests/engine/output_processor/test_stop_checker.py index f795403e3d8ad..0d84443c51f99 100644 --- a/tests/engine/output_processor/test_stop_checker.py +++ b/tests/engine/output_processor/test_stop_checker.py @@ -35,8 +35,8 @@ def sequence_with_eos(text: str, eos_token: str, @pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [ ("This text ends with EOS token", "", 2), ]) -@pytest.mark.parametrize("ignore_eos", [True, False, None]) -@pytest.mark.parametrize("include_stop_str_in_output", [True, False, None]) +@pytest.mark.parametrize("ignore_eos", [True, False]) +@pytest.mark.parametrize("include_stop_str_in_output", [True, False]) @pytest.mark.skip_global_cleanup def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int, ignore_eos: bool, include_stop_str_in_output: bool): diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 9a7abcfe5e590..464465494b714 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -32,7 +32,10 @@ async def _async_serving_chat_init(): model_config, served_model_names=[MODEL_NAME], response_role="assistant", - chat_template=CHAT_TEMPLATE) + chat_template=CHAT_TEMPLATE, + lora_modules=None, + prompt_adapters=None, + request_logger=None) return serving_completion diff --git a/vllm/__init__.py b/vllm/__init__.py index 318f078fdbee7..0895c571d1d89 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -5,7 +5,7 @@ from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.llm import LLM from vllm.executor.ray_utils import initialize_ray_cluster -from vllm.inputs import PromptStrictInputs, TextPrompt, TokensPrompt +from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt from vllm.model_executor.models import ModelRegistry from vllm.outputs import (CompletionOutput, EmbeddingOutput, EmbeddingRequestOutput, RequestOutput) @@ -19,7 +19,7 @@ "__version__", "LLM", "ModelRegistry", - "PromptStrictInputs", + "PromptInputs", "TextPrompt", "TokensPrompt", "SamplingParams", diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 972d4e0cd9942..4db071e4caef4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -827,7 +827,6 @@ class AsyncEngineArgs(EngineArgs): """Arguments for asynchronous vLLM engine.""" engine_use_ray: bool = False disable_log_requests: bool = False - max_log_len: Optional[int] = None @staticmethod def add_cli_args(parser: FlexibleArgumentParser, @@ -841,12 +840,6 @@ def add_cli_args(parser: FlexibleArgumentParser, parser.add_argument('--disable-log-requests', action='store_true', help='Disable logging requests.') - parser.add_argument('--max-log-len', - 
type=int, - default=None, - help='Max number of prompt characters or prompt ' - 'ID numbers being printed in log.' - '\n\nDefault: Unlimited') return parser diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index c258cd9fdad63..3089cafc670a4 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,8 +1,8 @@ import asyncio import time from functools import partial -from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional, - Set, Tuple, Type, Union) +from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping, + Optional, Set, Tuple, Type, Union) from transformers import PreTrainedTokenizer @@ -151,7 +151,10 @@ def process_exception(self, logger.info("Finished request %s.", request_id) self.abort_request(request_id) - def add_request(self, request_id: str, + def add_request(self, + request_id: str, + *, + verbose: bool = False, **engine_add_request_kwargs) -> AsyncStream: """Add a request to be sent to the engine on the next background loop iteration.""" @@ -166,6 +169,9 @@ def add_request(self, request_id: str, self.new_requests_event.set() + if verbose: + logger.info("Added request %s.", request_id) + return stream def abort_request(self, request_id: str, *, verbose: bool = False) -> None: @@ -299,14 +305,14 @@ async def process_model_inputs_async( return self.input_processor(llm_inputs) async def add_request_async( - self, - request_id: str, - inputs: PromptInputs, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Dict[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None + self, + request_id: str, + inputs: PromptInputs, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " @@ -353,8 +359,6 @@ class AsyncLLMEngine: async frontend will be executed in a separate process as the model workers. log_requests: Whether to log the requests. - max_log_len: Maximum number of prompt characters or prompt ID numbers - being printed in log. start_engine_loop: If True, the background task to run the engine will be automatically started in the generate call. *args: Arguments for :class:`LLMEngine`. 
@@ -368,13 +372,11 @@ def __init__(self, engine_use_ray: bool, *args, log_requests: bool = True, - max_log_len: Optional[int] = None, start_engine_loop: bool = True, **kwargs) -> None: self.worker_use_ray = worker_use_ray self.engine_use_ray = engine_use_ray self.log_requests = log_requests - self.max_log_len = max_log_len self.engine = self._init_engine(*args, **kwargs) # jimpang: for lora self.lora_names_map = {} @@ -471,7 +473,6 @@ def from_engine_args( executor_class=executor_class, log_requests=not engine_args.disable_log_requests, log_stats=not engine_args.disable_log_stats, - max_log_len=engine_args.max_log_len, start_engine_loop=start_engine_loop, usage_context=usage_context, stat_loggers=stat_loggers, @@ -670,30 +671,9 @@ async def add_request( params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Dict[str, str]] = None, + trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncStream: - if self.log_requests: - if isinstance(inputs, str): - shortened_prompt = inputs - shortened_token_ids = None - else: - shortened_prompt = inputs.get("prompt") - shortened_token_ids = inputs.get("prompt_token_ids") - - max_log_len = self.max_log_len - if max_log_len is not None: - if shortened_prompt is not None: - shortened_prompt = shortened_prompt[:max_log_len] - if shortened_token_ids is not None: - shortened_token_ids = shortened_token_ids[:max_log_len] - - logger.info( - "Received request %s: prompt: %r, " - "params: %s, prompt_token_ids: %s, " - "lora_request: %s.", request_id, shortened_prompt, params, - shortened_token_ids, lora_request) - if not self.is_running: if self.start_engine_loop: self.start_background_loop() @@ -709,6 +689,7 @@ async def add_request( stream = self._request_tracker.add_request( request_id, + verbose=self.log_requests, inputs=inputs, params=params, arrival_time=arrival_time, @@ -724,7 +705,7 @@ async def generate( sampling_params: SamplingParams, request_id: str, lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Dict[str, str]] = None, + trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -807,7 +788,7 @@ async def encode( pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Dict[str, str]] = None, + trace_headers: Optional[Mapping[str, str]] = None, ) -> AsyncIterator[EmbeddingRequestOutput]: """Generate outputs for a request from an embedding model. 
@@ -885,7 +866,7 @@ async def _process_request( params: Union[SamplingParams, PoolingParams], *, lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Dict[str, str]] = None, + trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: """Common logic to process requests with SamplingParams or diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 27c429ae131e6..eabe3b23a9d58 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,6 +1,7 @@ import time from contextlib import contextmanager -from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional +from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, + Mapping, Optional) from typing import Sequence as GenericSequence from typing import Set, Type, TypeVar, Union @@ -522,7 +523,7 @@ def _add_processed_request( arrival_time: float, lora_request: Optional[LoRARequest], prompt_adapter_request: Optional[PromptAdapterRequest], - trace_headers: Optional[Dict[str, str]] = None, + trace_headers: Optional[Mapping[str, str]] = None, ) -> None: # Create the sequences. block_size = self.cache_config.block_size @@ -603,7 +604,7 @@ def add_request( params: Union[SamplingParams, PoolingParams], arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Dict[str, str]] = None, + trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: """Add a request to the engine's request pool. @@ -677,7 +678,7 @@ def _create_sequence_group_with_sampling( sampling_params: SamplingParams, arrival_time: float, lora_request: Optional[LoRARequest], - trace_headers: Optional[Dict[str, str]] = None, + trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index cadaffa0e30cf..62309ed345b1d 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -6,8 +6,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine -from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt, - TextTokensPrompt, TokensPrompt, +from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt, parse_and_batch_prompt) from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -238,7 +237,7 @@ def generate( @overload def generate( self, - inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + inputs: Union[PromptInputs, Sequence[PromptInputs]], /, # We may enable `inputs` keyword after removing the old API *, sampling_params: Optional[Union[SamplingParams, @@ -255,7 +254,7 @@ def generate( "instead.") def generate( self, - prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], Optional[Union[str, List[str]]]] = None, sampling_params: Optional[Union[SamplingParams, Sequence[SamplingParams]]] = None, @@ -302,9 +301,7 @@ def generate( prompt_token_ids=prompt_token_ids, ) else: - inputs = cast( - Union[PromptStrictInputs, Sequence[PromptStrictInputs]], - prompts) + inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) if sampling_params is None: # Use default sampling params. 
@@ -383,7 +380,7 @@ def encode( @overload def encode( self, - inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + inputs: Union[PromptInputs, Sequence[PromptInputs]], /, # We may enable `inputs` keyword after removing the old API *, pooling_params: Optional[Union[PoolingParams, @@ -400,7 +397,7 @@ def encode( "instead.") def encode( self, - prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + prompts: Union[Union[PromptInputs, Sequence[PromptInputs]], Optional[Union[str, List[str]]]] = None, pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None, @@ -417,7 +414,7 @@ def encode( Args: inputs: The inputs to the LLM. You may pass a sequence of inputs for - batch inference. See :class:`~vllm.inputs.PromptStrictInputs` + batch inference. See :class:`~vllm.inputs.PromptInputs` for more details about the format of each input. pooling_params: The pooling parameters for pooling. If None, we use the default pooling parameters. @@ -446,9 +443,7 @@ def encode( prompt_token_ids=prompt_token_ids, ) else: - inputs = cast( - Union[PromptStrictInputs, Sequence[PromptStrictInputs]], - prompts) + inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts) if pooling_params is None: # Use default pooling params. @@ -496,17 +491,11 @@ def _convert_v1_inputs( inputs: List[PromptInputs] = [] for i in range(num_requests): if prompts is not None: - if prompt_token_ids is not None: - item = TextTokensPrompt( - prompt=prompts[i], - prompt_token_ids=prompt_token_ids[i]) - else: - item = TextPrompt(prompt=prompts[i]) + item = TextPrompt(prompt=prompts[i]) + elif prompt_token_ids is not None: + item = TokensPrompt(prompt_token_ids=prompt_token_ids[i]) else: - if prompt_token_ids is not None: - item = TokensPrompt(prompt_token_ids=prompt_token_ids[i]) - else: - raise AssertionError + raise AssertionError inputs.append(item) @@ -514,7 +503,7 @@ def _convert_v1_inputs( def _validate_and_add_requests( self, - inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]], + inputs: Union[PromptInputs, Sequence[PromptInputs]], params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], diff --git a/vllm/entrypoints/logger.py b/vllm/entrypoints/logger.py new file mode 100644 index 0000000000000..091896e1c7a69 --- /dev/null +++ b/vllm/entrypoints/logger.py @@ -0,0 +1,41 @@ +from typing import List, Optional, Union + +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams + +logger = init_logger(__name__) + + +class RequestLogger: + + def __init__(self, *, max_log_len: Optional[int]) -> None: + super().__init__() + + self.max_log_len = max_log_len + + def log_inputs( + self, + request_id: str, + prompt: Optional[str], + prompt_token_ids: Optional[List[int]], + params: Optional[Union[SamplingParams, PoolingParams]], + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> None: + max_log_len = self.max_log_len + if max_log_len is not None: + if prompt is not None: + prompt = prompt[:max_log_len] + + if prompt_token_ids is not None: + prompt_token_ids = prompt_token_ids[:max_log_len] + + logger.info( + "Received request %s: prompt: %r, " + "params: %s, prompt_token_ids: %s, " + "lora_request: %s, prompt_adapter_request: %s.", 
request_id, + prompt, params, prompt_token_ids, lora_request, + prompt_adapter_request) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 72f307c432020..931063d90566c 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -18,6 +18,7 @@ import vllm.envs as envs from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.cli_args import make_arg_parser # yapf conflicts with isort for this block # yapf: disable @@ -244,24 +245,48 @@ def run_server(args, llm_engine=None): # When using single vLLM without engine_use_ray model_config = asyncio.run(engine.get_model_config()) + if args.disable_log_requests: + request_logger = None + else: + request_logger = RequestLogger(max_log_len=args.max_log_len) + global openai_serving_chat global openai_serving_completion global openai_serving_embedding global openai_serving_tokenization - openai_serving_chat = OpenAIServingChat(engine, model_config, - served_model_names, - args.response_role, - args.lora_modules, - args.chat_template) + openai_serving_chat = OpenAIServingChat( + engine, + model_config, + served_model_names, + args.response_role, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + request_logger=request_logger, + chat_template=args.chat_template, + ) openai_serving_completion = OpenAIServingCompletion( - engine, model_config, served_model_names, args.lora_modules, - args.prompt_adapters) - openai_serving_embedding = OpenAIServingEmbedding(engine, model_config, - served_model_names) + engine, + model_config, + served_model_names, + lora_modules=args.lora_modules, + prompt_adapters=args.prompt_adapters, + request_logger=request_logger, + ) + openai_serving_embedding = OpenAIServingEmbedding( + engine, + model_config, + served_model_names, + request_logger=request_logger, + ) openai_serving_tokenization = OpenAIServingTokenization( - engine, model_config, served_model_names, args.lora_modules, - args.chat_template) + engine, + model_config, + served_model_names, + lora_modules=args.lora_modules, + request_logger=request_logger, + chat_template=args.chat_template, + ) app.root_path = args.root_path logger.info("Available routes are:") diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index f841633b572a9..64919c8be8642 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -130,6 +130,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "using app.add_middleware(). ") parser = AsyncEngineArgs.add_cli_args(parser) + + parser.add_argument('--max-log-len', + type=int, + default=None, + help='Max number of prompt characters or prompt ' + 'ID numbers being printed in log.' 
+ '\n\nDefault: Unlimited') + return parser diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 212483109a799..c024bbc07c069 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -121,40 +121,42 @@ class ChatCompletionRequest(OpenAIBaseModel): # doc: begin-chat-completion-sampling-params best_of: Optional[int] = None - use_beam_search: Optional[bool] = False - top_k: Optional[int] = -1 - min_p: Optional[float] = 0.0 - repetition_penalty: Optional[float] = 1.0 - length_penalty: Optional[float] = 1.0 - early_stopping: Optional[bool] = False - ignore_eos: Optional[bool] = False - min_tokens: Optional[int] = 0 + use_beam_search: bool = False + top_k: int = -1 + min_p: float = 0.0 + repetition_penalty: float = 1.0 + length_penalty: float = 1.0 + early_stopping: bool = False stop_token_ids: Optional[List[int]] = Field(default_factory=list) - skip_special_tokens: Optional[bool] = True - spaces_between_special_tokens: Optional[bool] = True + include_stop_str_in_output: bool = False + ignore_eos: bool = False + min_tokens: int = 0 + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None # doc: end-chat-completion-sampling-params # doc: begin-chat-completion-extra-params - echo: Optional[bool] = Field( + echo: bool = Field( default=False, description=( "If true, the new message will be prepended with the last message " "if they belong to the same role."), ) - add_generation_prompt: Optional[bool] = Field( + add_generation_prompt: bool = Field( default=True, description= ("If true, the generation prompt will be added to the chat template. " "This is a parameter used by chat template in tokenizer config of the " "model."), ) - add_special_tokens: Optional[bool] = Field( + add_special_tokens: bool = Field( default=False, description=( "If true, special tokens (e.g. BOS) will be added to the prompt " "on top of what is added by the chat template. " "For most models, the chat template takes care of adding the " - "special tokens so this should be set to False (as is the " + "special tokens so this should be set to false (as is the " "default)."), ) documents: Optional[List[Dict[str, str]]] = Field( @@ -178,12 +180,6 @@ class ChatCompletionRequest(OpenAIBaseModel): description=("Additional kwargs to pass to the template renderer. " "Will be accessible by the chat template."), ) - include_stop_str_in_output: Optional[bool] = Field( - default=False, - description=( - "Whether to include the stop string in the output. 
" - "This is only applied when the stop or stop_token_ids is set."), - ) guided_json: Optional[Union[str, dict, BaseModel]] = Field( default=None, description=("If specified, the output will follow the JSON schema."), @@ -244,22 +240,22 @@ def logit_bias_logits_processor( return SamplingParams( n=self.n, + best_of=self.best_of, presence_penalty=self.presence_penalty, frequency_penalty=self.frequency_penalty, repetition_penalty=self.repetition_penalty, temperature=self.temperature, top_p=self.top_p, + top_k=self.top_k, min_p=self.min_p, seed=self.seed, stop=self.stop, stop_token_ids=self.stop_token_ids, - max_tokens=self.max_tokens, - min_tokens=self.min_tokens, logprobs=self.top_logprobs if self.logprobs else None, prompt_logprobs=self.top_logprobs if self.echo else None, - best_of=self.best_of, - top_k=self.top_k, ignore_eos=self.ignore_eos, + max_tokens=self.max_tokens, + min_tokens=self.min_tokens, use_beam_search=self.use_beam_search, early_stopping=self.early_stopping, skip_special_tokens=self.skip_special_tokens, @@ -267,6 +263,7 @@ def logit_bias_logits_processor( include_stop_str_in_output=self.include_stop_str_in_output, length_penalty=self.length_penalty, logits_processors=logits_processors, + truncate_prompt_tokens=self.truncate_prompt_tokens, ) @model_validator(mode='before') @@ -348,26 +345,27 @@ class CompletionRequest(OpenAIBaseModel): user: Optional[str] = None # doc: begin-completion-sampling-params - use_beam_search: Optional[bool] = False - top_k: Optional[int] = -1 - min_p: Optional[float] = 0.0 - repetition_penalty: Optional[float] = 1.0 - length_penalty: Optional[float] = 1.0 - early_stopping: Optional[bool] = False + use_beam_search: bool = False + top_k: int = -1 + min_p: float = 0.0 + repetition_penalty: float = 1.0 + length_penalty: float = 1.0 + early_stopping: bool = False stop_token_ids: Optional[List[int]] = Field(default_factory=list) - ignore_eos: Optional[bool] = False - min_tokens: Optional[int] = 0 - skip_special_tokens: Optional[bool] = True - spaces_between_special_tokens: Optional[bool] = True + include_stop_str_in_output: bool = False + ignore_eos: bool = False + min_tokens: int = 0 + skip_special_tokens: bool = True + spaces_between_special_tokens: bool = True truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None # doc: end-completion-sampling-params # doc: begin-completion-extra-params - include_stop_str_in_output: Optional[bool] = Field( - default=False, + add_special_tokens: bool = Field( + default=True, description=( - "Whether to include the stop string in the output. " - "This is only applied when the stop or stop_token_ids is set."), + "If true (the default), special tokens (e.g. 
BOS) will be added to " + "the prompt."), ) response_format: Optional[ResponseFormat] = Field( default=None, @@ -447,15 +445,15 @@ def logit_bias_logits_processor( seed=self.seed, stop=self.stop, stop_token_ids=self.stop_token_ids, + logprobs=self.logprobs, ignore_eos=self.ignore_eos, max_tokens=self.max_tokens if not echo_without_generation else 1, min_tokens=self.min_tokens, - logprobs=self.logprobs, use_beam_search=self.use_beam_search, early_stopping=self.early_stopping, prompt_logprobs=self.logprobs if self.echo else None, skip_special_tokens=self.skip_special_tokens, - spaces_between_special_tokens=(self.spaces_between_special_tokens), + spaces_between_special_tokens=self.spaces_between_special_tokens, include_stop_str_in_output=self.include_stop_str_in_output, length_penalty=self.length_penalty, logits_processors=logits_processors, @@ -489,11 +487,11 @@ def check_logprobs(cls, data): def validate_stream_options(cls, data): if data.get("stream_options") and not data.get("stream"): raise ValueError( - "Stream options can only be defined when stream is True.") + "Stream options can only be defined when stream is true.") return data -class EmbeddingRequest(BaseModel): +class EmbeddingRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/embeddings model: str @@ -565,13 +563,13 @@ class CompletionStreamResponse(OpenAIBaseModel): usage: Optional[UsageInfo] = Field(default=None) -class EmbeddingResponseData(BaseModel): +class EmbeddingResponseData(OpenAIBaseModel): index: int object: str = "embedding" embedding: Union[List[float], str] -class EmbeddingResponse(BaseModel): +class EmbeddingResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") object: str = "list" created: int = Field(default_factory=lambda: int(time.time())) @@ -670,8 +668,8 @@ class BatchRequestInput(OpenAIBaseModel): # /v1/chat/completions is supported. url: str - # The parameteters of the request. - body: Union[ChatCompletionRequest, ] + # The parameters of the request. 
+ body: ChatCompletionRequest class BatchResponseData(OpenAIBaseModel): @@ -703,12 +701,22 @@ class BatchRequestOutput(OpenAIBaseModel): error: Optional[Any] -class TokenizeRequest(OpenAIBaseModel): +class TokenizeCompletionRequest(OpenAIBaseModel): + model: str + prompt: str + + add_special_tokens: bool = Field(default=True) + + +class TokenizeChatRequest(OpenAIBaseModel): + model: str + messages: List[ChatCompletionMessageParam] + add_generation_prompt: bool = Field(default=True) add_special_tokens: bool = Field(default=False) - prompt: Optional[str] = Field(default=None) - messages: Optional[List[ChatCompletionMessageParam]] = Field(default=None) - model: str + + +TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest] class TokenizeResponse(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index dac6c2b4cd48f..3c5e5e651b54d 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -6,6 +6,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (BatchRequestInput, BatchRequestOutput, BatchResponseData, @@ -44,9 +45,17 @@ def parse_args(): type=nullable_str, default="assistant", help="The role name to return if " - "`request.add_generation_prompt=true`.") + "`request.add_generation_prompt=True`.") parser = AsyncEngineArgs.add_cli_args(parser) + + parser.add_argument('--max-log-len', + type=int, + default=None, + help='Max number of prompt characters or prompt ' + 'ID numbers being printed in log.' + '\n\nDefault: Unlimited') + return parser.parse_args() @@ -114,11 +123,20 @@ async def main(args): # When using single vLLM without engine_use_ray model_config = await engine.get_model_config() + if args.disable_log_requests: + request_logger = None + else: + request_logger = RequestLogger(max_log_len=args.max_log_len) + openai_serving_chat = OpenAIServingChat( engine, model_config, served_model_names, args.response_role, + lora_modules=None, + prompt_adapters=None, + request_logger=request_logger, + chat_template=None, ) # Submit all requests in the file to the engine "concurrently". 
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 95ca5d080afca..b21c2bc513186 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -12,6 +12,7 @@ from vllm.entrypoints.chat_utils import (ConversationMessage, load_chat_template, parse_chat_message_content) +from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( ChatCompletionLogProb, ChatCompletionLogProbs, ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam, @@ -20,7 +21,8 @@ ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, FunctionCall, ToolCall, UsageInfo) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, - OpenAIServing) + OpenAIServing, + PromptAdapterPath) from vllm.inputs import PromptInputs from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( @@ -37,17 +39,24 @@ class OpenAIServingChat(OpenAIServing): - def __init__(self, - engine: AsyncLLMEngine, - model_config: ModelConfig, - served_model_names: List[str], - response_role: str, - lora_modules: Optional[List[LoRAModulePath]] = None, - chat_template: Optional[str] = None): + def __init__( + self, + engine: AsyncLLMEngine, + model_config: ModelConfig, + served_model_names: List[str], + response_role: str, + *, + lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]], + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + ): super().__init__(engine=engine, model_config=model_config, served_model_names=served_model_names, - lora_modules=lora_modules) + lora_modules=lora_modules, + prompt_adapters=prompt_adapters, + request_logger=request_logger) self.response_role = response_role @@ -74,7 +83,12 @@ async def create_chat_completion( return error_check_ret try: - _, lora_request = self._maybe_get_adapter(request) + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + model_config = self.model_config tokenizer = await self.engine.get_tokenizer(lora_request) conversation: List[ConversationMessage] = [] @@ -82,7 +96,7 @@ async def create_chat_completion( for msg in request.messages: chat_parsed_result = parse_chat_message_content( - msg, self.model_config, tokenizer) + msg, model_config, tokenizer) conversation.extend(chat_parsed_result.messages) mm_futures.extend(chat_parsed_result.mm_futures) @@ -116,14 +130,8 @@ async def create_chat_completion( logger.error("Error in loading multi-modal data: %s", e) return self.create_error_response(str(e)) - request_id = f"cmpl-{random_uuid()}" + request_id = f"chat-{random_uuid()}" try: - # Tokenize/detokenize depending on prompt format (string/token list) - prompt_ids, prompt_text = await self._validate_prompt_and_tokenize( - request, - tokenizer, - prompt=prompt, - add_special_tokens=request.add_special_tokens) sampling_params = request.to_sampling_params() decoding_config = await self.engine.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ @@ -137,31 +145,47 @@ async def create_chat_completion( sampling_params.logits_processors = [] sampling_params.logits_processors.append( guided_decode_logits_processor) + + prompt_inputs = self._tokenize_prompt_input( + request, + tokenizer, + prompt, + truncate_prompt_tokens=sampling_params.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + + self._log_inputs(request_id, + prompt_inputs, + params=sampling_params, + 
lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + engine_inputs: PromptInputs = { + "prompt_token_ids": prompt_inputs["prompt_token_ids"], + } + if mm_data is not None: + engine_inputs["multi_modal_data"] = mm_data + + is_tracing_enabled = await self.engine.is_tracing_enabled() + trace_headers = None + if is_tracing_enabled and raw_request: + trace_headers = extract_trace_headers(raw_request.headers) + if (not is_tracing_enabled and raw_request + and contains_trace_headers(raw_request.headers)): + log_tracing_disabled_warning() + + result_generator = self.engine.generate( + engine_inputs, + sampling_params, + request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, + ) except ValueError as e: + # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) - inputs: PromptInputs = { - "prompt": prompt_text, - "prompt_token_ids": prompt_ids, - } - if mm_data: - inputs["multi_modal_data"] = mm_data - - is_tracing_enabled = await self.engine.is_tracing_enabled() - trace_headers = None - if is_tracing_enabled and raw_request: - trace_headers = extract_trace_headers(raw_request.headers) - if not is_tracing_enabled and raw_request and contains_trace_headers( - raw_request.headers): - log_tracing_disabled_warning() - - result_generator = self.engine.generate( - inputs, - sampling_params, - request_id, - lora_request, - trace_headers=trace_headers, - ) # Streaming response if request.stream: return self.chat_completion_stream_generator( @@ -195,10 +219,11 @@ async def chat_completion_stream_generator( first_iteration = True # Send response for each token for each request.n (index) - assert request.n is not None - previous_texts = [""] * request.n - previous_num_tokens = [0] * request.n - finish_reason_sent = [False] * request.n + num_choices = 1 if request.n is None else request.n + previous_texts = [""] * num_choices + previous_num_tokens = [0] * num_choices + finish_reason_sent = [False] * num_choices + try: async for res in result_generator: # We need to do it here, because if there are exceptions in @@ -208,7 +233,7 @@ async def chat_completion_stream_generator( # Send first response for each request.n (index) with # the role role = self.get_chat_request_role(request) - for i in range(request.n): + for i in range(num_choices): choice_data = ChatCompletionResponseStreamChoice( index=i, delta=DeltaMessage(role=role), @@ -236,19 +261,19 @@ async def chat_completion_stream_generator( last_msg_content = conversation[-1]["content"] if last_msg_content: - for i in range(request.n): + for i in range(num_choices): choice_data = ( ChatCompletionResponseStreamChoice( index=i, delta=DeltaMessage( content=last_msg_content), + logprobs=None, finish_reason=None)) chunk = ChatCompletionStreamResponse( id=request_id, object=chunk_object_type, created=created_time, choices=[choice_data], - logprobs=None, model=model_name) if (request.stream_options and request.stream_options.include_usage): diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index e61f3fdbf6666..6aef4c9f96150 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -2,13 +2,14 @@ from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, Optional) from typing import Sequence as GenericSequence -from typing import Tuple +from typing import Tuple, cast from fastapi import Request from transformers import 
PreTrainedTokenizer from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable from vllm.entrypoints.openai.protocol import (CompletionLogProbs, @@ -39,40 +40,24 @@ [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs] -def parse_prompt_format(prompt) -> Tuple[bool, list]: - # get the prompt, openai supports the following - # "a string, array of strings, array of tokens, or array of token arrays." - prompt_is_tokens = False - prompts = [prompt] # case 1: a string - if isinstance(prompt, list): - if len(prompt) == 0: - raise ValueError("please provide at least one prompt") - elif isinstance(prompt[0], str): - prompt_is_tokens = False - prompts = prompt # case 2: array of strings - elif isinstance(prompt[0], int): - prompt_is_tokens = True - prompts = [prompt] # case 3: array of tokens - elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int): - prompt_is_tokens = True - prompts = prompt # case 4: array of token arrays - else: - raise ValueError("prompt must be a string, array of strings, " - "array of tokens, or array of token arrays") - return prompt_is_tokens, prompts - - class OpenAIServingCompletion(OpenAIServing): - def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, - served_model_names: List[str], - lora_modules: Optional[List[LoRAModulePath]], - prompt_adapters: Optional[List[PromptAdapterPath]]): + def __init__( + self, + engine: AsyncLLMEngine, + model_config: ModelConfig, + served_model_names: List[str], + *, + lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]], + request_logger: Optional[RequestLogger], + ): super().__init__(engine=engine, model_config=model_config, served_model_names=served_model_names, lora_modules=lora_modules, - prompt_adapters=prompt_adapters) + prompt_adapters=prompt_adapters, + request_logger=request_logger) async def create_completion(self, request: CompletionRequest, raw_request: Request): @@ -101,12 +86,11 @@ async def create_completion(self, request: CompletionRequest, # Schedule the request and get the result generator. generators: List[AsyncIterator[RequestOutput]] = [] try: - adapter_type, adapter_request = self._maybe_get_adapter(request) - lora_request, prompt_adapter_request = None, None - if adapter_type == 'LoRA': - lora_request, prompt_adapter_request = adapter_request, None - elif adapter_type == 'PromptAdapter': - lora_request, prompt_adapter_request = None, adapter_request + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + tokenizer = await self.engine.get_tokenizer(lora_request) sampling_params = request.to_sampling_params() @@ -122,17 +106,25 @@ async def create_completion(self, request: CompletionRequest, sampling_params.logits_processors = [] sampling_params.logits_processors.append( guided_decode_logit_processor) - prompt_is_tokens, prompts = parse_prompt_format(request.prompt) - for i, prompt in enumerate(prompts): - prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt" - prompt_formats = await self._validate_prompt_and_tokenize( + prompts = list( + self._tokenize_prompt_input_or_inputs( request, tokenizer, + request.prompt, truncate_prompt_tokens=sampling_params. 
truncate_prompt_tokens, - **{prompt_arg: prompt}) - prompt_ids, prompt_text = prompt_formats + add_special_tokens=request.add_special_tokens, + )) + + for i, prompt_inputs in enumerate(prompts): + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + prompt_inputs, + params=sampling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) is_tracing_enabled = await self.engine.is_tracing_enabled() trace_headers = None @@ -143,12 +135,9 @@ async def create_completion(self, request: CompletionRequest, log_tracing_disabled_warning() generator = self.engine.generate( - { - "prompt": prompt_text, - "prompt_token_ids": prompt_ids - }, + {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, sampling_params, - f"{request_id}-{i}", + request_id_item, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers, @@ -189,9 +178,27 @@ async def create_completion(self, request: CompletionRequest, await self.engine.abort(f"{request_id}-{i}") return self.create_error_response("Client disconnected") final_res_batch[i] = res + + for i, final_res in enumerate(final_res_batch): + assert final_res is not None + + # The output should contain the input text + # We did not pass it into vLLM engine to avoid being redundant + # with the inputs token IDs + if final_res.prompt is None: + final_res.prompt = prompts[i]["prompt"] + + final_res_batch_checked = cast(List[RequestOutput], + final_res_batch) + response = self.request_output_to_completion_response( - final_res_batch, request, request_id, created_time, model_name, - tokenizer) + final_res_batch_checked, + request, + request_id, + created_time, + model_name, + tokenizer, + ) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -220,10 +227,10 @@ async def completion_stream_generator( num_prompts: int, tokenizer: PreTrainedTokenizer, ) -> AsyncGenerator[str, None]: - assert request.n is not None - previous_texts = [""] * request.n * num_prompts - previous_num_tokens = [0] * request.n * num_prompts - has_echoed = [False] * request.n * num_prompts + num_choices = 1 if request.n is None else request.n + previous_texts = [""] * num_choices * num_prompts + previous_num_tokens = [0] * num_choices * num_prompts + has_echoed = [False] * num_choices * num_prompts try: async for prompt_idx, res in result_generator: @@ -234,7 +241,7 @@ async def completion_stream_generator( raise StopAsyncIteration() for output in res.outputs: - i = output.index + prompt_idx * request.n + i = output.index + prompt_idx * num_choices # TODO(simon): optimize the performance by avoiding full # text O(n^2) sending. 
@@ -343,8 +350,8 @@ def request_output_to_completion_response( choices: List[CompletionResponseChoice] = [] num_prompt_tokens = 0 num_generated_tokens = 0 + for final_res in final_res_batch: - assert final_res is not None prompt_token_ids = final_res.prompt_token_ids prompt_logprobs = final_res.prompt_logprobs prompt_text = final_res.prompt diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 19e4288f5aa1c..bccc90894e79f 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -1,16 +1,16 @@ import base64 import time -from typing import AsyncIterator, List, Optional, Tuple +from typing import AsyncIterator, List, Optional, Tuple, cast import numpy as np from fastapi import Request from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (EmbeddingRequest, EmbeddingResponse, EmbeddingResponseData, UsageInfo) -from vllm.entrypoints.openai.serving_completion import parse_prompt_format from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.logger import init_logger from vllm.outputs import EmbeddingRequestOutput @@ -28,11 +28,11 @@ def request_output_to_embedding_response( data: List[EmbeddingResponseData] = [] num_prompt_tokens = 0 for idx, final_res in enumerate(final_res_batch): - assert final_res is not None prompt_token_ids = final_res.prompt_token_ids embedding = final_res.outputs.embedding if encoding_format == "base64": - embedding = base64.b64encode(np.array(embedding)) + embedding_bytes = np.array(embedding).tobytes() + embedding = base64.b64encode(embedding_bytes).decode("utf-8") embedding_data = EmbeddingResponseData(index=idx, embedding=embedding) data.append(embedding_data) @@ -54,12 +54,20 @@ def request_output_to_embedding_response( class OpenAIServingEmbedding(OpenAIServing): - def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig, - served_model_names: List[str]): + def __init__( + self, + engine: AsyncLLMEngine, + model_config: ModelConfig, + served_model_names: List[str], + *, + request_logger: Optional[RequestLogger], + ): super().__init__(engine=engine, model_config=model_config, served_model_names=served_model_names, - lora_modules=None) + lora_modules=None, + prompt_adapters=None, + request_logger=request_logger) self._check_embedding_mode(model_config.embedding_mode) async def create_embedding(self, request: EmbeddingRequest, @@ -80,29 +88,47 @@ async def create_embedding(self, request: EmbeddingRequest, "dimensions is currently not supported") model_name = request.model - request_id = f"cmpl-{random_uuid()}" + request_id = f"embd-{random_uuid()}" created_time = int(time.monotonic()) # Schedule the request and get the result generator. 
- generators = [] + generators: List[AsyncIterator[EmbeddingRequestOutput]] = [] try: - prompt_is_tokens, prompts = parse_prompt_format(request.input) + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine.get_tokenizer(lora_request) + pooling_params = request.to_pooling_params() - tokenizer = await self.engine.get_tokenizer() - for i, prompt in enumerate(prompts): - prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt" - prompt_formats = await self._validate_prompt_and_tokenize( - request, tokenizer, **{prompt_arg: prompt}) - prompt_ids, prompt_text = prompt_formats + prompts = list( + self._tokenize_prompt_input_or_inputs( + request, + tokenizer, + request.input, + )) + + for i, prompt_inputs in enumerate(prompts): + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + prompt_inputs, + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + if prompt_adapter_request is not None: + raise NotImplementedError( + "Prompt adapter is not supported " + "for embedding models") generator = self.engine.encode( - { - "prompt": prompt_text, - "prompt_token_ids": prompt_ids - }, + {"prompt_token_ids": prompt_inputs["prompt_token_ids"]}, pooling_params, - f"{request_id}-{i}", + request_id_item, + lora_request=lora_request, ) generators.append(generator) @@ -121,11 +147,17 @@ async def create_embedding(self, request: EmbeddingRequest, if await raw_request.is_disconnected(): # Abort the request if the client disconnects. await self.engine.abort(f"{request_id}-{i}") - # TODO: Use a vllm-specific Validation Error return self.create_error_response("Client disconnected") final_res_batch[i] = res + + for final_res in final_res_batch: + assert final_res is not None + + final_res_batch_checked = cast(List[EmbeddingRequestOutput], + final_res_batch) + response = request_output_to_embedding_response( - final_res_batch, request_id, created_time, model_name, + final_res_batch_checked, request_id, created_time, model_name, encoding_format) except ValueError as e: # TODO: Use a vllm-specific Validation Error diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 4123ace36479e..7578dc9dc3c0c 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -2,23 +2,33 @@ import pathlib from dataclasses import dataclass from http import HTTPStatus -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union from pydantic import Field -from transformers import PreTrainedTokenizer +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from typing_extensions import Annotated from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.logger import RequestLogger +# yapf conflicts with isort for this block +# yapf: disable from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, CompletionRequest, DetokenizeRequest, EmbeddingRequest, ErrorResponse, ModelCard, ModelList, - ModelPermission, TokenizeRequest) + ModelPermission, + TokenizeChatRequest, + TokenizeCompletionRequest, + TokenizeRequest) +# yapf: enable +from vllm.inputs import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import 
PromptAdapterRequest +from vllm.sampling_params import SamplingParams from vllm.sequence import Logprob logger = init_logger(__name__) @@ -36,6 +46,17 @@ class LoRAModulePath: local_path: str +AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest, + EmbeddingRequest, TokenizeRequest] + +AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] + + +class TextTokensPrompt(TypedDict): + prompt: str + prompt_token_ids: List[int] + + class OpenAIServing: def __init__( @@ -43,8 +64,10 @@ def __init__( engine: AsyncLLMEngine, model_config: ModelConfig, served_model_names: List[str], + *, lora_modules: Optional[List[LoRAModulePath]], - prompt_adapters: Optional[List[PromptAdapterPath]] = None, + prompt_adapters: Optional[List[PromptAdapterPath]], + request_logger: Optional[RequestLogger], ): super().__init__() @@ -78,6 +101,8 @@ def __init__( prompt_adapter_local_path=prompt_adapter.local_path, prompt_adapter_num_virtual_tokens=num_virtual_tokens)) + self.request_logger = request_logger + async def show_available_models(self) -> ModelList: """Show available models. Right now we only have one model.""" model_cards = [ @@ -126,9 +151,8 @@ def create_streaming_error_response( return json_str async def _check_model( - self, request: Union[ChatCompletionRequest, CompletionRequest, - DetokenizeRequest, EmbeddingRequest, - TokenizeRequest] + self, + request: AnyRequest, ) -> Optional[ErrorResponse]: if request.model in self.served_model_names: return None @@ -144,64 +168,65 @@ async def _check_model( err_type="NotFoundError", status_code=HTTPStatus.NOT_FOUND) - def _maybe_get_adapter( - self, request: Union[CompletionRequest, ChatCompletionRequest, - EmbeddingRequest, TokenizeRequest, - DetokenizeRequest] - ) -> Tuple[Optional[str], Optional[Union[LoRARequest, - PromptAdapterRequest]]]: + def _maybe_get_adapters( + self, request: AnyRequest + ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[ + None, PromptAdapterRequest]]: if request.model in self.served_model_names: return None, None for lora in self.lora_requests: if request.model == lora.lora_name: - return 'LoRA', lora + return lora, None for prompt_adapter in self.prompt_adapter_requests: if request.model == prompt_adapter.prompt_adapter_name: - return 'PromptAdapter', prompt_adapter + return None, prompt_adapter # if _check_model has been called earlier, this will be unreachable raise ValueError(f"The model `{request.model}` does not exist.") - async def _validate_prompt_and_tokenize( - self, - request: Union[ChatCompletionRequest, CompletionRequest, - DetokenizeRequest, EmbeddingRequest, - TokenizeRequest], - tokenizer: "PreTrainedTokenizer", - prompt: Optional[str] = None, - prompt_ids: Optional[List[int]] = None, - truncate_prompt_tokens: Optional[Annotated[int, - Field(ge=1)]] = None, - add_special_tokens: Optional[bool] = True - ) -> Tuple[List[int], str]: - if not (prompt or prompt_ids): - raise ValueError("Either prompt or prompt_ids should be provided.") - if prompt and prompt_ids: - raise ValueError( - "Only one of prompt or prompt_ids should be provided.") - - if prompt_ids is None: - # When using OpenAIServingChat for chat completions, for - # most models the special tokens (e.g., BOS) have already - # been added by the chat template. Therefore, we do not - # need to add them again. - # Set add_special_tokens to False (by default) to avoid - # adding the BOS tokens again. 
- tokenizer_kwargs: Dict[str, Any] = { - "add_special_tokens": add_special_tokens - } - if truncate_prompt_tokens is not None: - tokenizer_kwargs.update({ - "truncation": True, - "max_length": truncate_prompt_tokens, - }) - input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids - elif truncate_prompt_tokens is not None: - input_ids = prompt_ids[-truncate_prompt_tokens:] + def _normalize_prompt_text_to_input( + self, + request: AnyRequest, + tokenizer: AnyTokenizer, + prompt: str, + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]], + add_special_tokens: bool, + ) -> TextTokensPrompt: + if truncate_prompt_tokens is None: + encoded = tokenizer(prompt, add_special_tokens=add_special_tokens) else: + encoded = tokenizer(prompt, + add_special_tokens=add_special_tokens, + truncation=True, + max_length=truncate_prompt_tokens) + + input_ids = encoded.input_ids + + input_text = prompt + + return self._validate_input(request, input_ids, input_text) + + def _normalize_prompt_tokens_to_input( + self, + request: AnyRequest, + tokenizer: AnyTokenizer, + prompt_ids: List[int], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]], + ) -> TextTokensPrompt: + if truncate_prompt_tokens is None: input_ids = prompt_ids + else: + input_ids = prompt_ids[-truncate_prompt_tokens:] + + input_text = tokenizer.decode(input_ids) - input_text = prompt if prompt is not None else tokenizer.decode( - input_ids) + return self._validate_input(request, input_ids, input_text) + + def _validate_input( + self, + request: AnyRequest, + input_ids: List[int], + input_text: str, + ) -> TextTokensPrompt: token_num = len(input_ids) # Note: EmbeddingRequest doesn't have max_tokens @@ -211,13 +236,16 @@ async def _validate_prompt_and_tokenize( f"This model's maximum context length is " f"{self.max_model_len} tokens. However, you requested " f"{token_num} tokens in the input for embedding " - f"generation. Please reduce the length of the input.", ) - return input_ids, input_text + f"generation. Please reduce the length of the input.") + return TextTokensPrompt(prompt=input_text, + prompt_token_ids=input_ids) # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens # and does not require model context length validation - if isinstance(request, (TokenizeRequest, DetokenizeRequest)): - return input_ids, input_text + if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest, + DetokenizeRequest)): + return TextTokensPrompt(prompt=input_text, + prompt_token_ids=input_ids) if request.max_tokens is None: if token_num >= self.max_model_len: @@ -225,7 +253,7 @@ async def _validate_prompt_and_tokenize( f"This model's maximum context length is " f"{self.max_model_len} tokens. However, you requested " f"{token_num} tokens in the messages, " - f"Please reduce the length of the messages.", ) + f"Please reduce the length of the messages.") request.max_tokens = self.max_model_len - token_num if token_num + request.max_tokens > self.max_model_len: @@ -235,13 +263,132 @@ async def _validate_prompt_and_tokenize( f"{request.max_tokens + token_num} tokens " f"({token_num} in the messages, " f"{request.max_tokens} in the completion). 
" - f"Please reduce the length of the messages or completion.", ) + f"Please reduce the length of the messages or completion.") + + return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids) + + def _tokenize_prompt_input( + self, + request: AnyRequest, + tokenizer: AnyTokenizer, + prompt_input: Union[str, List[int]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + add_special_tokens: bool = True, + ) -> TextTokensPrompt: + """ + A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs` + that assumes single input. + """ + return next( + self._tokenize_prompt_inputs( + request, + tokenizer, + [prompt_input], + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + )) + + def _tokenize_prompt_inputs( + self, + request: AnyRequest, + tokenizer: AnyTokenizer, + prompt_inputs: Iterable[Union[str, List[int]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + add_special_tokens: bool = True, + ) -> Iterator[TextTokensPrompt]: + """ + A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs` + that assumes multiple inputs. + """ + for text in prompt_inputs: + if isinstance(text, str): + yield self._normalize_prompt_text_to_input( + request, + tokenizer, + prompt=text, + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + ) + else: + yield self._normalize_prompt_tokens_to_input( + request, + tokenizer, + prompt_ids=text, + truncate_prompt_tokens=truncate_prompt_tokens, + ) + + def _tokenize_prompt_input_or_inputs( + self, + request: AnyRequest, + tokenizer: AnyTokenizer, + input_or_inputs: Union[str, List[str], List[int], List[List[int]]], + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, + add_special_tokens: bool = True, + ) -> Iterator[TextTokensPrompt]: + """ + Tokenize/detokenize depending on the input format. + + According to `OpenAI API `_ + , each input can be a string or array of tokens. Note that each request + can pass one or more inputs. 
+ """ + for prompt_input in parse_and_batch_prompt(input_or_inputs): + # Although our type checking is based on mypy, + # VSCode Pyright extension should still work properly + # "is True" is required for Pyright to perform type narrowing + # See: https://github.com/microsoft/pyright/issues/7672 + if prompt_input["is_tokens"] is False: + yield self._normalize_prompt_text_to_input( + request, + tokenizer, + prompt=prompt_input["content"], + truncate_prompt_tokens=truncate_prompt_tokens, + add_special_tokens=add_special_tokens, + ) + else: + yield self._normalize_prompt_tokens_to_input( + request, + tokenizer, + prompt_ids=prompt_input["content"], + truncate_prompt_tokens=truncate_prompt_tokens, + ) + + def _log_inputs( + self, + request_id: str, + inputs: Union[str, List[int], TextTokensPrompt], + params: Optional[Union[SamplingParams, PoolingParams]], + lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], + ) -> None: + if self.request_logger is None: + return + + if isinstance(inputs, str): + prompt = inputs + prompt_token_ids = None + elif isinstance(inputs, list): + prompt = None + prompt_token_ids = inputs else: - return input_ids, input_text + prompt = inputs["prompt"] + prompt_token_ids = inputs["prompt_token_ids"] + + self.request_logger.log_inputs( + request_id, + prompt, + prompt_token_ids, + params=params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) @staticmethod - def _get_decoded_token(logprob: Logprob, token_id: int, - tokenizer: PreTrainedTokenizer) -> str: + def _get_decoded_token( + logprob: Logprob, + token_id: int, + tokenizer: AnyTokenizer, + ) -> str: if logprob.decoded_token is not None: return logprob.decoded_token return tokenizer.decode(token_id) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 70a254785eba3..94e1b03ed4036 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,83 +1,135 @@ -from typing import List, Optional +from typing import List, Optional, Union from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine +# yapf conflicts with isort for this block +# yapf: disable from vllm.entrypoints.chat_utils import (ConversationMessage, load_chat_template, parse_chat_message_content) +from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import (DetokenizeRequest, DetokenizeResponse, + ErrorResponse, + TokenizeChatRequest, TokenizeRequest, TokenizeResponse) +# yapf: enable from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) +from vllm.utils import random_uuid class OpenAIServingTokenization(OpenAIServing): - def __init__(self, - engine: AsyncLLMEngine, - model_config: ModelConfig, - served_model_names: List[str], - lora_modules: Optional[List[LoRAModulePath]] = None, - chat_template: Optional[str] = None): + def __init__( + self, + engine: AsyncLLMEngine, + model_config: ModelConfig, + served_model_names: List[str], + *, + lora_modules: Optional[List[LoRAModulePath]], + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + ): super().__init__(engine=engine, model_config=model_config, served_model_names=served_model_names, - lora_modules=lora_modules) + lora_modules=lora_modules, + prompt_adapters=None, + request_logger=request_logger) # If this is None we use the tokenizer's default chat template self.chat_template = 
load_chat_template(chat_template) - async def create_tokenize(self, - request: TokenizeRequest) -> TokenizeResponse: + async def create_tokenize( + self, + request: TokenizeRequest, + ) -> Union[TokenizeResponse, ErrorResponse]: error_check_ret = await self._check_model(request) if error_check_ret is not None: return error_check_ret - if not (request.prompt or request.messages): - return self.create_error_response( - "Either `prompt` or `messages` should be provided.") + request_id = f"tokn-{random_uuid()}" - if (request.prompt and request.messages): - return self.create_error_response( - "Only one of `prompt` or `messages` should be provided.") + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) - _, lora_request = self._maybe_get_adapter(request) tokenizer = await self.engine.get_tokenizer(lora_request) - if request.messages: + + if isinstance(request, TokenizeChatRequest): + model_config = self.model_config + conversation: List[ConversationMessage] = [] for message in request.messages: - result = parse_chat_message_content(message, self.model_config, + result = parse_chat_message_content(message, model_config, tokenizer) conversation.extend(result.messages) - request.prompt = tokenizer.apply_chat_template( + prompt = tokenizer.apply_chat_template( add_generation_prompt=request.add_generation_prompt, conversation=conversation, tokenize=False, chat_template=self.chat_template) + assert isinstance(prompt, str) + else: + prompt = request.prompt + + self._log_inputs(request_id, + prompt, + params=None, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) - (input_ids, input_text) = await self._validate_prompt_and_tokenize( + # Silently ignore prompt adapter since it does not affect tokenization + + prompt_input = self._tokenize_prompt_input( request, tokenizer, - prompt=request.prompt, - add_special_tokens=request.add_special_tokens) + prompt, + add_special_tokens=request.add_special_tokens, + ) + input_ids = prompt_input["prompt_token_ids"] return TokenizeResponse(tokens=input_ids, count=len(input_ids), max_model_len=self.max_model_len) async def create_detokenize( - self, request: DetokenizeRequest) -> DetokenizeResponse: + self, + request: DetokenizeRequest, + ) -> Union[DetokenizeResponse, ErrorResponse]: error_check_ret = await self._check_model(request) if error_check_ret is not None: return error_check_ret - _, lora_request = self._maybe_get_adapter(request) + request_id = f"tokn-{random_uuid()}" + + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + tokenizer = await self.engine.get_tokenizer(lora_request) - (input_ids, input_text) = await self._validate_prompt_and_tokenize( - request, tokenizer, prompt_ids=request.tokens) + + self._log_inputs(request_id, + request.tokens, + params=None, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + if prompt_adapter_request is not None: + raise NotImplementedError("Prompt adapter is not supported " + "for tokenization") + + prompt_input = self._tokenize_prompt_input( + request, + tokenizer, + request.tokens, + ) + input_text = prompt_input["prompt"] return DetokenizeResponse(prompt=input_text) diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index d094156962955..b13d9acf93d3b 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -1,6 +1,5 @@ from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs, - PromptStrictInputs, TextPrompt, TextTokensPrompt, - TokensPrompt, 
parse_and_batch_prompt) + TextPrompt, TokensPrompt, parse_and_batch_prompt) from .registry import InputContext, InputRegistry INPUT_REGISTRY = InputRegistry() @@ -14,6 +13,6 @@ __all__ = [ "ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt", - "TokensPrompt", "TextTokensPrompt", "PromptStrictInputs", "PromptInputs", - "LLMInputs", "INPUT_REGISTRY", "InputContext", "InputRegistry" + "TokensPrompt", "PromptInputs", "LLMInputs", "INPUT_REGISTRY", + "InputContext", "InputRegistry" ] diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index c6381fcc01e5f..4443e6c70fe5b 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -92,25 +92,7 @@ class TokensPrompt(TypedDict): """ -class TextTokensPrompt(TypedDict): - """It is assumed that :attr:`prompt` is consistent with - :attr:`prompt_token_ids`. This is currently used in - :class:`AsyncLLMEngine` for logging both the text and token IDs.""" - - prompt: str - """The prompt text.""" - - prompt_token_ids: List[int] - """The token IDs of the prompt.""" - - multi_modal_data: NotRequired["MultiModalDataDict"] - """ - Optional multi-modal data to pass to the model, - if the model supports it. - """ - - -PromptStrictInputs = Union[str, TextPrompt, TokensPrompt] +PromptInputs = Union[str, TextPrompt, TokensPrompt] """ The inputs to the LLM, which can take one of the following forms: @@ -118,10 +100,6 @@ class TextTokensPrompt(TypedDict): - A tokenized prompt (:class:`TokensPrompt`) """ -PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt] -"""Same as :const:`PromptStrictInputs` but additionally accepts -:class:`TextTokensPrompt`.""" - class LLMInputs(TypedDict): """ diff --git a/vllm/sequence.py b/vllm/sequence.py index 6c12a01bd0b2b..0cd4c7e71d78d 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,7 +5,8 @@ from abc import ABC, abstractmethod from collections import defaultdict from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union +from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple, + Union) import torch @@ -438,7 +439,7 @@ def __init__( embeddings: Optional[List[float]] = None, pooling_params: Optional[PoolingParams] = None, encoder_seq: Optional[Sequence] = None, - trace_headers: Optional[Dict[str, str]] = None, + trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: self.request_id = request_id
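
For reference, a minimal sketch of how the consolidated PromptInputs type is consumed after this change (the model name, prompt strings, and token IDs below are illustrative placeholders, not part of the patch):

    from typing import List

    from vllm import LLM, PromptInputs, SamplingParams

    # PromptInputs now covers a plain string, a TextPrompt dict, or a
    # TokensPrompt dict; the old TextTokensPrompt form is no longer part
    # of the public input types.
    prompts: List[PromptInputs] = [
        "Hello, my name is",
        {"prompt": "The capital of France is"},
        {"prompt_token_ids": [1, 15043, 3186]},  # placeholder token IDs
    ]

    llm = LLM(model="facebook/opt-125m")  # placeholder model
    for output in llm.generate(prompts, SamplingParams(max_tokens=16)):
        print(output.outputs[0].text)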
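
The new _normalize_prompt_text_to_input / _normalize_prompt_tokens_to_input helpers always hand back a TextTokensPrompt carrying both the text and the token IDs. A rough standalone approximation of that behaviour, for illustration only (it calls a Hugging Face tokenizer directly and omits the context-length checks that _validate_input performs):

    from typing import List, Optional, Union

    from transformers import AutoTokenizer

    def normalize_prompt(tokenizer,
                         prompt: Union[str, List[int]],
                         truncate_prompt_tokens: Optional[int] = None,
                         add_special_tokens: bool = True) -> dict:
        if isinstance(prompt, str):
            # Text prompt: tokenize, optionally truncating via the tokenizer.
            if truncate_prompt_tokens is None:
                input_ids = tokenizer(
                    prompt, add_special_tokens=add_special_tokens).input_ids
            else:
                input_ids = tokenizer(prompt,
                                      add_special_tokens=add_special_tokens,
                                      truncation=True,
                                      max_length=truncate_prompt_tokens).input_ids
            input_text = prompt
        else:
            # Token-ID prompt: keep the last N tokens and decode for logging.
            input_ids = (prompt if truncate_prompt_tokens is None else
                         prompt[-truncate_prompt_tokens:])
            input_text = tokenizer.decode(input_ids)
        return {"prompt": input_text, "prompt_token_ids": input_ids}

    tok = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
    print(normalize_prompt(tok, "Hello world", truncate_prompt_tokens=1))
    print(normalize_prompt(tok, [15496, 995]))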
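
Assuming the OpenAI-compatible server exposes /tokenize and /detokenize routes wired to this handler (the route paths, served model name, and token values below are assumptions for illustration), the refactored tokenization endpoints can be exercised like this:

    import requests

    base_url = "http://localhost:8000"  # assumed default server address
    model = "facebook/opt-125m"         # placeholder served model name

    # Tokenize a plain text prompt (TokenizeCompletionRequest path).
    resp = requests.post(f"{base_url}/tokenize",
                         json={"model": model,
                               "prompt": "Hello world",
                               "add_special_tokens": True})
    print(resp.json())  # -> tokens, count, max_model_len

    # Turn token IDs back into text (DetokenizeRequest path).
    resp = requests.post(f"{base_url}/detokenize",
                         json={"model": model,
                               "tokens": [2, 31414, 232]})
    print(resp.json())  # -> prompt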