From 6531fd6eabeef18d21cb66fbd49056d75ed274fe Mon Sep 17 00:00:00 2001
From: Kunshang Ji
Date: Tue, 8 Oct 2024 11:51:14 +0800
Subject: [PATCH] [Intel GPU] Fix xpu decode input (#9145)

Signed-off-by: Sumit Dubey
---
 vllm/worker/xpu_model_runner.py | 21 ++++++++++++++-------
 1 file changed, 14 insertions(+), 7 deletions(-)

diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index 8282736cf479b..612428180226a 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -15,6 +15,7 @@
 from vllm.distributed import get_pp_group
 from vllm.inputs import INPUT_REGISTRY, InputRegistry
 from vllm.logger import init_logger
+from vllm.model_executor import SamplingMetadataCache
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
@@ -136,7 +137,7 @@ def build(self) -> ModelInputForXPU:
             (input_tokens, input_positions,
              attn_metadata) = self._prepare_decode(
                  self.seq_group_metadata_list)
-            seq_lens = []
+            seq_lens = None
             multi_modal_kwargs = None
 
         return self.model_input_cls(
@@ -390,6 +391,10 @@ def __init__(
         # Lazy initialization.
         self.model: nn.Module  # Set after init_Model
 
+        self.sampling_metadata_cache: SamplingMetadataCache = \
+            SamplingMetadataCache() \
+            if self.parallel_config.pipeline_parallel_size == 1 else None
+
     def load_model(self) -> None:
         with DeviceMemoryProfiler() as m:
             self.model = get_model(
@@ -524,12 +529,14 @@ def prepare_model_input(
             seq_group_metadata_list, finished_requests_ids)
         # Sampling metadata is only required for the final pp group
         generators = self.get_generators(finished_requests_ids)
-        sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
-                                                     model_input.seq_lens,
-                                                     model_input.query_lens,
-                                                     self.device,
-                                                     pin_memory=False,
-                                                     generators=generators)
+        sampling_metadata = SamplingMetadata.prepare(
+            seq_group_metadata_list,
+            model_input.seq_lens,
+            model_input.query_lens,
+            self.device,
+            pin_memory=False,
+            generators=generators,
+            cache=self.sampling_metadata_cache)
 
         return dataclasses.replace(model_input,
                                    sampling_metadata=sampling_metadata,
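
Note on the hunks above: the decode path now reports seq_lens as None rather
than an empty list, and the runner creates a SamplingMetadataCache (only when
pipeline_parallel_size == 1) and passes it to SamplingMetadata.prepare through
the new cache= argument, so sampling metadata can be reused across scheduler
steps instead of being rebuilt every step. What follows is a minimal,
hypothetical sketch of that build-once-then-reuse pattern, assuming reuse is
keyed per request; TinySamplingCache and make_runner_cache are illustrative
names, not vLLM's actual API.

    # Illustrative stand-in for the caching pattern the patch adopts;
    # not vLLM's actual SamplingMetadataCache implementation.
    from typing import Callable, Dict, Optional


    class TinySamplingCache:
        """Builds a per-request metadata object once, then returns the
        same object on every later decode step for that request."""

        def __init__(self) -> None:
            self._cache: Dict[str, object] = {}

        def get_or_build(self, request_id: str,
                         build: Callable[[], object]) -> object:
            # Build on first sight of the request, reuse afterwards.
            if request_id not in self._cache:
                self._cache[request_id] = build()
            return self._cache[request_id]


    def make_runner_cache(
            pipeline_parallel_size: int) -> Optional[TinySamplingCache]:
        # Mirrors the patch's guard: only cache when there is a single
        # pipeline stage.
        return TinySamplingCache() if pipeline_parallel_size == 1 else None


    if __name__ == "__main__":
        cache = make_runner_cache(pipeline_parallel_size=1)
        first = cache.get_or_build("req-0", lambda: {"temperature": 0.8})
        second = cache.get_or_build("req-0", lambda: {"temperature": 0.8})
        assert first is second  # built once, reused on the second step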