Skip to content

Commit

Permalink
[mypy] Enable mypy type checking for vllm/core (vllm-project#7229)
Browse files Browse the repository at this point in the history
  • Loading branch information
jberkhahn authored and Jeffwan committed Sep 19, 2024
1 parent eb060d4 commit b0ce749
Show file tree
Hide file tree
Showing 9 changed files with 31 additions and 18 deletions.
1 change: 0 additions & 1 deletion .github/workflows/mypy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ jobs:
mypy
mypy tests --follow-imports skip
mypy vllm/attention --follow-imports skip
mypy vllm/core --follow-imports skip
mypy vllm/distributed --follow-imports skip
mypy vllm/engine --follow-imports skip
mypy vllm/executor --follow-imports skip
Expand Down
1 change: 0 additions & 1 deletion format.sh
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ echo 'vLLM mypy:'
mypy --follow-imports skip # Note that this is less strict than CI
mypy tests --follow-imports skip
mypy vllm/attention --follow-imports skip
mypy vllm/core --follow-imports skip
mypy vllm/distributed --follow-imports skip
mypy vllm/engine --follow-imports skip
mypy vllm/executor --follow-imports skip
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ files = [
"vllm/adapter_commons",
"vllm/assets",
"vllm/entrypoints",
"vllm/core",
"vllm/inputs",
"vllm/logging",
"vllm/multimodal",
Expand Down
9 changes: 7 additions & 2 deletions vllm/block.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Token blocks."""
from typing import List, Optional
from typing import TYPE_CHECKING, Iterator, List, Optional

from vllm.utils import Device

DEFAULT_LAST_ACCESSED_TIME = -1
DEFAULT_LAST_ACCESSED_TIME: float = -1


class PhysicalTokenBlock:
Expand Down Expand Up @@ -59,6 +59,11 @@ def __len__(self) -> int:
def __getitem__(self, key):
return self._blocks[key]

if TYPE_CHECKING:

def __iter__(self) -> Iterator[PhysicalTokenBlock]:
raise RuntimeError("Method should be automatically generated")

def __setitem__(self, key, value):
if isinstance(key, slice):
blocks = value
Expand Down
2 changes: 1 addition & 1 deletion vllm/core/block/cpu_gpu_block_allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def allocate_mutable_block(self, prev_block: Optional[Block],

def allocate_immutable_blocks(self, prev_block: Optional[Block],
block_token_ids: List[List[int]],
device: Optional[Device]) -> List[Block]:
device: Device) -> List[Block]:
"""Allocates a new group of immutable blocks with the provided block
token IDs on the specified device.
Expand Down
7 changes: 4 additions & 3 deletions vllm/core/block_manager_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def __init__(
# request ID
self.cross_block_tables: Dict[str, BlockTable] = {}

def _get_seq_num_required_blocks(self, seq: Sequence) -> int:
def _get_seq_num_required_blocks(self, seq: Optional[Sequence]) -> int:
return 0 if seq is None else seq.n_blocks

def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
Expand Down Expand Up @@ -310,13 +310,14 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
return AllocStatus.LATER

def _allocate_sequence(self, \
seq: Sequence, \
seq: Optional[Sequence], \
ref_count: int, \
is_encoder_decoder: bool = True) -> BlockTable:
# Allocate new physical token blocks that will store the prompt tokens.
num_prompt_blocks = seq.n_blocks
num_prompt_blocks = self._get_seq_num_required_blocks(seq)

block_table: BlockTable = BlockTable()
assert seq is not None
for logical_idx in range(num_prompt_blocks):
if (self.block_sliding_window is not None
and logical_idx >= self.block_sliding_window):
Expand Down
8 changes: 6 additions & 2 deletions vllm/core/block_manager_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,10 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
)

if seq_group.is_encoder_decoder():
encoder_seq = seq_group.get_encoder_seq()
assert encoder_seq is not None
num_required_blocks += BlockTable.get_num_required_blocks(
seq_group.get_encoder_seq().get_token_ids(),
encoder_seq.get_token_ids(),
block_size=self.block_size,
)

Expand Down Expand Up @@ -189,7 +191,9 @@ def allocate(self, seq_group: SequenceGroup) -> None:
check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group)

if seq_group.is_encoder_decoder():
block_table = self._allocate_sequence(seq_group.get_encoder_seq())
encoder_seq = seq_group.get_encoder_seq()
assert encoder_seq is not None
block_table = self._allocate_sequence(encoder_seq)
self.cross_block_tables[request_id] = block_table

def can_append_slots(self, seq_group: SequenceGroup,
Expand Down
4 changes: 2 additions & 2 deletions vllm/core/embedding_model_block_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def access_all_blocks_in_seq(
pass

def get_common_computed_block_ids(self,
seq_group: SequenceGroup) -> List[int]:
return None # type: ignore
seq_group: List[Sequence]) -> List[int]:
return []

def mark_blocks_as_computed(self, seq_group: SequenceGroup):
pass
Expand Down
16 changes: 10 additions & 6 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,10 @@ class SchedulerSwappedInOutputs:
"""
# Selected sequences that are going to be swapped in and are in a
# decoding phase.
decode_seq_groups: List[SequenceGroup]
decode_seq_groups: List[ScheduledSequenceGroup]
# Selected sequences that are going to be swapped in and in a prefill
# phase, i.e., the prefill has been chunked.
prefill_seq_groups: List[SequenceGroup]
prefill_seq_groups: List[ScheduledSequenceGroup]
# The blocks to swap in.
blocks_to_swap_in: List[Tuple[int, int]]
# The blocks to copy.
Expand Down Expand Up @@ -254,7 +254,7 @@ class SchedulerPrefillOutputs:
to be recomputed from scratch.
"""
# Selected sequences for prefill.
seq_groups: List[SequenceGroup]
seq_groups: List[ScheduledSequenceGroup]
# Ignored sequence groups.
ignored_seq_groups: List[SequenceGroup]
num_lookahead_slots: int
Expand Down Expand Up @@ -289,7 +289,9 @@ def scheduler_running_outputs_builder():


def scheduled_seq_group_builder():
return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)
return ScheduledSequenceGroup(SequenceGroup("", [], -1),
token_chunk_size=0)
# return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0)


class Scheduler:
Expand Down Expand Up @@ -791,7 +793,7 @@ def _schedule_prefills(
SchedulerPrefillOutputs.
"""
ignored_seq_groups: List[SequenceGroup] = []
seq_groups: List[SequenceGroup] = []
seq_groups: List[ScheduledSequenceGroup] = []

waiting_queue = self.waiting

Expand Down Expand Up @@ -1130,7 +1132,9 @@ def schedule(

if seq_group.is_encoder_decoder():
# Encoder associated with SequenceGroup
encoder_seq_data = seq_group.get_encoder_seq().data
encoder_seq = seq_group.get_encoder_seq()
assert encoder_seq is not None
encoder_seq_data = encoder_seq.data
# Block table for cross-attention
# Also managed at SequenceGroup level
cross_block_table = self.block_manager.get_cross_block_table(
Expand Down

0 comments on commit b0ce749

Please sign in to comment.