diff --git a/examples/offline_inference_tt.py b/examples/offline_inference_tt.py
index 67e0c733f20d9..d1b3347c276e5 100644
--- a/examples/offline_inference_tt.py
+++ b/examples/offline_inference_tt.py
@@ -16,8 +16,8 @@
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
 
 # Import and register model from tt-metal
-from models.demos.t3000.llama2_70b.tt.llama_generation import TtLlamaModelForGeneration
-ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaModelForGeneration)
+from models.demos.t3000.llama2_70b.tt.generator_vllm import TtLlamaForCausalLM
+ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaForCausalLM)
 
 
 def run_inference(
diff --git a/examples/server_example_tt.py b/examples/server_example_tt.py
index 8116a54584593..941573c69ebcd 100644
--- a/examples/server_example_tt.py
+++ b/examples/server_example_tt.py
@@ -5,8 +5,8 @@
 from vllm import ModelRegistry
 
 # Import and register model from tt-metal
-from models.demos.t3000.llama2_70b.tt.llama_generation import TtLlamaModelForGeneration
-ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaModelForGeneration)
+from models.demos.t3000.llama2_70b.tt.generator_vllm import TtLlamaForCausalLM
+ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaForCausalLM)
 
 
 def main():
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index f90ef219fcfc9..a6320d235c113 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -694,6 +694,9 @@ def _add_processed_request(
             raise ValueError(
                 "Either SamplingParams or PoolingParams must be provided.")
 
+        # Validate that the sequence group is compatible with the device's model executor.
+        self._validate_device_inputs(seq_group)
+
         # Add the sequence group to the scheduler with least unfinished seqs.
         costs = [
             scheduler.get_num_unfinished_seq_groups()
@@ -1905,6 +1908,19 @@ def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
         # TODO: Find out how many placeholder tokens are there so we can
         # check that chunked prefill does not truncate them
         # max_batch_len = self.scheduler_config.max_num_batched_tokens
+
+    def _validate_device_inputs(self, seq_group: SequenceGroup):
+        '''
+        Device-specific validations for model inputs and params. This is in
+        contrast to the device-agnostic validations in _validate_model_inputs.
+        These validations must be performed when adding a request (prior to step
+        execution) so that server instances can reject requests that are not
+        compatible with the device they are running on (instead of crashing).
+        '''
+
+        # Currently only supported for TT devices
+        if self.device_config.device_type == "tt":
+            self.model_executor.validate_seq_group(seq_group)
 
     def _build_logits_processors(
             self, sampling_params: SamplingParams,
diff --git a/vllm/executor/tt_executor.py b/vllm/executor/tt_executor.py
index 737217fbf815d..fa441c43b758c 100644
--- a/vllm/executor/tt_executor.py
+++ b/vllm/executor/tt_executor.py
@@ -4,7 +4,7 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.sequence import ExecuteModelRequest
+from vllm.sequence import ExecuteModelRequest, SequenceGroup
 from vllm.utils import make_async
 
 logger = init_logger(__name__)
@@ -99,6 +99,12 @@ def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
     def list_prompt_adapters(self) -> Set[int]:
         raise NotImplementedError(
             "Soft prompt is currently not supported by the TT backend.")
+
+    def validate_seq_group(self, seq_group: SequenceGroup) -> None:
+        '''
+        Validate the sequence group before it is scheduled for execution in LLMEngine::_add_processed_request.
+        '''
+        self.driver_worker.model_runner.validate_seq_group(seq_group)
 
 
 class TTExecutorAsync(TTExecutor, ExecutorAsyncBase):
diff --git a/vllm/worker/tt_model_runner.py b/vllm/worker/tt_model_runner.py
index 1c586dd313bcb..240ee786063f8 100644
--- a/vllm/worker/tt_model_runner.py
+++ b/vllm/worker/tt_model_runner.py
@@ -1,7 +1,7 @@
 import dataclasses
 import time
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Type, Union
 
 import torch
 import torch.nn.functional as F
@@ -13,11 +13,9 @@
 from vllm.logger import init_logger
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader.tt_loader import TTModelLoader
-from vllm.sequence import IntermediateTensors, SequenceGroupMetadata, Logprob, SequenceOutput, CompletionSequenceGroupOutput
+from vllm.sequence import IntermediateTensors, SequenceGroupMetadata, Logprob, SequenceOutput, CompletionSequenceGroupOutput, SequenceGroup
 from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase
 from vllm.utils import make_tensor_with_pad
 
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend
 
 logger = init_logger(__name__)
@@ -136,6 +134,26 @@ def make_model_input_from_broadcasted_tensor_dict(
             tensor_dict,
         )
 
+    def validate_seq_group(self, seq_group: SequenceGroup) -> None:
+        '''
+        Validate a sequence group before it is scheduled for execution.
+        Called by TTExecutor::validate_seq_group before the sequence group
+        is scheduled for execution in LLMEngine::_add_processed_request.
+        '''
+
+        sampling_params = seq_group.sampling_params
+
+        if seq_group.num_seqs() != 1:
+            raise ValueError("Currently only supporting one sequence per request group")
+        if sampling_params.n != 1:
+            raise ValueError("Currently only supporting n=1")
+        if sampling_params.best_of is not None:
+            raise ValueError("Currently not supporting best_of")
+        if sampling_params.logprobs is not None:
+            raise ValueError("Currently not supporting logprobs")
+        if sampling_params.prompt_logprobs is not None:
+            raise ValueError("Currently not supporting prompt_logprobs")
+
     def prepare_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
@@ -160,7 +178,7 @@ def prepare_model_input(
 
         for seq_group_metadata in seq_group_metadata_list:
             seq_ids = list(seq_group_metadata.seq_data.keys())
-            assert len(seq_ids) == 1  # Only support one sequence per request group
+            assert len(seq_ids) == 1, "Currently only supporting one sequence per request group"
             seq_id = seq_ids[0]
             seq_groups.append(seq_id)
 
@@ -189,15 +207,17 @@ def prepare_model_input(
             # Sampling params
             # TODO: Add support for different sampling params in the same batch
             sampling_params = seq_group_metadata.sampling_params
-            self._validate_sampling_params(sampling_params)
             if len(top_pk_sampling_params) == 0:
                 top_pk_sampling_params["temperature"] = sampling_params.temperature
                 top_pk_sampling_params["top_k"] = sampling_params.top_k
                 top_pk_sampling_params["top_p"] = sampling_params.top_p
             else:
-                assert top_pk_sampling_params["temperature"] == sampling_params.temperature, "Currently only supporting same temperature for all sequences in batch"
-                assert top_pk_sampling_params["top_k"] == sampling_params.top_k, "Currently only supporting same top_k for all sequences in batch"
-                assert top_pk_sampling_params["top_p"] == sampling_params.top_p, "Currently only supporting same top_p for all sequences in batch"
+                if top_pk_sampling_params["temperature"] != sampling_params.temperature:
+                    logger.warning(f"Currently only supporting same temperature for all sequences in batch, falling back to first sequence's temperature ({top_pk_sampling_params['temperature']})")
+                if top_pk_sampling_params["top_k"] != sampling_params.top_k:
+                    logger.warning(f"Currently only supporting same top_k for all sequences in batch, falling back to first sequence's top_k ({top_pk_sampling_params['top_k']})")
+                if top_pk_sampling_params["top_p"] != sampling_params.top_p:
+                    logger.warning(f"Currently only supporting same top_p for all sequences in batch, falling back to first sequence's top_p ({top_pk_sampling_params['top_p']})")
 
         tt_sampling_params = TTSamplingParams(
             temperature=top_pk_sampling_params["temperature"],
@@ -423,12 +443,6 @@ def _sample_tokens(self, logits, tt_sampling_params : TTSamplingParams):
             k=tt_sampling_params.top_k,
             temperature=tt_sampling_params.temperature
         )
-
-    def _validate_sampling_params(self, sampling_params):
-        assert sampling_params.n == 1, "Currently only supporting n=1"
-        assert sampling_params.best_of is None, "Currently not supporting best_of"
-        assert sampling_params.logprobs is None, "Currently not supporting logprobs"
-        assert sampling_params.prompt_logprobs is None, "Currently not supporting prompt_logprobs"
 
 
     ## Destructor (used to delete ttnn trace if using trace mode)
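
For reference, a minimal usage sketch (not part of the diff) of how the new add-time validation is expected to surface to a caller. It assumes a tt-metal checkout on PYTHONPATH and a TT device config; the registration lines mirror examples/offline_inference_tt.py, while the model name and LLM constructor arguments are placeholders.

# Hypothetical sketch: unsupported sampling params are now rejected with a
# ValueError when the request is added (LLMEngine._add_processed_request ->
# TTExecutor.validate_seq_group -> TTModelRunner.validate_seq_group), instead
# of tripping an assert inside prepare_model_input during a later engine step.
from vllm import LLM, ModelRegistry, SamplingParams

# Registration mirrors examples/offline_inference_tt.py (requires tt-metal).
from models.demos.t3000.llama2_70b.tt.generator_vllm import TtLlamaForCausalLM
ModelRegistry.register_model("TTLlamaForCausalLM", TtLlamaForCausalLM)

llm = LLM(model="meta-llama/Llama-2-70b")  # placeholder engine args for a TT deployment

try:
    # logprobs are not supported by the TT backend, so this fails fast at add time.
    llm.generate("Hello, world!", SamplingParams(logprobs=5))
except ValueError as err:
    print(f"Request rejected at add time: {err}")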