
[mypy] Add mypy type annotation part 1 (vllm-project#4006)
rkooo567 authored and robertgshaw2-neuralmagic committed Apr 26, 2024
1 parent 405a695 commit 801ad22
Showing 25 changed files with 171 additions and 72 deletions.
50 changes: 50 additions & 0 deletions .github/workflows/mypy.yaml
@@ -0,0 +1,50 @@
name: mypy

on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
pull_request:
branches:
- main

jobs:
ruff:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.8"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install mypy==1.9.0
pip install types-setuptools
pip install types-PyYAML
pip install types-requests
pip install types-setuptools
- name: Mypy
run: |
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
# TODO(sang): Follow up
# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
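Note on the workflow above: `--follow-imports=skip` means only the files passed on the command line are type-checked; modules they import are not transitively analyzed, which keeps this first pass tractable while parts of the tree remain untyped. As a rough illustration only (not part of the commit), the same per-directory checks could be reproduced from Python via mypy's programmatic API, assuming mypy==1.9.0 is installed locally and the script runs from the repository root:

```python
# Sketch: mirror the CI's per-directory mypy invocations via mypy.api.
# Assumes mypy==1.9.0 is installed; the directory list is copied from the workflow.
import sys
from glob import glob

from mypy import api

CHECKED_DIRS = [
    "vllm/attention", "vllm/core", "vllm/distributed", "vllm/entrypoints",
    "vllm/executor", "vllm/usage", "vllm", "vllm/transformers_utils",
]

worst_status = 0
for directory in CHECKED_DIRS:
    args = glob(f"{directory}/*.py") + [
        "--follow-imports=skip", "--config-file", "pyproject.toml"
    ]
    stdout, _stderr, status = api.run(args)  # (report, errors, exit status)
    print(stdout, end="")
    worst_status = max(worst_status, status)

sys.exit(worst_status)
```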
22 changes: 17 additions & 5 deletions format.sh
@@ -93,9 +93,23 @@ fi
echo 'vLLM yapf: Done'

# Run mypy
# TODO(zhuohan): Enable mypy
# echo 'vLLM mypy:'
# mypy
echo 'vLLM mypy:'
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml

# TODO(sang): Follow up
# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml


CODESPELL_EXCLUDES=(
'--skip' '*docs/source/_build/**'
@@ -228,5 +242,3 @@ if ! git diff --quiet &>/dev/null; then

exit 1
fi


5 changes: 4 additions & 1 deletion pyproject.toml
@@ -45,10 +45,13 @@ ignore = [
python_version = "3.8"

ignore_missing_imports = true
check_untyped_defs = true

files = "vllm"
# TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = "vllm/model_executor/parallel_utils/|vllm/model_executor/models/"
exclude = [
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
]


[tool.codespell]
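The notable config change here is `check_untyped_defs = true`: by default mypy skips the bodies of functions that carry no annotations, and this flag makes it type-check those bodies as well, which is what surfaces most of the fixes in the rest of this commit. A minimal, hypothetical example (not from the vLLM tree) of an error that only appears with the flag enabled:

```python
# With check_untyped_defs = false, mypy ignores this body because the function
# has no annotations. With check_untyped_defs = true, it reports:
#   error: Unsupported operand types for + ("int" and "list[str]")
def count_tokens(prompts):
    total = 0
    for prompt in prompts:
        total += prompt.split()  # bug: should be len(prompt.split())
    return total
```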
3 changes: 2 additions & 1 deletion requirements-common.txt
@@ -11,4 +11,5 @@ uvicorn[standard]
pydantic >= 2.0 # Required for OpenAI server.
prometheus_client >= 0.18.0
tiktoken == 0.6.0 # Required for DBRX tokenizer
outlines == 0.0.34 # Requires torch >= 2.1.0
outlines == 0.0.34 # Requires torch >= 2.1.0
typing_extensions
2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -7,7 +7,7 @@ codespell==2.2.6
isort==5.13.2

# type checking
mypy==0.991
mypy==1.9.0
types-PyYAML
types-requests
types-setuptools
9 changes: 6 additions & 3 deletions vllm/config.py
@@ -2,7 +2,7 @@
import json
import os
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, ClassVar, Optional, Union
from typing import TYPE_CHECKING, ClassVar, List, Optional, Union

import torch
from packaging.version import Version
@@ -147,7 +147,7 @@ def _verify_load_format(self) -> None:
supported_load_format = [
"auto", "pt", "safetensors", "npcache", "dummy"
]
rocm_not_supported_load_format = []
rocm_not_supported_load_format: List[str] = []
if load_format not in supported_load_format:
raise ValueError(
f"Unknown load format: {self.load_format}. Must be one of "
@@ -719,6 +719,9 @@ def maybe_create_spec_config(
"num_speculative_tokens to be provided, but found "
f"{speculative_model=} and {num_speculative_tokens=}.")

assert (speculative_model is not None
and num_speculative_tokens is not None)

# TODO: The user should be able to specify revision/quantization/max
# model len for the draft model. It is not currently supported.
draft_revision = None
@@ -1033,7 +1036,7 @@ def _get_and_verify_max_len(
derived_max_model_len *= scaling_factor

if max_model_len is None:
max_model_len = derived_max_model_len
max_model_len = int(derived_max_model_len)
elif max_model_len > derived_max_model_len:
# Some models might have a separate key for specifying model_max_length
# that will be bigger than derived_max_model_len. We compare user input
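Two of the `config.py` fixes above are representative typing patterns rather than behavior changes: an empty container literal needs an explicit annotation because mypy cannot infer its element type, and a value scaled by a float has to be converted back with `int(...)` before it can satisfy an `int` result type. A self-contained sketch of both, using illustrative names rather than the actual vLLM ones:

```python
from typing import List, Optional

# Without the annotation, mypy reports "Need type annotation" for the empty list.
rocm_not_supported: List[str] = []

def resolve_max_len(derived_max_len: int,
                    scaling_factor: float,
                    max_model_len: Optional[int]) -> int:
    derived = derived_max_len * scaling_factor  # float after rope scaling
    if max_model_len is None:
        # int(...) is required; returning the raw float would fail the
        # declared int return type.
        return int(derived)
    return max_model_len
```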
12 changes: 7 additions & 5 deletions vllm/core/block_manager_v1.py
@@ -1,5 +1,6 @@
"""A block manager that manages token blocks."""
from abc import ABC, abstractmethod
from collections.abc import Sequence as GenericSequence
from itertools import count, takewhile
from os.path import commonprefix
from typing import Dict, List, Optional, Set
@@ -231,10 +232,10 @@ def __init__(

if self.enable_caching:
logger.info("Automatic prefix caching is enabled.")
self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size,
num_gpu_blocks)
self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size,
num_cpu_blocks)
self.gpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
Device.GPU, block_size, num_gpu_blocks)
self.cpu_allocator: BlockAllocatorBase = CachedBlockAllocator(
Device.CPU, block_size, num_cpu_blocks)
else:
self.gpu_allocator = UncachedBlockAllocator(
Device.GPU, block_size, num_gpu_blocks)
@@ -588,7 +589,8 @@ def get_all_computed_blocks(self, seq: Sequence) -> List[int]:
for b in takewhile(lambda b: b.computed, block_table[:-1])
]

def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]:
"""Return the block ids that are common for a given sequence group.
Used in prefill (can skip prefill of some blocks).
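`collections.abc.Sequence` is imported under the alias `GenericSequence` because vLLM already has its own `Sequence` class in `vllm.sequence`; the alias lets the return annotation promise only "a read-only sequence of block ids" without shadowing that class. The same signature change is applied to `block_manager_v2.py` and `interfaces.py` below. A small sketch of the pattern, with a stand-in domain class:

```python
from __future__ import annotations  # keeps Sequence[int] valid on Python 3.8

from collections.abc import Sequence as GenericSequence
from typing import List


class Sequence:
    """Stand-in for vllm.sequence.Sequence (the domain object)."""


def get_common_computed_block_ids(
        seqs: List[Sequence]) -> GenericSequence[int]:
    # Returning a tuple is fine: the annotation only promises a read-only
    # sequence of ints, not specifically a list.
    return (0, 1, 2)
```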
4 changes: 3 additions & 1 deletion vllm/core/block_manager_v2.py
@@ -1,4 +1,5 @@
"""A block manager that manages token blocks."""
from collections.abc import Sequence as GenericSequence
from typing import Dict, List, Optional

from vllm.core.block.block_table import BlockTable
@@ -205,7 +206,8 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup):
# as computed.
self.block_allocator.mark_blocks_as_computed()

def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]:
"""Determine which blocks for which we skip prefill.
With prefix caching we can skip prefill for previously-generated blocks.
4 changes: 3 additions & 1 deletion vllm/core/interfaces.py
@@ -1,5 +1,6 @@
import enum
from abc import ABC, abstractmethod
from collections.abc import Sequence as GenericSequence
from typing import Dict, List

from vllm.sequence import Sequence, SequenceGroup
@@ -103,7 +104,8 @@ def access_all_blocks_in_seq(
pass

@abstractmethod
def get_common_computed_block_ids(self, seqs: List[Sequence]) -> List[int]:
def get_common_computed_block_ids(
self, seqs: List[Sequence]) -> GenericSequence[int]:
pass

@abstractmethod
25 changes: 15 additions & 10 deletions vllm/core/scheduler.py
@@ -42,8 +42,8 @@ class SchedulingBudget:
"""
token_budget: int
max_num_seqs: int
_requeset_ids_num_batched_tokens: Set[int] = field(default_factory=set)
_requeset_ids_num_curr_seqs: Set[int] = field(default_factory=set)
_requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set)
_requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set)
_num_batched_tokens: int = 0
_num_curr_seqs: int = 0

@@ -133,7 +133,7 @@ def is_empty(self) -> bool:
return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
and not self.blocks_to_swap_out and not self.blocks_to_copy)

def _sort_by_lora_ids(self) -> bool:
def _sort_by_lora_ids(self):
self.scheduled_seq_groups = sorted(
self.scheduled_seq_groups,
key=lambda g: (g.seq_group.lora_int_id, g.seq_group.request_id))
@@ -337,7 +338,8 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
self.free_seq(seq)

def has_unfinished_seqs(self) -> bool:
return self.waiting or self.running or self.swapped
return len(self.waiting) != 0 or len(self.running) != 0 or len(
self.swapped) != 0

def get_num_unfinished_seq_groups(self) -> int:
return len(self.waiting) + len(self.running) + len(self.swapped)
@@ -404,7 +405,7 @@ def _schedule_running(
budget.subtract_num_seqs(seq_group.request_id,
num_running_seqs)
if curr_loras is not None and seq_group.lora_int_id > 0:
curr_loras.pop(seq_group.lora_int_id)
curr_loras.remove(seq_group.lora_int_id)

if running_queue:
# Preempt the lowest-priority sequence groups.
@@ -496,7 +497,7 @@ def _schedule_swapped(
now = time.time()
swapped_queue = policy.sort_by_priority(now, swapped_queue)

leftover_swapped = deque()
leftover_swapped: Deque[SequenceGroup] = deque()
while swapped_queue:
seq_group = swapped_queue[0]

@@ -507,7 +508,9 @@ def _schedule_swapped(
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
if (lora_int_id > 0 and lora_int_id not in curr_loras
assert curr_loras is not None
assert self.lora_config is not None
if (lora_int_id > 0 and (lora_int_id not in curr_loras)
and len(curr_loras) >= self.lora_config.max_loras):
# We don't have a space for another LoRA, so
# we ignore this request for now.
@@ -593,7 +596,7 @@ def _schedule_prefills(
# Copy the queue so that the input queue is not modified.
waiting_queue = deque([s for s in waiting_queue])

leftover_waiting_sequences = deque()
leftover_waiting_sequences: Deque[SequenceGroup] = deque()
while self._passed_delay(time.time()) and waiting_queue:
seq_group = waiting_queue[0]

@@ -635,6 +638,8 @@ def _schedule_prefills(
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
assert curr_loras is not None
assert self.lora_config is not None
if (self.lora_enabled and lora_int_id > 0
and lora_int_id not in curr_loras
and len(curr_loras) >= self.lora_config.max_loras):
@@ -780,7 +785,7 @@ def _schedule_chunked_prefill(self):
token_budget=self.scheduler_config.max_num_batched_tokens,
max_num_seqs=self.scheduler_config.max_num_seqs,
)
curr_loras = set()
curr_loras: Set[int] = set()

remaining_waiting, prefills = (self.waiting,
SchedulerPrefillOutputs.create_empty())
@@ -1108,7 +1113,7 @@ def _get_num_lookahead_slots(self, is_prefill: bool) -> int:

def _get_num_new_tokens(self, seq_group: SequenceGroup,
status: SequenceStatus, enable_chunking: bool,
budget: SchedulingBudget) -> Tuple[int, bool]:
budget: SchedulingBudget) -> int:
"""Get the next new tokens to compute for a given sequence group
that's in a given `status`.
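Several of the scheduler edits above are hints for the type checker rather than behavior changes: `assert curr_loras is not None` and `assert self.lora_config is not None` narrow `Optional[...]` values so the following `len(curr_loras)` and `self.lora_config.max_loras` accesses type-check, and `curr_loras.remove(...)` replaces `.pop(...)` because `set.pop()` takes no argument. A minimal sketch of the narrowing pattern, with illustrative names only:

```python
from typing import Optional, Set


class LoRAConfig:
    max_loras: int = 8


class SchedulerSketch:
    def __init__(self, lora_config: Optional[LoRAConfig]) -> None:
        self.lora_enabled = lora_config is not None
        self.lora_config = lora_config

    def has_lora_capacity(self, curr_loras: Optional[Set[int]]) -> bool:
        if not self.lora_enabled:
            return True
        # mypy cannot infer that lora_enabled implies these are non-None,
        # so the asserts narrow Optional[...] to the concrete types.
        assert curr_loras is not None
        assert self.lora_config is not None
        return len(curr_loras) < self.lora_config.max_loras
```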
10 changes: 5 additions & 5 deletions vllm/distributed/communication_op.py
@@ -1,5 +1,5 @@
from collections import namedtuple
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch.distributed import ProcessGroup
@@ -144,7 +144,7 @@ def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
) -> Dict[Any, Union[torch.Tensor, Any]]:
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary."""
group = group or torch.distributed.group.WORLD
ranks = torch.distributed.get_process_group_ranks(group)
@@ -157,10 +157,10 @@ def broadcast_tensor_dict(

rank = torch.distributed.get_rank()
if rank == src:
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
metadata_list = []
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
assert value.is_cuda, (
@@ -190,10 +190,10 @@ def broadcast_tensor_dict(
torch.distributed.broadcast_object_list(recv_metadata_list,
src=src,
group=group)
metadata_list = recv_metadata_list[0]
assert recv_metadata_list[0] is not None
tensor_dict = {}
async_handles = []
for key, value in metadata_list:
for key, value in recv_metadata_list[0]:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
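The `broadcast_tensor_dict` changes follow the same pattern: the return type is widened to `Optional[...]`, presumably because the function can return its `tensor_dict` argument unchanged and that argument is itself Optional, and the empty `metadata_list` gets an explicit element type before the loop that fills it. A distilled, purely illustrative sketch of those two typing patterns, with the actual collective communication omitted:

```python
from typing import Any, Dict, List, Optional, Tuple


def broadcast_dict_sketch(
    tensor_dict: Optional[Dict[str, Any]] = None,
    world_size: int = 1,
) -> Optional[Dict[str, Any]]:
    # Single-process case: the (possibly None) input is passed through,
    # which is why the declared return type must be Optional.
    if world_size == 1:
        return tensor_dict

    assert tensor_dict is not None, "sender must provide a dict"
    # Annotate the empty list up front so mypy knows its element type
    # before the loop appends to it.
    metadata_list: List[Tuple[str, Any]] = []
    for key, value in tensor_dict.items():
        metadata_list.append((key, value))
    return dict(metadata_list)
```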
(Diffs for the remaining 14 changed files are not shown here.)
