Commit

[Mypy] Part 3 fix typing for nested directories for most of directory (
rkooo567 authored Apr 23, 2024
1 parent 34128a6 commit 0ae11f7
Showing 29 changed files with 126 additions and 88 deletions.
29 changes: 15 additions & 14 deletions .github/workflows/mypy.yaml
@@ -32,19 +32,20 @@ jobs:
pip install types-setuptools
- name: Mypy
run: |
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/attention --config-file pyproject.toml
# TODO(sang): Fix nested dir
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# TODO(sang): Follow up
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
# TODO(sang): Fix nested dir
mypy vllm/model_executor/*.py --config-file pyproject.toml
# TODO(sang): Fix nested dir
# mypy vllm/lora/*.py --config-file pyproject.toml
26 changes: 12 additions & 14 deletions format.sh
@@ -94,21 +94,19 @@ echo 'vLLM yapf: Done'

# Run mypy
echo 'vLLM mypy:'
mypy vllm/attention/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/attention --config-file pyproject.toml
mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/entrypoints/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/executor/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/usage/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml

# TODO(sang): Follow up
mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/distributed --config-file pyproject.toml
mypy vllm/entrypoints --config-file pyproject.toml
mypy vllm/executor --config-file pyproject.toml
mypy vllm/usage --config-file pyproject.toml
mypy vllm/*.py --config-file pyproject.toml
mypy vllm/transformers_utils --config-file pyproject.toml
mypy vllm/engine --config-file pyproject.toml
mypy vllm/worker --config-file pyproject.toml
mypy vllm/spec_decode --config-file pyproject.toml
mypy vllm/model_executor/*.py --config-file pyproject.toml
# mypy vllm/lora/*.py --config-file pyproject.toml


CODESPELL_EXCLUDES=(
6 changes: 4 additions & 2 deletions pyproject.toml
@@ -46,15 +46,17 @@ ignore = [
python_version = "3.8"

ignore_missing_imports = true
check_untyped_defs = true
follow_imports = "skip"

files = "vllm"
# TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = [
"vllm/model_executor/parallel_utils/|vllm/model_executor/models/",
# Ignore triton kernels in ops.
'vllm/attention/ops/.*\.py$'
]


[tool.codespell]
ignore-words-list = "dout, te, indicies"
skip = "./tests/prompts,./benchmarks/sonnet.txt"
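With `follow_imports = "skip"` now set once under [tool.mypy], the per-invocation `--follow-imports=skip` flags become redundant, which is why the workflow and format.sh commands above drop them. The config also keeps `check_untyped_defs = true`, which is what makes mypy look inside functions that have no annotations; a minimal sketch (hypothetical code, not from the vLLM tree) of the kind of error that setting surfaces:

    # With check_untyped_defs = true, mypy checks this unannotated function
    # and reports the bad keyword argument; without it, the body is skipped.
    def tokenize(text):
        return text.split(maxsplit="two")   # error: "maxsplit" expects an int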
2 changes: 1 addition & 1 deletion vllm/attention/backends/abstract.py
@@ -116,7 +116,7 @@ def forward(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata[AttentionMetadataPerStage],
attn_metadata: AttentionMetadata,
kv_scale: float,
) -> torch.Tensor:
raise NotImplementedError
1 change: 1 addition & 0 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -248,6 +248,7 @@ def forward(

if prefill_meta := attn_metadata.prefill_metadata:
# Prompt run.
assert prefill_meta.prompt_lens is not None
if kv_cache is None or prefill_meta.block_tables.numel() == 0:
# triton attention
# When block_tables are not filled, it means q and k are the
3 changes: 2 additions & 1 deletion vllm/attention/backends/torch_sdpa.py
@@ -106,7 +106,7 @@ def forward(
key: torch.Tensor,
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: TorchSDPAMetadata,
attn_metadata: TorchSDPAMetadata, # type: ignore
kv_scale: float,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
@@ -136,6 +136,7 @@ def forward(
kv_scale)

if attn_metadata.is_prompt:
assert attn_metadata.prompt_lens is not None
if (kv_cache is None or attn_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
1 change: 1 addition & 0 deletions vllm/attention/backends/xformers.py
@@ -288,6 +288,7 @@ def _run_memory_efficient_xformers_forward(
value: shape = [num_prefill_tokens, num_kv_heads, head_size]
attn_metadata: Metadata for attention.
"""
assert attn_metadata.prompt_lens is not None
original_query = query
if self.num_kv_heads != self.num_heads:
# GQA/MQA requires the shape [B, M, G, H, K].
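The `assert ... is not None` lines added across the attention backends above exist for mypy rather than for runtime behaviour: the metadata fields are typed as Optional, and the assert narrows them to the non-None type before use. A minimal, self-contained sketch of the pattern (the class and field names are illustrative, not the real vLLM metadata types):

    from typing import List, Optional

    class PrefillMetadata:
        def __init__(self, prompt_lens: Optional[List[int]]) -> None:
            self.prompt_lens = prompt_lens

    def total_prompt_tokens(meta: PrefillMetadata) -> int:
        # Without the assert, mypy rejects the sum() call below because
        # Optional[List[int]] might be None.
        assert meta.prompt_lens is not None
        return sum(meta.prompt_lens)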
1 change: 1 addition & 0 deletions vllm/core/block/block_table.py
@@ -104,6 +104,7 @@ def append_token_ids(self,
token_ids (List[int]): The sequence of token IDs to be appended.
"""
assert self._is_allocated
assert self._blocks is not None

self.ensure_num_empty_slots(num_empty_slots=len(token_ids) +
num_lookahead_slots)
6 changes: 4 additions & 2 deletions vllm/core/block/common.py
@@ -99,7 +99,7 @@ def __init__(
refcounter: RefCounter,
allocator: BlockAllocator,
):
self._copy_on_writes = defaultdict(list)
self._copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)
self._refcounter = refcounter
self._allocator = allocator

@@ -138,6 +138,8 @@ def cow_block_if_not_appendable(self, block: Block) -> Optional[BlockId]:
prev_block=block.prev_block).block_id

# Track src/dst copy.
assert src_block_id is not None
assert block_id is not None
self._copy_on_writes[src_block_id].append(block_id)

return block_id
@@ -180,6 +182,6 @@ def recurse(block: Block, lst: List[Block]) -> None:
recurse(block.prev_block, lst)
lst.append(block)

all_blocks = []
all_blocks: List[Block] = []
recurse(last_block, all_blocks)
return all_blocks
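The two annotations added in common.py share one motivation: an empty literal or a bare `defaultdict(list)` gives mypy nothing to infer an element type from, so the variable has to be annotated where it is created. A short illustrative sketch, with BlockId treated as a plain int alias:

    from collections import defaultdict
    from typing import Dict, List

    BlockId = int   # illustrative alias

    # Annotated at creation so mypy knows the value type of the defaultdict.
    copy_on_writes: Dict[BlockId, List[BlockId]] = defaultdict(list)
    copy_on_writes[3].append(7)

    # Same idea for an empty list: a bare "all_blocks = []" triggers mypy's
    # "Need type annotation" error once the enclosing function is checked.
    all_blocks: List[BlockId] = []
    all_blocks.append(11)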
6 changes: 2 additions & 4 deletions vllm/core/block/interfaces.py
@@ -52,8 +52,7 @@ def __call__(
class BlockAllocator(ABC):

@abstractmethod
def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
pass

@abstractmethod
@@ -98,8 +97,7 @@ class NoFreeBlocksError(ValueError):
class DeviceAwareBlockAllocator(BlockAllocator):

@abstractmethod
def allocate_mutable(self, prev_block: Optional[Block],
device: Device) -> Block:
def allocate_mutable(self, prev_block: Optional[Block]) -> Block:
pass

@abstractmethod
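Dropping the `device` parameter from the abstract `allocate_mutable` presumably keeps the ABC signatures compatible with implementations that do not take a device argument, since mypy flags overrides whose signatures do not match the supertype. A small sketch of the error this avoids (illustrative classes, not the real allocators):

    from abc import ABC, abstractmethod
    from typing import Optional

    class Allocator(ABC):
        @abstractmethod
        def allocate_mutable(self, prev_block: Optional[int], device: str) -> int:
            ...

    class NaiveAllocator(Allocator):
        # mypy: Signature of "allocate_mutable" incompatible with supertype
        # "Allocator": the override drops a parameter the ABC declares.
        def allocate_mutable(self, prev_block: Optional[int]) -> int:
            return 0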
16 changes: 9 additions & 7 deletions vllm/distributed/device_communicators/custom_all_reduce.py
@@ -1,6 +1,6 @@
import os
from contextlib import contextmanager
from typing import List, Optional
from typing import Any, List, Optional

import torch
import torch.distributed as dist
@@ -18,7 +18,7 @@

logger = init_logger(__name__)

_CA_HANDLE = None
_CA_HANDLE: Optional["CustomAllreduce"] = None
_IS_CAPTURING = False
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]

@@ -51,7 +51,7 @@ def init_custom_ar() -> None:
"Cannot test GPU P2P because not all GPUs are visible to the "
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" is set.")
return False
return
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
if "CUDA_VISIBLE_DEVICES" in os.environ:
@@ -117,7 +117,7 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
ca_handle = get_handle()
# when custom allreduce is disabled, this will be None
if ca_handle is None:
return
return None
if is_capturing():
if torch.cuda.is_current_stream_capturing():
if ca_handle.should_custom_ar(input):
@@ -135,6 +135,8 @@ def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
if ca_handle.should_custom_ar(input):
return ca_handle.all_reduce_unreg(input)

return None


@contextmanager
def _nvml():
@@ -224,14 +226,14 @@ def _get_ipc_meta(self, inp: torch.Tensor):
return self._gather_ipc_meta(shard_data)

def _gather_ipc_meta(self, shard_data):
all_data = [None] * self.world_size
all_data: List[Optional[Any]] = [None] * self.world_size
dist.all_gather_object(all_data, shard_data)

handles = []
offsets = []
for i in range(len(all_data)):
handles.append(all_data[i][0])
offsets.append(all_data[i][1])
handles.append(all_data[i][0]) # type: ignore
offsets.append(all_data[i][1]) # type: ignore
return handles, offsets

def register_buffer(self, inp: torch.Tensor):
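`dist.all_gather_object` fills a pre-sized output list in place, so the list has to start as `[None] * self.world_size`; typing it as `List[Optional[Any]]` matches that, and the `# type: ignore` comments acknowledge that the slots are read as non-None afterwards. A standalone sketch of the same shape (an assumed helper, not the real `_gather_ipc_meta`):

    from typing import Any, List, Optional, Tuple

    def gather_pairs(world_size: int, my_pair: Tuple[str, int]) -> List[Tuple[str, int]]:
        all_data: List[Optional[Any]] = [None] * world_size
        # In the real code, dist.all_gather_object(all_data, my_pair) fills
        # every slot; here we fake it so the sketch runs on its own.
        for rank in range(world_size):
            all_data[rank] = my_pair
        # Each slot is populated by now, so narrow explicitly instead of
        # sprinkling "# type: ignore" at every use site.
        return [item for item in all_data if item is not None]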
22 changes: 13 additions & 9 deletions vllm/distributed/device_communicators/pynccl.py
@@ -107,9 +107,10 @@ def ncclGetUniqueId() -> NcclUniqueId:
ctypes.POINTER(ctypes.c_void_p), ctypes.c_int, NcclUniqueId, ctypes.c_int
]

ncclDataType_t = ctypes.c_int

# enums
class ncclDataType_t(ctypes.c_int):

class ncclDataTypeEnum:
ncclInt8 = 0
ncclChar = 0
ncclUint8 = 1
@@ -128,7 +129,7 @@ class ncclDataType_t(ctypes.c_int):
ncclNumTypes = 10

@classmethod
def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t':
def from_torch(cls, dtype: torch.dtype) -> int:
if dtype == torch.int8:
return cls.ncclInt8
if dtype == torch.uint8:
@@ -148,7 +149,7 @@ def from_torch(cls, dtype: torch.dtype) -> 'ncclDataType_t':
raise ValueError(f"Unsupported dtype: {dtype}")


class ncclRedOp_t(ctypes.c_int):
ncclRedOp_t = ctypes.c_int


class ncclRedOpTypeEnum:
ncclSum = 0
ncclProd = 1
ncclMax = 2
@@ -157,7 +161,7 @@ class ncclRedOp_t(ctypes.c_int):
ncclNumOps = 5

@classmethod
def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t':
def from_torch(cls, op: ReduceOp) -> int:
if op == ReduceOp.SUM:
return cls.ncclSum
if op == ReduceOp.PRODUCT:
@@ -180,8 +184,8 @@ def from_torch(cls, op: ReduceOp) -> 'ncclRedOp_t':
_c_ncclAllReduce = nccl.ncclAllReduce
_c_ncclAllReduce.restype = ctypes.c_int
_c_ncclAllReduce.argtypes = [
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclDataType_t,
ncclRedOp_t, ctypes.c_void_p, ctypes.c_void_p
ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ncclRedOp_t,
ncclDataType_t, ctypes.c_void_p, ctypes.c_void_p
]

# equivalent to c declaration:
@@ -251,8 +255,8 @@ def all_reduce(self,
result = _c_ncclAllReduce(ctypes.c_void_p(tensor.data_ptr()),
ctypes.c_void_p(tensor.data_ptr()),
tensor.numel(),
ncclDataType_t.from_torch(tensor.dtype),
ncclRedOp_t.from_torch(op), self.comm,
ncclDataTypeEnum.from_torch(tensor.dtype),
ncclRedOpTypeEnum.from_torch(op), self.comm,
ctypes.c_void_p(stream.cuda_stream))
assert result == 0

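The pynccl hunks stop subclassing `ctypes.c_int` for the NCCL enums: `from_torch` was annotated as returning the ctypes type while actually returning the plain integer constants stored on the class. The refactor keeps a bare `c_int` alias for the argtypes declarations and a separate helper class that maps torch-level values to ints. A cut-down sketch of the pattern, showing only the two dtypes whose values appear in this diff:

    import ctypes
    import torch

    ncclDataType_t = ctypes.c_int   # what the argtypes declarations refer to

    class ncclDataTypeEnum:
        ncclInt8 = 0
        ncclUint8 = 1

        @classmethod
        def from_torch(cls, dtype: torch.dtype) -> int:
            if dtype == torch.int8:
                return cls.ncclInt8
            if dtype == torch.uint8:
                return cls.ncclUint8
            raise ValueError(f"Unsupported dtype: {dtype}")

Call sites then pass `ncclDataTypeEnum.from_torch(tensor.dtype)` as a plain int, which ctypes converts through the `c_int` argtype, as in the `all_reduce` hunk above.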
5 changes: 4 additions & 1 deletion vllm/distributed/device_communicators/pynccl_utils.py
@@ -30,6 +30,7 @@ def is_initialized() -> bool:
def set_pynccl_stream(stream: torch.cuda.Stream):
"""Set the cuda stream for communication"""
try:
assert comm is not None
comm.stream = stream
yield
finally:
@@ -52,6 +53,7 @@ def init_process_group(world_size: int,
def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None:
"""All-reduces the input tensor across the process group."""
assert input_.is_cuda, f"{input_} should be a cuda tensor"
assert comm is not None
comm.all_reduce(input_, op)


@@ -62,8 +64,9 @@ def destroy_process_group() -> None:

def get_world_size() -> int:
"""Returns the world size."""
assert comm is not None
return comm.world_size


def get_nccl_backend():
def get_nccl_backend() -> Optional["NCCLCommunicator"]:
return comm
5 changes: 3 additions & 2 deletions vllm/engine/output_processor/interfaces.py
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, Iterable, List
from typing import Callable, List

from transformers import PreTrainedTokenizer

@@ -8,6 +8,7 @@
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.sequence import Sequence, SequenceGroup, SequenceGroupOutput
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter


class SequenceGroupOutputProcessor(ABC):
@@ -27,7 +28,7 @@ def create_output_processor(
scheduler_config: SchedulerConfig,
detokenizer: Detokenizer,
scheduler: Scheduler,
seq_counter: Iterable[int],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: "StopChecker",
):
5 changes: 3 additions & 2 deletions vllm/engine/output_processor/multi_step.py
@@ -1,4 +1,4 @@
from typing import Callable, Iterable, List
from typing import Callable, List

from transformers import PreTrainedTokenizer

@@ -11,6 +11,7 @@
from vllm.sequence import (Logprob, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.utils import Counter

logger = init_logger(__name__)

@@ -33,7 +34,7 @@ def __init__(
self,
detokenizer: Detokenizer,
scheduler: Scheduler,
seq_counter: Iterable[int],
seq_counter: Counter,
get_tokenizer_for_seq: Callable[[Sequence], PreTrainedTokenizer],
stop_checker: StopChecker,
):
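Both output-processor hunks tighten the `seq_counter` parameter from the loose `Iterable[int]` to the concrete `Counter` utility imported from `vllm.utils`, so the annotation documents the actual object being passed around. A hedged sketch, assuming `Counter` is essentially a monotonically increasing id generator along these lines:

    class Counter:
        """Illustrative stand-in for vllm.utils.Counter."""

        def __init__(self, start: int = 0) -> None:
            self.counter = start

        def __next__(self) -> int:
            i = self.counter
            self.counter += 1
            return i

    seq_counter = Counter()
    first_seq_id = next(seq_counter)    # 0
    second_seq_id = next(seq_counter)   # 1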