[mypy] Enable following imports for entrypoints (vllm-project#7248)
Co-authored-by: Woosuk Kwon <[email protected]>
Co-authored-by: Fei <[email protected]>
3 people authored and omrishiv committed Aug 26, 2024
1 parent d377c3e commit 5c225f8
Showing 26 changed files with 480 additions and 320 deletions.
1 change: 0 additions & 1 deletion .github/workflows/mypy.yaml
@@ -38,7 +38,6 @@ jobs:
mypy vllm/core --follow-imports skip
mypy vllm/distributed --follow-imports skip
mypy vllm/engine --follow-imports skip
mypy vllm/entrypoints --follow-imports skip
mypy vllm/executor --follow-imports skip
mypy vllm/lora --follow-imports skip
mypy vllm/model_executor --follow-imports skip
2 changes: 1 addition & 1 deletion docs/requirements-docs.txt
@@ -6,7 +6,7 @@ sphinx-argparse==0.4.0
msgspec

# packages to install to build the documentation
pydantic
pydantic >= 2.8
-f https://download.pytorch.org/whl/cpu
torch
py-cpuinfo
1 change: 0 additions & 1 deletion format.sh
@@ -102,7 +102,6 @@ mypy vllm/attention --follow-imports skip
mypy vllm/core --follow-imports skip
mypy vllm/distributed --follow-imports skip
mypy vllm/engine --follow-imports skip
mypy vllm/entrypoints --follow-imports skip
mypy vllm/executor --follow-imports skip
mypy vllm/lora --follow-imports skip
mypy vllm/model_executor --follow-imports skip
1 change: 1 addition & 0 deletions pyproject.toml
@@ -56,6 +56,7 @@ files = [
"vllm/*.py",
"vllm/adapter_commons",
"vllm/assets",
"vllm/entrypoints",
"vllm/inputs",
"vllm/logging",
"vllm/multimodal",
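
With "vllm/entrypoints" added to the files list in pyproject.toml, mypy now type-checks the package as part of the main run with imports followed, which is why the separate --follow-imports skip invocation was dropped from the CI workflow and format.sh above. A minimal sketch (not part of this commit) of checking that package from the repository root via mypy's Python API; the path is the one listed in pyproject.toml:

from mypy import api

# Equivalent to running `mypy vllm/entrypoints` from the repository root;
# mypy reads the [tool.mypy] settings in pyproject.toml automatically.
stdout, stderr, exit_status = api.run(["vllm/entrypoints"])
print(stdout, end="")
if exit_status != 0:
    raise SystemExit(exit_status)
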
2 changes: 1 addition & 1 deletion requirements-common.txt
@@ -11,7 +11,7 @@ fastapi
aiohttp
openai >= 1.0 # Ensure modern openai package (ensure types module present)
uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
pydantic >= 2.8 # Required for OpenAI server.
pillow # Required for image processing
prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0
84 changes: 83 additions & 1 deletion tests/entrypoints/openai/test_chat.py
@@ -1,7 +1,7 @@
# imports for guided decoding tests
import json
import re
from typing import List
from typing import Dict, List, Optional

import jsonschema
import openai # use the official client for correctness check
@@ -174,6 +174,88 @@ async def test_too_many_chat_logprobs(client: openai.AsyncOpenAI,
assert message.content is not None and len(message.content) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name, prompt_logprobs",
[(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
)
async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str,
prompt_logprobs: Optional[int]):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name
}

if prompt_logprobs is not None:
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}

if prompt_logprobs is not None and prompt_logprobs < 0:
with pytest.raises(BadRequestError):
await client.chat.completions.create(**params)
else:
completion = await client.chat.completions.create(**params)
if prompt_logprobs is not None:
assert completion.prompt_logprobs is not None
assert len(completion.prompt_logprobs) > 0
else:
assert completion.prompt_logprobs is None


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name,
"extra_body": {
"prompt_logprobs": 1
}
}

completion_1 = await client.chat.completions.create(**params)

params["extra_body"] = {"prompt_logprobs": 2}
completion_2 = await client.chat.completions.create(**params)

assert len(completion_1.prompt_logprobs[3]) == 1
assert len(completion_2.prompt_logprobs[3]) == 2


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
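
The new tests above exercise a vLLM-specific extension of the OpenAI chat API: prompt_logprobs is passed through the official client via extra_body and comes back as an extra field on the response. A minimal client-side sketch under assumed values (the base_url, api_key, and model name are placeholders for a locally running vLLM OpenAI-compatible server):

import asyncio

import openai


async def main() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    try:
        completion = await client.chat.completions.create(
            model="placeholder-model",
            messages=[{"role": "user", "content": "Where was it played?"}],
            # vLLM extension; not part of the upstream OpenAI API.
            extra_body={"prompt_logprobs": 1},
        )
        # vLLM attaches prompt logprobs as an extra field on the response.
        print(completion.prompt_logprobs)
    except openai.BadRequestError as err:
        # Negative values are rejected with HTTP 400, as the tests assert.
        print(f"request rejected: {err}")


asyncio.run(main())
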
101 changes: 5 additions & 96 deletions tests/entrypoints/openai/test_completion.py
@@ -3,7 +3,7 @@
import re
import shutil
from tempfile import TemporaryDirectory
from typing import Dict, List
from typing import Dict, List, Optional

import jsonschema
import openai # use the official client for correctness check
@@ -268,118 +268,27 @@ async def test_too_many_completion_logprobs(client: openai.AsyncOpenAI,
assert len(completion.choices[0].text) >= 0


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name, prompt_logprobs",
[(MODEL_NAME, 1), (MODEL_NAME, 0), (MODEL_NAME, -1), (MODEL_NAME, None)],
)
async def test_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str, prompt_logprobs: int):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name
}

if prompt_logprobs is not None:
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}

if prompt_logprobs and prompt_logprobs < 0:
with pytest.raises(BadRequestError) as err_info:
await client.chat.completions.create(**params)
expected_err_string = (
"Error code: 400 - {'object': 'error', 'message': "
"'Prompt_logprobs set to invalid negative value: -1',"
" 'type': 'BadRequestError', 'param': None, 'code': 400}")
assert str(err_info.value) == expected_err_string
else:
completion = await client.chat.completions.create(**params)
if prompt_logprobs and prompt_logprobs > 0:
assert completion.prompt_logprobs is not None
assert len(completion.prompt_logprobs) > 0
else:
assert completion.prompt_logprobs is None


@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME],
)
async def test_more_than_one_prompt_logprobs_chat(client: openai.AsyncOpenAI,
model_name: str):
params: Dict = {
"messages": [{
"role": "system",
"content": "You are a helpful assistant."
}, {
"role": "user",
"content": "Who won the world series in 2020?"
}, {
"role":
"assistant",
"content":
"The Los Angeles Dodgers won the World Series in 2020."
}, {
"role": "user",
"content": "Where was it played?"
}],
"model":
model_name,
"extra_body": {
"prompt_logprobs": 1
}
}

completion_1 = await client.chat.completions.create(**params)

params["extra_body"] = {"prompt_logprobs": 2}
completion_2 = await client.chat.completions.create(**params)

assert len(completion_1.prompt_logprobs[3]) == 1
assert len(completion_2.prompt_logprobs[3]) == 2


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name, prompt_logprobs", [(MODEL_NAME, -1),
(MODEL_NAME, 0),
(MODEL_NAME, 1),
(MODEL_NAME, None)])
async def test_prompt_logprobs_completion(client: openai.AsyncOpenAI,
model_name: str,
prompt_logprobs: int):
prompt_logprobs: Optional[int]):
params: Dict = {
"prompt": ["A robot may not injure another robot", "My name is"],
"model": model_name,
}
if prompt_logprobs is not None:
params["extra_body"] = {"prompt_logprobs": prompt_logprobs}

if prompt_logprobs and prompt_logprobs < 0:
with pytest.raises(BadRequestError) as err_info:
if prompt_logprobs is not None and prompt_logprobs < 0:
with pytest.raises(BadRequestError):
await client.completions.create(**params)
expected_err_string = (
"Error code: 400 - {'object': 'error', 'message': "
"'Prompt_logprobs set to invalid negative value: -1',"
" 'type': 'BadRequestError', 'param': None, 'code': 400}")
assert str(err_info.value) == expected_err_string
else:
completion = await client.completions.create(**params)
if prompt_logprobs and prompt_logprobs > 0:
if prompt_logprobs is not None:
assert completion.choices[0].prompt_logprobs is not None
assert len(completion.choices[0].prompt_logprobs) > 0

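
The completions endpoint accepts the same extension; per the test retained above, the prompt logprobs are reported on each choice rather than at the top level of the response. A short sketch under the same placeholder assumptions as the chat example:

import asyncio

import openai


async def main() -> None:
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    completion = await client.completions.create(
        model="placeholder-model",
        prompt=["A robot may not injure another robot", "My name is"],
        extra_body={"prompt_logprobs": 1},  # vLLM extension
    )
    print(completion.choices[0].prompt_logprobs)


asyncio.run(main())
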
8 changes: 4 additions & 4 deletions vllm/engine/async_llm_engine.py
@@ -6,7 +6,6 @@
Optional, Set, Tuple, Type, Union)

import torch
from transformers import PreTrainedTokenizer
from typing_extensions import assert_never

import vllm.envs as envs
@@ -31,6 +30,7 @@
from vllm.sampling_params import SamplingParams
from vllm.sequence import (ExecuteModelRequest, SamplerOutput,
SequenceGroupMetadata)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import print_warning_once

@@ -427,8 +427,8 @@ async def _tokenize_prompt_async(
lora_request: Optional[LoRARequest],
) -> List[int]:
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
tokenizer = self.get_tokenizer_group(
missing_msg="prompts must be None if skip_tokenizer_init is True")

return await tokenizer.encode_async(request_id=request_id,
prompt=prompt,
@@ -771,7 +771,7 @@ def _error_callback(self, exc: Exception) -> None:
async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> "PreTrainedTokenizer":
) -> AnyTokenizer:
if self.engine_use_ray:
return await self.engine.get_tokenizer.remote( # type: ignore
lora_request)
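
The annotation change above swaps the concrete transformers class for vLLM's AnyTokenizer alias, so the engine's tokenizer accessors type-check against both slow and fast tokenizers. An illustrative sketch of the pattern; the alias below only approximates vLLM's own definition in vllm/transformers_utils/tokenizer.py:

from typing import List, Union

from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

# Approximation of the AnyTokenizer alias, for illustration only.
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]


def decode_ids(tokenizer: AnyTokenizer, token_ids: List[int]) -> str:
    # Both union members provide `decode`, so mypy accepts the call
    # without narrowing to one concrete tokenizer class.
    return tokenizer.decode(token_ids)
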
31 changes: 21 additions & 10 deletions vllm/engine/llm_engine.py
@@ -3,9 +3,9 @@
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
Mapping, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Type, TypeVar, Union
from typing import Set, Tuple, Type, Union

from typing_extensions import assert_never
from typing_extensions import TypeVar, assert_never

import vllm.envs as envs
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
@@ -43,8 +43,9 @@
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import (
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter, Device
@@ -67,6 +68,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
return config.to_diff_dict()


_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)
_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)

PromptComponents = Tuple[Optional[str], List[int],
@@ -493,12 +495,21 @@ def __del__(self):
"skip_tokenizer_init is True")

def get_tokenizer_group(
self,
fail_msg: str = MISSING_TOKENIZER_GROUP_MSG) -> BaseTokenizerGroup:
if self.tokenizer is None:
raise ValueError(fail_msg)
self,
group_type: Type[_G] = BaseTokenizerGroup,
*,
missing_msg: str = MISSING_TOKENIZER_GROUP_MSG,
) -> _G:
tokenizer_group = self.tokenizer

if tokenizer_group is None:
raise ValueError(missing_msg)
if not isinstance(tokenizer_group, group_type):
raise TypeError("Invalid type of tokenizer group. "
f"Expected type: {group_type}, but "
f"found type: {type(tokenizer_group)}")

return self.tokenizer
return tokenizer_group

def get_tokenizer(
self,
@@ -693,8 +704,8 @@ def _tokenize_prompt(
* prompt token ids
'''

tokenizer = self.get_tokenizer_group("prompts must be None if "
"skip_tokenizer_init is True")
tokenizer = self.get_tokenizer_group(
missing_msg="prompts must be None if skip_tokenizer_init is True")

return tokenizer.encode(request_id=request_id,
prompt=prompt,
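
get_tokenizer_group is now generic: callers can request a specific tokenizer-group subclass and get a correspondingly typed return value, with BaseTokenizerGroup as the default thanks to typing_extensions.TypeVar's default= support (PEP 696). A condensed, self-contained sketch of the pattern using stand-in classes (the real method lives on LLMEngine):

from typing import Optional, Type

from typing_extensions import TypeVar


class BaseTokenizerGroup:
    ...


class RayTokenizerGroupStub(BaseTokenizerGroup):  # stand-in subclass
    ...


# default= requires the typing_extensions backport of TypeVar (PEP 696).
_G = TypeVar("_G", bound=BaseTokenizerGroup, default=BaseTokenizerGroup)


class EngineSketch:

    def __init__(self, tokenizer: Optional[BaseTokenizerGroup]) -> None:
        self.tokenizer = tokenizer

    def get_tokenizer_group(
        self,
        group_type: Type[_G] = BaseTokenizerGroup,
        *,
        missing_msg: str = "tokenizer group was not initialized",
    ) -> _G:
        tokenizer_group = self.tokenizer
        if tokenizer_group is None:
            raise ValueError(missing_msg)
        if not isinstance(tokenizer_group, group_type):
            raise TypeError("Invalid type of tokenizer group. "
                            f"Expected type: {group_type}, but "
                            f"found type: {type(tokenizer_group)}")
        return tokenizer_group


engine = EngineSketch(RayTokenizerGroupStub())
# Typed as RayTokenizerGroupStub when the subclass is requested explicitly,
# and as BaseTokenizerGroup when group_type is omitted.
ray_group = engine.get_tokenizer_group(RayTokenizerGroupStub)
base_group = engine.get_tokenizer_group()
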