diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml
new file mode 100644
index 0000000000000..fbe0f816fd4af
--- /dev/null
+++ b/.github/workflows/mypy.yaml
@@ -0,0 +1,50 @@
+name: mypy
+
+on:
+  # Trigger the workflow on push or pull request,
+  # but only for the main branch
+  push:
+    branches:
+      - main
+  pull_request:
+    branches:
+      - main
+
+jobs:
+  ruff:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: ["3.8"]
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install mypy==1.9.0
+        pip install types-setuptools
+        pip install types-PyYAML
+        pip install types-requests
+        pip install types-setuptools
+    - name: Mypy
+      run: |
+        mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
+        mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
+
+        # TODO(sang): Follow up
+        # mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
+        # mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
+        # mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
+        # mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
+        # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
+
diff --git a/format.sh b/format.sh
index deb57b2b049d1..1c195b899c742 100755
--- a/format.sh
+++ b/format.sh
@@ -93,9 +93,23 @@ fi
 echo 'vLLM yapf: Done'
 
 # Run mypy
-# TODO(zhuohan): Enable mypy
-# echo 'vLLM mypy:'
-# mypy
+echo 'vLLM mypy:'
+mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
+mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
+
+# TODO(sang): Follow up
+# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
+# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
+# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
+# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
+# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
+
 
 CODESPELL_EXCLUDES=(
     '--skip' '*docs/source/_build/**'
@@ -228,5 +242,3 @@ if ! git diff --quiet &>/dev/null; then
 
     exit 1
 fi
-
-
diff --git a/pyproject.toml b/pyproject.toml
index 607c09935db89..805170719e50a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,10 +45,13 @@ ignore = [
 python_version = "3.8"
 
 ignore_missing_imports = true
+check_untyped_defs = true
 
 files = "vllm"
 # TODO(woosuk): Include the code from Megatron and HuggingFace.
-exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/"
+exclude = [
+    "vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
+]
 
 
 [tool.codespell]
diff --git a/requirements-common.txt b/requirements-common.txt
index ff053388a23e1..c96f9c9937fb0 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -11,4 +11,5 @@ uvicorn[standard]
 pydantic >= 2.0  # Required for OpenAI server.
 prometheus_client >= 0.18.0
 tiktoken == 0.6.0  # Required for DBRX tokenizer
-outlines == 0.0.34 # Requires torch >= 2.1.0
\ No newline at end of file
+outlines == 0.0.34 # Requires torch >= 2.1.0
+typing_extensions
\ No newline at end of file
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 75d22bbdb2a1b..96dfda6faf00f 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -7,7 +7,7 @@ codespell==2.2.6
 isort==5.13.2
 
 # type checking
-mypy==0.991
+mypy==1.9.0
 types-PyYAML
 types-requests
 types-setuptools
diff --git a/vllm/config.py b/vllm/config.py
index 744fecdc7c64f..ca3f004d0dd4c 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2,7 +2,7 @@
 import json
 import os
 from dataclasses import dataclass, fields
-from typing import TYPE_CHECKING, ClassVar, Optional, Union
+from typing import TYPE_CHECKING, ClassVar, List, Optional, Union
 
 import torch
 from packaging.version import Version
@@ -147,7 +147,7 @@ def _verify_load_format(self) -> None:
         supported_load_format = [
             "auto", "pt", "safetensors", "npcache", "dummy"
         ]
-        rocm_not_supported_load_format = []
+        rocm_not_supported_load_format: List[str] = []
         if load_format not in supported_load_format:
             raise ValueError(
                 f"Unknown load format: {self.load_format}. Must be one of "
@@ -719,6 +719,9 @@ def maybe_create_spec_config(
                 "num_speculative_tokens to be provided, but found "
                 f"{speculative_model=} and {num_speculative_tokens=}.")
 
+        assert (speculative_model is not None
+                and num_speculative_tokens is not None)
+
         # TODO: The user should be able to specify revision/quantization/max
         # model len for the draft model. It is not currently supported.
         draft_revision = None
@@ -1033,7 +1036,7 @@ def _get_and_verify_max_len(
             derived_max_model_len *= scaling_factor
 
     if max_model_len is None:
-        max_model_len = derived_max_model_len
+        max_model_len = int(derived_max_model_len)
     elif max_model_len > derived_max_model_len:
         # Some models might have a separate key for specifying model_max_length
         # that will be bigger than derived_max_model_len. We compare user input
diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py
index e7e3b4dc1e9b4..e391a3b1e5a33 100644
--- a/vllm/core/block_manager_v1.py
+++ b/vllm/core/block_manager_v1.py
@@ -1,5 +1,6 @@
 """A block manager that manages token blocks."""
 from abc import ABC, abstractmethod
+from collections.abc import Sequence as GenericSequence
 from itertools import count, takewhile
 from os.path import commonprefix
 from typing import Dict, List, Optional, Set
@@ -231,10 +232,10 @@ def __init__(
 
         if self.enable_caching:
             logger.info("Automatic prefix caching is enabled.")
-            self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size,
-                                                      num_gpu_blocks)
-            self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size,
-                                                      num_cpu_blocks)
+            self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
+                Device.GPU, block_size, num_gpu_blocks)
+            self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
+                Device.CPU, block_size, num_cpu_blocks)
         else:
             self.gpu_allocator = UncachedBlockAllocator(
                 Device.GPU, block_size, num_gpu_blocks)
@@ -588,7 +589,8 @@ def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
             for b in takewhile(lambda b: b.computed, block_table[:-1])
         ]
 
-    def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
+    def get_common_computed_block_ids(
+            self, seqs: List[Sequence]) -> GenericSequence[int]:
         """Return the block ids that are common for a given sequence group.
 
         Used in prefill (can skip prefill of some blocks).
diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py
index 813e71ad883b2..19f0cf415eb34 100644
--- a/vllm/core/block_manager_v2.py
+++ b/vllm/core/block_manager_v2.py
@@ -1,4 +1,5 @@
 """A block manager that manages token blocks."""
+from collections.abc import Sequence as GenericSequence
 from typing import Dict, List, Optional
 
 from vllm.core.block.block_table import BlockTable
@@ -205,7 +206,8 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup):
         # as computed.
         self.block_allocator.mark_blocks_as_computed()
 
-    def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
+    def get_common_computed_block_ids(
+            self, seqs: List[Sequence]) -> GenericSequence[int]:
         """Determine which blocks for which we skip prefill.
 
         With prefix caching we can skip prefill for previously-generated blocks.
diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py
index 711536bcc97be..c1f68a2e891bf 100644
--- a/vllm/core/interfaces.py
+++ b/vllm/core/interfaces.py
@@ -1,5 +1,6 @@
 import enum
 from abc import ABC, abstractmethod
+from collections.abc import Sequence as GenericSequence
 from typing import Dict, List
 
 from vllm.sequence import Sequence, SequenceGroup
@@ -103,7 +104,8 @@ def access_all_blocks_in_seq(
         pass
 
     @abstractmethod
-    def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
+    def get_common_computed_block_ids(
+            self, seqs: List[Sequence]) -> GenericSequence[int]:
         pass
 
     @abstractmethod
diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py
index 4aa17adb602d2..4da7600387f05 100644
--- a/vllm/core/scheduler.py
+++ b/vllm/core/scheduler.py
@@ -42,8 +42,8 @@ class SchedulingBudget:
     """
     token_budget: int
    max_num_seqs: int
-    _requeset_ids_num_batched_tokens: Set[int] = field(default_factory=set)
-    _requeset_ids_num_curr_seqs: Set[int] = field(default_factory=set)
+    _requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set)
+    _requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set)
     _num_batched_tokens: int = 0
     _num_curr_seqs: int = 0
 
@@ -133,7 +133,7 @@ def is_empty(self) -> bool:
         return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
                 and not self.blocks_to_swap_out and not self.blocks_to_copy)
 
-    def _sort_by_lora_ids(self) -> bool:
+    def _sort_by_lora_ids(self):
         self.scheduled_seq_groups = sorted(
             self.scheduled_seq_groups,
             key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
@@ -337,7 +337,8 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
                 self.free_seq(seq)
 
     def has_unfinished_seqs(self) -> bool:
-        return self.waiting or self.running or self.swapped
+        return len(self.waiting) != 0 or len(self.running) != 0 or len(
+            self.swapped) != 0
 
     def get_num_unfinished_seq_groups(self) -> int:
         return len(self.waiting) + len(self.running) + len(self.swapped)
@@ -404,7 +405,7 @@ def _schedule_running(
                 budget.subtract_num_seqs(seq_group.request_id,
                                          num_running_seqs)
                 if curr_loras is not None and seq_group.lora_int_id > 0:
-                    curr_loras.pop(seq_group.lora_int_id)
+                    curr_loras.remove(seq_group.lora_int_id)
 
                 if running_queue:
                     # Preempt the lowest-priority sequence groups.
@@ -496,7 +497,7 @@ def _schedule_swapped(
 
         now = time.time()
         swapped_queue = policy.sort_by_priority(now, swapped_queue)
 
-        leftover_swapped = deque()
+        leftover_swapped: Deque[SequenceGroup] = deque()
         while swapped_queue:
             seq_group = swapped_queue[0]
@@ -507,7 +508,9 @@ def _schedule_swapped(
             lora_int_id = 0
             if self.lora_enabled:
                 lora_int_id = seq_group.lora_int_id
-            if (lora_int_id > 0 and lora_int_id not in curr_loras
+                assert curr_loras is not None
+                assert self.lora_config is not None
+            if (lora_int_id > 0 and (lora_int_id not in curr_loras)
                     and len(curr_loras) >= self.lora_config.max_loras):
                 # We don't have a space for another LoRA, so
                 # we ignore this request for now.
@@ -593,7 +596,7 @@ def _schedule_prefills(
 
         # Copy the queue so that the input queue is not modified.
         waiting_queue = deque([s for s in waiting_queue])
 
-        leftover_waiting_sequences = deque()
+        leftover_waiting_sequences: Deque[SequenceGroup] = deque()
         while self._passed_delay(time.time()) and waiting_queue:
             seq_group = waiting_queue[0]
@@ -635,6 +638,8 @@ def _schedule_prefills(
             lora_int_id = 0
             if self.lora_enabled:
                 lora_int_id = seq_group.lora_int_id
+                assert curr_loras is not None
+                assert self.lora_config is not None
             if (self.lora_enabled and lora_int_id > 0
                     and lora_int_id not in curr_loras
                     and len(curr_loras) >= self.lora_config.max_loras):
@@ -780,7 +785,7 @@ def _schedule_chunked_prefill(self):
             token_budget=self.scheduler_config.max_num_batched_tokens,
             max_num_seqs=self.scheduler_config.max_num_seqs,
         )
-        curr_loras = set()
+        curr_loras: Set[int] = set()
 
         remaining_waiting, prefills = (self.waiting,
                                        SchedulerPrefillOutputs.create_empty())
@@ -1108,7 +1113,7 @@ def _get_num_lookahead_slots(self, is_prefill: bool) -> int:
 
     def _get_num_new_tokens(self, seq_group: SequenceGroup,
                             status: SequenceStatus, enable_chunking: bool,
-                            budget: SchedulingBudget) -> Tuple[int, bool]:
+                            budget: SchedulingBudget) -> int:
         """Get the next new tokens to compute for a given sequence group
             that's in a given `status`.
 
diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py
index 1004d626b6a4b..a3e93691a1e8e 100644
--- a/vllm/distributed/communication_op.py
+++ b/vllm/distributed/communication_op.py
@@ -1,5 +1,5 @@
 from collections import namedtuple
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 from torch.distributed import ProcessGroup
@@ -144,7 +144,7 @@ def broadcast_tensor_dict(
     tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
     src: int = 0,
     group: Optional[ProcessGroup] = None,
-) -> Dict[Any, Union[torch.Tensor, Any]]:
+) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
     """Broadcast the input tensor dictionary."""
     group = group or torch.distributed.group.WORLD
     ranks = torch.distributed.get_process_group_ranks(group)
@@ -157,10 +157,10 @@ def broadcast_tensor_dict(
 
     rank = torch.distributed.get_rank()
     if rank == src:
+        metadata_list: List[Tuple[Any, Any]] = []
         assert isinstance(
             tensor_dict,
             dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
-        metadata_list = []
         for key, value in tensor_dict.items():
             if isinstance(value, torch.Tensor):
                 assert value.is_cuda, (
@@ -190,10 +190,10 @@ def broadcast_tensor_dict(
         torch.distributed.broadcast_object_list(recv_metadata_list,
                                                 src=src,
                                                 group=group)
-        metadata_list = recv_metadata_list[0]
+        assert recv_metadata_list[0] is not None
         tensor_dict = {}
         async_handles = []
-        for key, value in metadata_list:
+        for key, value in recv_metadata_list[0]:
             if isinstance(value, TensorMetadata):
                 tensor = torch.empty(value.size,
                                      dtype=value.dtype,
diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py
index 70d5c9b1fae05..04d4ed83976d0 100644
--- a/vllm/engine/ray_utils.py
+++ b/vllm/engine/ray_utils.py
@@ -1,9 +1,10 @@
 import pickle
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
 from vllm.utils import get_ip, is_hip, set_cuda_visible_devices
+from vllm.worker.worker import Worker
 
 logger = init_logger(__name__)
 
@@ -18,15 +19,20 @@ def __init__(self, init_cached_hf_modules=False) -> None:
             if init_cached_hf_modules:
                 from transformers.dynamic_module_utils import init_hf_modules
                 init_hf_modules()
-            self.worker = None
+            self._worker: Optional[Worker] = None
             # Since the compiled DAG runs a main execution
             # in a different thread that calls cuda.set_device.
             # The flag indicates is set_device is called on
             # that thread.
             self.compiled_dag_cuda_device_set = False
 
-        def init_worker(self, worker_init_fn):
-            self.worker = worker_init_fn()
+        def init_worker(self, worker_init_fn: Callable[[], Worker]):
+            self._worker = worker_init_fn()
+
+        @property
+        def worker(self) -> Worker:
+            assert self._worker is not None
+            return self._worker
 
         def __getattr__(self, name):
             return getattr(self.worker, name)
@@ -70,8 +76,8 @@ def execute_model_compiled_dag_remote(self, ignored):
     logger.warning(f"Failed to import Ray with {e!r}. "
                    "For distributed inference, please install Ray with "
                    "`pip install ray`.")
-    ray = None
-    RayWorkerVllm = None
+    ray = None  # type: ignore
+    RayWorkerVllm = None  # type: ignore
 
 
 def initialize_ray_cluster(
diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py
index 2a47eae112c12..587142adb9c6b 100644
--- a/vllm/entrypoints/api_server.py
+++ b/vllm/entrypoints/api_server.py
@@ -47,6 +47,7 @@ async def generate(request: Request) -> Response:
     sampling_params = SamplingParams(**request_dict)
     request_id = random_uuid()
 
+    assert engine is not None
     results_generator = engine.generate(prompt, sampling_params, request_id)
 
     # Streaming case
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index c5261d6e4556c..0694d3b7cdb91 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -179,8 +179,12 @@ def generate(
             multi_modal_data.data = multi_modal_data.data.to(torch.float16)
 
         # Add requests to the engine.
-        num_requests = len(prompts) if prompts is not None else len(
-            prompt_token_ids)
+        if prompts is not None:
+            num_requests = len(prompts)
+        else:
+            assert prompt_token_ids is not None
+            num_requests = len(prompt_token_ids)
+
         for i in range(num_requests):
             prompt = prompts[i] if prompts is not None else None
             token_ids = None if prompt_token_ids is None else prompt_token_ids[
diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py
index eda4e8989c163..33e67d8b3eec2 100644
--- a/vllm/executor/cpu_executor.py
+++ b/vllm/executor/cpu_executor.py
@@ -1,5 +1,5 @@
 import os
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 import torch
 
@@ -61,7 +61,7 @@ def _init_worker(self):
         self.driver_worker.init_device()
         self.driver_worker.load_model()
 
-    def determine_num_available_blocks(self) -> tuple[int, int]:
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks by invoking the
         underlying worker.
         """
diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py
index 80ca5cb7367c5..f20221a0b941a 100644
--- a/vllm/executor/gpu_executor.py
+++ b/vllm/executor/gpu_executor.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
@@ -66,7 +66,7 @@ def _init_worker(self):
         self.driver_worker.init_device()
         self.driver_worker.load_model()
 
-    def determine_num_available_blocks(self) -> tuple[int, int]:
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks by invoking the
         underlying worker.
         """
diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py
index 57436a85cfa27..ee8e87432fa67 100644
--- a/vllm/executor/neuron_executor.py
+++ b/vllm/executor/neuron_executor.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
 
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
@@ -47,7 +47,7 @@ def _init_worker(self):
         self.driver_worker.init_device()
         self.driver_worker.load_model()
 
-    def determine_num_available_blocks(self) -> tuple[int, int]:
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks by invoking the
         underlying worker.
         """
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index 6c0ccd7e64c90..b937693c92257 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -3,7 +3,7 @@
 import os
 import pickle
 from collections import defaultdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig,
                          ParallelConfig, SchedulerConfig, SpeculativeConfig,
@@ -197,7 +197,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
             max_parallel_loading_workers,
         )
 
-    def determine_num_available_blocks(self) -> tuple[int, int]:
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks.
 
         This invokes `determine_num_available_blocks` on each worker and takes
@@ -205,7 +205,7 @@ def determine_num_available_blocks(self) -> tuple[int, int]:
         compatible with all workers.
 
         Returns:
-            - tuple[num_gpu_blocks, num_cpu_blocks]
+            - Tuple[num_gpu_blocks, num_cpu_blocks]
         """
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
         num_blocks = self._run_workers("determine_num_available_blocks", )
@@ -276,7 +276,7 @@ def _run_workers(
         self,
         method: str,
         *args,
-        driver_args: Optional[List[Any]] = None,
+        driver_args: Optional[Tuple[Any, ...]] = None,
         driver_kwargs: Optional[Dict[str, Any]] = None,
         max_concurrent_workers: Optional[int] = None,
         use_ray_compiled_dag: bool = False,
@@ -291,6 +291,7 @@ def _run_workers(
         if use_ray_compiled_dag:
             # Right now, compiled DAG can only accept a single
             # input. TODO(sang): Fix it.
+            assert self.forward_dag is not None
             output_channels = self.forward_dag.execute(1)
         else:
             # Start the ray workers first.
@@ -369,7 +370,7 @@ async def _run_workers_async(
         self,
         method: str,
         *args,
-        driver_args: Optional[List[Any]] = None,
+        driver_args: Optional[Tuple[Any, ...]] = None,
         driver_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> Any:
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 0b9787608798c..53a38b25bfdac 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -5,7 +5,8 @@
 from typing import Callable, List, Optional, Union
 
 import torch
-from pydantic import conint
+from pydantic import Field
+from typing_extensions import Annotated
 
 _SAMPLING_EPS = 1e-5
 
@@ -127,7 +128,7 @@ def __init__(
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
         logits_processors: Optional[List[LogitsProcessor]] = None,
-        truncate_prompt_tokens: Optional[conint(ge=1)] = None,
+        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
     ) -> None:
         self.n = n
         self.best_of = best_of if best_of is not None else n
diff --git a/vllm/sequence.py b/vllm/sequence.py
index cdb6cce6f0255..dcde81df19923 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -171,10 +171,10 @@ def get_last_token_id(self) -> int:
             return self.prompt_token_ids[-1]
         return self.output_token_ids[-1]
 
-    def get_prompt_token_ids(self) -> int:
+    def get_prompt_token_ids(self) -> List[int]:
         return self.prompt_token_ids
 
-    def get_output_token_ids(self) -> int:
+    def get_output_token_ids(self) -> List[int]:
         return self.output_token_ids
 
     @property
@@ -370,7 +370,7 @@ class SequenceGroupState:
     """Mutable state tied to a specific sequence group"""
 
     # torch.Generator used in seeded sampling
-    generator: Optional = None
+    generator: Optional = None  # type: ignore
 
 
 class MultiModalData:
@@ -599,7 +599,7 @@ def lora_int_id(self) -> int:
         return self.lora_request.lora_int_id if self.lora_request else 0
 
     @property
-    def token_chunk_size(self) -> int:
+    def token_chunk_size(self) -> Optional[int]:
         """Return the number of tokens to be processed (chunk size)."""
         return self._token_chunk_size
 
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index ce7a30dce72fa..1756c91a612f0 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -2,7 +2,8 @@
 
 from transformers import AutoConfig, PretrainedConfig
 
-from vllm.transformers_utils.configs import *
+from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
+                                             JAISConfig, MPTConfig, RWConfig)
 
 _CONFIG_REGISTRY: Dict[str, PretrainedConfig] = {
     "chatglm": ChatGLMConfig,
diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py
index 005932f1e3df4..f064c26c3f40c 100644
--- a/vllm/transformers_utils/detokenizer.py
+++ b/vllm/transformers_utils/detokenizer.py
@@ -168,8 +168,8 @@ def _convert_tokens_to_string_with_added_encoders(
     # NOTE(woosuk): The following code is slow because it runs a for loop over
     # the output_tokens. In Python, running a for loop over a list can be slow
     # even when the loop body is very simple.
-    sub_texts = []
-    current_sub_text = []
+    sub_texts: List[str] = []
+    current_sub_text: List[str] = []
     all_special_tokens = set(tokenizer.all_special_tokens)
     for token in output_tokens:
         if skip_special_tokens and token in all_special_tokens:
@@ -263,6 +263,7 @@ def detokenize_incrementally(
             tokenizer,
             all_input_ids[:-1],
             skip_special_tokens=skip_special_tokens)
+        assert prev_tokens is not None
 
     # If the new token id is out of bounds, return an empty string.
     if new_token_id >= len(tokenizer):
@@ -271,6 +272,8 @@ def detokenize_incrementally(
     # Put new_token_id in a list so skip_special_tokens is respected
     new_tokens = tokenizer.convert_ids_to_tokens(
         [new_token_id], skip_special_tokens=skip_special_tokens)
+    if isinstance(new_tokens, str):
+        new_tokens = [new_tokens]
     output_tokens = prev_tokens + new_tokens
 
     # If this is the first iteration, return all tokens.
diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py
index e216a99af91f9..5d3d5801c960d 100644
--- a/vllm/transformers_utils/tokenizer.py
+++ b/vllm/transformers_utils/tokenizer.py
@@ -5,7 +5,7 @@
 
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.transformers_utils.tokenizers import *
+from vllm.transformers_utils.tokenizers import BaichuanTokenizer
 from vllm.utils import make_async
 
 logger = init_logger(__name__)
@@ -28,7 +28,7 @@ def get_cached_tokenizer(
     tokenizer_all_special_tokens = set(tokenizer.all_special_tokens)
     tokenizer_len = len(tokenizer)
 
-    class CachedTokenizer(tokenizer.__class__):
+    class CachedTokenizer(tokenizer.__class__):  # type: ignore
 
         @property
         def all_special_ids(self):
diff --git a/vllm/usage/usage_lib.py b/vllm/usage/usage_lib.py
index 658fe5c98f5ee..b2672f7f1da61 100644
--- a/vllm/usage/usage_lib.py
+++ b/vllm/usage/usage_lib.py
@@ -7,7 +7,7 @@
 from enum import Enum
 from pathlib import Path
 from threading import Thread
-from typing import Dict, Optional
+from typing import Any, Dict, Optional
 from uuid import uuid4
 
 import cpuinfo
@@ -124,7 +124,7 @@ def __init__(self) -> None:
     def report_usage(self,
                      model_architecture: str,
                      usage_context: UsageContext,
-                     extra_kvs: Dict[str, any] = None) -> None:
+                     extra_kvs: Optional[Dict[str, Any]] = None) -> None:
         t = Thread(target=self._report_usage_worker,
                    args=(model_architecture, usage_context, extra_kvs or {}),
                    daemon=True)
@@ -132,13 +132,13 @@ def report_usage(self,
 
     def _report_usage_worker(self, model_architecture: str,
                              usage_context: UsageContext,
-                             extra_kvs: Dict[str, any]) -> None:
+                             extra_kvs: Dict[str, Any]) -> None:
         self._report_usage_once(model_architecture, usage_context, extra_kvs)
         self._report_continous_usage()
 
     def _report_usage_once(self, model_architecture: str,
                            usage_context: UsageContext,
-                           extra_kvs: Dict[str, any]) -> None:
+                           extra_kvs: Dict[str, Any]) -> None:
         # Platform information
         if torch.cuda.is_available():
             device_property = torch.cuda.get_device_properties(0)
diff --git a/vllm/utils.py b/vllm/utils.py
index e67d267aed408..f4aaf0cc51d7a 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -60,7 +60,7 @@ def __contains__(self, key: Hashable) -> bool:
     def __len__(self) -> int:
         return len(self.cache)
 
-    def __getitem__(self, key: Hashable) -> T:
+    def __getitem__(self, key: Hashable) -> Optional[T]:
        return self.get(key)
 
     def __setitem__(self, key: Hashable, value: T) -> None:
@@ -76,7 +76,7 @@ def get(self,
             key: Hashable,
             default_value: Optional[T] = None) -> Optional[T]:
         if key in self.cache:
-            value = self.cache[key]
+            value: Optional[T] = self.cache[key]
             self.cache.move_to_end(key)
         else:
             value = default_value
@@ -87,7 +87,7 @@ def put(self, key: Hashable, value: T) -> None:
         self.cache.move_to_end(key)
         self._remove_old_if_needed()
 
-    def _on_remove(self, key: Hashable, value: T):
+    def _on_remove(self, key: Hashable, value: Optional[T]):
         pass
 
     def remove_oldest(self):
@@ -100,9 +100,11 @@ def _remove_old_if_needed(self) -> None:
         while len(self.cache) > self.capacity:
             self.remove_oldest()
 
-    def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T:
+    def pop(self,
+            key: Hashable,
+            default_value: Optional[T] = None) -> Optional[T]:
         run_on_remove = key in self.cache
-        value = self.cache.pop(key, default_value)
+        value: Optional[T] = self.cache.pop(key, default_value)
         if run_on_remove:
             self._on_remove(key, value)
         return value
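
One pattern from the patch, isolated as a runnable sketch: mypy rejects a call expression such as pydantic's `conint(ge=1)` when it is used as a type annotation, whereas `Annotated[int, Field(ge=1)]` type-checks as a plain `int` while pydantic v2 still enforces the `ge=1` bound wherever the annotation is validated. The `TruncationDemo` model below is illustrative only and not part of the patch.

from typing import Optional

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class TruncationDemo(BaseModel):
    # Plain `int` for mypy; the ge=1 bound travels as Annotated metadata
    # that pydantic applies at validation time.
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None


TruncationDemo(truncate_prompt_tokens=4)    # accepted
# TruncationDemo(truncate_prompt_tokens=0)  # would raise a ValidationError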