[Misc] Use dataclass for InputMetadata (#3452)
Co-authored-by: youkaichao <[email protected]>
WoosukKwon and youkaichao authored Mar 17, 2024
1 parent: 6b78837 · commit: abfc4f3
Showing 3 changed files with 24 additions and 63 deletions.
setup.py — 1 change: 0 additions & 1 deletion
@@ -2,7 +2,6 @@
 import io
 import os
 import re
-import shutil
 import subprocess
 import warnings
 from pathlib import Path
vllm/model_executor/input_metadata.py — 49 changes: 14 additions & 35 deletions
@@ -1,8 +1,10 @@
+from dataclasses import dataclass
 from typing import Optional
 
 import torch
 
 
+@dataclass
 class InputMetadata:
     """Metadata for input sequences. Used in PagedAttention.
 
@@ -15,40 +17,17 @@ class InputMetadata:
         kv_cache_dtype: Data type to store kv cache.
     """
 
-    def __init__(
-        self,
-        is_prompt: bool,
-        slot_mapping: torch.Tensor,
-        prompt_lens: Optional[torch.Tensor],
-        max_seq_len: Optional[int],
-        start_loc: Optional[torch.Tensor],
-        max_context_len: Optional[int],
-        context_lens: Optional[torch.Tensor],
-        block_tables: Optional[torch.Tensor],
-        use_cuda_graph: bool,
-        kv_cache_dtype: str,
-    ) -> None:
-        self.is_prompt = is_prompt
-        self.prompt_lens = prompt_lens
-        self.max_seq_len = max_seq_len
-        self.start_loc = start_loc
-        self.max_context_len = max_context_len
-        self.slot_mapping = slot_mapping
-        self.context_lens = context_lens
-        self.block_tables = block_tables
-        self.use_cuda_graph = use_cuda_graph
-        self.kv_cache_dtype = kv_cache_dtype
+    is_prompt: bool
+    slot_mapping: torch.Tensor
+    prompt_lens: Optional[torch.Tensor]
+    max_seq_len: Optional[int]
+    start_loc: Optional[torch.Tensor]
+    max_context_len: Optional[int]
+    context_lens: Optional[torch.Tensor]
+    block_tables: Optional[torch.Tensor]
+    use_cuda_graph: bool
+    kv_cache_dtype: str
 
-        # Set during the execution of the first attention op.
-        # FIXME(woosuk): This is a hack.
+    def __post_init__(self):
+        # will not appear in the __repr__ and __init__
         self.attn_bias = None
-
-    def __repr__(self) -> str:
-        return ("InputMetadata("
-                f"is_prompt={self.is_prompt}, "
-                f"max_context_len={self.max_context_len}, "
-                f"slot_mapping={self.slot_mapping}, "
-                f"context_lens={self.context_lens}, "
-                f"block_tables={self.block_tables}, "
-                f"use_cuda_graph={self.use_cuda_graph}, "
-                f"kv_cache_dtype={self.kv_cache_dtype})")
vllm/worker/model_runner.py — 37 changes: 10 additions & 27 deletions
@@ -1,4 +1,5 @@
 import contextlib
+import dataclasses
 import time
 from typing import Dict, List, Optional, Tuple, Set, Union
 
@@ -521,45 +522,27 @@ def prepare_input_tensors(
             metadata_dict = {
                 "input_tokens": input_tokens,
                 "input_positions": input_positions,
-                "is_prompt": input_metadata.is_prompt,
-                "slot_mapping": input_metadata.slot_mapping,
-                "prompt_lens": input_metadata.prompt_lens,
-                "max_seq_len": input_metadata.max_seq_len,
-                "start_loc": input_metadata.start_loc,
-                "max_context_len": input_metadata.max_context_len,
-                "context_lens": input_metadata.context_lens,
-                "block_tables": input_metadata.block_tables,
-                "use_cuda_graph": input_metadata.use_cuda_graph,
-                "kv_cache_dtype": input_metadata.kv_cache_dtype,
                 "selected_token_indices":
                 sampling_metadata.selected_token_indices,
                 "lora_requests": lora_requests,
                 "lora_mapping": lora_mapping,
             }
+            metadata_dict.update(dataclasses.asdict(input_metadata))
             broadcast_tensor_dict(metadata_dict, src=0)
         else:
             metadata_dict = broadcast_tensor_dict(src=0)
-            input_tokens = metadata_dict["input_tokens"]
-            input_positions = metadata_dict["input_positions"]
-            lora_mapping = metadata_dict["lora_mapping"]
-            lora_requests = metadata_dict["lora_requests"]
-            input_metadata = InputMetadata(
-                is_prompt=metadata_dict["is_prompt"],
-                slot_mapping=metadata_dict["slot_mapping"],
-                prompt_lens=metadata_dict["prompt_lens"],
-                max_seq_len=metadata_dict["max_seq_len"],
-                start_loc=metadata_dict["start_loc"],
-                max_context_len=metadata_dict["max_context_len"],
-                context_lens=metadata_dict["context_lens"],
-                block_tables=metadata_dict["block_tables"],
-                use_cuda_graph=metadata_dict["use_cuda_graph"],
-                kv_cache_dtype=metadata_dict["kv_cache_dtype"],
-            )
+            input_tokens = metadata_dict.pop("input_tokens")
+            input_positions = metadata_dict.pop("input_positions")
+            selected_token_indices = metadata_dict.pop(
+                "selected_token_indices")
+            lora_mapping = metadata_dict.pop("lora_mapping")
+            lora_requests = metadata_dict.pop("lora_requests")
+            input_metadata = InputMetadata(**metadata_dict)
             sampling_metadata = SamplingMetadata(
                 seq_groups=None,
                 seq_data=None,
                 prompt_lens=None,
-                selected_token_indices=metadata_dict["selected_token_indices"],
+                selected_token_indices=selected_token_indices,
                 categorized_sample_indices=None,
                 generators=None,
                 perform_sampling=False,
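
Editor's note: the two sides of the broadcast now mirror each other: dataclasses.asdict() keys the payload by field name, the receiver pops off the non-field entries, and whatever remains feeds InputMetadata(**metadata_dict) directly. A minimal sketch of that round trip (a hypothetical Meta class again; a plain dict stands in for the actual broadcast_tensor_dict transport):

import dataclasses
from dataclasses import dataclass


@dataclass
class Meta:
    """Toy stand-in for InputMetadata; illustrative only."""
    is_prompt: bool
    kv_cache_dtype: str


# Driver side: non-field entries first, then every field keyed by name.
metadata_dict = {"input_tokens": [1, 2, 3]}
metadata_dict.update(dataclasses.asdict(Meta(is_prompt=True, kv_cache_dtype="auto")))

# ... broadcast_tensor_dict(metadata_dict, src=0) would run here ...

# Worker side: pop the non-field entries so only field names remain,
# then rebuild the dataclass from what is left.
input_tokens = metadata_dict.pop("input_tokens")
meta = Meta(**metadata_dict)
assert meta == Meta(is_prompt=True, kv_cache_dtype="auto")

Because attn_bias is set in __post_init__ rather than declared as a field, asdict() never serializes it and each worker re-derives it locally. Note also that asdict() deep-copies field values, which is harmless here but worth knowing.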
