Add device-specific request validation in LLMEngine and modify request-specific asserts in TTModelRunner to not crash server instances #41

Merged · 1 commit · Dec 10, 2024
examples/offline_inference_tt.py (4 changes: 2 additions & 2 deletions)

@@ -16,8 +16,8 @@
 from vllm.engine.multiprocessing.client import MQLLMEngineClient

 # Import and register model from tt-metal
-from models.demos.t3000.llama2_70b.tt.llama_generation import TtLlamaModelForGeneration
-ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaModelForGeneration)
+from models.demos.t3000.llama2_70b.tt.generator_vllm import TtLlamaForCausalLM
+ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaForCausalLM)


 def run_inference(
examples/server_example_tt.py (4 changes: 2 additions & 2 deletions)

@@ -5,8 +5,8 @@
 from vllm import ModelRegistry

 # Import and register model from tt-metal
-from models.demos.t3000.llama2_70b.tt.llama_generation import TtLlamaModelForGeneration
-ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaModelForGeneration)
+from models.demos.t3000.llama2_70b.tt.generator_vllm import TtLlamaForCausalLM
+ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaForCausalLM)


 def main():
vllm/engine/llm_engine.py (16 changes: 16 additions & 0 deletions)

@@ -694,6 +694,9 @@ def _add_processed_request(
             raise ValueError(
                 "Either SamplingParams or PoolingParams must be provided.")

+        # Validate that the sequence group is compatible with the device's model executor.
+        self._validate_device_inputs(seq_group)
+
         # Add the sequence group to the scheduler with least unfinished seqs.
         costs = [
             scheduler.get_num_unfinished_seq_groups()

@@ -1905,6 +1908,19 @@ def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
         # TODO: Find out how many placeholder tokens are there so we can
         # check that chunked prefill does not truncate them
         # max_batch_len = self.scheduler_config.max_num_batched_tokens
+
+    def _validate_device_inputs(self, seq_group: SequenceGroup):
+        '''
+        Device-specific validations for model inputs and params. This is in
+        contrast to the device-agnostic validations in _validate_model_inputs.
+        These validations must be performed when adding a request (prior to step
+        execution) so that server instances can reject requests that are not
+        compatible with the device they are running on (instead of crashing).
+        '''
+
+        # Currently only supported for TT devices
+        if self.device_config.device_type == "tt":
+            self.model_executor.validate_seq_group(seq_group)
+
     def _build_logits_processors(
             self, sampling_params: SamplingParams,
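Note on client-visible behavior (not part of the diff): with _validate_device_inputs wired into _add_processed_request, an unsupported request is rejected with a ValueError when it is added, before scheduling or model execution, so a long-running engine or server instance keeps serving other requests. A minimal offline sketch, assuming a TT device and a TT-registered model; the model name below is illustrative:

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-70B")  # illustrative; assumes a TT backend model is registered

# logprobs are not supported by the TT backend, so validation raises ValueError
# at add_request time instead of tripping an assert inside TTModelRunner later.
try:
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.8, logprobs=5))
except ValueError as err:
    print(f"Request rejected: {err}")  # the engine itself keeps running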
vllm/executor/tt_executor.py (8 changes: 7 additions & 1 deletion)

@@ -4,7 +4,7 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.sequence import ExecuteModelRequest
+from vllm.sequence import ExecuteModelRequest, SequenceGroup
 from vllm.utils import make_async

 logger = init_logger(__name__)

@@ -99,6 +99,12 @@ def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
     def list_prompt_adapters(self) -> Set[int]:
         raise NotImplementedError(
             "Soft prompt is currently not supported by the TT backend.")
+
+    def validate_seq_group(self, seq_group: SequenceGroup) -> None:
+        '''
+        Validate the sequence group before it is scheduled for execution in LLMEngine::_add_processed_request.
+        '''
+        self.driver_worker.model_runner.validate_seq_group(seq_group)

 class TTExecutorAsync(TTExecutor, ExecutorAsyncBase):
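For reference, a toy sketch of the delegation pattern this hook uses (hypothetical class names, not vLLM code): the engine calls validate_seq_group on the executor, which is a thin pass-through to the driver worker's model runner, where the device-specific checks actually live.

class ToyModelRunner:
    def validate_seq_group(self, seq_group) -> None:
        # Device-specific checks live here; raise ValueError for anything unsupported.
        if getattr(seq_group.sampling_params, "logprobs", None) is not None:
            raise ValueError("Currently not supporting logprobs")

class ToyDriverWorker:
    def __init__(self) -> None:
        self.model_runner = ToyModelRunner()

class ToyExecutor:
    def __init__(self) -> None:
        self.driver_worker = ToyDriverWorker()

    def validate_seq_group(self, seq_group) -> None:
        # Mirrors TTExecutor: forward to the model runner without extra logic.
        self.driver_worker.model_runner.validate_seq_group(seq_group)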
vllm/worker/tt_model_runner.py (44 changes: 29 additions & 15 deletions)

@@ -1,7 +1,7 @@
 import dataclasses
 import time
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Type, Union

 import torch
 import torch.nn.functional as F

@@ -13,11 +13,9 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.tt_loader import TTModelLoader
-from vllm.sequence import IntermediateTensors, SequenceGroupMetadata, Logprob, SequenceOutput, CompletionSequenceGroupOutput
+from vllm.sequence import IntermediateTensors, SequenceGroupMetadata, Logprob, SequenceOutput, CompletionSequenceGroupOutput, SequenceGroup
 from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
 from vllm.utils import make_tensor_with_pad
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend

 logger = init_logger(__name__)

@@ -136,6 +134,26 @@ def make_model_input_from_broadcasted_tensor_dict(
             tensor_dict,
         )

+    def validate_seq_group(self, seq_group: SequenceGroup) -> None:
+        '''
+        Validate a sequence group before it is scheduled for execution.
+        Called by TTExecutor::validate_seq_group before the sequence group
+        is scheduled for execution in LLMEngine::_add_processed_request.
+        '''
+
+        sampling_params = seq_group.sampling_params
+
+        if seq_group.num_seqs() != 1:
+            raise ValueError("Currently only supporting one sequence per request group")
+        if sampling_params.n != 1:
+            raise ValueError("Currently only supporting n=1")
+        if sampling_params.best_of is not None:
+            raise ValueError("Currently not supporting best_of")
+        if sampling_params.logprobs is not None:
+            raise ValueError("Currently not supporting logprobs")
+        if sampling_params.prompt_logprobs is not None:
+            raise ValueError("Currently not supporting prompt_logprobs")
+
     def prepare_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],

@@ -160,7 +178,7 @@ def prepare_model_input(

         for seq_group_metadata in seq_group_metadata_list:
             seq_ids = list(seq_group_metadata.seq_data.keys())
-            assert len(seq_ids) == 1  # Only support one sequence per request group
+            assert len(seq_ids) == 1, "Currently only supporting one sequence per request group"
             seq_id = seq_ids[0]
             seq_groups.append(seq_id)

@@ -189,15 +207,17 @@ def prepare_model_input(
             # Sampling params
             # TODO: Add support for different sampling params in the same batch
             sampling_params = seq_group_metadata.sampling_params
-            self._validate_sampling_params(sampling_params)
             if len(top_pk_sampling_params) == 0:
                 top_pk_sampling_params["temperature"] = sampling_params.temperature
                 top_pk_sampling_params["top_k"] = sampling_params.top_k
                 top_pk_sampling_params["top_p"] = sampling_params.top_p
             else:
-                assert top_pk_sampling_params["temperature"] == sampling_params.temperature, "Currently only supporting same temperature for all sequences in batch"
-                assert top_pk_sampling_params["top_k"] == sampling_params.top_k, "Currently only supporting same top_k for all sequences in batch"
-                assert top_pk_sampling_params["top_p"] == sampling_params.top_p, "Currently only supporting same top_p for all sequences in batch"
+                if top_pk_sampling_params["temperature"] != sampling_params.temperature:
+                    logger.warning(f"Currently only supporting same temperature for all sequences in batch, falling back to first sequence's temperature ({top_pk_sampling_params['temperature']})")
+                if top_pk_sampling_params["top_k"] != sampling_params.top_k:
+                    logger.warning(f"Currently only supporting same top_k for all sequences in batch, falling back to first sequence's top_k ({top_pk_sampling_params['top_k']})")
+                if top_pk_sampling_params["top_p"] != sampling_params.top_p:
+                    logger.warning(f"Currently only supporting same top_p for all sequences in batch, falling back to first sequence's top_p ({top_pk_sampling_params['top_p']})")

             tt_sampling_params = TTSamplingParams(
                 temperature=top_pk_sampling_params["temperature"],

@@ -423,12 +443,6 @@ def _sample_tokens(self, logits, tt_sampling_params : TTSamplingParams):
             k=tt_sampling_params.top_k,
             temperature=tt_sampling_params.temperature
         )
-
-    def _validate_sampling_params(self, sampling_params):
-        assert sampling_params.n == 1, "Currently only supporting n=1"
-        assert sampling_params.best_of is None, "Currently not supporting best_of"
-        assert sampling_params.logprobs is None, "Currently not supporting logprobs"
-        assert sampling_params.prompt_logprobs is None, "Currently not supporting prompt_logprobs"

     ## Destructor (used to delete ttnn trace if using trace mode)
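Besides the new validate_seq_group checks, the other behavior change in prepare_model_input is the switch from batch-level asserts to warnings when sequences in one batch carry different temperature/top_k/top_p values: the first sequence's values are used for the whole batch. A generic, self-contained illustration of that fallback (an assumed sketch, not vLLM code):

import logging

logger = logging.getLogger("tt_batch_params")
logging.basicConfig(level=logging.WARNING)

def merge_batch_sampling_params(per_seq_params: list) -> dict:
    """Return one temperature/top_k/top_p set for the whole batch,
    falling back to the first sequence's values on any mismatch."""
    merged = {}
    for params in per_seq_params:
        if not merged:
            merged.update(params)  # first sequence sets the batch-wide values
            continue
        for key in ("temperature", "top_k", "top_p"):
            if merged[key] != params[key]:
                logger.warning(
                    "Currently only supporting same %s for all sequences in "
                    "batch, falling back to first sequence's %s (%s)",
                    key, key, merged[key])
    return merged

if __name__ == "__main__":
    batch = [
        {"temperature": 0.7, "top_k": 10, "top_p": 0.9},
        {"temperature": 1.0, "top_k": 10, "top_p": 0.9},  # mismatched temperature
    ]
    print(merge_batch_sampling_params(batch))  # keeps 0.7 and logs one warning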