diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 6db0bb7645ecd..477ce9bc9ce85 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -41,10 +41,10 @@ jobs: mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml + mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml # TODO(sang): Follow up - # mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml - # mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml - # mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml - # mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml diff --git a/format.sh b/format.sh index 1c195b899c742..84ee88b5b4c8a 100755 --- a/format.sh +++ b/format.sh @@ -104,10 +104,10 @@ mypy vllm/*.py --follow-imports=skip --config-file pyproject.toml mypy vllm/transformers_utils/*.py --follow-imports=skip --config-file pyproject.toml # TODO(sang): Follow up -# mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml -# mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml -# mypy vllm/spec_decoding/*.py --follow-imports=skip --config-file pyproject.toml -# mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/engine/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/worker/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/spec_decode/*.py --follow-imports=skip --config-file pyproject.toml +mypy vllm/model_executor/*.py --follow-imports=skip --config-file pyproject.toml # mypy vllm/lora/*.py --follow-imports=skip --config-file pyproject.toml diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 27192449bf15a..c3020d2b38db0 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -2,8 +2,8 @@ import os import time from functools import partial -from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional, - Set, Tuple, Type, Union) +from typing import (Any, AsyncIterator, Callable, Dict, Iterable, List, + Optional, Set, Tuple, Type, Union) from transformers import PreTrainedTokenizer @@ -52,7 +52,7 @@ class AsyncStream: def __init__(self, request_id: str) -> None: self.request_id = request_id - self._queue = asyncio.Queue() + self._queue: asyncio.Queue = asyncio.Queue() self._finished = False def put(self, item: Union[RequestOutput, Exception]) -> None: @@ -312,15 +312,17 @@ def __init__(self, self.max_log_len = max_log_len self.engine = self._init_engine(*args, **kwargs) - self.background_loop = None + self.background_loop: Optional[asyncio.Future] = None # We need to keep a reference to unshielded # task as well to prevent it from being garbage # collected - self._background_loop_unshielded = None + self._background_loop_unshielded: Optional[asyncio.Task[Any]] = None self.start_engine_loop = start_engine_loop - self._request_tracker: Optional[RequestTracker] = None self._errored_with: Optional[BaseException] = None + # Lazy initialized fields + self._request_tracker: RequestTracker + 
@classmethod def from_engine_args( cls, @@ -361,11 +363,13 @@ def from_engine_args( @property def is_running(self) -> bool: return (self.background_loop is not None + and self._background_loop_unshielded is not None and not self._background_loop_unshielded.done()) @property def is_stopped(self) -> bool: - return self.errored or (self.background_loop is not None + return self.errored or (self.background_loop is not None and + self._background_loop_unshielded is not None and self._background_loop_unshielded.done()) @property @@ -381,7 +385,7 @@ def _error_callback(self, exc: Exception) -> None: async def get_tokenizer(self) -> "PreTrainedTokenizer": if self.engine_use_ray: - return await self.engine.get_tokenizer.remote() + return await self.engine.get_tokenizer.remote() # type: ignore else: return self.engine.get_tokenizer() @@ -434,7 +438,8 @@ async def engine_step(self) -> bool: # TODO: Maybe add add_request_batch to reduce Ray overhead try: if self.engine_use_ray: - await self.engine.add_request.remote(**new_request) + await self.engine.add_request.remote( # type: ignore + **new_request) else: await self.engine.add_request_async(**new_request) except ValueError as e: @@ -449,7 +454,7 @@ async def engine_step(self) -> bool: await self._engine_abort(finished_requests) if self.engine_use_ray: - request_outputs = await self.engine.step.remote() + request_outputs = await self.engine.step.remote() # type: ignore else: request_outputs = await self.engine.step_async() @@ -462,7 +467,7 @@ async def engine_step(self) -> bool: async def _engine_abort(self, request_ids: Iterable[str]): if self.engine_use_ray: - await self.engine.abort_request.remote(request_ids) + await self.engine.abort_request.remote(request_ids) # type: ignore else: self.engine.abort_request(request_ids) @@ -525,11 +530,12 @@ async def add_request( arrival_time = time.time() if self.engine_use_ray: - prompt_token_ids = await self.engine.encode_request_async.remote( - request_id=request_id, - prompt=prompt, - prompt_token_ids=prompt_token_ids, - lora_request=lora_request) + prompt_token_ids = await ( + self.engine.encode_request_async.remote( # type: ignore + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request)) else: prompt_token_ids = await self.engine.encode_request_async( request_id=request_id, @@ -676,13 +682,13 @@ def _abort(self, request_id: str) -> None: async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine.""" if self.engine_use_ray: - return await self.engine.get_model_config.remote() + return await self.engine.get_model_config.remote() # type: ignore else: return self.engine.get_model_config() async def do_log_stats(self) -> None: if self.engine_use_ray: - await self.engine.do_log_stats.remote() + await self.engine.do_log_stats.remote() # type: ignore else: self.engine.do_log_stats() @@ -695,7 +701,7 @@ async def check_health(self) -> None: if self.engine_use_ray: try: - await self.engine.check_health.remote() + await self.engine.check_health.remote() # type: ignore except ray.exceptions.RayActorError as e: raise RuntimeError("Engine is dead.") from e else: diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index a0868defbd3ca..5356b79537b05 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -107,12 +107,12 @@ def create_lora_manager( self._lora_manager: LoRAModelManager = lora_manager return lora_manager.model - def set_active_loras(self, lora_requests: List[LoRARequest], + 
def set_active_loras(self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping) -> None: self._apply_loras(lora_requests) self._lora_manager.set_lora_mapping(lora_mapping) - def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: + def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None: loras_that_exist = self.list_loras() loras_map = { lora_request.lora_int_id: lora_request diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index bd4564a36e1ed..53efebb604048 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -55,7 +55,7 @@ class GuidedDecodingMode(Enum): async def get_outlines_guided_decoding_logits_processor( request: Union[CompletionRequest, ChatCompletionRequest], - tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]: + tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, None]: """ Given an OpenAI-compatible request, check for guided decoding parameters and get the necessary logits processor for the given guide. @@ -84,7 +84,7 @@ async def get_outlines_guided_decoding_logits_processor( def _get_guide_and_mode( request: Union[CompletionRequest, ChatCompletionRequest] -) -> Tuple[str, GuidedDecodingMode]: +) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: if request.guided_json: json = request.guided_json diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 28041695546dc..95a67b612f08b 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -21,7 +21,7 @@ from typing import Callable, DefaultDict, Dict, List, Optional, Union import torch -from outlines.fsm.fsm import CFGFSM, RegexFSM +from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM from outlines.fsm.json_schema import build_regex_from_schema from pydantic import BaseModel from transformers import PreTrainedTokenizerBase @@ -29,6 +29,10 @@ class BaseLogitsProcessor: + def __init__(self): + # Child class should use initialize in their init. + self.fsm: FSM + def init_state(self): """Initialize the FSM states.""" self.fsm_state: DefaultDict[int, int] = defaultdict(int) diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py index 43d17ad373b87..07e23aca6cc5f 100644 --- a/vllm/model_executor/model_loader/neuron.py +++ b/vllm/model_executor/model_loader/neuron.py @@ -1,7 +1,7 @@ """Utilities for selecting and loading neuron models.""" import importlib import os -from typing import Optional, Type +from typing import Dict, Optional, Tuple import torch import torch.nn as nn @@ -27,7 +27,7 @@ } # Models supported by Neuron. 
-_NEURON_SUPPORTED_MODELS = { +_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = { "LlamaForCausalLM": ("transformers_neuronx.llama.model", "LlamaForSampling", "LlamaForCausalLM"), "MistralForCausalLM": ("transformers_neuronx.mistral.model", @@ -43,11 +43,13 @@ def __init__( ) -> None: super().__init__() self.config = config - self.model = None self.logits_processor = LogitsProcessor(config.vocab_size, logits_as_input=True) self.sampler = Sampler() + # Lazy initialized + self.model: nn.Module + def forward( self, input_ids: torch.Tensor, @@ -74,17 +76,17 @@ def sample( def load_weights(self, model_name_or_path: str, **kwargs): arch = _get_model_architecture(self.config) - neuronx_module_path, neuronx_model_cls, hf_model_cls = ( + neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = ( _NEURON_SUPPORTED_MODELS[arch]) neuronx_module = importlib.import_module(neuronx_module_path) - neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls) + neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) split_model_dir = f"{model_name_or_path}-split" if os.path.isdir(os.path.join(model_name_or_path, "pytorch_model.bin")): split_model_dir = model_name_or_path elif not os.path.exists(f"{model_name_or_path}-split"): - hf_model_cls = getattr(transformers, hf_model_cls) + hf_model_cls = getattr(transformers, hf_model_cls_name) from transformers_neuronx.module import save_pretrained_split hf_model = hf_model_cls.from_pretrained(model_name_or_path, @@ -96,7 +98,7 @@ def load_weights(self, model_name_or_path: str, **kwargs): self.model.to_neuron() -def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: +def _get_model_architecture(config: PretrainedConfig) -> str: architectures = getattr(config, "architectures", []) for arch in architectures: if arch in _NEURON_SUPPORTED_MODELS: diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index ad554844384eb..16be0ecf9ce07 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -167,6 +167,7 @@ def __post_init__(self): decryption_params = DecryptionParams.from_key(key) self.deserializer_params['encryption'] = decryption_params + @staticmethod def add_cli_args( parser: argparse.ArgumentParser) -> argparse.ArgumentParser: """Tensorizer CLI arguments""" diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 534cb75c2fd2f..31032c4cead20 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -113,6 +113,8 @@ def from_sampling_metadata( get_num_triton_sampler_splits(vocab_size)) sample_indices_start_idx = 0 + assert sampling_metadata.seq_groups is not None + assert sampling_metadata.seq_data is not None for i, seq_group in enumerate(sampling_metadata.seq_groups): seq_ids, sampling_params = seq_group temperature = sampling_params.temperature @@ -147,6 +149,7 @@ def from_sampling_metadata( and sampling_params.prompt_logprobs is not None): # For tokens in the prompt that we only need to get # their logprobs + assert sampling_metadata.prompt_lens is not None prompt_len = sampling_metadata.prompt_lens[i] temperatures += [temperature] * (prompt_len - 1) top_ps += [top_p] * (prompt_len - 1) @@ -172,6 +175,7 @@ def from_sampling_metadata( is_prompt = i < sampling_metadata.num_prompts if is_prompt: prompt_best_of.append(sampling_params.best_of) + assert sampling_metadata.prompt_lens is not None prompt_len = 
sampling_metadata.prompt_lens[i] if sampling_params.prompt_logprobs is not None: diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 88af1dd360155..bbc5b1778854f 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -106,7 +106,7 @@ def score_proposals( def _expand_batch( self, seq_group_metadata_list: List[SequenceGroupMetadata], - proposal_token_ids_list: List[TokenId], + proposal_token_ids_list: List[List[TokenId]], proposal_lens_list: List[int], ) -> Tuple[List[int], List[int], List[SequenceGroupMetadata], int]: """Given the input sequences and potentially multiple corresponding @@ -218,7 +218,7 @@ def _create_scoring_model_input( def _create_target_seq_group_metadata( self, input_seq_group_metadata: SequenceGroupMetadata, - proposal_token_ids: List[TokenId], # shape: [batch_size, k] + proposal_token_ids: List[List[TokenId]], # shape: [batch_size, k] batch_index: int, target_seq_ids_iter: Iterator[TargetSeqId], ) -> List[SequenceGroupMetadata]: @@ -360,7 +360,7 @@ def _get_token_ids_to_score( [0, 1, 2] [0, 1, 2, 3] """ - empty_token_ids = [] + empty_token_ids: List[TokenId] = [] token_ids_to_score = [empty_token_ids] token_ids_to_score.extend([ diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py index 2a72974d01bdc..f0715120192e5 100644 --- a/vllm/spec_decode/interfaces.py +++ b/vllm/spec_decode/interfaces.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional import torch @@ -73,5 +73,5 @@ def score_proposals( blocks_to_copy: Optional[Dict[int, List[int]]], k: int, proposals: SpeculativeProposals, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> SpeculativeScores: raise NotImplementedError diff --git a/vllm/spec_decode/metrics.py b/vllm/spec_decode/metrics.py index 5df8fc4316d48..d1e72b6640548 100644 --- a/vllm/spec_decode/metrics.py +++ b/vllm/spec_decode/metrics.py @@ -112,6 +112,7 @@ def _copy_rejsample_metrics_async(self) -> torch.cuda.Event: Returns a CUDA event recording when the copy is complete. """ + assert self._copy_stream is not None self._copy_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(self._copy_stream): diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index ce63c329a40aa..8b722476853fa 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -26,7 +26,8 @@ class MultiStepWorker(Worker): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._proposer: Optional[DraftModelTop1Proposer] = None + # Lazy initialization list. 
+ self._proposer: DraftModelTop1Proposer def init_device(self): super().init_device() @@ -338,10 +339,10 @@ def _merge_outputs( self._vocab_size, dtype=torch.float32, device=self._device) - proposal_lens = torch.zeros(len(proposal_lens), - dtype=torch.long, - device=self._device) - return proposal_tokens, proposal_probs, proposal_lens + proposal_lens_tensor = torch.zeros(len(proposal_lens), + dtype=torch.long, + device=self._device) + return proposal_tokens, proposal_probs, proposal_lens_tensor sampler_output = maybe_sampler_output @@ -376,9 +377,9 @@ def _merge_outputs( proposal_tokens, proposal_probs = (entire_proposal_tokens, entire_proposal_probs) - proposal_lens = torch.zeros(batch_size, - dtype=torch.long, - device=self._device) - proposal_lens[nonzero_proposal_len_indices] = max_proposal_len + proposal_lens_tensor = torch.zeros(batch_size, + dtype=torch.long, + device=self._device) + proposal_lens_tensor[nonzero_proposal_len_indices] = max_proposal_len - return proposal_tokens, proposal_probs, proposal_lens + return proposal_tokens, proposal_probs, proposal_lens_tensor diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index be3af7be93864..68a2a774ef4b7 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -89,7 +89,8 @@ def __init__( self.probs_dtype = self.rejection_sampler.probs_dtype self.token_id_dtype = self.rejection_sampler.token_id_dtype - self.scorer: SpeculativeScorer = None + # Lazy initialization. + self.scorer: SpeculativeScorer def init_device(self) -> None: """Initialize both scorer and proposer models. @@ -233,6 +234,9 @@ def _run_speculative_decoding_step( logger.info("get spec proposals") # Generate proposals using draft worker. + assert blocks_to_swap_in is not None + assert blocks_to_swap_out is not None + assert blocks_to_copy is not None proposals = self.proposer_worker.get_spec_proposals( seq_group_metadata_list, blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy, k) diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index d378e3a90e1e7..7377c8931cefa 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -1,6 +1,7 @@ from typing import Dict, List, Optional, Tuple import torch +from torch import nn from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, @@ -48,14 +49,15 @@ def __init__( if device_config is not None else DeviceConfig()) self.device = self.device_config.device - self.model = None - self.block_size = None # Set after initial profiling. - self.kv_cache_dtype = kv_cache_dtype self.attn_backend = get_attn_backend( self.model_config.dtype if model_config is not None else None) + # Lazy initialization. + self.model: nn.Module # Set after load_model. + self.block_size: int # Set after initial profiling. 
+ def load_model(self) -> None: self.model = get_model(model_config=self.model_config, load_config=self.load_config, @@ -245,7 +247,11 @@ def _prepare_sample( selected_token_indices: List[int] = [] generators: List[torch.Generator] = [] selected_token_start_idx = 0 - categorized_sample_indices = {t: [] for t in SamplingType} + categorized_sample_indices: Dict[SamplingType, + List[Tuple[int, int]]] = { + t: [] + for t in SamplingType + } categorized_sample_indices_start_idx = 0 categorized_sampled_token_indices_start_idx = 0 @@ -262,10 +268,9 @@ def _prepare_sample( categorized_sample_indices_start_idx += subquery_len - 1 categorized_sample_indices[ - sampling_params.sampling_type].append([ - categorized_sample_indices_start_idx, - categorized_sampled_token_indices_start_idx - ]) + sampling_params.sampling_type].append( + (categorized_sample_indices_start_idx, + categorized_sampled_token_indices_start_idx)) categorized_sample_indices_start_idx += 1 categorized_sampled_token_indices_start_idx += 1 @@ -328,7 +333,7 @@ def _prepare_sample( def prepare_input_tensors( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata]: if self.is_driver_worker: @@ -381,7 +386,7 @@ def prepare_input_tensors( @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + seq_group_metadata_list: List[SequenceGroupMetadata], kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: (input_tokens, input_positions, attn_metadata, sampling_metadata diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 8468ace5a2fdc..3652830b7d519 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -1,5 +1,5 @@ """A CPU worker class.""" -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import torch import torch.distributed @@ -152,8 +152,8 @@ def __init__( is_driver_worker=is_driver_worker) # Uninitialized cache engine. Will be initialized by # initialize_cache. - self.cache_engine = None - self.cpu_cache = None + self.cache_engine: CPUCacheEngine + self.cpu_cache: List[torch.Tensor] def init_device(self) -> None: self.init_distributed_environment() @@ -257,13 +257,13 @@ def execute_model( ) -> List[SamplerOutput]: if self.is_driver_worker: assert seq_group_metadata_list is not None - num_seq_groups = len(seq_group_metadata_list) + num_seq_groups: int = len(seq_group_metadata_list) assert blocks_to_swap_in is not None assert blocks_to_swap_out is not None assert blocks_to_copy is not None assert len(blocks_to_swap_in) == 0 assert len(blocks_to_swap_out) == 0 - data = { + data: Dict[str, Any] = { "num_seq_groups": num_seq_groups, "blocks_to_copy": blocks_to_copy, } @@ -273,6 +273,7 @@ def execute_model( num_seq_groups = data["num_seq_groups"] blocks_to_copy = data["blocks_to_copy"] + assert blocks_to_copy is not None self.cache_copy(blocks_to_copy) # If there is no input, we don't need to execute the model. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 42c06a1b19361..31e08789dfd1f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -128,23 +128,17 @@ def __init__( if device_config is not None else DeviceConfig()) self.device = self.device_config.device - self.model = None - self.block_size = None # Set after initial profiling. - self.lora_manager = None + # Set after load_model. 
+ self.lora_manager: LRUCacheWorkerLoRAManager = None self.graph_runners: Dict[int, CUDAGraphRunner] = {} - self.graph_memory_pool = None # Set during graph capture. + self.graph_memory_pool: Optional[Tuple[ + int, int]] = None # Set during graph capture. self.max_context_len_to_capture = ( self.model_config.max_context_len_to_capture if self.model_config is not None else 0) - # When using CUDA graph, the input block tables must be padded to - # max_context_len_to_capture. However, creating the block table in - # Python can be expensive. To optimize this, we cache the block table - # in numpy and only copy the actual input content at every iteration. - # The shape of the cached block table will be - # (max batch size to capture, max context len to capture / block size). - self.graph_block_tables = None # Set after initial profiling. + self.pin_memory = is_pin_memory_available() self.kv_cache_dtype = kv_cache_dtype self.vision_language_config = vision_language_config @@ -152,6 +146,17 @@ def __init__( self.attn_backend = get_attn_backend( self.model_config.dtype if model_config is not None else None) + # Lazy initialization + self.model: torch.nn.Module # Set after load_model + self.block_size: int # Set after initial profiling. + # When using CUDA graph, the input block tables must be padded to + # max_context_len_to_capture. However, creating the block table in + # Python can be expensive. To optimize this, we cache the block table + # in numpy and only copy the actual input content at every iteration. + # The shape of the cached block table will be + # (max batch size to capture, max context len to capture / block size). + self.graph_block_tables: torch.Tensor # Set after initial profiling. + def load_model(self) -> None: with CudaMemoryProfiler() as m: self.model = get_model( @@ -489,16 +494,16 @@ def _prepare_decode( lora_index_mapping.append(0) batch_size = graph_batch_size - context_lens = torch.tensor(context_lens, - dtype=torch.int, - device=self.device) + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) if use_captured_graph: # When using cuda-graph all these tensors should be # padded. - assert context_lens.shape[0] == len(input_tokens) - assert context_lens.shape[0] == len(input_positions) - assert context_lens.shape[0] == len(slot_mapping) + assert context_lens_tensor.shape[0] == len(input_tokens) + assert context_lens_tensor.shape[0] == len(input_positions) + assert context_lens_tensor.shape[0] == len(slot_mapping) # The shape of graph_block_tables is # [max batch size, max context len // block size]. 
@@ -527,7 +532,7 @@ def _prepare_decode( max_prompt_len=None, subquery_start_loc=None, seq_start_loc=None, - context_lens=context_lens, + context_lens=context_lens_tensor, block_tables=block_tables, use_cuda_graph=use_captured_graph, ) @@ -551,7 +556,11 @@ def _prepare_sample( selected_token_indices: List[int] = [] generators: List[torch.Generator] = [] selected_token_start_idx = 0 - categorized_sample_indices = {t: [] for t in SamplingType} + categorized_sample_indices: Dict[SamplingType, + List[Tuple[int, int]]] = { + t: [] + for t in SamplingType + } categorized_sample_indices_start_idx = 0 categorized_sampled_token_indices_start_idx = 0 @@ -569,10 +578,9 @@ def _prepare_sample( categorized_sample_indices_start_idx += subquery_len - 1 categorized_sample_indices[ - sampling_params.sampling_type].append([ - categorized_sample_indices_start_idx, - categorized_sampled_token_indices_start_idx - ]) + sampling_params.sampling_type].append( + (categorized_sample_indices_start_idx, + categorized_sampled_token_indices_start_idx)) categorized_sample_indices_start_idx += 1 categorized_sampled_token_indices_start_idx += 1 @@ -596,15 +604,16 @@ def _prepare_sample( categorized_sample_indices[ sampling_params.sampling_type].extend( - zip( - range( - categorized_sample_indices_start_idx, - categorized_sample_indices_start_idx + - num_seqs), - range( - categorized_sampled_token_indices_start_idx, - categorized_sampled_token_indices_start_idx + - num_seqs))) + list( + zip( + range( + categorized_sample_indices_start_idx, + categorized_sample_indices_start_idx + + num_seqs), + range( + categorized_sampled_token_indices_start_idx, + categorized_sampled_token_indices_start_idx + + num_seqs)))) categorized_sample_indices_start_idx += num_seqs categorized_sampled_token_indices_start_idx += num_seqs @@ -641,9 +650,9 @@ def _prepare_sample( def prepare_input_tensors( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata, - Set[int], LoRAMapping, torch.Tensor]: + Set[LoRARequest], LoRAMapping, torch.Tensor]: if self.is_driver_worker: prefill_reqs = [] decode_reqs = [] @@ -741,6 +750,7 @@ def prepare_input_tensors( if prefill_attn_metadata is not None: metadata_dict.update(prefill_attn_metadata.asdict_zerocopy()) else: + assert decode_attn_metadata is not None metadata_dict.update(decode_attn_metadata.asdict_zerocopy()) broadcast_tensor_dict(metadata_dict, src=0) @@ -809,7 +819,7 @@ def prepare_input_tensors( @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + seq_group_metadata_list: List[SequenceGroupMetadata], kv_caches: List[torch.Tensor], ) -> Optional[SamplerOutput]: (input_tokens, input_positions, attn_metadata, sampling_metadata, @@ -923,7 +933,7 @@ def remove_all_loras(self) -> bool: raise RuntimeError("LoRA is not enabled.") return self.lora_manager.remove_all_loras() - def set_active_loras(self, lora_requests: List[LoRARequest], + def set_active_loras(self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping) -> None: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") @@ -1065,10 +1075,16 @@ class CUDAGraphRunner: def __init__(self, model: nn.Module): self.model = model - self.graph = None self.input_buffers: Dict[str, torch.Tensor] = {} self.output_buffers: Dict[str, torch.Tensor] = {} + self._graph: Optional[torch.cuda.CUDAGraph] = None + + @property + def 
graph(self): + assert self._graph is not None + return self._graph + def capture( self, input_ids: torch.Tensor, @@ -1078,7 +1094,7 @@ def capture( memory_pool, **kwargs, ) -> None: - assert self.graph is None + assert self._graph is None # Run the model once without capturing the graph. # This is to make sure that the captured graph does not include the # kernel launches for initial benchmarking (e.g., Triton autotune). @@ -1095,8 +1111,8 @@ def capture( # Capture the graph. # NOTE(woosuk): Python 3.8 does not support multi-line with statements. # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement - self.graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117 + self._graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self._graph, pool=memory_pool): # noqa: SIM117 with _maybe_pynccl(): hidden_states = self.model( input_ids, diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index f70a7193effeb..487df334d73e3 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -1,6 +1,7 @@ from typing import Dict, List, Optional, Tuple import torch +from torch import nn from vllm.config import (DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig) @@ -34,9 +35,11 @@ def __init__( self.device_config = (device_config if device_config is not None else DeviceConfig()) self.device = self.device_config.device - self.model = None self.pin_memory = is_pin_memory_available() + # Lazy initialization. + self.model: nn.Module # initialize after load_model. + def load_model(self) -> None: self.model = get_neuron_model(self.model_config, parallel_config=self.parallel_config, @@ -147,7 +150,11 @@ def _prepare_sample( selected_token_indices: List[int] = [] generators: List[torch.Generator] = [] selected_token_start_idx = 0 - categorized_sample_indices = {t: [] for t in SamplingType} + categorized_sample_indices: Dict[SamplingType, + List[Tuple[int, int]]] = { + t: [] + for t in SamplingType + } categorized_sample_indices_start_idx = 0 categorized_sampled_token_indices_start_idx = 0 @@ -165,10 +172,9 @@ def _prepare_sample( categorized_sample_indices_start_idx += prompt_len - 1 categorized_sample_indices[ - sampling_params.sampling_type].append([ - categorized_sample_indices_start_idx, - categorized_sampled_token_indices_start_idx - ]) + sampling_params.sampling_type].append( + (categorized_sample_indices_start_idx, + categorized_sampled_token_indices_start_idx)) categorized_sample_indices_start_idx += 1 categorized_sampled_token_indices_start_idx += 1 @@ -237,7 +243,7 @@ def _prepare_sample( def prepare_input_tensors( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, SamplingMetadata]: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. 
@@ -259,7 +265,7 @@ def prepare_input_tensors( @torch.inference_mode() def execute_model( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Optional[SamplerOutput]: (input_tokens, input_positions, input_block_ids, sampling_metadata ) = self.prepare_input_tensors(seq_group_metadata_list) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index b021866965401..2203570b37ad6 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,7 +1,7 @@ """A GPU worker class.""" import gc import os -from typing import Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple import torch import torch.distributed @@ -82,8 +82,8 @@ def __init__( ) # Uninitialized cache engine. Will be initialized by # initialize_cache. - self.cache_engine = None - self.gpu_cache = None + self.cache_engine: CacheEngine + self.gpu_cache: List[torch.Tensor] def init_device(self) -> None: if self.device_config.device.type == "cuda": @@ -223,7 +223,7 @@ def execute_model( assert blocks_to_swap_in is not None assert blocks_to_swap_out is not None assert blocks_to_copy is not None - data = { + data: Dict[str, Any] = { "num_seq_groups": num_seq_groups, "blocks_to_swap_in": blocks_to_swap_in, "blocks_to_swap_out": blocks_to_swap_out, @@ -237,6 +237,9 @@ def execute_model( blocks_to_swap_out = data["blocks_to_swap_out"] blocks_to_copy = data["blocks_to_copy"] + assert blocks_to_swap_in is not None + assert blocks_to_swap_out is not None + assert blocks_to_copy is not None self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) # If there is no input, we don't need to execute the model. diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 309aa6256acea..13e062fe64b29 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,7 +1,7 @@ import importlib import os from abc import ABC, abstractmethod -from typing import Dict, List, Tuple +from typing import Dict, List, Set, Tuple from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -56,7 +56,7 @@ def execute_model( raise NotImplementedError @abstractmethod - def get_cache_block_size_bytes() -> int: + def get_cache_block_size_bytes(self) -> int: """Return the size of a single cache block, in bytes. Used in speculative decoding. """ @@ -71,7 +71,7 @@ def remove_lora(self, lora_id: int) -> bool: raise NotImplementedError @abstractmethod - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: raise NotImplementedError @@ -86,7 +86,7 @@ def add_lora(self, lora_request: LoRARequest) -> bool: def remove_lora(self, lora_id: int) -> bool: raise ValueError(f"{type(self)} does not support LoRA") - def list_loras(self) -> List[int]: + def list_loras(self) -> Set[int]: raise ValueError(f"{type(self)} does not support LoRA")
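
The recurring change across async_llm_engine.py, the worker/runner classes, and the spec_decode files is to replace `self.attr = None` with a bare annotation (e.g. `self.cache_engine: CacheEngine`) for fields that are always assigned before first use, and to keep `Optional[...]` only where `None` is a legitimate state. A minimal, self-contained sketch of that pattern, using illustrative names rather than vLLM classes:

from typing import Optional


class Model:
    """Stand-in for an expensive object such as an nn.Module."""

    def run(self) -> str:
        return "ok"


class Runner:
    def __init__(self) -> None:
        # Lazy initialization: annotate without assigning. mypy then treats
        # self.model as always being a Model (no None checks needed later);
        # touching it before load_model() fails at runtime with an
        # AttributeError instead of silently carrying a None around.
        self.model: Model  # Set by load_model().

        # Keep Optional[...] when "not yet available" is a real state that
        # callers are expected to check for.
        self.background_task: Optional[str] = None

    def load_model(self) -> None:
        self.model = Model()

    def execute(self) -> str:
        return self.model.run()


runner = Runner()
runner.load_model()
print(runner.execute())  # -> "ok"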
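
Where a value genuinely can be `None` at the call site (the driver-only `blocks_to_swap_in` / `blocks_to_copy` arguments, `sampling_metadata.prompt_lens`, and so on), the patch instead narrows the type with a runtime `assert x is not None` immediately before use, rather than loosening the annotation. A small sketch of why this satisfies mypy, again with made-up names:

from typing import Dict, List, Optional


def cache_copy(blocks_to_copy: Dict[int, List[int]]) -> int:
    # Expects a real dict; the Optional has already been narrowed away.
    return len(blocks_to_copy)


def execute_step(blocks_to_copy: Optional[Dict[int, List[int]]]) -> int:
    # Without this assert mypy rejects the call below, because the argument
    # may still be None here; the assert documents the invariant and narrows
    # Optional[Dict[...]] to Dict[...] for the rest of the function.
    assert blocks_to_copy is not None
    return cache_copy(blocks_to_copy)


print(execute_step({0: [1, 2]}))  # prints 1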