diff --git a/setup.py b/setup.py
index 0531e1f01d33f..6f1f2faf54dbc 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,6 @@
 import io
 import os
 import re
-import shutil
 import subprocess
 import warnings
 from pathlib import Path
diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py
index f0a88ac8e27f8..ebba0ba0a261a 100644
--- a/vllm/model_executor/input_metadata.py
+++ b/vllm/model_executor/input_metadata.py
@@ -1,8 +1,10 @@
+from dataclasses import dataclass
 from typing import Optional
 
 import torch
 
 
+@dataclass
 class InputMetadata:
     """Metadata for input sequences. Used in PagedAttention.
 
@@ -15,40 +17,17 @@ class InputMetadata:
         kv_cache_dtype: Data type to store kv cache.
     """
 
-    def __init__(
-        self,
-        is_prompt: bool,
-        slot_mapping: torch.Tensor,
-        prompt_lens: Optional[torch.Tensor],
-        max_seq_len: Optional[int],
-        start_loc: Optional[torch.Tensor],
-        max_context_len: Optional[int],
-        context_lens: Optional[torch.Tensor],
-        block_tables: Optional[torch.Tensor],
-        use_cuda_graph: bool,
-        kv_cache_dtype: str,
-    ) -> None:
-        self.is_prompt = is_prompt
-        self.prompt_lens = prompt_lens
-        self.max_seq_len = max_seq_len
-        self.start_loc = start_loc
-        self.max_context_len = max_context_len
-        self.slot_mapping = slot_mapping
-        self.context_lens = context_lens
-        self.block_tables = block_tables
-        self.use_cuda_graph = use_cuda_graph
-        self.kv_cache_dtype = kv_cache_dtype
+    is_prompt: bool
+    slot_mapping: torch.Tensor
+    prompt_lens: Optional[torch.Tensor]
+    max_seq_len: Optional[int]
+    start_loc: Optional[torch.Tensor]
+    max_context_len: Optional[int]
+    context_lens: Optional[torch.Tensor]
+    block_tables: Optional[torch.Tensor]
+    use_cuda_graph: bool
+    kv_cache_dtype: str
 
-        # Set during the execution of the first attention op.
-        # FIXME(woosuk): This is a hack.
+    def __post_init__(self):
+        # will not appear in the __repr__ and __init__
         self.attn_bias = None
-
-    def __repr__(self) -> str:
-        return ("InputMetadata("
-                f"is_prompt={self.is_prompt}, "
-                f"max_context_len={self.max_context_len}, "
-                f"slot_mapping={self.slot_mapping}, "
-                f"context_lens={self.context_lens}, "
-                f"block_tables={self.block_tables}, "
-                f"use_cuda_graph={self.use_cuda_graph}, "
-                f"kv_cache_dtype={self.kv_cache_dtype})")
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 7eac576e3f0fe..1ef783da6d08e 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -1,4 +1,5 @@
 import contextlib
+import dataclasses
 import time
 from typing import Dict, List, Optional, Tuple, Set, Union
 
@@ -521,45 +522,27 @@ def prepare_input_tensors(
             metadata_dict = {
                 "input_tokens": input_tokens,
                 "input_positions": input_positions,
-                "is_prompt": input_metadata.is_prompt,
-                "slot_mapping": input_metadata.slot_mapping,
-                "prompt_lens": input_metadata.prompt_lens,
-                "max_seq_len": input_metadata.max_seq_len,
-                "start_loc": input_metadata.start_loc,
-                "max_context_len": input_metadata.max_context_len,
-                "context_lens": input_metadata.context_lens,
-                "block_tables": input_metadata.block_tables,
-                "use_cuda_graph": input_metadata.use_cuda_graph,
-                "kv_cache_dtype": input_metadata.kv_cache_dtype,
                 "selected_token_indices":
                 sampling_metadata.selected_token_indices,
                 "lora_requests": lora_requests,
                 "lora_mapping": lora_mapping,
             }
+            metadata_dict.update(dataclasses.asdict(input_metadata))
             broadcast_tensor_dict(metadata_dict, src=0)
         else:
             metadata_dict = broadcast_tensor_dict(src=0)
-            input_tokens = metadata_dict["input_tokens"]
-            input_positions = metadata_dict["input_positions"]
-            lora_mapping = metadata_dict["lora_mapping"]
-            lora_requests = metadata_dict["lora_requests"]
-            input_metadata = InputMetadata(
-                is_prompt=metadata_dict["is_prompt"],
-                slot_mapping=metadata_dict["slot_mapping"],
-                prompt_lens=metadata_dict["prompt_lens"],
-                max_seq_len=metadata_dict["max_seq_len"],
-                start_loc=metadata_dict["start_loc"],
-                max_context_len=metadata_dict["max_context_len"],
-                context_lens=metadata_dict["context_lens"],
-                block_tables=metadata_dict["block_tables"],
-                use_cuda_graph=metadata_dict["use_cuda_graph"],
-                kv_cache_dtype=metadata_dict["kv_cache_dtype"],
-            )
+            input_tokens = metadata_dict.pop("input_tokens")
+            input_positions = metadata_dict.pop("input_positions")
+            selected_token_indices = metadata_dict.pop(
+                "selected_token_indices")
+            lora_mapping = metadata_dict.pop("lora_mapping")
+            lora_requests = metadata_dict.pop("lora_requests")
+            input_metadata = InputMetadata(**metadata_dict)
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
                 seq_data=None,
                 prompt_lens=None,
-                selected_token_indices=metadata_dict["selected_token_indices"],
+                selected_token_indices=selected_token_indices,
                 categorized_sample_indices=None,
                 generators=None,
                 perform_sampling=False,
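
For context, the refactor works because dataclasses.asdict() flattens every declared field into the dict that the driver worker broadcasts, and InputMetadata(**metadata_dict) rebuilds the object on the other ranks once the non-field keys have been popped off. Below is a minimal standalone sketch of that round-trip, not vLLM code: the field set is trimmed to four entries and the broadcast_tensor_dict call is simulated by reusing the dict in-process.

    import dataclasses
    from dataclasses import dataclass
    from typing import Optional

    import torch


    @dataclass
    class InputMetadata:
        # Trimmed field set, for illustration only.
        is_prompt: bool
        slot_mapping: torch.Tensor
        max_context_len: Optional[int]
        kv_cache_dtype: str

        def __post_init__(self):
            # Plain attribute, not a dataclass field: excluded from
            # __init__, __repr__, and dataclasses.asdict(), so it never
            # enters the broadcast dict.
            self.attn_bias = None


    # Driver-worker side: flatten the dataclass into the dict to broadcast.
    input_metadata = InputMetadata(
        is_prompt=False,
        slot_mapping=torch.tensor([0, 1, 2]),
        max_context_len=16,
        kv_cache_dtype="auto",
    )
    metadata_dict = {"input_tokens": torch.tensor([11, 22, 33])}
    metadata_dict.update(dataclasses.asdict(input_metadata))

    # Receiving side (after broadcast_tensor_dict in the real code): pop
    # the non-field keys, then rebuild from the remaining field names.
    input_tokens = metadata_dict.pop("input_tokens")
    rebuilt = InputMetadata(**metadata_dict)
    assert torch.equal(rebuilt.slot_mapping, input_metadata.slot_mapping)
    assert rebuilt.attn_bias is None  # re-initialized by __post_init__

One design note: dataclasses.asdict() deep-copies non-dataclass field values, so the broadcast dict holds copies of the tensors; iterating dataclasses.fields() with getattr would avoid the copies if that overhead mattered here.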