diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index f7b84eebc8b6a..3474bd3861598 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -38,7 +38,6 @@ jobs: mypy vllm/core --follow-imports skip mypy vllm/distributed --follow-imports skip mypy vllm/engine --follow-imports skip - mypy vllm/entrypoints --follow-imports skip mypy vllm/executor --follow-imports skip mypy vllm/lora --follow-imports skip mypy vllm/model_executor --follow-imports skip diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 6a8d99635b8f0..e292c32999d63 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -6,7 +6,7 @@ sphinx-argparse==0.4.0 msgspec # packages to install to build the documentation -pydantic +pydantic >= 2.8 -f https://download.pytorch.org/whl/cpu torch py-cpuinfo diff --git a/format.sh b/format.sh index a8fd95a1ea445..9e0780870303d 100755 --- a/format.sh +++ b/format.sh @@ -102,7 +102,6 @@ mypy vllm/attention --follow-imports skip mypy vllm/core --follow-imports skip mypy vllm/distributed --follow-imports skip mypy vllm/engine --follow-imports skip -mypy vllm/entrypoints --follow-imports skip mypy vllm/executor --follow-imports skip mypy vllm/lora --follow-imports skip mypy vllm/model_executor --follow-imports skip diff --git a/pyproject.toml b/pyproject.toml index ba0e10241ca2d..90df64ad2ae33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ files = [ "vllm/*.py", "vllm/adapter_commons", "vllm/assets", + "vllm/entrypoints", "vllm/inputs", "vllm/logging", "vllm/multimodal", diff --git a/requirements-common.txt b/requirements-common.txt index b6bed8a73d8c8..534d63feec2b8 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -11,7 +11,7 @@ fastapi aiohttp openai >= 1.0 # Ensure modern openai package (ensure types module present) uvicorn[standard] -pydantic >= 2.0 # Required for OpenAI server. +pydantic >= 2.8 # Required for OpenAI server. pillow # Required for image processing prometheus_client >= 0.18.0 prometheus-fastapi-instrumentator >= 7.0.0 diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index c96d602b63438..afcb0f44befc5 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -1,7 +1,7 @@ # imports for guided decoding tests import json import re -from typing import List +from typing import Dict, List, Optional import jsonschema import openai # use the official client for correctness check @@ -174,6 +174,88 @@ async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI, assert message.content is not None and len(message.content) >= 0 +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name, prompt_logprobs", + [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)], +) +async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI, + model_name: str, + prompt_logprobs: Optional[int]): + params: Dict = { + "messages": [{ + "role": "system", + "content": "You are a helpful assistant." + }, { + "role": "user", + "content": "Who won the world series in 2020?" + }, { + "role": + "assistant", + "content": + "The Los Angeles Dodgers won the World Series in 2020." + }, { + "role": "user", + "content": "Where was it played?" 
+ }], + "model": + model_name + } + + if prompt_logprobs is not None: + params["extra_body"] = {"prompt_logprobs": prompt_logprobs} + + if prompt_logprobs is not None and prompt_logprobs < 0: + with pytest.raises(BadRequestError): + await client.chat.completions.create(**params) + else: + completion = await client.chat.completions.create(**params) + if prompt_logprobs is not None: + assert completion.prompt_logprobs is not None + assert len(completion.prompt_logprobs) > 0 + else: + assert completion.prompt_logprobs is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI, + model_name: str): + params: Dict = { + "messages": [{ + "role": "system", + "content": "You are a helpful assistant." + }, { + "role": "user", + "content": "Who won the world series in 2020?" + }, { + "role": + "assistant", + "content": + "The Los Angeles Dodgers won the World Series in 2020." + }, { + "role": "user", + "content": "Where was it played?" + }], + "model": + model_name, + "extra_body": { + "prompt_logprobs": 1 + } + } + + completion_1 = await client.chat.completions.create(**params) + + params["extra_body"] = {"prompt_logprobs": 2} + completion_2 = await client.chat.completions.create(**params) + + assert len(completion_1.prompt_logprobs[3]) == 1 + assert len(completion_2.prompt_logprobs[3]) == 2 + + @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 4d0c6d73518dd..18f41f5fc671b 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -3,7 +3,7 @@ import re import shutil from tempfile import TemporaryDirectory -from typing import Dict, List +from typing import Dict, List, Optional import jsonschema import openai # use the official client for correctness check @@ -268,92 +268,6 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI, assert len(completion.choices[0].text) >= 0 -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name, prompt_logprobs", - [(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)], -) -async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI, - model_name: str, prompt_logprobs: int): - params: Dict = { - "messages": [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Who won the world series in 2020?" - }, { - "role": - "assistant", - "content": - "The Los Angeles Dodgers won the World Series in 2020." - }, { - "role": "user", - "content": "Where was it played?" 
- }], - "model": - model_name - } - - if prompt_logprobs is not None: - params["extra_body"] = {"prompt_logprobs": prompt_logprobs} - - if prompt_logprobs and prompt_logprobs < 0: - with pytest.raises(BadRequestError) as err_info: - await client.chat.completions.create(**params) - expected_err_string = ( - "Error code: 400 - {'object': 'error', 'message': " - "'Prompt_logprobs set to invalid negative value: -1'," - " 'type': 'BadRequestError', 'param': None, 'code': 400}") - assert str(err_info.value) == expected_err_string - else: - completion = await client.chat.completions.create(**params) - if prompt_logprobs and prompt_logprobs > 0: - assert completion.prompt_logprobs is not None - assert len(completion.prompt_logprobs) > 0 - else: - assert completion.prompt_logprobs is None - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME], -) -async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI, - model_name: str): - params: Dict = { - "messages": [{ - "role": "system", - "content": "You are a helpful assistant." - }, { - "role": "user", - "content": "Who won the world series in 2020?" - }, { - "role": - "assistant", - "content": - "The Los Angeles Dodgers won the World Series in 2020." - }, { - "role": "user", - "content": "Where was it played?" - }], - "model": - model_name, - "extra_body": { - "prompt_logprobs": 1 - } - } - - completion_1 = await client.chat.completions.create(**params) - - params["extra_body"] = {"prompt_logprobs": 2} - completion_2 = await client.chat.completions.create(**params) - - assert len(completion_1.prompt_logprobs[3]) == 1 - assert len(completion_2.prompt_logprobs[3]) == 2 - - @pytest.mark.asyncio @pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1), (MODEL_NAME, 0), @@ -361,7 +275,7 @@ async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI, (MODEL_NAME, None)]) async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, model_name: str, - prompt_logprobs: int): + prompt_logprobs: Optional[int]): params: Dict = { "prompt": ["A robot may not injure another robot", "My name is"], "model": model_name, @@ -369,17 +283,12 @@ async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI, if prompt_logprobs is not None: params["extra_body"] = {"prompt_logprobs": prompt_logprobs} - if prompt_logprobs and prompt_logprobs < 0: - with pytest.raises(BadRequestError) as err_info: + if prompt_logprobs is not None and prompt_logprobs < 0: + with pytest.raises(BadRequestError): await client.completions.create(**params) - expected_err_string = ( - "Error code: 400 - {'object': 'error', 'message': " - "'Prompt_logprobs set to invalid negative value: -1'," - " 'type': 'BadRequestError', 'param': None, 'code': 400}") - assert str(err_info.value) == expected_err_string else: completion = await client.completions.create(**params) - if prompt_logprobs and prompt_logprobs > 0: + if prompt_logprobs is not None: assert completion.choices[0].prompt_logprobs is not None assert len(completion.choices[0].prompt_logprobs) > 0 diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 6385d3ca2297e..b33c19e97141a 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -6,7 +6,6 @@ Optional, Set, Tuple, Type, Union) import torch -from transformers import PreTrainedTokenizer from typing_extensions import assert_never import vllm.envs as envs @@ -31,6 +30,7 @@ from vllm.sampling_params import SamplingParams from vllm.sequence 
import (ExecuteModelRequest, SamplerOutput, SequenceGroupMetadata) +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils import print_warning_once @@ -427,8 +427,8 @@ async def _tokenize_prompt_async( lora_request: Optional[LoRARequest], ) -> List[int]: """Async version of :meth:`_tokenize_prompt`.""" - tokenizer = self.get_tokenizer_group("prompts must be None if " - "skip_tokenizer_init is True") + tokenizer = self.get_tokenizer_group( + missing_msg="prompts must be None if skip_tokenizer_init is True") return await tokenizer.encode_async(request_id=request_id, prompt=prompt, @@ -771,7 +771,7 @@ def _error_callback(self, exc: Exception) -> None: async def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, - ) -> "PreTrainedTokenizer": + ) -> AnyTokenizer: if self.engine_use_ray: return await self.engine.get_tokenizer.remote( # type: ignore lora_request) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 36cb6ce795f3e..94aed6b8c50c7 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -3,9 +3,9 @@ from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Mapping, Optional) from typing import Sequence as GenericSequence -from typing import Set, Tuple, Type, TypeVar, Union +from typing import Set, Tuple, Type, Union -from typing_extensions import assert_never +from typing_extensions import TypeVar, assert_never import vllm.envs as envs from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, @@ -43,8 +43,9 @@ init_tracer) from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import ( - AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs) + BaseTokenizerGroup, init_tokenizer_from_configs) from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, usage_message) from vllm.utils import Counter, Device @@ -67,6 +68,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: return config.to_diff_dict() +_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup) _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) PromptComponents = Tuple[Optional[str], List[int], @@ -493,12 +495,21 @@ def __del__(self): "skip_tokenizer_init is True") def get_tokenizer_group( - self, - fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup: - if self.tokenizer is None: - raise ValueError(fail_msg) + self, + group_type: Type[_G] = BaseTokenizerGroup, + *, + missing_msg: str = MISSING_TOKENIZER_GROUP_MSG, + ) -> _G: + tokenizer_group = self.tokenizer + + if tokenizer_group is None: + raise ValueError(missing_msg) + if not isinstance(tokenizer_group, group_type): + raise TypeError("Invalid type of tokenizer group. 
" + f"Expected type: {group_type}, but " + f"found type: {type(tokenizer_group)}") - return self.tokenizer + return tokenizer_group def get_tokenizer( self, @@ -693,8 +704,8 @@ def _tokenize_prompt( * prompt token ids ''' - tokenizer = self.get_tokenizer_group("prompts must be None if " - "skip_tokenizer_init is True") + tokenizer = self.get_tokenizer_group( + missing_msg="prompts must be None if skip_tokenizer_init is True") return tokenizer.encode(request_id=request_id, prompt=prompt, diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index 92aecebe6ec38..a385f37d807ad 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -1,13 +1,12 @@ from abc import ABC, abstractmethod from typing import Callable, List -from transformers import PreTrainedTokenizer - from vllm.config import SchedulerConfig from vllm.core.scheduler import Scheduler from vllm.engine.output_processor.stop_checker import StopChecker from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import Counter @@ -29,7 +28,7 @@ def create_output_processor( detokenizer: Detokenizer, scheduler: List[Scheduler], seq_counter: Counter, - get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], + get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer], stop_checker: "StopChecker", ): """Create an output processor. diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 25d15df9f915d..6c472528a7a9c 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -1,8 +1,6 @@ import functools from typing import Callable, List -from transformers import PreTrainedTokenizer - from vllm.core.scheduler import Scheduler from vllm.engine.output_processor.interfaces import ( SequenceGroupOutputProcessor) @@ -12,6 +10,7 @@ from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import Counter logger = init_logger(__name__) @@ -36,7 +35,7 @@ def __init__( detokenizer: Detokenizer, scheduler: List[Scheduler], seq_counter: Counter, - get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer], + get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer], stop_checker: StopChecker, ): self.detokenizer = detokenizer diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 96f0d1142611b..0c5f8fb7f5be7 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -1,10 +1,9 @@ from typing import Callable, Optional -from transformers import PreTrainedTokenizer - from vllm.lora.request import LoRARequest from vllm.sampling_params import SamplingParams from vllm.sequence import Sequence, SequenceStatus +from vllm.transformers_utils.tokenizer import AnyTokenizer class StopChecker: @@ -15,8 +14,7 @@ class StopChecker: """ def __init__(self, max_model_len: int, - get_tokenizer_for_seq: Callable[[Sequence], - PreTrainedTokenizer]): + get_tokenizer_for_seq: Callable[[Sequence], AnyTokenizer]): # Do not use it directly, but use `self._get_max_model_len`. 
self._max_model_len = max_model_len self.get_tokenizer_for_seq = get_tokenizer_for_seq diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index e05c01fa8d6c3..cb16775a1cd59 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -1,8 +1,6 @@ from typing import (AsyncGenerator, List, Mapping, Optional, Protocol, runtime_checkable) -from transformers import PreTrainedTokenizer - from vllm.config import DecodingConfig, ModelConfig from vllm.core.scheduler import SchedulerOutputs from vllm.inputs.data import PromptInputs @@ -12,6 +10,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput +from vllm.transformers_utils.tokenizer import AnyTokenizer @runtime_checkable @@ -40,6 +39,7 @@ def generate( prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncGenerator[RequestOutput, None]: """Generates outputs for a request""" + ... def encode( self, @@ -50,6 +50,7 @@ def encode( trace_headers: Optional[Mapping[str, str]] = None, ) -> AsyncGenerator[EmbeddingRequestOutput, None]: """Generate outputs for a request from an embedding model.""" + ... async def abort(self, request_id: str) -> None: """Abort a request. @@ -60,25 +61,29 @@ async def abort(self, request_id: str) -> None: async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine.""" + ... async def get_decoding_config(self) -> DecodingConfig: + ... """Get the decoding configuration of the vLLM engine.""" async def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, - ) -> PreTrainedTokenizer: - """Get the appropriate Tokenizer for the request""" + ) -> AnyTokenizer: + """Get the appropriate tokenizer for the request""" + ... async def is_tracing_enabled(self) -> bool: - pass + ... async def do_log_stats( self, scheduler_outputs: Optional[SchedulerOutputs] = None, model_output: Optional[List[SamplerOutput]] = None, ) -> None: - pass + ... async def check_health(self) -> None: """Raise if unhealthy""" + ... 
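The reworked LLMEngine.get_tokenizer_group() in this patch leans on PEP 696 type-variable defaults, which is why TypeVar is now imported from typing_extensions: callers may pass the concrete tokenizer-group class they expect and get a correspondingly narrowed return type, with the isinstance check keeping that narrowing honest at runtime. A minimal, self-contained sketch of the same pattern, using illustrative class names rather than vLLM's:

from typing import Optional, Type

from typing_extensions import TypeVar


class BaseGroup:
    """Stand-in for BaseTokenizerGroup."""


class LocalGroup(BaseGroup):
    """Stand-in for a concrete TokenizerGroup subclass."""

    def only_here(self) -> str:
        return "method that only the concrete group provides"


# `default=` (PEP 696) means a bare get_group() call is still typed as
# returning BaseGroup instead of an unbound type variable.
_G = TypeVar("_G", bound=BaseGroup, default=BaseGroup)


class Engine:

    def __init__(self, group: Optional[BaseGroup]) -> None:
        self._group = group

    def get_group(
        self,
        group_type: Type[_G] = BaseGroup,
        *,
        missing_msg: str = "tokenizer group was not initialized",
    ) -> _G:
        group = self._group
        if group is None:
            raise ValueError(missing_msg)
        if not isinstance(group, group_type):
            raise TypeError(f"Expected {group_type}, found {type(group)}")
        return group


engine = Engine(LocalGroup())
# mypy sees the result as LocalGroup, so attribute access type-checks.
print(engine.get_group(LocalGroup).only_here())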
diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index f6e8a417b648c..6127177b4d889 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -61,6 +61,7 @@ async def generate(request: Request) -> Response: async def stream_results() -> AsyncGenerator[bytes, None]: async for request_output in results_generator: prompt = request_output.prompt + assert prompt is not None text_outputs = [ prompt + output.text for output in request_output.outputs ] @@ -80,6 +81,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: assert final_output is not None prompt = final_output.prompt + assert prompt is not None text_outputs = [prompt + output.text for output in final_output.outputs] ret = {"text": text_outputs} return JSONResponse(ret) @@ -115,6 +117,7 @@ async def run_server(args: Namespace, logger.info("args: %s", args) app = await init_app(args, llm_engine) + assert engine is not None shutdown_task = await serve_http( app, diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 4a0b0f879e8ef..48fd1333d8f40 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -3,7 +3,7 @@ from functools import lru_cache from pathlib import Path from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple, - Union, cast) + Union) # yapf conflicts with isort for this block # yapf: disable @@ -15,9 +15,8 @@ ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) # yapf: enable # pydantic needs the TypedDict from typing_extensions -from pydantic import ConfigDict -from transformers import PreTrainedTokenizer -from typing_extensions import Required, TypedDict +from pydantic import ConfigDict, TypeAdapter +from typing_extensions import Required, TypeAlias, TypedDict from vllm.config import ModelConfig from vllm.logger import init_logger @@ -50,9 +49,9 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False): """The type of the content part.""" -ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam, - ChatCompletionContentPartAudioParam, - CustomChatCompletionContentPartParam] +ChatCompletionContentPartParam: TypeAlias = Union[ + OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, + CustomChatCompletionContentPartParam, ] class CustomChatCompletionMessageParam(TypedDict, total=False): @@ -114,7 +113,7 @@ def load_chat_template( @lru_cache(maxsize=None) -def _mm_token_str(model_config: ModelConfig, tokenizer: PreTrainedTokenizer, +def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer, modality: Literal["image", "audio"]) -> Optional[str]: # TODO: Let user specify how to insert image tokens into prompt # (similar to chat template) @@ -151,11 +150,16 @@ def _get_full_multimodal_text_prompt(placeholder_token_str: str, return f"{placeholder_token_str}\n{text_prompt}" +_TextParser = TypeAdapter(ChatCompletionContentPartTextParam) +_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam) +_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam) + + def _parse_chat_message_content_parts( role: str, parts: Iterable[ChatCompletionContentPartParam], model_config: ModelConfig, - tokenizer: PreTrainedTokenizer, + tokenizer: AnyTokenizer, ) -> ChatMessageParseResult: texts: List[str] = [] mm_futures: List[Awaitable[MultiModalDataDict]] = [] @@ -164,7 +168,7 @@ def _parse_chat_message_content_parts( for part in parts: part_type = part["type"] if part_type == "text": - text = 
cast(ChatCompletionContentPartTextParam, part)["text"] + text = _TextParser.validate_python(part)["text"] texts.append(text) elif part_type == "image_url": modality = "image" @@ -172,8 +176,7 @@ def _parse_chat_message_content_parts( raise NotImplementedError( "Multiple multimodal inputs is currently not supported.") - image_url = cast(ChatCompletionContentPartImageParam, - part)["image_url"] + image_url = _ImageParser.validate_python(part)["image_url"] if image_url.get("detail", "auto") != "auto": logger.warning( @@ -188,8 +191,7 @@ def _parse_chat_message_content_parts( raise NotImplementedError( "Multiple multimodal inputs is currently not supported.") - audio_url = cast(ChatCompletionContentPartAudioParam, - part)["audio_url"] + audio_url = _AudioParser.validate_python(part)["audio_url"] audio_future = async_get_and_parse_audio(audio_url["url"]) mm_futures.append(audio_future) else: @@ -219,7 +221,7 @@ def _parse_chat_message_content_parts( def _parse_chat_message_content( message: ChatCompletionMessageParam, model_config: ModelConfig, - tokenizer: PreTrainedTokenizer, + tokenizer: AnyTokenizer, ) -> ChatMessageParseResult: role = message["role"] content = message.get("content") @@ -230,14 +232,18 @@ def _parse_chat_message_content( messages = [ConversationMessage(role=role, content=content)] return ChatMessageParseResult(messages=messages, mm_futures=[]) - return _parse_chat_message_content_parts(role, content, model_config, - tokenizer) + return _parse_chat_message_content_parts( + role, + content, # type: ignore + model_config, + tokenizer, + ) def parse_chat_messages( messages: List[ChatCompletionMessageParam], model_config: ModelConfig, - tokenizer: PreTrainedTokenizer, + tokenizer: AnyTokenizer, ) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]: conversation: List[ConversationMessage] = [] mm_futures: List[Awaitable[MultiModalDataDict]] = [] diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index ecd6dc64d343b..372e96e3716aa 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,8 +1,7 @@ from contextlib import contextmanager from typing import ClassVar, List, Optional, Sequence, Union, cast, overload -from tqdm.auto import tqdm -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from tqdm import tqdm from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine @@ -20,7 +19,9 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.transformers_utils.tokenizer import get_cached_tokenizer +from vllm.transformers_utils.tokenizer import (AnyTokenizer, + get_cached_tokenizer) +from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.usage.usage_lib import UsageContext from vllm.utils import Counter, deprecate_kwargs @@ -122,7 +123,7 @@ def __init__( tokenizer_revision: Optional[str] = None, seed: int = 0, gpu_memory_utilization: float = 0.9, - swap_space: int = 4, + swap_space: float = 4, cpu_offload_gb: float = 0, enforce_eager: Optional[bool] = None, max_context_len_to_capture: Optional[int] = None, @@ -175,22 +176,19 @@ def __init__( engine_args, usage_context=UsageContext.LLM_CLASS) self.request_counter = Counter() - def get_tokenizer( - self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: - return self.llm_engine.tokenizer.tokenizer + def get_tokenizer(self) -> AnyTokenizer: + return 
self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer + + def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: + tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup) - def set_tokenizer( - self, - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - ) -> None: # While CachedTokenizer is dynamic, have no choice but # compare class name. Misjudgment will arise from # user-defined tokenizer started with 'Cached' if tokenizer.__class__.__name__.startswith("Cached"): - self.llm_engine.tokenizer.tokenizer = tokenizer + tokenizer_group.tokenizer = tokenizer else: - self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer( - tokenizer) + tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer) @overload # LEGACY: single (prompt + optional token ids) def generate( @@ -578,6 +576,8 @@ def _convert_v1_inputs( inputs: List[PromptInputs] = [] for i in range(num_requests): + item: PromptInputs + if prompts is not None: item = TextPrompt(prompt=prompts[i]) elif prompt_token_ids is not None: @@ -635,7 +635,7 @@ def _add_request( self, inputs: PromptInputs, params: Union[SamplingParams, PoolingParams], - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + lora_request: Optional[LoRARequest] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: request_id = str(next(self.request_counter)) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index d79238e08d540..f37c7f4d91d57 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -15,6 +15,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, Response, StreamingResponse from starlette.routing import Mount +from typing_extensions import assert_never import vllm.envs as envs from vllm.config import ModelConfig @@ -29,14 +30,16 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, ChatCompletionResponse, CompletionRequest, + CompletionResponse, DetokenizeRequest, DetokenizeResponse, - EmbeddingRequest, ErrorResponse, + EmbeddingRequest, + EmbeddingResponse, ErrorResponse, TokenizeRequest, TokenizeResponse) +# yapf: enable from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient from vllm.entrypoints.openai.rpc.server import run_rpc_server -# yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding @@ -90,7 +93,8 @@ async def _force_log(): @asynccontextmanager -async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: +async def build_async_engine_client( + args: Namespace) -> AsyncIterator[AsyncEngineClient]: # Context manager to handle async_engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit global engine_args @@ -142,12 +146,15 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: logger.info("Started engine process with PID %d", rpc_server_process.pid) # Build RPCClient, which conforms to AsyncEngineClient Protocol. - async_engine_client = AsyncEngineRPCClient(rpc_path) + # NOTE: Actually, this is not true yet. 
We still need to support + # embedding models via RPC (see TODO above) + rpc_client = AsyncEngineRPCClient(rpc_path) + async_engine_client = rpc_client # type: ignore try: while True: try: - await async_engine_client.setup() + await rpc_client.setup() break except TimeoutError as e: if not rpc_server_process.is_alive(): @@ -161,7 +168,7 @@ async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]: rpc_server_process.terminate() # Close all open connections to the backend - async_engine_client.close() + rpc_client.close() # Wait for server process to join rpc_server_process.join() @@ -216,10 +223,11 @@ async def tokenize(request: TokenizeRequest): if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) - else: - assert isinstance(generator, TokenizeResponse) + elif isinstance(generator, TokenizeResponse): return JSONResponse(content=generator.model_dump()) + assert_never(generator) + @router.post("/detokenize") async def detokenize(request: DetokenizeRequest): @@ -227,10 +235,11 @@ async def detokenize(request: DetokenizeRequest): if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) - else: - assert isinstance(generator, DetokenizeResponse) + elif isinstance(generator, DetokenizeResponse): return JSONResponse(content=generator.model_dump()) + assert_never(generator) + @router.get("/v1/models") async def show_available_models(): @@ -252,13 +261,11 @@ async def create_chat_completion(request: ChatCompletionRequest, if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) - if request.stream: - return StreamingResponse(content=generator, - media_type="text/event-stream") - else: - assert isinstance(generator, ChatCompletionResponse) + elif isinstance(generator, ChatCompletionResponse): return JSONResponse(content=generator.model_dump()) + return StreamingResponse(content=generator, media_type="text/event-stream") + @router.post("/v1/completions") async def create_completion(request: CompletionRequest, raw_request: Request): @@ -267,12 +274,11 @@ async def create_completion(request: CompletionRequest, raw_request: Request): if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) - if request.stream: - return StreamingResponse(content=generator, - media_type="text/event-stream") - else: + elif isinstance(generator, CompletionResponse): return JSONResponse(content=generator.model_dump()) + return StreamingResponse(content=generator, media_type="text/event-stream") + @router.post("/v1/embeddings") async def create_embedding(request: EmbeddingRequest, raw_request: Request): @@ -281,9 +287,11 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) - else: + elif isinstance(generator, EmbeddingResponse): return JSONResponse(content=generator.model_dump()) + assert_never(generator) + def build_app(args: Namespace) -> FastAPI: app = FastAPI(lifespan=lifespan) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 1facedac72ca8..94742838b421c 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -7,6 +7,7 @@ import argparse import json import ssl +from typing import List, Optional, Sequence, Union from 
vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, @@ -16,8 +17,19 @@ class LoRAParserAction(argparse.Action): - def __call__(self, parser, namespace, values, option_string=None): - lora_list = [] + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Optional[Union[str, Sequence[str]]], + option_string: Optional[str] = None, + ): + if values is None: + values = [] + if isinstance(values, str): + raise TypeError("Expected values to be a list") + + lora_list: List[LoRAModulePath] = [] for item in values: name, path = item.split('=') lora_list.append(LoRAModulePath(name, path)) @@ -26,8 +38,19 @@ def __call__(self, parser, namespace, values, option_string=None): class PromptAdapterParserAction(argparse.Action): - def __call__(self, parser, namespace, values, option_string=None): - adapter_list = [] + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Optional[Union[str, Sequence[str]]], + option_string: Optional[str] = None, + ): + if values is None: + values = [] + if isinstance(values, str): + raise TypeError("Expected values to be a list") + + adapter_list: List[PromptAdapterPath] = [] for item in values: name, path = item.split('=') adapter_list.append(PromptAdapterPath(name, path)) diff --git a/vllm/entrypoints/openai/logits_processors.py b/vllm/entrypoints/openai/logits_processors.py index c470c32c27ede..7913f8720ca73 100644 --- a/vllm/entrypoints/openai/logits_processors.py +++ b/vllm/entrypoints/openai/logits_processors.py @@ -2,9 +2,9 @@ from typing import Dict, FrozenSet, Iterable, List, Optional, Union import torch -from transformers import PreTrainedTokenizer from vllm.sampling_params import LogitsProcessor +from vllm.transformers_utils.tokenizer import AnyTokenizer class AllowedTokenIdsLogitsProcessor: @@ -51,10 +51,11 @@ def logit_bias_logits_processor( def get_logits_processors( - logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]], - allowed_token_ids: Optional[List[int]], - tokenizer: PreTrainedTokenizer) -> List[LogitsProcessor]: - logits_processors = [] + logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]], + allowed_token_ids: Optional[List[int]], + tokenizer: AnyTokenizer, +) -> List[LogitsProcessor]: + logits_processors: List[LogitsProcessor] = [] if logit_bias: try: # Convert token_id to integer diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index aef42e9425ef5..c46f5cf8ce663 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -6,7 +6,6 @@ import torch from pydantic import BaseModel, ConfigDict, Field, model_validator -from transformers import PreTrainedTokenizer from typing_extensions import Annotated from vllm.entrypoints.chat_utils import ChatCompletionMessageParam @@ -14,11 +13,13 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.sequence import Logprob +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import random_uuid # torch is mocked during docs generation, # so we have to provide the values as literals _MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807) +_LONG_INFO: Union["torch.iinfo", Namespace] try: from sphinx.ext.autodoc.mock import _MockModule @@ -235,13 +236,17 @@ class ChatCompletionRequest(OpenAIBaseModel): # doc: end-chat-completion-extra-params def 
to_sampling_params( - self, tokenizer: PreTrainedTokenizer, + self, tokenizer: AnyTokenizer, guided_decode_logits_processor: Optional[LogitsProcessor], default_max_tokens: int) -> SamplingParams: max_tokens = self.max_tokens if max_tokens is None: max_tokens = default_max_tokens + prompt_logprobs = self.prompt_logprobs + if prompt_logprobs is None and self.echo: + prompt_logprobs = self.top_logprobs + # We now allow logprobs being true without top_logrobs. logits_processors = get_logits_processors( logit_bias=self.logit_bias, @@ -251,7 +256,7 @@ def to_sampling_params( if guided_decode_logits_processor: logits_processors.append(guided_decode_logits_processor) - return SamplingParams( + return SamplingParams.from_optional( n=self.n, best_of=self.best_of, presence_penalty=self.presence_penalty, @@ -265,8 +270,7 @@ def to_sampling_params( stop=self.stop, stop_token_ids=self.stop_token_ids, logprobs=self.top_logprobs if self.logprobs else None, - prompt_logprobs=self.prompt_logprobs if self.prompt_logprobs else - (self.top_logprobs if self.echo else None), + prompt_logprobs=prompt_logprobs, ignore_eos=self.ignore_eos, max_tokens=max_tokens, min_tokens=self.min_tokens, @@ -280,14 +284,36 @@ def to_sampling_params( truncate_prompt_tokens=self.truncate_prompt_tokens, ) - @model_validator(mode='before') + @model_validator(mode="before") @classmethod - def validate_stream_options(cls, values): - if (values.get('stream_options') is not None - and not values.get('stream')): + def validate_stream_options(cls, data): + if data.get("stream_options") and not data.get("stream"): raise ValueError( - "stream_options can only be set if stream is true") - return values + "Stream options can only be defined when `stream=True`.") + + return data + + @model_validator(mode="before") + @classmethod + def check_logprobs(cls, data): + if (prompt_logprobs := data.get("prompt_logprobs")) is not None: + if data.get("stream") and prompt_logprobs > 0: + raise ValueError( + "`prompt_logprobs` are not available when `stream=True`.") + + if prompt_logprobs < 0: + raise ValueError("`prompt_logprobs` must be a positive value.") + + if (top_logprobs := data.get("top_logprobs")) is not None: + if top_logprobs < 0: + raise ValueError("`top_logprobs` must be a positive value.") + + if not data.get("logprobs"): + raise ValueError( + "when using `top_logprobs`, `logprobs` must be set to true." + ) + + return data @model_validator(mode="before") @classmethod @@ -320,19 +346,6 @@ def check_tool_choice(cls, data): "When using `tool_choice`, `tools` must be set.") return data - @model_validator(mode="before") - @classmethod - def check_logprobs(cls, data): - if "top_logprobs" in data and data["top_logprobs"] is not None: - if "logprobs" not in data or data["logprobs"] is False: - raise ValueError( - "when using `top_logprobs`, `logprobs` must be set to true." 
- ) - elif data["top_logprobs"] < 0: - raise ValueError( - "`top_logprobs` must be a value a positive value.") - return data - class CompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation @@ -422,13 +435,17 @@ class CompletionRequest(OpenAIBaseModel): # doc: end-completion-extra-params def to_sampling_params( - self, tokenizer: PreTrainedTokenizer, + self, tokenizer: AnyTokenizer, guided_decode_logits_processor: Optional[LogitsProcessor], default_max_tokens: int) -> SamplingParams: max_tokens = self.max_tokens if max_tokens is None: max_tokens = default_max_tokens + prompt_logprobs = self.prompt_logprobs + if prompt_logprobs is None and self.echo: + prompt_logprobs = self.logprobs + echo_without_generation = self.echo and self.max_tokens == 0 logits_processors = get_logits_processors( @@ -439,7 +456,7 @@ def to_sampling_params( if guided_decode_logits_processor: logits_processors.append(guided_decode_logits_processor) - return SamplingParams( + return SamplingParams.from_optional( n=self.n, best_of=self.best_of, presence_penalty=self.presence_penalty, @@ -458,8 +475,7 @@ def to_sampling_params( min_tokens=self.min_tokens, use_beam_search=self.use_beam_search, early_stopping=self.early_stopping, - prompt_logprobs=self.prompt_logprobs - if self.prompt_logprobs else self.logprobs if self.echo else None, + prompt_logprobs=prompt_logprobs, skip_special_tokens=self.skip_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens, include_stop_str_in_output=self.include_stop_str_in_output, @@ -485,9 +501,17 @@ def check_guided_decoding_count(cls, data): @model_validator(mode="before") @classmethod def check_logprobs(cls, data): - if "logprobs" in data and data[ - "logprobs"] is not None and not data["logprobs"] >= 0: - raise ValueError("if passed, `logprobs` must be a positive value.") + if (prompt_logprobs := data.get("prompt_logprobs")) is not None: + if data.get("stream") and prompt_logprobs > 0: + raise ValueError( + "`prompt_logprobs` are not available when `stream=True`.") + + if prompt_logprobs < 0: + raise ValueError("`prompt_logprobs` must be a positive value.") + + if (logprobs := data.get("logprobs")) is not None and logprobs < 0: + raise ValueError("`logprobs` must be a positive value.") + return data @model_validator(mode="before") @@ -495,7 +519,8 @@ def check_logprobs(cls, data): def validate_stream_options(cls, data): if data.get("stream_options") and not data.get("stream"): raise ValueError( - "Stream options can only be defined when stream is true.") + "Stream options can only be defined when `stream=True`.") + return data @@ -504,7 +529,7 @@ class EmbeddingRequest(OpenAIBaseModel): # https://platform.openai.com/docs/api-reference/embeddings model: str input: Union[List[int], List[List[int]], str, List[str]] - encoding_format: Optional[str] = Field('float', pattern='^(float|base64)$') + encoding_format: Literal["float", "base64"] = "float" dimensions: Optional[int] = None user: Optional[str] = None diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 471d62631135a..770ee77926df9 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -23,8 +23,8 @@ class AsyncEngineRPCServer: def __init__(self, async_engine_args: AsyncEngineArgs, usage_context: UsageContext, rpc_path: str): # Initialize engine first. 
- self.engine = AsyncLLMEngine.from_engine_args(async_engine_args, - usage_context) + self.engine = AsyncLLMEngine.from_engine_args( + async_engine_args, usage_context=usage_context) # Initialize context. self.context = zmq.asyncio.Context() @@ -39,7 +39,7 @@ def cleanup(self): self.context.destroy() self.engine.shutdown_background_loop() # Clear the engine reference so that it can be GC'ed. - self.engine = None + del self.engine async def get_model_config(self, identity): """Send the ModelConfig""" diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 08209d44d207c..4d8e240a88ee6 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,11 +1,10 @@ import asyncio import time -from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional +from typing import AsyncGenerator, AsyncIterator, Dict, Final, List, Optional from typing import Sequence as GenericSequence from typing import Union from fastapi import Request -from transformers import PreTrainedTokenizer from vllm.config import ModelConfig from vllm.engine.protocol import AsyncEngineClient @@ -24,13 +23,14 @@ from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing, PromptAdapterPath) -from vllm.inputs import PromptInputs +from vllm.inputs import TokensPrompt from vllm.logger import init_logger from vllm.multimodal import MultiModalDataDict from vllm.outputs import RequestOutput from vllm.sequence import Logprob from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import iterate_with_cancellation, random_uuid logger = init_logger(__name__) @@ -67,9 +67,9 @@ def __init__( async def create_chat_completion( self, request: ChatCompletionRequest, - raw_request: Optional[Request] = None - ) -> Union[ErrorResponse, AsyncGenerator[str, None], - ChatCompletionResponse]: + raw_request: Optional[Request] = None, + ) -> Union[AsyncGenerator[str, None], ChatCompletionResponse, + ErrorResponse]: """Completion API similar to OpenAI's API. 
See https://platform.openai.com/docs/api-reference/chat/create @@ -83,16 +83,6 @@ async def create_chat_completion( if error_check_ret is not None: return error_check_ret - if request.prompt_logprobs is not None: - if request.stream and request.prompt_logprobs > 0: - return self.create_error_response( - "Prompt_logprobs are not available when stream is enabled") - - if request.prompt_logprobs < 0: - return self.create_error_response( - f"Prompt_logprobs set to invalid " - f"negative value: {request.prompt_logprobs}") - try: ( lora_request, @@ -160,9 +150,8 @@ async def create_chat_completion( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request) - engine_inputs: PromptInputs = { - "prompt_token_ids": prompt_inputs["prompt_token_ids"], - } + engine_inputs = TokensPrompt( + prompt_token_ids=prompt_inputs["prompt_token_ids"]) if mm_data is not None: engine_inputs["multi_modal_data"] = mm_data @@ -214,11 +203,11 @@ async def chat_completion_stream_generator( result_generator: AsyncIterator[RequestOutput], request_id: str, conversation: List[ConversationMessage], - tokenizer: PreTrainedTokenizer, + tokenizer: AnyTokenizer, ) -> AsyncGenerator[str, None]: model_name = self.served_model_names[0] created_time = int(time.time()) - chunk_object_type = "chat.completion.chunk" + chunk_object_type: Final = "chat.completion.chunk" first_iteration = True # Send response for each token for each request.n (index) @@ -438,7 +427,7 @@ async def chat_completion_full_generator( result_generator: AsyncIterator[RequestOutput], request_id: str, conversation: List[ConversationMessage], - tokenizer: PreTrainedTokenizer, + tokenizer: AnyTokenizer, ) -> Union[ErrorResponse, ChatCompletionResponse]: model_name = self.served_model_names[0] @@ -523,7 +512,7 @@ async def chat_completion_full_generator( def _get_top_logprobs( self, logprobs: Dict[int, Logprob], top_logprobs: Optional[int], - tokenizer: PreTrainedTokenizer) -> List[ChatCompletionLogProb]: + tokenizer: AnyTokenizer) -> List[ChatCompletionLogProb]: return [ ChatCompletionLogProb(token=(token := self._get_decoded_token( p[1], @@ -541,12 +530,11 @@ def _create_chat_logprobs( self, token_ids: GenericSequence[int], top_logprobs: GenericSequence[Optional[Dict[int, Logprob]]], - tokenizer: PreTrainedTokenizer, + tokenizer: AnyTokenizer, num_output_top_logprobs: Optional[int] = None, ) -> ChatCompletionLogProbs: """Create OpenAI-style logprobs.""" - - logprobs_content = [] + logprobs_content: List[ChatCompletionLogProbsContent] = [] for i, token_id in enumerate(token_ids): step_top_logprobs = top_logprobs[i] @@ -554,23 +542,32 @@ def _create_chat_logprobs( token = tokenizer.decode(token_id) if self.return_tokens_as_token_ids: token = f"token_id:{token_id}" + logprobs_content.append( ChatCompletionLogProbsContent( token=token, - bytes=list(token.encode("utf-8", errors="replace")))) + bytes=list(token.encode("utf-8", errors="replace")), + )) else: + step_token = step_top_logprobs[token_id] + step_decoded = step_token.decoded_token + logprobs_content.append( ChatCompletionLogProbsContent( token=self._get_decoded_token( - step_top_logprobs[token_id], token_id, tokenizer, - self.return_tokens_as_token_ids), - logprob=max(step_top_logprobs[token_id].logprob, - -9999.0), - bytes=list( - step_top_logprobs[token_id].decoded_token.encode( - "utf-8", errors="replace")), + step_token, + token_id, + tokenizer, + self.return_tokens_as_token_ids, + ), + logprob=max(step_token.logprob, -9999.0), + bytes=None if step_decoded is None else list( + 
step_decoded.encode("utf-8", errors="replace")), top_logprobs=self._get_top_logprobs( - step_top_logprobs, num_output_top_logprobs, - tokenizer))) + step_top_logprobs, + num_output_top_logprobs, + tokenizer, + ), + )) return ChatCompletionLogProbs(content=logprobs_content) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 24206b59cf5e6..34f1200753f8d 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -3,10 +3,9 @@ from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, Optional) from typing import Sequence as GenericSequence -from typing import Tuple, cast +from typing import Tuple, Union, cast from fastapi import Request -from transformers import PreTrainedTokenizer from vllm.config import ModelConfig from vllm.engine.protocol import AsyncEngineClient @@ -19,7 +18,7 @@ CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, - UsageInfo) + ErrorResponse, UsageInfo) # yapf: enable from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing, @@ -29,6 +28,7 @@ from vllm.sequence import Logprob from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) +from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.utils import merge_async_iterators, random_uuid logger = init_logger(__name__) @@ -60,8 +60,11 @@ def __init__( request_logger=request_logger, return_tokens_as_token_ids=return_tokens_as_token_ids) - async def create_completion(self, request: CompletionRequest, - raw_request: Request): + async def create_completion( + self, + request: CompletionRequest, + raw_request: Request, + ) -> Union[AsyncGenerator[str, None], CompletionResponse, ErrorResponse]: """Completion API similar to OpenAI's API. See https://platform.openai.com/docs/api-reference/completions/create @@ -84,15 +87,6 @@ async def create_completion(self, request: CompletionRequest, request_id = f"cmpl-{random_uuid()}" created_time = int(time.time()) - if request.prompt_logprobs is not None: - if request.stream and request.prompt_logprobs > 0: - return self.create_error_response( - "Prompt_logprobs are not available when stream is enabled") - elif request.prompt_logprobs < 0: - return self.create_error_response( - f"Prompt_logprobs set to invalid negative " - f"value: {request.prompt_logprobs}") - # Schedule the request and get the result generator. generators: List[AsyncGenerator[RequestOutput, None]] = [] try: @@ -153,9 +147,8 @@ async def create_completion(self, request: CompletionRequest, # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) - result_generator: AsyncIterator[Tuple[ - int, RequestOutput]] = merge_async_iterators( - *generators, is_cancelled=raw_request.is_disconnected) + result_generator = merge_async_iterators( + *generators, is_cancelled=raw_request.is_disconnected) # Similar to the OpenAI API, when n != best_of, we do not stream the # results. 
In addition, we do not stream the results when use @@ -227,7 +220,7 @@ async def completion_stream_generator( created_time: int, model_name: str, num_prompts: int, - tokenizer: PreTrainedTokenizer, + tokenizer: AnyTokenizer, ) -> AsyncGenerator[str, None]: num_choices = 1 if request.n is None else request.n previous_texts = [""] * num_choices * num_prompts @@ -236,6 +229,13 @@ async def completion_stream_generator( try: async for prompt_idx, res in result_generator: + prompt_token_ids = res.prompt_token_ids + prompt_logprobs = res.prompt_logprobs + prompt_text = res.prompt + + delta_token_ids: GenericSequence[int] + out_logprobs: Optional[GenericSequence[Optional[Dict[ + int, Logprob]]]] for output in res.outputs: i = output.index + prompt_idx * num_choices @@ -244,19 +244,25 @@ async def completion_stream_generator( assert request.max_tokens is not None if request.echo and request.max_tokens == 0: + assert prompt_text is not None # only return the prompt - delta_text = res.prompt - delta_token_ids = res.prompt_token_ids - out_logprobs = res.prompt_logprobs + delta_text = prompt_text + delta_token_ids = prompt_token_ids + out_logprobs = prompt_logprobs has_echoed[i] = True elif (request.echo and request.max_tokens > 0 and not has_echoed[i]): + assert prompt_text is not None + assert prompt_logprobs is not None # echo the prompt and first token - delta_text = res.prompt + output.text - delta_token_ids = (res.prompt_token_ids + - output.token_ids) - out_logprobs = res.prompt_logprobs + (output.logprobs - or []) + delta_text = prompt_text + output.text + delta_token_ids = [ + *prompt_token_ids, *output.token_ids + ] + out_logprobs = [ + *prompt_logprobs, + *(output.logprobs or []), + ] has_echoed[i] = True else: # return just the delta @@ -301,7 +307,7 @@ async def completion_stream_generator( and request.stream_options.include_usage): if (request.stream_options.continuous_usage_stats or output.finish_reason is not None): - prompt_tokens = len(res.prompt_token_ids) + prompt_tokens = len(prompt_token_ids) completion_tokens = len(output.token_ids) usage = UsageInfo( prompt_tokens=prompt_tokens, @@ -342,7 +348,7 @@ def request_output_to_completion_response( request_id: str, created_time: int, model_name: str, - tokenizer: PreTrainedTokenizer, + tokenizer: AnyTokenizer, ) -> CompletionResponse: choices: List[CompletionResponseChoice] = [] num_prompt_tokens = 0 @@ -353,16 +359,31 @@ def request_output_to_completion_response( prompt_logprobs = final_res.prompt_logprobs prompt_text = final_res.prompt + token_ids: GenericSequence[int] + out_logprobs: Optional[GenericSequence[Optional[Dict[int, + Logprob]]]] + for output in final_res.outputs: assert request.max_tokens is not None if request.echo and request.max_tokens == 0: + assert prompt_text is not None token_ids = prompt_token_ids out_logprobs = prompt_logprobs output_text = prompt_text elif request.echo and request.max_tokens > 0: - token_ids = prompt_token_ids + list(output.token_ids) - out_logprobs = (prompt_logprobs + output.logprobs - if request.logprobs is not None else None) + assert prompt_text is not None + token_ids = [*prompt_token_ids, *output.token_ids] + + if request.logprobs is None: + out_logprobs = None + else: + assert prompt_logprobs is not None + assert output.logprobs is not None + out_logprobs = [ + *prompt_logprobs, + *output.logprobs, + ] + output_text = prompt_text + output.text else: token_ids = output.token_ids @@ -413,7 +434,7 @@ def _create_completion_logprobs( token_ids: GenericSequence[int], top_logprobs: 
GenericSequence[Optional[Dict[int, Logprob]]], num_output_top_logprobs: int, - tokenizer: PreTrainedTokenizer, + tokenizer: AnyTokenizer, initial_text_offset: int = 0, ) -> CompletionLogProbs: """Create logprobs for OpenAI Completion API.""" @@ -430,17 +451,21 @@ def _create_completion_logprobs( token = tokenizer.decode(token_id) if self.return_tokens_as_token_ids: token = f"token_id:{token_id}" + out_tokens.append(token) out_token_logprobs.append(None) out_top_logprobs.append(None) else: + step_token = step_top_logprobs[token_id] + token = self._get_decoded_token( - step_top_logprobs[token_id], + step_token, token_id, tokenizer, - return_as_token_id=self.return_tokens_as_token_ids) - token_logprob = max(step_top_logprobs[token_id].logprob, - -9999.0) + return_as_token_id=self.return_tokens_as_token_ids, + ) + token_logprob = max(step_token.logprob, -9999.0) + out_tokens.append(token) out_token_logprobs.append(token_logprob) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 0dc3c3bc7d154..b0f70ff43e228 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -1,11 +1,11 @@ import asyncio import base64 import time -from typing import (AsyncGenerator, AsyncIterator, List, Optional, Tuple, - Union, cast) +from typing import AsyncGenerator, List, Literal, Optional, Union, cast import numpy as np from fastapi import Request +from typing_extensions import assert_never from vllm.config import ModelConfig from vllm.engine.protocol import AsyncEngineClient @@ -16,7 +16,7 @@ ErrorResponse, UsageInfo) from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.logger import init_logger -from vllm.outputs import EmbeddingRequestOutput +from vllm.outputs import EmbeddingOutput, EmbeddingRequestOutput from vllm.utils import merge_async_iterators, random_uuid logger = init_logger(__name__) @@ -24,18 +24,28 @@ TypeTokenIDs = List[int] +def _get_embedding( + output: EmbeddingOutput, + encoding_format: Literal["float", "base64"], +) -> Union[List[float], str]: + if encoding_format == "float": + return output.embedding + elif encoding_format == "base64": + embedding_bytes = np.array(output.embedding).tobytes() + return base64.b64encode(embedding_bytes).decode("utf-8") + + assert_never(encoding_format) + + def request_output_to_embedding_response( final_res_batch: List[EmbeddingRequestOutput], request_id: str, created_time: int, model_name: str, - encoding_format: str) -> EmbeddingResponse: + encoding_format: Literal["float", "base64"]) -> EmbeddingResponse: data: List[EmbeddingResponseData] = [] num_prompt_tokens = 0 for idx, final_res in enumerate(final_res_batch): prompt_token_ids = final_res.prompt_token_ids - embedding = final_res.outputs.embedding - if encoding_format == "base64": - embedding_bytes = np.array(embedding).tobytes() - embedding = base64.b64encode(embedding_bytes).decode("utf-8") + embedding = _get_embedding(final_res.outputs, encoding_format) embedding_data = EmbeddingResponseData(index=idx, embedding=embedding) data.append(embedding_data) @@ -76,8 +86,8 @@ def __init__( async def create_embedding( self, request: EmbeddingRequest, - raw_request: Optional[Request] = None - ) -> Union[ErrorResponse, EmbeddingResponse]: + raw_request: Optional[Request] = None, + ) -> Union[EmbeddingResponse, ErrorResponse]: """Completion API similar to OpenAI's API. 
See https://platform.openai.com/docs/api-reference/embeddings/create @@ -89,8 +99,7 @@ async def create_embedding( if error_check_ret is not None: return error_check_ret - encoding_format = (request.encoding_format - if request.encoding_format else "float") + encoding_format = request.encoding_format if request.dimensions is not None: return self.create_error_response( "dimensions is currently not supported") @@ -145,11 +154,10 @@ async def create_embedding( # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) - result_generator: AsyncIterator[Tuple[ - int, EmbeddingRequestOutput]] = merge_async_iterators( - *generators, - is_cancelled=raw_request.is_disconnected - if raw_request else None) + result_generator = merge_async_iterators( + *generators, + is_cancelled=raw_request.is_disconnected if raw_request else None, + ) # Non-streaming response final_res_batch: List[Optional[EmbeddingRequestOutput]] @@ -175,7 +183,7 @@ async def create_embedding( return response - def _check_embedding_mode(self, embedding_mode: bool): + def _check_embedding_mode(self, embedding_mode: bool) -> bool: if not embedding_mode: logger.warning( "embedding_mode is False. Embedding API will not work.") diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 8d8b5ea4bdf5d..26e91e7cc94dd 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -31,7 +31,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.sequence import Logprob -from vllm.transformers_utils.tokenizer_group import AnyTokenizer +from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 7197b51398538..c83ed5cca6791 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -153,6 +153,68 @@ class SamplingParams( output_text_buffer_length: int = 0 _all_stop_token_ids: Set[int] = msgspec.field(default_factory=set) + @staticmethod + def from_optional( + n: Optional[int] = 1, + best_of: Optional[int] = None, + presence_penalty: Optional[float] = 0.0, + frequency_penalty: Optional[float] = 0.0, + repetition_penalty: Optional[float] = 1.0, + temperature: Optional[float] = 1.0, + top_p: Optional[float] = 1.0, + top_k: int = -1, + min_p: float = 0.0, + seed: Optional[int] = None, + use_beam_search: bool = False, + length_penalty: float = 1.0, + early_stopping: Union[bool, str] = False, + stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, + include_stop_str_in_output: bool = False, + ignore_eos: bool = False, + max_tokens: Optional[int] = 16, + min_tokens: int = 0, + logprobs: Optional[int] = None, + prompt_logprobs: Optional[int] = None, + detokenize: bool = True, + skip_special_tokens: bool = True, + spaces_between_special_tokens: bool = True, + logits_processors: Optional[List[LogitsProcessor]] = None, + truncate_prompt_tokens: Optional[Annotated[int, + msgspec.Meta(ge=1)]] = None, + ) -> "SamplingParams": + return SamplingParams( + n=1 if n is None else n, + best_of=best_of, + presence_penalty=0.0 + if presence_penalty is None else presence_penalty, + frequency_penalty=0.0 + if frequency_penalty is None else frequency_penalty, + repetition_penalty=1.0 + if repetition_penalty is None else repetition_penalty, + temperature=1.0 if temperature is None else temperature, + top_p=1.0 if top_p is 
None else top_p, + top_k=top_k, + min_p=min_p, + seed=seed, + use_beam_search=use_beam_search, + length_penalty=length_penalty, + early_stopping=early_stopping, + stop=stop, + stop_token_ids=stop_token_ids, + include_stop_str_in_output=include_stop_str_in_output, + ignore_eos=ignore_eos, + max_tokens=max_tokens, + min_tokens=min_tokens, + logprobs=logprobs, + prompt_logprobs=prompt_logprobs, + detokenize=detokenize, + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + logits_processors=logits_processors, + truncate_prompt_tokens=truncate_prompt_tokens, + ) + def __post_init__(self) -> None: self.best_of = self.best_of or self.n if 0 < self.temperature < _MAX_TEMP:
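A few of the patterns introduced above, sketched in isolation. On the chat_utils.py side, the typing.cast calls over message content parts are replaced with module-level pydantic TypeAdapter instances, which actually validate the incoming dict instead of merely relabeling it. A small sketch of that pattern under pydantic >= 2 (the TypedDict below is a stripped-down stand-in, not vLLM's ChatCompletionContentPartTextParam):

from pydantic import TypeAdapter, ValidationError
from typing_extensions import Literal, Required, TypedDict


class TextPart(TypedDict, total=False):
    """Stand-in for ChatCompletionContentPartTextParam."""
    type: Required[Literal["text"]]
    text: Required[str]


# Built once at import time, like _TextParser in the diff, so the
# validation schema is not recompiled for every message part.
_TEXT_PARSER = TypeAdapter(TextPart)

part = _TEXT_PARSER.validate_python({"type": "text", "text": "hello"})
print(part["text"])  # -> hello

try:
    _TEXT_PARSER.validate_python({"type": "text"})  # "text" is missing
except ValidationError as exc:
    print(f"rejected with {exc.error_count()} error(s)")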
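The api_server.py routes now dispatch on per-type isinstance checks and close with typing_extensions.assert_never(), so mypy flags any response variant a route forgets to handle. The same idea in isolation (the result classes here are illustrative stand-ins, not vLLM's response models):

from dataclasses import dataclass
from typing import Union

from typing_extensions import assert_never


@dataclass
class ErrorResult:
    message: str


@dataclass
class OkResult:
    payload: str


HandlerResult = Union[ErrorResult, OkResult]


def render(result: HandlerResult) -> str:
    if isinstance(result, ErrorResult):
        return f"error: {result.message}"
    elif isinstance(result, OkResult):
        return f"ok: {result.payload}"

    # Statically unreachable; if a new member is added to HandlerResult
    # without a branch above, mypy reports an error on this line.
    assert_never(result)


print(render(OkResult("hello")))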
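The rewritten check_logprobs and validate_stream_options validators run in pydantic's mode="before", i.e. on the raw request mapping, which is why they use .get() and the walrus operator instead of attribute access. A cut-down model showing the same validation flow (MiniRequest is a stand-in, not vLLM's ChatCompletionRequest):

from typing import Optional

from pydantic import BaseModel, ValidationError, model_validator


class MiniRequest(BaseModel):
    """Stand-in carrying only the fields the validator looks at."""
    stream: bool = False
    prompt_logprobs: Optional[int] = None

    @model_validator(mode="before")
    @classmethod
    def check_logprobs(cls, data):
        # mode="before" receives the raw input mapping, so missing keys
        # are handled with .get() rather than attribute access.
        if (prompt_logprobs := data.get("prompt_logprobs")) is not None:
            if data.get("stream") and prompt_logprobs > 0:
                raise ValueError(
                    "`prompt_logprobs` are not available when `stream=True`.")
            if prompt_logprobs < 0:
                raise ValueError("`prompt_logprobs` must be a positive value.")
        return data


print(MiniRequest(prompt_logprobs=1).prompt_logprobs)  # -> 1

try:
    MiniRequest(prompt_logprobs=-1)
except ValidationError as exc:
    print(exc.errors()[0]["msg"])  # Value error, `prompt_logprobs` must be ...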
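SamplingParams.from_optional() exists because OpenAI-style request fields can legitimately arrive as null, and explicit None values would otherwise clobber the constructor defaults. The mapping it performs is a per-field "None means default" fallback, sketched here with a stand-in dataclass rather than vllm.SamplingParams:

from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniSamplingParams:
    """Stand-in with a handful of the real fields."""
    n: int = 1
    temperature: float = 1.0
    top_p: float = 1.0

    @staticmethod
    def from_optional(
        n: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> "MiniSamplingParams":
        # Each None collapses back to the regular default, so a request
        # body like {"temperature": null} behaves as if the field were
        # omitted entirely.
        return MiniSamplingParams(
            n=1 if n is None else n,
            temperature=1.0 if temperature is None else temperature,
            top_p=1.0 if top_p is None else top_p,
        )


print(MiniSamplingParams.from_optional(temperature=None, top_p=0.9))
# -> MiniSamplingParams(n=1, temperature=1.0, top_p=0.9)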
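The encoding_format handling moved into _get_embedding() keeps the existing wire format: "float" returns the list as-is, while "base64" packs the values into raw bytes. The round trip below shows what a client has to do to decode it; the float64 dtype is an assumption that follows from np.array() being called on a plain Python float list, as in the diff:

import base64

import numpy as np

embedding = [0.1, 0.2, 0.3]

# Server side, as in serving_embedding.py's base64 branch:
embedding_bytes = np.array(embedding).tobytes()
encoded = base64.b64encode(embedding_bytes).decode("utf-8")

# Client side: undo the encoding (dtype must match the server's default,
# which is float64 for a Python list of floats).
decoded = np.frombuffer(base64.b64decode(encoded), dtype=np.float64)
print(decoded.tolist())  # -> [0.1, 0.2, 0.3]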
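The cli_args.py actions gain full signatures mainly to satisfy mypy now that vllm/entrypoints is type-checked: argparse hands values over as a sequence when nargs is used, and the explicit None/str checks encode that assumption for the type checker. A compact stand-alone version (ModulePath stands in for LoRAModulePath):

import argparse
from typing import List, NamedTuple, Optional, Sequence, Union


class ModulePath(NamedTuple):
    """Stand-in for LoRAModulePath."""
    name: str
    path: str


class ModuleParserAction(argparse.Action):

    def __call__(
        self,
        parser: argparse.ArgumentParser,
        namespace: argparse.Namespace,
        values: Optional[Union[str, Sequence[str]]],
        option_string: Optional[str] = None,
    ):
        if values is None:
            values = []
        if isinstance(values, str):
            raise TypeError("Expected values to be a list")

        modules: List[ModulePath] = []
        for item in values:
            name, path = item.split("=")
            modules.append(ModulePath(name, path))
        setattr(namespace, self.dest, modules)


parser = argparse.ArgumentParser()
parser.add_argument("--lora-modules", nargs="+", action=ModuleParserAction)
args = parser.parse_args(["--lora-modules", "sql-lora=/adapters/sql"])
print(args.lora_modules)  # -> [ModulePath(name='sql-lora', path='/adapters/sql')]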
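For reference, the request shape exercised by the relocated test_prompt_logprobs_chat test looks like the following from a plain openai >= 1.0 client. The base URL, API key, and model name are placeholders for a locally running vLLM OpenAI-compatible server, and prompt_logprobs rides in extra_body because it is a vLLM extension rather than an OpenAI parameter:

import openai

client = openai.OpenAI(
    base_url="http://localhost:8000/v1",  # placeholder: local vLLM server
    api_key="EMPTY",                      # placeholder: server ignores it
)

completion = client.chat.completions.create(
    model="my-served-model",  # placeholder model name
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Where was it played?"},
    ],
    # vLLM-specific extension; negative values are now rejected with a
    # 400 by the pydantic validator instead of inside the serving layer.
    extra_body={"prompt_logprobs": 1},
)

# The test asserts this extra field is present on the response object.
print(completion.prompt_logprobs)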