[Typing] Mypy typing part 2 (vllm-project#4043)
Co-authored-by: SangBin Cho <[email protected]>
rkooo567 and SangBin Cho authored Apr 18, 2024
1 parent 42f9152 commit aa59b2f
Showing 20 changed files with 180 additions and 126 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/mypy.yaml
@@ -41,10 +41,10 @@ jobs:
mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# TODO(sang): Follow up
# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml
8 changes: 4 additions & 4 deletions format.sh
@@ -104,10 +104,10 @@ mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml

# TODO(sang): Follow up
# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml
mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml
# mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml


44 changes: 25 additions & 19 deletions vllm/engine/async_llm_engine.py
@@ -2,8 +2,8 @@
import os
import time
from functools import partial
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
Set, Tuple, Type, Union)
from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List,
Optional, Set, Tuple, Type, Union)

from transformers import PreTrainedTokenizer

@@ -52,7 +52,7 @@ class AsyncStream:

def __init__(self, request_id: str) -> None:
self.request_id = request_id
self._queue = asyncio.Queue()
self._queue: asyncio.Queue = asyncio.Queue()
self._finished = False

def put(self, item: Union[RequestOutput, Exception]) -> None:
@@ -312,15 +312,17 @@ def __init__(self,
self.max_log_len = max_log_len
self.engine = self._init_engine(*args, **kwargs)

self.background_loop = None
self.background_loop: Optional[asyncio.Future] = None
# We need to keep a reference to unshielded
# task as well to prevent it from being garbage
# collected
self._background_loop_unshielded = None
self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None
self.start_engine_loop = start_engine_loop
self._request_tracker: Optional[RequestTracker] = None
self._errored_with: Optional[BaseException] = None

# Lazy initialized fields
self._request_tracker: RequestTracker

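The annotations added in the two hunks above follow the usual mypy-driven pattern: a field that starts out as `None` becomes `Optional[...]`, and a container built with no arguments (here `asyncio.Queue()`) gets an explicit type because mypy has nothing to infer it from. A minimal sketch of the same pattern, using an illustrative `EngineLoop` class that is not part of vLLM:

```python
import asyncio
from typing import Any, Optional


class EngineLoop:
    """Illustrative stand-in, not the vLLM engine class."""

    def __init__(self) -> None:
        # A bare Queue() gives mypy nothing to infer an element type
        # from, so the attribute needs an explicit annotation.
        self._queue: asyncio.Queue = asyncio.Queue()
        # A field that starts as None and is filled in later is
        # declared Optional so the later assignment still type-checks.
        self._task: Optional["asyncio.Task[Any]"] = None

    async def start(self) -> None:
        self._task = asyncio.ensure_future(self._run())

    async def _run(self) -> None:
        item = await self._queue.get()
        print("consumed", item)
```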
@classmethod
def from_engine_args(
cls,
@@ -361,11 +363,13 @@ def from_engine_args(
@property
def is_running(self) -> bool:
return (self.background_loop is not None
and self._background_loop_unshielded is not None
and not self._background_loop_unshielded.done())

@property
def is_stopped(self) -> bool:
return self.errored or (self.background_loop is not None
return self.errored or (self.background_loop is not None and
self._background_loop_unshielded is not None
and self._background_loop_unshielded.done())

@property
@@ -381,7 +385,7 @@ def _error_callback(self, exc: Exception) -> None:

async def get_tokenizer(self) -> "PreTrainedTokenizer":
if self.engine_use_ray:
return await self.engine.get_tokenizer.remote()
return await self.engine.get_tokenizer.remote() # type: ignore
else:
return self.engine.get_tokenizer()

@@ -434,7 +438,8 @@ async def engine_step(self) -> bool:
# TODO: Maybe add add_request_batch to reduce Ray overhead
try:
if self.engine_use_ray:
await self.engine.add_request.remote(**new_request)
await self.engine.add_request.remote( # type: ignore
**new_request)
else:
await self.engine.add_request_async(**new_request)
except ValueError as e:
@@ -449,7 +454,7 @@ async def engine_step(self) -> bool:
await self._engine_abort(finished_requests)

if self.engine_use_ray:
request_outputs = await self.engine.step.remote()
request_outputs = await self.engine.step.remote() # type: ignore
else:
request_outputs = await self.engine.step_async()

@@ -462,7 +467,7 @@

async def _engine_abort(self, request_ids: Iterable[str]):
if self.engine_use_ray:
await self.engine.abort_request.remote(request_ids)
await self.engine.abort_request.remote(request_ids) # type: ignore
else:
self.engine.abort_request(request_ids)

@@ -525,11 +530,12 @@ async def add_request(
arrival_time = time.time()

if self.engine_use_ray:
prompt_token_ids = await self.engine.encode_request_async.remote(
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request)
prompt_token_ids = await (
self.engine.encode_request_async.remote( # type: ignore
request_id=request_id,
prompt=prompt,
prompt_token_ids=prompt_token_ids,
lora_request=lora_request))
else:
prompt_token_ids = await self.engine.encode_request_async(
request_id=request_id,
@@ -676,13 +682,13 @@ def _abort(self, request_id: str) -> None:
async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_model_config.remote()
return await self.engine.get_model_config.remote() # type: ignore
else:
return self.engine.get_model_config()

async def do_log_stats(self) -> None:
if self.engine_use_ray:
await self.engine.do_log_stats.remote()
await self.engine.do_log_stats.remote() # type: ignore
else:
self.engine.do_log_stats()

@@ -695,7 +701,7 @@ async def check_health(self) -> None:

if self.engine_use_ray:
try:
await self.engine.check_health.remote()
await self.engine.check_health.remote() # type: ignore
except ray.exceptions.RayActorError as e:
raise RuntimeError("Engine is dead.") from e
else:
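Most of the remaining changes in this file add `# type: ignore` to `.remote(...)` calls. When `engine_use_ray` is set, the engine attribute actually holds a Ray actor handle, but its declared type does not expose the `.remote` attribute Ray injects on each method, so the ignores silence that one class of error without relaxing anything else. A rough sketch of the situation, assuming Ray is installed and using a made-up `Counter` actor rather than the vLLM engine:

```python
# Sketch only; requires `pip install ray` and a local Ray runtime.
import ray


class Counter:
    """Stand-in class; not part of vLLM."""

    def __init__(self) -> None:
        self.value = 0

    def increment(self) -> int:
        self.value += 1
        return self.value


ray.init()

# At runtime this variable holds a Ray actor handle, but it is
# annotated with the plain class, so mypy knows nothing about the
# `.remote` attribute Ray injects on each method -- the same mismatch
# the targeted ignores above work around.
counter: Counter = ray.remote(Counter).remote()  # type: ignore[assignment]
ref = counter.increment.remote()  # type: ignore[attr-defined]
print(ray.get(ref))  # -> 1
```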
4 changes: 2 additions & 2 deletions vllm/lora/worker_manager.py
@@ -107,12 +107,12 @@ def create_lora_manager(
self._lora_manager: LoRAModelManager = lora_manager
return lora_manager.model

def set_active_loras(self, lora_requests: List[LoRARequest],
def set_active_loras(self, lora_requests: Set[LoRARequest],
lora_mapping: LoRAMapping) -> None:
self._apply_loras(lora_requests)
self._lora_manager.set_lora_mapping(lora_mapping)

def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None:
loras_that_exist = self.list_loras()
loras_map = {
lora_request.lora_int_id: lora_request
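The change from `List[LoRARequest]` to `Set[LoRARequest]` makes the signature match what callers actually pass; with mypy now covering this code, a container mismatch at the call site is reported rather than silently accepted. A small illustration with a stand-in request class (not the real `LoRARequest`):

```python
from dataclasses import dataclass
from typing import Set


@dataclass(frozen=True)  # frozen -> hashable, so instances can go in a set
class LoRARequestStub:
    lora_int_id: int
    lora_name: str


def apply_loras(lora_requests: Set[LoRARequestStub]) -> None:
    # Build an id -> request map, mirroring _apply_loras in the diff.
    loras_map = {req.lora_int_id: req for req in lora_requests}
    print(sorted(loras_map))


requests = {LoRARequestStub(1, "sql"), LoRARequestStub(2, "chat")}
apply_loras(requests)

# Passing a list here would now be flagged by mypy as an incompatible
# argument type, even though it would still run.
```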
4 changes: 2 additions & 2 deletions vllm/model_executor/guided_decoding/outlines_decoding.py
@@ -55,7 +55,7 @@ class GuidedDecodingMode(Enum):

async def get_outlines_guided_decoding_logits_processor(
request: Union[CompletionRequest, ChatCompletionRequest],
tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
@@ -84,7 +84,7 @@ async def get_outlines_guided_decoding_logits_processor(

def _get_guide_and_mode(
request: Union[CompletionRequest, ChatCompletionRequest]
) -> Tuple[str, GuidedDecodingMode]:
) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]:

if request.guided_json:
json = request.guided_json
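Both return annotations are widened to admit the "no guided decoding requested" case: the factory can return `None`, and `_get_guide_and_mode` can return `(None, None)`. A hedged sketch of the tuple variant, with a simplified request type in place of the OpenAI-protocol classes:

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Tuple, Union


class GuideMode(Enum):
    JSON = "json"
    REGEX = "regex"


@dataclass
class RequestStub:
    guided_json: Optional[str] = None
    guided_regex: Optional[str] = None


def get_guide_and_mode(
    request: RequestStub,
) -> Union[Tuple[str, GuideMode], Tuple[None, None]]:
    # Returning (None, None) instead of None keeps tuple unpacking
    # at the call site (`guide, mode = ...`) valid in both branches.
    if request.guided_json is not None:
        return request.guided_json, GuideMode.JSON
    if request.guided_regex is not None:
        return request.guided_regex, GuideMode.REGEX
    return None, None


guide, mode = get_guide_and_mode(RequestStub(guided_regex=r"\d+"))
print(guide, mode)
```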
vllm/model_executor/guided_decoding/outlines_logits_processors.py
@@ -21,14 +21,18 @@
from typing import Callable, DefaultDict, Dict, List, Optional, Union

import torch
from outlines.fsm.fsm import CFGFSM, RegexFSM
from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase


class BaseLogitsProcessor:

def __init__(self):
# Child class should use initialize in their init.
self.fsm: FSM

def init_state(self):
"""Initialize the FSM states."""
self.fsm_state: DefaultDict[int, int] = defaultdict(int)
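The new `__init__` declares `self.fsm: FSM` without assigning it; the same device is used for `self._request_tracker` above and for `self.model` in the Neuron loader below. A bare annotation tells mypy the attribute's type while leaving initialization to subclasses (or a later setup step), avoiding both an `Optional` type and an assert at every use site. A toy version with an illustrative `Guide` class standing in for the outlines FSM:

```python
class Guide:
    def advance(self, token: int) -> int:
        return token + 1


class BaseProcessor:
    def __init__(self) -> None:
        # Declaration only: nothing is assigned here, but mypy now knows
        # the attribute's type. Subclasses must set it before use,
        # otherwise access raises AttributeError at runtime.
        self.fsm: Guide


class RegexProcessor(BaseProcessor):
    def __init__(self) -> None:
        super().__init__()
        self.fsm = Guide()  # checked against the declared type


p = RegexProcessor()
print(p.fsm.advance(3))  # -> 4
```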
16 changes: 9 additions & 7 deletions vllm/model_executor/model_loader/neuron.py
@@ -1,7 +1,7 @@
"""Utilities for selecting and loading neuron models."""
import importlib
import os
from typing import Optional, Type
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
@@ -27,7 +27,7 @@
}

# Models supported by Neuron.
_NEURON_SUPPORTED_MODELS = {
_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
"LlamaForCausalLM": ("transformers_neuronx.llama.model",
"LlamaForSampling", "LlamaForCausalLM"),
"MistralForCausalLM": ("transformers_neuronx.mistral.model",
@@ -43,11 +43,13 @@ def __init__(
) -> None:
super().__init__()
self.config = config
self.model = None
self.logits_processor = LogitsProcessor(config.vocab_size,
logits_as_input=True)
self.sampler = Sampler()

# Lazy initialized
self.model: nn.Module

def forward(
self,
input_ids: torch.Tensor,
@@ -74,17 +76,17 @@ def sample(

def load_weights(self, model_name_or_path: str, **kwargs):
arch = _get_model_architecture(self.config)
neuronx_module_path, neuronx_model_cls, hf_model_cls = (
neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = (
_NEURON_SUPPORTED_MODELS[arch])
neuronx_module = importlib.import_module(neuronx_module_path)
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls)
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)

split_model_dir = f"{model_name_or_path}-split"
if os.path.isdir(os.path.join(model_name_or_path,
"pytorch_model.bin")):
split_model_dir = model_name_or_path
elif not os.path.exists(f"{model_name_or_path}-split"):
hf_model_cls = getattr(transformers, hf_model_cls)
hf_model_cls = getattr(transformers, hf_model_cls_name)
from transformers_neuronx.module import save_pretrained_split

hf_model = hf_model_cls.from_pretrained(model_name_or_path,
@@ -96,7 +98,7 @@ def load_weights(self, model_name_or_path: str, **kwargs):
self.model.to_neuron()


def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
def _get_model_architecture(config: PretrainedConfig) -> str:
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in _NEURON_SUPPORTED_MODELS:
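Two of the fixes here are general mypy idioms: the registry gets an explicit `Dict[str, Tuple[str, str, str]]` annotation, and the strings looked up from it are bound to new `*_cls_name` variables so each name keeps a single, consistent type (the name for the class object is then bound exactly once, to the resolved class). A rough, self-contained illustration that uses the standard-library `json` module instead of transformers_neuronx:

```python
import importlib
from typing import Dict, Tuple, Type

# Explicit annotation so mypy knows both the key and value shapes.
_SUPPORTED: Dict[str, Tuple[str, str]] = {
    # arch name -> (module path, class name); "json.JSONDecoder" is just
    # a stand-in for the real (module, class, hf class) triples.
    "Demo": ("json", "JSONDecoder"),
}


def load_class(arch: str) -> Type:
    module_path, cls_name = _SUPPORTED[arch]
    module = importlib.import_module(module_path)
    # Binding the resolved class to a separate variable keeps `cls_name`
    # a plain str throughout, mirroring the `*_cls_name` renames above.
    cls = getattr(module, cls_name)
    return cls


print(load_class("Demo"))  # -> <class 'json.decoder.JSONDecoder'>
```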
1 change: 1 addition & 0 deletions vllm/model_executor/model_loader/tensorizer.py
@@ -167,6 +167,7 @@ def __post_init__(self):
decryption_params = DecryptionParams.from_key(key)
self.deserializer_params['encryption'] = decryption_params

@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
"""Tensorizer CLI arguments"""
4 changes: 4 additions & 0 deletions vllm/model_executor/sampling_metadata.py
@@ -113,6 +113,8 @@ def from_sampling_metadata(
get_num_triton_sampler_splits(vocab_size))

sample_indices_start_idx = 0
assert sampling_metadata.seq_groups is not None
assert sampling_metadata.seq_data is not None
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
temperature = sampling_params.temperature
@@ -147,6 +149,7 @@ def from_sampling_metadata(
and sampling_params.prompt_logprobs is not None):
# For tokens in the prompt that we only need to get
# their logprobs
assert sampling_metadata.prompt_lens is not None
prompt_len = sampling_metadata.prompt_lens[i]
temperatures += [temperature] * (prompt_len - 1)
top_ps += [top_p] * (prompt_len - 1)
@@ -172,6 +175,7 @@
is_prompt = i < sampling_metadata.num_prompts
if is_prompt:
prompt_best_of.append(sampling_params.best_of)
assert sampling_metadata.prompt_lens is not None
prompt_len = sampling_metadata.prompt_lens[i]

if sampling_params.prompt_logprobs is not None:
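The added `assert ... is not None` lines are how this commit narrows `Optional` fields before indexing them: mypy treats the assert as a type guard, so subscripting `seq_groups` or `prompt_lens` afterwards checks cleanly (the same device appears in `vllm/spec_decode/metrics.py` further down). A compact sketch with a stand-in metadata class:

```python
from typing import List, Optional


class SamplingMetadataStub:
    """Illustrative stand-in; the real class has many more fields."""

    def __init__(self, prompt_lens: Optional[List[int]] = None) -> None:
        self.prompt_lens = prompt_lens


def first_prompt_len(meta: SamplingMetadataStub) -> int:
    # Without the assert, mypy reports that `meta.prompt_lens` may be
    # None and so cannot be indexed. The assert both narrows the type
    # and documents the expectation at runtime.
    assert meta.prompt_lens is not None
    return meta.prompt_lens[0]


print(first_prompt_len(SamplingMetadataStub(prompt_lens=[7, 3])))  # -> 7
```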
6 changes: 3 additions & 3 deletions vllm/spec_decode/batch_expansion.py
@@ -106,7 +106,7 @@ def score_proposals(
def _expand_batch(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
proposal_token_ids_list: List[TokenId],
proposal_token_ids_list: List[List[TokenId]],
proposal_lens_list: List[int],
) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]:
"""Given the input sequences and potentially multiple corresponding
@@ -218,7 +218,7 @@ def _create_scoring_model_input(
def _create_target_seq_group_metadata(
self,
input_seq_group_metadata: SequenceGroupMetadata,
proposal_token_ids: List[TokenId], # shape: [batch_size, k]
proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k]
batch_index: int,
target_seq_ids_iter: Iterator[TargetSeqId],
) -> List[SequenceGroupMetadata]:
@@ -360,7 +360,7 @@ def _get_token_ids_to_score(
[0, 1, 2]
[0, 1, 2, 3]
"""
empty_token_ids = []
empty_token_ids: List[TokenId] = []

token_ids_to_score = [empty_token_ids]
token_ids_to_score.extend([
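Two related fixes in this file: proposal token ids are per-sequence lists, so the parameter becomes `List[List[TokenId]]`, and the empty literal assigned to `empty_token_ids` needs an annotation because mypy cannot infer an element type from `[]`. In miniature, with `TokenId` treated as a plain `int` alias:

```python
from typing import List

TokenId = int  # simple alias for readability


def token_ids_to_score(proposal_token_ids: List[TokenId]) -> List[List[TokenId]]:
    # An empty literal has no element type for mypy to infer,
    # hence the explicit annotation.
    empty_token_ids: List[TokenId] = []
    result = [empty_token_ids]
    result.extend(
        proposal_token_ids[: i + 1] for i in range(len(proposal_token_ids))
    )
    return result


print(token_ids_to_score([0, 1, 2]))
# -> [[], [0], [0, 1], [0, 1, 2]]
```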
4 changes: 2 additions & 2 deletions vllm/spec_decode/interfaces.py
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional

import torch

@@ -73,5 +73,5 @@ def score_proposals(
blocks_to_copy: Optional[Dict[int, List[int]]],
k: int,
proposals: SpeculativeProposals,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> SpeculativeScores:
raise NotImplementedError
1 change: 1 addition & 0 deletions vllm/spec_decode/metrics.py
@@ -112,6 +112,7 @@ def _copy_rejsample_metrics_async(self) -> torch.cuda.Event:
Returns a CUDA event recording when the copy is complete.
"""
assert self._copy_stream is not None
self._copy_stream.wait_stream(torch.cuda.current_stream())

with torch.cuda.stream(self._copy_stream):