From ffe24e86789f781cf7bff115506b6b9d2c51d2ba Mon Sep 17 00:00:00 2001 From: dbogunowicz Date: Tue, 26 Mar 2024 10:37:08 +0000 Subject: [PATCH 1/5] initial commit --- tests/conftest.py | 75 ++- tests/models/test_whisper.py | 67 +++ vllm/config.py | 6 + vllm/core/scheduler.py | 2 + vllm/engine/arg_utils.py | 10 +- vllm/engine/async_llm_engine.py | 26 +- vllm/engine/llm_engine.py | 14 +- vllm/entrypoints/llm.py | 5 + .../layers/enc_dec_attention.py | 12 +- vllm/model_executor/model_loader.py | 11 +- vllm/model_executor/models/__init__.py | 2 + vllm/model_executor/models/whisper.py | 444 ++++++++++++++++++ vllm/sequence.py | 24 + vllm/worker/model_runner.py | 105 ++++- vllm/worker/worker.py | 25 +- 15 files changed, 761 insertions(+), 67 deletions(-) create mode 100644 tests/models/test_whisper.py create mode 100644 vllm/model_executor/models/whisper.py diff --git a/tests/conftest.py b/tests/conftest.py index 6eb8159837d51..fab45f0e5668b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,8 +3,8 @@ import pytest import torch -from transformers import AutoModelForCausalLM - +from transformers import AutoModelForCausalLM, WhisperForConditionalGeneration, AutoProcessor +from vllm.sequence import MultiModalData from vllm import LLM, SamplingParams from vllm.transformers_utils.tokenizer import get_tokenizer @@ -16,7 +16,10 @@ def _read_prompts(filename: str) -> List[str]: with open(filename, "r") as f: prompts = f.readlines() - return prompts + return prompts + + +AUDIO_MODELS = {"openai/whisper-tiny": WhisperForConditionalGeneration} @pytest.fixture @@ -51,12 +54,21 @@ def __init__( dtype: str = "half", ) -> None: assert dtype in _STR_DTYPE_TO_TORCH_DTYPE - torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] - self.model = AutoModelForCausalLM.from_pretrained( - model_name, - torch_dtype=torch_dtype, - trust_remote_code=True, - ).cuda() + self.torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + self.model_name = model_name + if model_name in AUDIO_MODELS: + self.model = AUDIO_MODELS[model_name].from_pretrained( + model_name, + torch_dtype=self.torch_dtype, + trust_remote_code=True).cuda() + self.processor = AutoProcessor.from_pretrained( + model_name, torch_dtype=self.torch_dtype) + else: + self.model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=self.torch_dtype, + trust_remote_code=True).cuda() + self.processor = None if tokenizer_name is None: tokenizer_name = model_name self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True) @@ -64,13 +76,29 @@ def __init__( def generate( self, prompts: List[str], + audio_samples: Optional["TODO"] = None, **kwargs, ) -> List[Tuple[List[int], str]]: outputs: List[Tuple[List[int], str]] = [] - for prompt in prompts: - input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids + if audio_samples: + assert len(prompts) == len(audio_samples) + for i, prompt in enumerate(prompts): + if self.model_name in AUDIO_MODELS: + audio_sample = audio_samples[i] + input_features = self.processor( + audio_sample["array"], + sampling_rate=audio_sample["sampling_rate"], + return_tensors="pt").input_features + + inputs = dict( + input_features=input_features.cuda()) #TODO: Fix this + else: + input_ids = self.tokenizer(prompt, + return_tensors="pt").input_ids + inputs = dict(input_ids=input_ids.cuda()) + output_ids = self.model.generate( - input_ids.cuda(), + **inputs, use_cache=True, **kwargs, ) @@ -87,8 +115,10 @@ def generate_greedy( self, prompts: List[str], max_tokens: int, + audio_samples: Optional["TODO"] = None, ) -> 
List[Tuple[List[int], str]]: outputs = self.generate(prompts, + audio_samples=audio_samples, do_sample=False, max_new_tokens=max_tokens) for i in range(len(outputs)): @@ -182,8 +212,24 @@ def generate( self, prompts: List[str], sampling_params: SamplingParams, + audio_samples: Optional["TODO"] = None, ) -> List[Tuple[List[int], str]]: + if audio_samples: + assert len(prompts) == len(audio_samples) + processor = AutoProcessor.from_pretrained("openai/whisper-tiny") + input_features = processor( + audio_samples[0]["array"], + sampling_rate=audio_samples[0]["sampling_rate"], + return_tensors="pt").input_features + # change type of input features + input_features = input_features.to( + dtype=self.model.llm_engine.model_config.dtype) + multi_modal_data = MultiModalData(type=input_features.dtype, + data=input_features[0]) + else: + multi_modal_data = None req_outputs = self.model.generate(prompts, + multi_modal_data=multi_modal_data, sampling_params=sampling_params) outputs = [] for req_output in req_outputs: @@ -221,9 +267,12 @@ def generate_greedy( self, prompts: List[str], max_tokens: int, + audio_samples: Optional["TODO"] = None, ) -> List[Tuple[List[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) - outputs = self.generate(prompts, greedy_params) + outputs = self.generate(prompts, + greedy_params, + audio_samples=audio_samples) return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs] diff --git a/tests/models/test_whisper.py b/tests/models/test_whisper.py new file mode 100644 index 0000000000000..80cba1e778a37 --- /dev/null +++ b/tests/models/test_whisper.py @@ -0,0 +1,67 @@ +import pytest +import torch +from datasets import load_dataset +from vllm.config import AudioFeaturesConfig + +import os + +os.environ['CUDA_LAUNCH_BLOCKING'] = "0" +os.environ['CUDA_VISIBLE_DEVICES'] = "6" + + +@pytest.fixture() +def model_id(): + return "openai/whisper-tiny" + + +@pytest.fixture() +def audio_features_config(): + return AudioFeaturesConfig() + + +def sample_from_librispeech(): + dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", + "clean", + split="validation") + return dataset[0] + + +audio_sample = sample_from_librispeech()["audio"] + + +@pytest.mark.parametrize("dtype", ["half"]) # TODO fix that +@pytest.mark.parametrize("max_tokens", [2]) +@pytest.mark.parametrize("prompts, audio_samples", [([""], [audio_sample])]) +def test_text_to_audio_scenario(hf_runner, vllm_runner, model_id, prompts, + audio_samples, dtype: str, + max_tokens: int) -> None: + # hf_model = hf_runner(model_id, dtype=dtype) + # hf_outputs = hf_model.generate_greedy(prompts=prompts, + # audio_samples=audio_samples, + # max_tokens=max_tokens) + # del hf_model + + # Truly cleans up GPU memory. + torch.cuda.empty_cache() + + vllm_model = vllm_runner(model_id, + dtype=dtype, + enforce_eager=True, + tensor_parallel_size=1, + gpu_memory_utilization=0.9) + vllm_outputs = vllm_model.generate_greedy(prompts, + max_tokens, + audio_samples=audio_samples) + del vllm_model + # Truly cleans up GPU memory. 
+ torch.cuda.empty_cache() + + for i in range(len(hf_image_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = sanitize_vllm_output( + vllm_outputs[i], vision_language_config, model_id) + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") + diff --git a/vllm/config.py b/vllm/config.py index e260e6a0cb1d6..9c9b59bf565b6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -510,6 +510,12 @@ def is_neuron(self): return self.device_type == "neuron" +@dataclass +class AudioFeaturesConfig: + feature_dims: int = 80 + sequence_length: int = 3000 + + @dataclass class LoRAConfig: max_lora_rank: int diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 5e7cc3091d775..b0854858ee7a7 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -388,6 +388,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: lora_request=seq_group.lora_request, prefix=seq_group.prefix, state=seq_group.state, + multi_modal_data=seq_group.multi_modal_data + if scheduler_outputs.prompt_run else None, ) seq_group_metadata_list.append(seq_group_metadata) return seq_group_metadata_list, scheduler_outputs diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c01e7311fb89a..4db3322a0a084 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -4,7 +4,8 @@ from typing import Optional, Tuple from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, - ParallelConfig, SchedulerConfig, LoRAConfig) + ParallelConfig, SchedulerConfig, LoRAConfig, + AudioFeaturesConfig) @dataclass @@ -50,6 +51,7 @@ def __post_init__(self): if self.tokenizer is None: self.tokenizer = self.model + # TODO: Add Whisper-specific args maybe @staticmethod def add_cli_args( parser: argparse.ArgumentParser) -> argparse.ArgumentParser: @@ -282,7 +284,8 @@ def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': def create_engine_configs( self, ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig, - DeviceConfig, Optional[LoRAConfig]]: + DeviceConfig, Optional[LoRAConfig], + Optional[AudioFeaturesConfig]]: device_config = DeviceConfig(self.device) model_config = ModelConfig( self.model, self.tokenizer, self.tokenizer_mode, @@ -310,8 +313,9 @@ def create_engine_configs( lora_dtype=self.lora_dtype, max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None) if self.enable_lora else None + audio_features_config = AudioFeaturesConfig() # for now return (model_config, cache_config, parallel_config, scheduler_config, - device_config, lora_config) + device_config, lora_config, audio_features_config) @dataclass diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index daa6419cdad3b..893e13cf4da0d 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -12,6 +12,7 @@ from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams +from vllm.sequence import MultiModalData logger = init_logger(__name__) @@ -226,6 +227,7 @@ async def add_request_async( arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, prefix_pos: Optional[int] = None, + multi_modal_data: Optional[MultiModalData] = None, ) -> None: if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request 
{lora_request} but LoRA is " @@ -246,6 +248,7 @@ async def add_request_async( arrival_time=arrival_time, lora_request=lora_request, prefix_pos=prefix_pos, + multi_modal_data=multi_modal_data, ) async def _run_workers_async( @@ -423,6 +426,7 @@ async def add_request( arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, prefix_pos: Optional[int] = None, + multi_modal_data: Optional[MultiModalData] = None, ) -> AsyncStream: if self.log_requests: shortened_prompt = prompt @@ -439,6 +443,7 @@ async def add_request( f"sampling_params: {sampling_params}, " f"prompt_token_ids: {shortened_token_ids}, " f"lora_request: {lora_request}.") + # TODO: Add multi_modal_data to the log. if not self.is_running: if self.start_engine_loop: @@ -473,7 +478,8 @@ async def add_request( prompt_token_ids=prompt_token_ids, arrival_time=arrival_time, lora_request=lora_request, - prefix_pos=prefix_pos) + prefix_pos=prefix_pos, + multi_modal_data=multi_modal_data) return stream @@ -485,6 +491,7 @@ async def generate( prompt_token_ids: Optional[List[int]] = None, lora_request: Optional[LoRARequest] = None, prefix_pos: Optional[int] = None, + multi_modal_data: Optional[MultiModalData] = None, ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -558,15 +565,14 @@ async def generate( arrival_time = time.monotonic() try: - stream = await self.add_request( - request_id, - prompt, - sampling_params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time, - lora_request=lora_request, - prefix_pos=prefix_pos, - ) + stream = await self.add_request(request_id, + prompt, + sampling_params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time, + lora_request=lora_request, + prefix_pos=prefix_pos, + multi_modal_data=multi_modal_data) async for request_output in stream: yield request_output diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3f63015c382fe..8961d9c4283ad 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -9,7 +9,8 @@ from vllm.lora.request import LoRARequest from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, - ParallelConfig, SchedulerConfig, LoRAConfig) + ParallelConfig, SchedulerConfig, LoRAConfig, + AudioFeaturesConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import StatLogger, Stats @@ -18,7 +19,8 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, - SequenceGroupOutput, SequenceOutput, SequenceStatus) + SequenceGroupOutput, SequenceOutput, SequenceStatus, + MultiModalData) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, TokenizerGroup) from vllm.utils import (Counter, set_cuda_visible_devices, get_ip, @@ -81,6 +83,7 @@ def __init__( scheduler_config: SchedulerConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], + audio_features_config: Optional[AudioFeaturesConfig], placement_group: Optional["PlacementGroup"], log_stats: bool, ) -> None: @@ -108,6 +111,7 @@ def __init__( self.model_config = model_config self.cache_config = cache_config self.lora_config = lora_config + self.audio_features_config = audio_features_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config @@ -177,6 +181,7 @@ def _init_workers(self): rank=0, distributed_init_method=distributed_init_method, 
lora_config=self.lora_config, + audio_features_config=self.audio_features_config, kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=True, ) @@ -402,6 +407,7 @@ def encode_request( prompt: Optional[str], prompt_token_ids: Optional[List[int]] = None, lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, ): if prompt_token_ids is None: assert prompt is not None @@ -418,6 +424,7 @@ def add_request( prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, prefix_pos: Optional[int] = None, ) -> None: """Add a request to the engine's request pool. @@ -493,7 +500,8 @@ def add_request( # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time, lora_request, prefix) + arrival_time, lora_request, prefix, + multi_modal_data) # Add the sequence group to the scheduler. self.scheduler.add_seq_group(seq_group) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 2f475dd7924ae..57ffafc3d7f9d 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -9,6 +9,7 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.utils import Counter +from vllm.sequence import MultiModalData class LLM: @@ -127,6 +128,7 @@ def generate( prefix_pos: Optional[Union[int, List[int]]] = None, use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -182,6 +184,7 @@ def generate( sampling_params, token_ids, lora_request=lora_request, + multi_modal_data=multi_modal_data, prefix_pos=prefix_pos_i) return self._run_engine(use_tqdm) @@ -191,6 +194,7 @@ def _add_request( sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]], lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, prefix_pos: Optional[int] = None, ) -> None: request_id = str(next(self.request_counter)) @@ -199,6 +203,7 @@ def _add_request( sampling_params, prompt_token_ids, lora_request=lora_request, + multi_modal_data=multi_modal_data, prefix_pos=prefix_pos) def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: diff --git a/vllm/model_executor/layers/enc_dec_attention.py b/vllm/model_executor/layers/enc_dec_attention.py index bbdeee8e5e343..8cbb1c703c1f8 100644 --- a/vllm/model_executor/layers/enc_dec_attention.py +++ b/vllm/model_executor/layers/enc_dec_attention.py @@ -73,11 +73,11 @@ def forward( query = query.view(batch_size, seq_len, self.num_heads, self.head_size) key = key.view(batch_size, seq_len, self.num_heads, self.head_size) value = value.view(batch_size, seq_len, self.num_heads, self.head_size) - if input_metadata.attn_bias is None: - input_metadata.attn_bias = BlockDiagonalCausalMask.from_seqlens( - [seq_len] * batch_size) + # if input_metadata.attn_bias is None: + # input_metadata.attn_bias = BlockDiagonalCausalMask.from_seqlens( + # [seq_len] * batch_size) - input_metadata.attn_bias = input_metadata.attn_bias[:, :, :, :seq_len] + # input_metadata.attn_bias = input_metadata.attn_bias[:, :, :, :seq_len] # Normal attention out = xops.memory_efficient_attention_forward( @@ -137,7 +137,6 @@ def forward( # vectors will not be cached. This happens during the initial memory # profiling run. 
if key_cache is not None and value_cache is not None: - cache_ops.reshape_and_cache( key, value, key_cache, value_cache, input_metadata.slot_mapping[:, -1].flatten().contiguous(), @@ -159,7 +158,8 @@ def forward( num_kv_heads=self.num_heads, scale=self.scale, alibi_slopes=None, - custom_bias=input_metadata.attn_bias.to(torch.float32), + custom_bias=input_metadata.attn_bias.to(torch.float32) + if input_metadata.attn_bias is not None else None, kv_cache_dtype=input_metadata.kv_cache_dtype, ) return output.view(batch_size, seq_len, hidden_size) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index cb64d80c8147d..d7bfcbac682d5 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -4,12 +4,14 @@ import torch import torch.nn as nn - +from vllm.model_executor.models.whisper import WhisperForConditionalGeneration from vllm.config import DeviceConfig, ModelConfig from vllm.model_executor.models import ModelRegistry from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights) +_AUDIO_MODEL_CLASSES = [WhisperForConditionalGeneration] + @contextlib.contextmanager def _set_default_torch_dtype(dtype: torch.dtype): @@ -40,6 +42,7 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]: def get_model(model_config: ModelConfig, device_config: DeviceConfig, **kwargs) -> nn.Module: lora_config = kwargs.get("lora_config", None) + audio_features_config = kwargs.get("audio_features_config", None) model_class = _get_model_architecture(model_config) # Get the (maybe quantized) linear method. @@ -76,7 +79,11 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig, "be added in the future. If this is important to you, " "please open an issue on github.") else: - model = model_class(model_config.hf_config, linear_method) + if model_class in _AUDIO_MODEL_CLASSES: + model = model_class(model_config.hf_config, + audio_features_config, linear_method) + else: + model = model_class(model_config.hf_config, linear_method) if model_config.load_format == "dummy": # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. 
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index a7d68a7cb3c7a..e305b452dd324 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -46,6 +46,8 @@ "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), "T5ForConditionalGeneration": ("t5", "T5ForConditionalGeneration"), + "WhisperForConditionalGeneration": + ("whisper", "WhisperForConditionalGeneration"), "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), } diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py new file mode 100644 index 0000000000000..20edfc43643d0 --- /dev/null +++ b/vllm/model_executor/models/whisper.py @@ -0,0 +1,444 @@ +from torch import nn +import torch +from torch import Tensor +from typing import List, Tuple, Optional, Union +from transformers import WhisperConfig +from transformers.activations import GELUActivation +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.linear import ( + LinearMethodBase, + RowParallelLinear, + ColumnParallelLinear, +) +from vllm.config import AudioFeaturesConfig +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.layers.enc_dec_attention import EncoderAttention, DecoderAttention, CrossAttention +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import hf_model_weights_iterator, default_weight_loader, load_tensor_parallel_weights +from vllm.model_executor.parallel_utils.parallel_state import get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class WhisperPositionalEmbedding(nn.Embedding): + + def __init__(self, num_positions: int, embedding_dim: int): + super().__init__(num_positions, embedding_dim) + + def forward(self, input_ids, past_key_values_length=0, position_ids=None): + if position_ids is None: + return self.weight[past_key_values_length:past_key_values_length + + input_ids.shape[1]] + else: + return self.weight[position_ids] + + +class WhisperAttention(nn.Module): + + def __init__( + self, + config: WhisperConfig, + num_heads: int, + is_decoder: bool = False, + bias: bool = True, + is_cross: bool = False, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.d_model = config.d_model + self.total_num_heads = num_heads + self.num_heads = num_heads // get_tensor_model_parallel_world_size() + self.is_decoder = is_decoder + self.is_cross = is_cross + self.key_value_proj_dim = self.d_model + self.head_dim = self.d_model // self.total_num_heads + if (self.head_dim * self.total_num_heads) != self.d_model: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.d_model}" + f" and `num_heads`: {num_heads}).") + + self.scaling = self.head_dim**-0.5 + + self.k_proj = ColumnParallelLinear(self.d_model, + self.d_model, + bias=False) + self.v_proj = ColumnParallelLinear(self.d_model, + self.d_model, + bias=bias) + self.q_proj = ColumnParallelLinear(self.d_model, + self.d_model, + bias=bias) + self.out_proj = RowParallelLinear(self.d_model, + self.d_model, + bias=True) + + if self.is_decoder and is_cross: + self.attn = CrossAttention(self.num_heads, self.head_dim, 1) + elif self.is_decoder and not is_cross: + self.attn = 
DecoderAttention(self.num_heads, self.head_dim, 1) + else: + self.attn = EncoderAttention(self.num_heads, self.head_dim, 1) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return (tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous()) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: Union[Tuple[Tensor, Tensor], None], + input_metadata: InputMetadata, + encoder_hidden_states: Optional[Tensor] = None, + ) -> torch.Tensor: + + bsz, seq_len, _ = hidden_states.size() + q, _ = self.q_proj(hidden_states) + q = q * self.scaling # could be potentially done elsewhere + + if self.is_decoder and self.is_cross: + print("Decoder Cross Attn") + if encoder_hidden_states is None: + raise ValueError( + "Decoder cross-attention step. The encoder_hidden_states must be specified" + ) + assert kv_cache is not None + key_cache, value_cache = kv_cache + k, _ = self.k_proj(encoder_hidden_states) + v, _ = self.v_proj(encoder_hidden_states) + # reshape the tensors to the shape required by the EncoderAttention + proj_shape = (bsz, -1, self.head_dim * self.num_heads) + q = q.reshape(*proj_shape) + k = k.reshape(*proj_shape) + v = v.reshape(*proj_shape) + input_metadata.attn_bias = None + + attn_output = self.attn(q, k, v, key_cache, value_cache, + input_metadata) + + elif self.is_decoder and not self.is_cross: + print("Decoder Self Attn") + key_cache, value_cache = kv_cache + k, _ = self.k_proj(hidden_states) + v, _ = self.v_proj(hidden_states) + # reshape the tensors to the shape required by the EncoderAttention + proj_shape = (bsz, -1, self.head_dim * self.num_heads) + q = q.reshape(*proj_shape) + k = k.reshape(*proj_shape) + v = v.reshape(*proj_shape) + # from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask + + # input_metadata.attn_bias = BlockDiagonalCausalMask.from_seqlens( + # [seq_len] * bsz) + # input_metadata.attn_bias = input_metadata.attn_bias.materialize((bsz, 1, seq_len, seq_len), device = q.device) + + attn_output = self.attn(q, k, v, key_cache, value_cache, + input_metadata) + + else: + # Encoding step. This means that the transformer blocks + # only employ self-attention and there is no KV cache + # available to be used + print("Encoder Attn") + if kv_cache is not None: + raise ValueError( + "Encoder self-attention step. The KV cache should not be populated." 
+ ) + k, _ = self.k_proj(hidden_states) + v, _ = self.v_proj(hidden_states) + + # reshape the tensors to the shape required by the EncoderAttention + proj_shape = (bsz, -1, self.head_dim * self.num_heads) + q = q.reshape(*proj_shape) + k = k.reshape(*proj_shape) + v = v.reshape(*proj_shape) + input_metadata.attn_bias = None + attn_output = self.attn(q, k, v, input_metadata) + + o, _ = self.out_proj(attn_output) + + return o + + +class WhisperEncoderBlock(nn.Module): + + def __init__(self, + config: WhisperConfig, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.d_model = config.d_model + + self.self_attn = WhisperAttention( + config=config, + num_heads=config.encoder_attention_heads, + linear_method=linear_method, + ) + self.self_attn_layer_norm = nn.LayerNorm(self.d_model) + self.activation_fn = GELUActivation() + self.fc1 = ColumnParallelLinear(self.d_model, config.encoder_ffn_dim) + self.fc2 = RowParallelLinear(config.encoder_ffn_dim, self.d_model) + self.final_layer_norm = nn.LayerNorm(self.d_model) + + def forward( + self, + hidden_states: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + + residual = hidden_states + + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn(hidden_states, None, input_metadata) + hidden_states = residual + hidden_states + + residual = hidden_states + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + + hidden_states = residual + hidden_states + + return hidden_states + + +class WhisperEncoder(nn.Module): + + def __init__(self, + config: WhisperConfig, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.d_model = config.d_model + self.num_mel_bins = config.num_mel_bins + self.max_source_positions = config.max_source_positions + + self.conv1 = nn.Conv1d(self.num_mel_bins, + self.d_model, + kernel_size=3, + padding=1) + self.conv2 = nn.Conv1d(self.d_model, + self.d_model, + kernel_size=3, + stride=2, + padding=1) + + self.embed_positions = nn.Embedding(self.max_source_positions, + self.d_model) + self.layers = nn.ModuleList([ + WhisperEncoderBlock(config, linear_method) + for i in range(config.encoder_layers) + ]) + self.layer_norm = nn.LayerNorm(config.d_model) + + def forward( + self, + input_features: torch.Tensor, + input_metadata: InputMetadata, + ) -> torch.Tensor: + + expected_seq_length = (self.max_source_positions * + self.conv1.stride[0] * self.conv2.stride[0]) + if input_features.shape[-1] != expected_seq_length: + raise ValueError( + f"Whisper expects the mel input features to be of length {expected_seq_length}, " + f"but found {input_features.shape[-1]}. 
Make sure to pad the " + f"input mel features to {expected_seq_length}.") + + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + inputs_embeds = inputs_embeds.permute(0, 2, 1) + embed_pos = self.embed_positions.weight + + hidden_states = inputs_embeds + embed_pos + + for enc_block in self.layers: + hidden_states = enc_block(hidden_states, input_metadata) + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + +class WhisperDecoderBlock(nn.Module): + + def __init__(self, + config: WhisperConfig, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.d_model = config.d_model + + self.self_attn = WhisperAttention( + config=config, + is_decoder=True, + num_heads=config.decoder_attention_heads, + linear_method=linear_method, + ) + + self.encoder_attn = WhisperAttention( + config=config, + is_decoder=True, + num_heads=config.decoder_attention_heads, + is_cross=True, + linear_method=linear_method, + ) + + self.activation_fn = GELUActivation() + + self.self_attn_layer_norm = nn.LayerNorm(self.d_model) + self.encoder_attn_layer_norm = nn.LayerNorm(self.d_model) + self.fc1 = ColumnParallelLinear(self.d_model, config.decoder_ffn_dim) + self.fc2 = RowParallelLinear(config.decoder_ffn_dim, self.d_model) + self.final_layer_norm = nn.LayerNorm(self.d_model) + + def forward(self, hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, kv_cache: Tuple[Tensor, + Tensor], + input_metadata: InputMetadata) -> torch.Tensor: + + residual = hidden_states + # self-attention + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn(hidden_states, kv_cache, input_metadata) + hidden_states = residual + hidden_states + + residual = hidden_states + # cross-attention + hidden_states = self.encoder_attn_layer_norm(hidden_states) + hidden_states = self.encoder_attn(hidden_states, kv_cache, + input_metadata, + encoder_hidden_states) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class WhisperDecoder(nn.Module): + + def __init__(self, + config: WhisperConfig, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.d_model = config.d_model + + self.embed_tokens = nn.Embedding(config.vocab_size, self.d_model) + self.embed_positions = WhisperPositionalEmbedding( + config.max_target_positions, self.d_model) + self.layers = nn.ModuleList([ + WhisperDecoderBlock(config, linear_method=linear_method) + for _ in range(config.decoder_layers) + ]) + self.layer_norm = nn.LayerNorm(config.d_model) + + def forward( + self, + input_ids: torch.Tensor, + encoder_hidden_states: torch.Tensor, + kv_cache: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + + inputs_embeds = self.embed_tokens(input_ids) + positions = self.embed_positions( + inputs_embeds, + past_key_values_length=0, + ) + hidden_states = inputs_embeds + positions + for i, dec_block in enumerate(self.layers): + hidden_states = dec_block(hidden_states, encoder_hidden_states, + kv_cache[i], input_metadata) + + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + +class WhisperForConditionalGeneration(nn.Module): + + def __init__( + self, + config: 
WhisperConfig, + audio_features_config: AudioFeaturesConfig, + linear_method: Optional[LinearMethodBase] = None # probably not needed + ): + super().__init__() + self.config = config + self.encoder = WhisperEncoder(config, linear_method) + self.decoder = WhisperDecoder(config, linear_method) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_features: torch.FloatTensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + input_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + if input_metadata.is_prompt: + # prompt run, need to run encoder once + hidden_states = self.encoder(input_features, input_metadata=input_metadata) + input_metadata.attn_bias = None + + if input_ids is None: + decoder_input_ids = torch.tensor([[1]]).to( + input_features.device) * self.config.decoder_start_token_id + else: + decoder_input_ids = input_ids + + if kv_caches[0][0] is None: + hidden_states = None + else: + hidden_states = self.decoder(input_ids=decoder_input_ids, + encoder_hidden_states=hidden_states, + kv_cache=kv_caches, + input_metadata=input_metadata) + return hidden_states + + def sample(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata): + # TODO: For now we are not implementing the sampling method + return hidden_states + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + column_parallel_weight_names = [ + "k_proj.weight", "v_proj.weight", "q_proj.weight", "q_proj.bias", + "v_proj.bias", "fc1.bias", "fc1.weight" + ] + row_parallel_weight_names = ["out_proj.weight", "fc2.weight"] + + parallel_weight_names = column_parallel_weight_names + row_parallel_weight_names + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + name = name.replace("model.", "") + assert name in params_dict, f"{name} not in params_dict" + param = params_dict[name] + if any(_name in name for _name in parallel_weight_names): + load_tensor_parallel_weights( + param, + loaded_weight, + name, + column_parallel_weight_names=column_parallel_weight_names, + row_parallel_weight_names=row_parallel_weight_names, + tensor_model_parallel_rank=get_tensor_model_parallel_rank( + )) + continue + assert param.shape == loaded_weight.shape, ( + f"{name} shape mismatch between model and checkpoint: " + f"{param.shape} != {loaded_weight.shape}") + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/sequence.py b/vllm/sequence.py index 9f74e74d20b04..5b768a1e9e13d 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -269,6 +269,25 @@ class SequenceGroupState: generator: Optional = None +class MultiModalData: + """Multi modal request. + + Args: + type: The data type. + data: The actual data. + The required shape and semantic meaning of it depends on the vision + language config of the hosted model. + See `VisionLanguageConfig` in `config.py`. + """ + + class Type(enum.Enum): + FEATURES = enum.auto() + + def __init__(self, type: Type, data: "torch.Tensor"): + self.type = type + self.data = data + + class SequenceGroup: """A group of sequences that are generated from the same prompt. 
@@ -289,6 +308,7 @@ def __init__( arrival_time: float, lora_request: Optional[LoRARequest] = None, prefix: Optional[Prefix] = None, + multi_modal_data: Optional[MultiModalData] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -300,6 +320,7 @@ def __init__( time_in_queue=None) self.lora_request = lora_request self.prefix: Optional[Prefix] = prefix + self.multi_modal_data: Optional[MultiModalData] = multi_modal_data self.prompt_logprobs: Optional[PromptLogprobs] = None self.state = SequenceGroupState() @@ -432,6 +453,7 @@ def __init__( sampling_params: SamplingParams, block_tables: Dict[int, List[int]], lora_request: Optional[LoRARequest] = None, + multi_modal_data: Optional[MultiModalData] = None, prefix: Optional[Prefix] = None, state: Optional[SequenceGroupState] = None, ) -> None: @@ -441,7 +463,9 @@ def __init__( self.sampling_params = sampling_params self.block_tables = block_tables self.lora_request = lora_request + self.multi_modal_data = multi_modal_data self.prefix = prefix + self.multi_modal_data = multi_modal_data self.state = SequenceGroupState() if state is None else state @property diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 8d6da05ffc334..107ba303b34af 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -7,7 +7,7 @@ import torch.nn as nn from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, AudioFeaturesConfig) from vllm.logger import init_logger from vllm.model_executor import get_model, InputMetadata, SamplingMetadata from vllm.model_executor.parallel_utils import cupy_utils @@ -17,7 +17,7 @@ with_cupy_nccl_for_all_reduce) from vllm.model_executor.parallel_utils import custom_all_reduce from vllm.sampling_params import SamplingParams, SamplingType -from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata, MultiModalData from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest @@ -42,6 +42,7 @@ def __init__( scheduler_config: SchedulerConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], + audio_features_config: Optional[AudioFeaturesConfig], kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, ): @@ -49,6 +50,7 @@ def __init__( self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.lora_config = lora_config + self.audio_features_config = audio_features_config self.is_driver_worker = is_driver_worker # model_config can be None in tests/samplers/test_sampler.py. 
@@ -92,11 +94,13 @@ def __init__( getattr(self.model_config.hf_config, "is_encoder_decoder", False) def load_model(self) -> None: - self.model = get_model(self.model_config, - self.device_config, - lora_config=self.lora_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + self.model = get_model( + self.model_config, + self.device_config, + audio_features_config=self.audio_features_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) vocab_size = self.model.config.vocab_size @@ -129,10 +133,11 @@ def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], - List[int], List[int], Set[LoRARequest]]: + List[int], List[int], Set[LoRARequest], torch.Tensor]: assert len(seq_group_metadata_list) > 0 input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] + multi_modal_inputs: List = [] slot_mapping: List[List[int]] = [] lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] @@ -187,6 +192,10 @@ def _prepare_prompt( (prompt_len - prefix_len if seq_group_metadata.sampling_params.prompt_logprobs else 1)) + if seq_group_metadata.multi_modal_data: + multi_modal_inputs.append( + seq_group_metadata.multi_modal_data.data) + if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. @@ -253,6 +262,12 @@ def _prepare_prompt( context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device=self.device) + + if multi_modal_inputs: + multi_modal_input = torch.stack(multi_modal_inputs).to('cuda') + else: + multi_modal_input = None + if self.is_encoder_decoder: padded_block_tables = [] # Pad the encoder block tables to the same length and then add a decoder block table in the end @@ -300,7 +315,7 @@ def _prepare_prompt( ) return (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, - lora_requests) + lora_requests, multi_modal_input) def _prepare_decode( self, @@ -458,7 +473,8 @@ def _prepare_decode( kv_cache_dtype=self.kv_cache_dtype, ) return (input_tokens, input_positions, input_metadata, - lora_index_mapping, lora_prompt_mapping, lora_requests) + lora_index_mapping, lora_prompt_mapping, lora_requests, + multi_modal_input) def _prepare_sample( self, @@ -550,7 +566,7 @@ def prepare_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, - Set[int], LoRAMapping]: + Set[int], LoRAMapping, torch.Tensor]: if self.is_driver_worker: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. 
@@ -559,13 +575,15 @@ def prepare_input_tensors( if is_prompt: (input_tokens, input_positions, input_metadata, prompt_lens, subquery_lens, lora_index_mapping, lora_prompt_mapping, - lora_requests) = self._prepare_prompt(seq_group_metadata_list) + lora_requests, multi_modal_input + ) = self._prepare_prompt(seq_group_metadata_list) else: (input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests) = self._prepare_decode(seq_group_metadata_list) prompt_lens = [] subquery_lens = None + multi_modal_input = None sampling_metadata = self._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens) @@ -599,6 +617,7 @@ def prepare_input_tensors( sampling_metadata.selected_token_indices, "lora_requests": lora_requests, "lora_mapping": lora_mapping, + "multi_modal_input": multi_modal_input, } broadcast_tensor_dict(metadata_dict, src=0) else: @@ -607,6 +626,7 @@ def prepare_input_tensors( input_positions = metadata_dict["input_positions"] lora_mapping = metadata_dict["lora_mapping"] lora_requests = metadata_dict["lora_requests"] + multi_modal_input = metadata_dict["multi_modal_input"] input_metadata = InputMetadata( is_prompt=metadata_dict["is_prompt"], slot_mapping=metadata_dict["slot_mapping"], @@ -630,7 +650,8 @@ def prepare_input_tensors( ) return (input_tokens, input_positions, input_metadata, - sampling_metadata, lora_requests, lora_mapping) + sampling_metadata, lora_requests, lora_mapping, + multi_modal_input) @torch.inference_mode() def execute_model( @@ -639,8 +660,8 @@ def execute_model( kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], ) -> Optional[SamplerOutput]: (input_tokens, input_positions, input_metadata, sampling_metadata, - lora_requests, - lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list) + lora_requests, lora_mapping, multi_modal_input + ) = self.prepare_input_tensors(seq_group_metadata_list) if self.lora_config: self.set_active_loras(lora_requests, lora_mapping) @@ -651,12 +672,17 @@ def execute_model( model_executable = self.graph_runners[graph_batch_size] else: model_executable = self.model - hidden_states = model_executable( - input_ids=input_tokens, - positions=input_positions, - kv_caches=kv_caches, - input_metadata=input_metadata, - ) + + execute_model_kwargs = { + "input_ids": input_tokens, + "positions": input_positions, + "kv_caches": kv_caches, + "input_metadata": input_metadata, + } + # TODO: Hard code for now + execute_model_kwargs.update({"input_features": multi_modal_input}) + + hidden_states = model_executable(**execute_model_kwargs) # Sample the next token. 
output = self.model.sample( @@ -701,7 +727,9 @@ def profile_run(self) -> None: for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) - seq_data = SequenceData([0] * seq_len) + self.audio_features_config = AudioFeaturesConfig() # hack + seq_data, fake_multi_modal_input = _prepare_fake_inputs( + seq_len, self.audio_features_config) seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, @@ -710,6 +738,7 @@ def profile_run(self) -> None: block_tables=None, lora_request=dummy_lora_requests_per_seq[group_id] if dummy_lora_requests_per_seq else None, + multi_modal_data=fake_multi_modal_input, ) seqs.append(seq) @@ -774,6 +803,15 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda() + kwargs = {} + if self.audio_features_config: + fake_multi_modal_input = torch.zeros( + max_batch_size, + self.audio_features_config.feature_dims, + self.audio_features_config.sequence_length, + dtype=torch.float32).cuda() + kwargs["input_features"] = fake_multi_modal_input + graph_batch_size = _get_graph_batch_size( self.scheduler_config.max_num_seqs) batch_size_capture_list = [ @@ -818,6 +856,7 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: kv_caches, input_metadata, memory_pool=self.graph_memory_pool, + **kwargs, ) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[batch_size] = graph_runner @@ -851,6 +890,7 @@ def capture( kv_caches: List[KVCache], input_metadata: InputMetadata, memory_pool, + **kwargs, ) -> None: assert self.graph is None # Run the model once without capturing the graph. @@ -862,6 +902,7 @@ def capture( positions, kv_caches, input_metadata, + **kwargs, ) torch.cuda.synchronize() @@ -876,6 +917,7 @@ def capture( positions, kv_caches, input_metadata, + **kwargs, ) torch.cuda.synchronize() @@ -897,6 +939,7 @@ def forward( positions: torch.Tensor, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], input_metadata: InputMetadata, + **kwargs, ) -> torch.Tensor: # KV caches are fixed tensors, so we don't need to copy them. 
del kv_caches @@ -963,3 +1006,21 @@ def _async_h2d( ) -> torch.Tensor: t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") return t.to(device=target_device, non_blocking=True) + + +def _prepare_fake_inputs(seq_len: int, + audio_features_config: Optional[AudioFeaturesConfig]): + """Prepare fake inputs for profile run.""" + if audio_features_config: + prompt_tokens = [0] * seq_len + fake_input_features = MultiModalData( + type=MultiModalData.Type.FEATURES, + data=torch.zeros(audio_features_config.feature_dims, + audio_features_config.sequence_length, + dtype=torch.float16)) + fake_multi_modal_inputs = fake_input_features + else: + + prompt_tokens = [0] * seq_len + fake_multi_modal_inputs = None + return SequenceData(prompt_tokens), fake_multi_modal_inputs diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 9df518d155ec2..61d90be34cd0a 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -7,7 +7,8 @@ import torch.distributed from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, - ParallelConfig, SchedulerConfig, LoRAConfig) + ParallelConfig, SchedulerConfig, LoRAConfig, + AudioFeaturesConfig) from vllm.model_executor import set_random_seed from vllm.model_executor.parallel_utils import cupy_utils from vllm.model_executor.parallel_utils.communication_op import ( @@ -40,9 +41,11 @@ def __init__( rank: int, distributed_init_method: str, lora_config: Optional[LoRAConfig] = None, + audio_features_config: Optional[AudioFeaturesConfig] = None, kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, ) -> None: + self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config @@ -51,17 +54,23 @@ def __init__( self.rank = rank self.distributed_init_method = distributed_init_method self.lora_config = lora_config + self.audio_features_config = audio_features_config + if self.audio_features_config and self.lora_config: + raise NotImplementedError( + "Not yet tested: audio models with LoRA settings.") self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." - self.model_runner = ModelRunner(model_config, - parallel_config, - scheduler_config, - device_config, - lora_config=self.lora_config, - kv_cache_dtype=kv_cache_dtype, - is_driver_worker=is_driver_worker) + self.model_runner = ModelRunner( + model_config, + parallel_config, + scheduler_config, + device_config, + lora_config=self.lora_config, + audio_features_config=self.audio_features_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker) # Uninitialized cache engine. Will be initialized by # self.init_cache_engine(). 
self.cache_config = None From 33821391d869825349b84115fbc40e2704a0d818 Mon Sep 17 00:00:00 2001 From: dbogunowicz Date: Tue, 26 Mar 2024 10:49:27 +0000 Subject: [PATCH 2/5] trying something new --- tests/models/test_whisper.py | 4 ++-- vllm/model_executor/models/whisper.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/models/test_whisper.py b/tests/models/test_whisper.py index 80cba1e778a37..03b1f5e138045 100644 --- a/tests/models/test_whisper.py +++ b/tests/models/test_whisper.py @@ -5,7 +5,7 @@ import os -os.environ['CUDA_LAUNCH_BLOCKING'] = "0" +os.environ['CUDA_LAUNCH_BLOCKING'] = "1" os.environ['CUDA_VISIBLE_DEVICES'] = "6" @@ -29,7 +29,7 @@ def sample_from_librispeech(): audio_sample = sample_from_librispeech()["audio"] -@pytest.mark.parametrize("dtype", ["half"]) # TODO fix that +@pytest.mark.parametrize("dtype", ["float"]) # TODO fix that @pytest.mark.parametrize("max_tokens", [2]) @pytest.mark.parametrize("prompts, audio_samples", [([""], [audio_sample])]) def test_text_to_audio_scenario(hf_runner, vllm_runner, model_id, prompts, diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index 20edfc43643d0..e1e97a4354ef9 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -380,6 +380,7 @@ def forward( input_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: + input_features = input_features.to(dtype=torch.float32) if input_metadata.is_prompt: # prompt run, need to run encoder once hidden_states = self.encoder(input_features, input_metadata=input_metadata) From 73891d203037625573c24b5956f99fc2d1869f4e Mon Sep 17 00:00:00 2001 From: dbogunowicz Date: Wed, 27 Mar 2024 13:43:19 +0000 Subject: [PATCH 3/5] sucessfull pass, time to check correctness --- examples/offline_inference_enc_dec.py | 6 +--- tests/models/test_whisper.py | 3 +- vllm/model_executor/models/whisper.py | 51 ++++++++++----------------- vllm/sequence.py | 7 ++-- vllm/worker/model_runner.py | 14 ++++---- 5 files changed, 32 insertions(+), 49 deletions(-) diff --git a/examples/offline_inference_enc_dec.py b/examples/offline_inference_enc_dec.py index 20c9bc06f2d82..746c931fae611 100644 --- a/examples/offline_inference_enc_dec.py +++ b/examples/offline_inference_enc_dec.py @@ -22,11 +22,7 @@ hf_model_id = "t5-small" dtype = "bfloat16" -prompts = [ - "Who are you?", - "Who are you?", - "How do you like your egg made", - "How do you like your egg made", +prompts = ["How do you like your egg made" * 50, ] dtype_obj = getattr(torch, dtype) diff --git a/tests/models/test_whisper.py b/tests/models/test_whisper.py index 03b1f5e138045..6629781b559c8 100644 --- a/tests/models/test_whisper.py +++ b/tests/models/test_whisper.py @@ -48,7 +48,8 @@ def test_text_to_audio_scenario(hf_runner, vllm_runner, model_id, prompts, dtype=dtype, enforce_eager=True, tensor_parallel_size=1, - gpu_memory_utilization=0.9) + gpu_memory_utilization=0.4, + max_model_len=64) vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens, audio_samples=audio_samples) diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py index e1e97a4354ef9..fb5a0d57c1b5e 100644 --- a/vllm/model_executor/models/whisper.py +++ b/vllm/model_executor/models/whisper.py @@ -106,12 +106,11 @@ def forward( key_cache, value_cache = kv_cache k, _ = self.k_proj(encoder_hidden_states) v, _ = self.v_proj(encoder_hidden_states) - # reshape the tensors to the shape required by the EncoderAttention - proj_shape = (bsz, -1, self.head_dim * 
self.num_heads) - q = q.reshape(*proj_shape) - k = k.reshape(*proj_shape) - v = v.reshape(*proj_shape) - input_metadata.attn_bias = None + from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask + + input_metadata.attn_bias = BlockDiagonalCausalMask.from_seqlens( + [seq_len] * bsz) + input_metadata.attn_bias = input_metadata.attn_bias.materialize((bsz, 1, seq_len, seq_len), device = q.device) attn_output = self.attn(q, k, v, key_cache, value_cache, input_metadata) @@ -121,11 +120,6 @@ def forward( key_cache, value_cache = kv_cache k, _ = self.k_proj(hidden_states) v, _ = self.v_proj(hidden_states) - # reshape the tensors to the shape required by the EncoderAttention - proj_shape = (bsz, -1, self.head_dim * self.num_heads) - q = q.reshape(*proj_shape) - k = k.reshape(*proj_shape) - v = v.reshape(*proj_shape) # from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask # input_metadata.attn_bias = BlockDiagonalCausalMask.from_seqlens( @@ -146,12 +140,6 @@ def forward( ) k, _ = self.k_proj(hidden_states) v, _ = self.v_proj(hidden_states) - - # reshape the tensors to the shape required by the EncoderAttention - proj_shape = (bsz, -1, self.head_dim * self.num_heads) - q = q.reshape(*proj_shape) - k = k.reshape(*proj_shape) - v = v.reshape(*proj_shape) input_metadata.attn_bias = None attn_output = self.attn(q, k, v, input_metadata) @@ -385,26 +373,23 @@ def forward( # prompt run, need to run encoder once hidden_states = self.encoder(input_features, input_metadata=input_metadata) input_metadata.attn_bias = None - - if input_ids is None: - decoder_input_ids = torch.tensor([[1]]).to( - input_features.device) * self.config.decoder_start_token_id - else: - decoder_input_ids = input_ids + bsz = hidden_states.shape[0] + decoder_input_ids = torch.ones((bsz, 1), dtype=torch.int32).to(input_features.device) * self.config.decoder_start_token_id - if kv_caches[0][0] is None: - hidden_states = None else: - hidden_states = self.decoder(input_ids=decoder_input_ids, - encoder_hidden_states=hidden_states, - kv_cache=kv_caches, - input_metadata=input_metadata) + if kv_caches[0][0] is None: + hidden_states = None + else: + hidden_states = self.decoder(input_ids=decoder_input_ids, + encoder_hidden_states=hidden_states, + kv_cache=kv_caches, + input_metadata=input_metadata) return hidden_states - def sample(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata): - # TODO: For now we are not implementing the sampling method - return hidden_states + def sample(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata): + next_tokens = self.sampler(self.decoder.embed_tokens.weight, hidden_states, + sampling_metadata) + return next_tokens def load_weights( self, diff --git a/vllm/sequence.py b/vllm/sequence.py index 5b768a1e9e13d..8cea667cf1a5e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -150,12 +150,13 @@ def __init__( self.logical_token_blocks: List[LogicalTokenBlock] = [] initial_token_ids = prompt_token_ids if is_decoder_encoder: + from vllm.config import AudioFeaturesConfig # We need to separate the prompt and generated tokens for encoder-decoder models. 
- num_prompt_blocks = (len(prompt_token_ids) + block_size - + num_prompt_blocks = (AudioFeaturesConfig().sequence_length + block_size - 1) // block_size padded_prompt_len = num_prompt_blocks * block_size - initial_token_ids = prompt_token_ids + [0] * ( - padded_prompt_len - len(prompt_token_ids)) + initial_token_ids = [0] * ( + padded_prompt_len) # Also need to append decoder_start_token_id initial_token_ids.append(0) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 107ba303b34af..76d9967c1f471 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -154,10 +154,14 @@ def _prepare_prompt( seq_ids = list(seq_group_metadata.seq_data.keys()) assert len(seq_ids) == 1 seq_id = seq_ids[0] + + if seq_group_metadata.multi_modal_data: + multi_modal_inputs.append( + seq_group_metadata.multi_modal_data.data) seq_data = seq_group_metadata.seq_data[seq_id] prompt_tokens = seq_data.get_token_ids() - prompt_len = len(prompt_tokens) + prompt_len = multi_modal_inputs[-1].shape[1] prompt_lens.append(prompt_len) prefix_len = 0 @@ -179,7 +183,7 @@ def _prepare_prompt( # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. input_positions.append( - list(range(prefix_len, prefix_len + len(prompt_tokens)))) + list(range(prefix_len, prefix_len + multi_modal_inputs[-1].shape[1]))) lora_id = seq_group_metadata.lora_int_id @@ -192,10 +196,6 @@ def _prepare_prompt( (prompt_len - prefix_len if seq_group_metadata.sampling_params.prompt_logprobs else 1)) - if seq_group_metadata.multi_modal_data: - multi_modal_inputs.append( - seq_group_metadata.multi_modal_data.data) - if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. @@ -232,7 +232,7 @@ def _prepare_prompt( len(block_table)) max_prompt_len = max(subquery_lens) input_tokens = _make_tensor_with_pad(input_tokens, - max_prompt_len, + len(input_tokens[0]), pad=0, dtype=torch.long, device=self.device) From 642c8c1d484d39a3d9ac866d15ff0adb3060f5e4 Mon Sep 17 00:00:00 2001 From: dbogunowicz Date: Tue, 2 Apr 2024 12:22:04 +0000 Subject: [PATCH 4/5] some progress --- tests/models/test_whisper.py | 30 ++++++------ .../layers/enc_dec_attention.py | 11 ++--- vllm/model_executor/models/whisper.py | 49 ++++++++----------- vllm/worker/model_runner.py | 3 +- 4 files changed, 41 insertions(+), 52 deletions(-) diff --git a/tests/models/test_whisper.py b/tests/models/test_whisper.py index 6629781b559c8..c0d75f0e24fb9 100644 --- a/tests/models/test_whisper.py +++ b/tests/models/test_whisper.py @@ -30,16 +30,17 @@ def sample_from_librispeech(): @pytest.mark.parametrize("dtype", ["float"]) # TODO fix that -@pytest.mark.parametrize("max_tokens", [2]) +@pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("prompts, audio_samples", [([""], [audio_sample])]) def test_text_to_audio_scenario(hf_runner, vllm_runner, model_id, prompts, audio_samples, dtype: str, max_tokens: int) -> None: - # hf_model = hf_runner(model_id, dtype=dtype) - # hf_outputs = hf_model.generate_greedy(prompts=prompts, - # audio_samples=audio_samples, - # max_tokens=max_tokens) - # del hf_model + + hf_model = hf_runner(model_id, dtype=dtype) + hf_outputs = hf_model.generate_greedy(prompts=prompts, + audio_samples=audio_samples, + max_tokens=max_tokens) + del hf_model # Truly cleans up GPU memory. 
From 642c8c1d484d39a3d9ac866d15ff0adb3060f5e4 Mon Sep 17 00:00:00 2001
From: dbogunowicz
Date: Tue, 2 Apr 2024 12:22:04 +0000
Subject: [PATCH 4/5] some progress

---
 tests/models/test_whisper.py                    | 30 ++++++------
 .../layers/enc_dec_attention.py                 | 11 ++---
 vllm/model_executor/models/whisper.py           | 49 ++++++-----------
 vllm/worker/model_runner.py                     |  3 +-
 4 files changed, 41 insertions(+), 52 deletions(-)

diff --git a/tests/models/test_whisper.py b/tests/models/test_whisper.py
index 6629781b559c8..c0d75f0e24fb9 100644
--- a/tests/models/test_whisper.py
+++ b/tests/models/test_whisper.py
@@ -30,16 +30,17 @@ def sample_from_librispeech():
 @pytest.mark.parametrize("dtype", ["float"])  # TODO fix that
-@pytest.mark.parametrize("max_tokens", [2])
+@pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("prompts, audio_samples", [([""], [audio_sample])])
 def test_text_to_audio_scenario(hf_runner, vllm_runner, model_id, prompts,
                                 audio_samples, dtype: str,
                                 max_tokens: int) -> None:
-    # hf_model = hf_runner(model_id, dtype=dtype)
-    # hf_outputs = hf_model.generate_greedy(prompts=prompts,
-    #                                       audio_samples=audio_samples,
-    #                                       max_tokens=max_tokens)
-    # del hf_model
+
+    hf_model = hf_runner(model_id, dtype=dtype)
+    hf_outputs = hf_model.generate_greedy(prompts=prompts,
+                                          audio_samples=audio_samples,
+                                          max_tokens=max_tokens)
+    del hf_model

     # Truly cleans up GPU memory.
     torch.cuda.empty_cache()
@@ -48,8 +49,8 @@ def test_text_to_audio_scenario(hf_runner, vllm_runner, model_id, prompts,
                              dtype=dtype,
                              enforce_eager=True,
                              tensor_parallel_size=1,
-                             gpu_memory_utilization=0.4,
-                             max_model_len=64)
+                             gpu_memory_utilization=0.5)
+
     vllm_outputs = vllm_model.generate_greedy(prompts,
                                               max_tokens,
                                               audio_samples=audio_samples)
@@ -57,12 +58,11 @@ def test_text_to_audio_scenario(hf_runner, vllm_runner, model_id, prompts,
     # Truly cleans up GPU memory.
     torch.cuda.empty_cache()

-    for i in range(len(hf_image_prompts)):
+    for i in range(len(prompts)):
         hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = sanitize_vllm_output(
-            vllm_outputs[i], vision_language_config, model_id)
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+        vllm_output_ids, vllm_output_str = vllm_outputs[i]
+        print(f"hf_output_str: {hf_output_str}")
+        print(f"first 10 tokens: {hf_output_ids[:10]}")
+        print(f"vllm_output_str: {vllm_output_str}")
+        print(f"first 10 tokens: {vllm_output_ids[:10]}")
diff --git a/vllm/model_executor/layers/enc_dec_attention.py b/vllm/model_executor/layers/enc_dec_attention.py
index 8cbb1c703c1f8..c27be0c6dce61 100644
--- a/vllm/model_executor/layers/enc_dec_attention.py
+++ b/vllm/model_executor/layers/enc_dec_attention.py
@@ -73,11 +73,8 @@ def forward(
         query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
         key = key.view(batch_size, seq_len, self.num_heads, self.head_size)
         value = value.view(batch_size, seq_len, self.num_heads, self.head_size)
-        # if input_metadata.attn_bias is None:
-        #     input_metadata.attn_bias = BlockDiagonalCausalMask.from_seqlens(
-        #         [seq_len] * batch_size)
-
-        # input_metadata.attn_bias = input_metadata.attn_bias[:, :, :, :seq_len]
+        if input_metadata.attn_bias is None:
+            pass

         # Normal attention
         out = xops.memory_efficient_attention_forward(
@@ -137,6 +134,7 @@ def forward(
         # vectors will not be cached. This happens during the initial memory
         # profiling run.
         if key_cache is not None and value_cache is not None:
+
             cache_ops.reshape_and_cache(
                 key, value, key_cache, value_cache,
                 input_metadata.slot_mapping[:, -1].flatten().contiguous(),
@@ -158,8 +156,7 @@ def forward(
             num_kv_heads=self.num_heads,
             scale=self.scale,
             alibi_slopes=None,
-            custom_bias=input_metadata.attn_bias.to(torch.float32)
-            if input_metadata.attn_bias is not None else None,
+            custom_bias=input_metadata.attn_bias.to(torch.float32) if input_metadata.attn_bias is not None else None,
             kv_cache_dtype=input_metadata.kv_cache_dtype,
         )
         return output.view(batch_size, seq_len, hidden_size)
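
In the enc_dec_attention hunks above, a decode step caches only the newest key/value pair, written at the slot taken from the last column of the slot mapping, and then lets PagedAttention read everything back out of the cache, optionally with a custom bias. Ignoring vLLM's blocked cache layout, the bookkeeping reduces to an indexed write; the shapes and names below are illustrative only:

import torch

def write_newest_kv(key_cache: torch.Tensor, value_cache: torch.Tensor,
                    new_key: torch.Tensor, new_value: torch.Tensor,
                    slots: torch.Tensor) -> None:
    # key_cache / value_cache: (num_slots, num_heads, head_size) flat caches.
    # new_key / new_value: (batch, num_heads, head_size), one new token per sequence.
    # slots: (batch,) flat slot index of each sequence's newest position.
    key_cache[slots] = new_key
    value_cache[slots] = new_value

The real reshape_and_cache kernel performs the same kind of write against a block-structured cache rather than a flat one.
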
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index fb5a0d57c1b5e..ad22e1da9f1d2 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -94,24 +94,21 @@ def forward(
         bsz, seq_len, _ = hidden_states.size()
         q, _ = self.q_proj(hidden_states)
-        q = q * self.scaling  # could be potentially done elsewhere

         if self.is_decoder and self.is_cross:
-            print("Decoder Cross Attn")
-            if encoder_hidden_states is None:
-                raise ValueError(
-                    "Decoder cross-attention step. The encoder_hidden_states must be specified"
-                )
             assert kv_cache is not None
             key_cache, value_cache = kv_cache
-            k, _ = self.k_proj(encoder_hidden_states)
-            v, _ = self.v_proj(encoder_hidden_states)
-            from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
-
-            input_metadata.attn_bias = BlockDiagonalCausalMask.from_seqlens(
-                [seq_len] * bsz)
-            input_metadata.attn_bias = input_metadata.attn_bias.materialize((bsz, 1, seq_len, seq_len), device = q.device)
-
+            print("Decoder Cross Attn")
+            if input_metadata.is_prompt:
+                if encoder_hidden_states is None:
+                    raise ValueError(
+                        "Decoder cross-attention step. The encoder_hidden_states must be specified"
+                    )
+
+                k, _ = self.k_proj(encoder_hidden_states)
+                v, _ = self.v_proj(encoder_hidden_states)
+            else:
+                k, v = None, None
             attn_output = self.attn(q, k, v, key_cache, value_cache,
                                     input_metadata)
@@ -120,11 +117,6 @@ def forward(
             key_cache, value_cache = kv_cache
             k, _ = self.k_proj(hidden_states)
             v, _ = self.v_proj(hidden_states)
-            # from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
-
-            # input_metadata.attn_bias = BlockDiagonalCausalMask.from_seqlens(
-                [seq_len] * bsz)
-            # input_metadata.attn_bias = input_metadata.attn_bias.materialize((bsz, 1, seq_len, seq_len), device = q.device)

             attn_output = self.attn(q, k, v, key_cache, value_cache,
                                     input_metadata)
@@ -138,6 +130,7 @@ def forward(
             raise ValueError(
                 "Encoder self-attention step. The KV cache should not be populated."
             )
+            q = q * self.scaling  # could be potentially done elsewhere
             k, _ = self.k_proj(hidden_states)
             v, _ = self.v_proj(hidden_states)
             input_metadata.attn_bias = None
@@ -368,22 +361,22 @@ def forward(
         input_ids: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:

-        input_features = input_features.to(dtype=torch.float32)
         if input_metadata.is_prompt:
+            input_features = input_features.to(dtype=torch.float32)
             # prompt run, need to run encoder once
             hidden_states = self.encoder(input_features, input_metadata=input_metadata)
             input_metadata.attn_bias = None
             bsz = hidden_states.shape[0]
             decoder_input_ids = torch.ones((bsz, 1), dtype=torch.int32).to(input_features.device) * self.config.decoder_start_token_id
-
         else:
-            if kv_caches[0][0] is None:
-                hidden_states = None
-            else:
-                hidden_states = self.decoder(input_ids=decoder_input_ids,
-                                             encoder_hidden_states=hidden_states,
-                                             kv_cache=kv_caches,
-                                             input_metadata=input_metadata)
+            hidden_states = None
+            decoder_input_ids = input_ids
+
+        if kv_caches[0][0] is not None:
+            hidden_states = self.decoder(input_ids=decoder_input_ids,
+                                         encoder_hidden_states=hidden_states,
+                                         kv_cache=kv_caches,
+                                         input_metadata=input_metadata)
         return hidden_states

     def sample(self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata):
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 76d9967c1f471..1c737ff388681 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -473,8 +473,7 @@ def _prepare_decode(
             kv_cache_dtype=self.kv_cache_dtype,
         )
         return (input_tokens, input_positions, input_metadata,
-                lora_index_mapping, lora_prompt_mapping, lora_requests,
-                multi_modal_input)
+                lora_index_mapping, lora_prompt_mapping, lora_requests)

     def _prepare_sample(
         self,
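
The cross-attention hunk above separates the prompt pass from decode passes: K and V are projected from the encoder output only on the prompt run, after which the attention layer is handed None and reuses the cached values, so the encoder states are not re-projected for every generated token. A condensed sketch of that control flow, with k_proj and v_proj standing in for the projection layers (names are illustrative):

from typing import Callable, Optional, Tuple
import torch

def cross_attention_kv(
    is_prompt: bool,
    encoder_states: Optional[torch.Tensor],
    k_proj: Callable[[torch.Tensor], torch.Tensor],
    v_proj: Callable[[torch.Tensor], torch.Tensor],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    if is_prompt:
        # Prompt pass: project the encoder output once; these K/V are the ones
        # that end up in the cross-attention KV cache.
        if encoder_states is None:
            raise ValueError("encoder_states are required on the prompt pass")
        return k_proj(encoder_states), v_proj(encoder_states)
    # Decode passes: nothing new to project; the cached K/V are reused.
    return None, None
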
From ff1b0a9f0cc08e56059edcebbf87302c7058863e Mon Sep 17 00:00:00 2001
From: dbogunowicz
Date: Tue, 2 Apr 2024 15:00:25 +0000
Subject: [PATCH 5/5] cleanup

---
 examples/offline_inference_enc_dec.py           | 6 +++++-
 tests/conftest.py                               | 1 +
 tests/models/test_whisper.py                    | 4 +---
 vllm/model_executor/layers/enc_dec_attention.py | 1 -
 vllm/model_executor/models/whisper.py           | 2 ++
 vllm/sequence.py                                | 3 +--
 vllm/worker/model_runner.py                     | 8 +++++++-
 7 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/examples/offline_inference_enc_dec.py b/examples/offline_inference_enc_dec.py
index 746c931fae611..20c9bc06f2d82 100644
--- a/examples/offline_inference_enc_dec.py
+++ b/examples/offline_inference_enc_dec.py
@@ -22,7 +22,11 @@
 hf_model_id = "t5-small"
 dtype = "bfloat16"
-prompts = ["How do you like your egg made" * 50,
+prompts = [
+    "Who are you?",
+    "Who are you?",
+    "How do you like your egg made",
+    "How do you like your egg made",
 ]

 dtype_obj = getattr(torch, dtype)
diff --git a/tests/conftest.py b/tests/conftest.py
index fab45f0e5668b..8c6b7fd4b04c1 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -60,6 +60,7 @@ def __init__(
             self.model = AUDIO_MODELS[model_name].from_pretrained(
                 model_name,
                 torch_dtype=self.torch_dtype,
+                attn_implementation="eager",
                 trust_remote_code=True).cuda()
             self.processor = AutoProcessor.from_pretrained(
                 model_name, torch_dtype=self.torch_dtype)
diff --git a/tests/models/test_whisper.py b/tests/models/test_whisper.py
index c0d75f0e24fb9..6a04e225296f0 100644
--- a/tests/models/test_whisper.py
+++ b/tests/models/test_whisper.py
@@ -5,9 +5,7 @@
 import os

-os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
-os.environ['CUDA_VISIBLE_DEVICES'] = "6"
-
+os.environ['CUDA_LAUNCH_BLOCKING'] = "0"

 @pytest.fixture()
 def model_id():
diff --git a/vllm/model_executor/layers/enc_dec_attention.py b/vllm/model_executor/layers/enc_dec_attention.py
index c27be0c6dce61..ee8efa597055f 100644
--- a/vllm/model_executor/layers/enc_dec_attention.py
+++ b/vllm/model_executor/layers/enc_dec_attention.py
@@ -134,7 +134,6 @@ def forward(
         # vectors will not be cached. This happens during the initial memory
         # profiling run.
         if key_cache is not None and value_cache is not None:
-
             cache_ops.reshape_and_cache(
                 key, value, key_cache, value_cache,
                 input_metadata.slot_mapping[:, -1].flatten().contiguous(),
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index ad22e1da9f1d2..e36d9c66b0197 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -99,6 +99,7 @@ def forward(
             assert kv_cache is not None
             key_cache, value_cache = kv_cache
             print("Decoder Cross Attn")
+            q = q * self.scaling
             if input_metadata.is_prompt:
                 if encoder_hidden_states is None:
                     raise ValueError(
@@ -115,6 +116,7 @@ def forward(
         elif self.is_decoder and not self.is_cross:
             print("Decoder Self Attn")
             key_cache, value_cache = kv_cache
+            q = q * self.scaling
             k, _ = self.k_proj(hidden_states)
             v, _ = self.v_proj(hidden_states)

diff --git a/vllm/sequence.py b/vllm/sequence.py
index 8cea667cf1a5e..363def788d709 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -155,8 +155,7 @@ def __init__(
             num_prompt_blocks = (AudioFeaturesConfig().sequence_length + block_size -
                                  1) // block_size
             padded_prompt_len = num_prompt_blocks * block_size
-            initial_token_ids = [0] * (
-                padded_prompt_len)
+            initial_token_ids = [0] * padded_prompt_len
             # Also need to append decoder_start_token_id
             initial_token_ids.append(0)
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 1c737ff388681..3443d0d7faec3 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -349,10 +349,16 @@ def _prepare_decode(
             input_tokens.append([generation_token])

             seq_len = seq_data.get_len()
+            prompt_len = 3000
+            # seq_len = len(prompt_token) + len(gen_tokens)
+            # we need to make it:
+            # seq_len = len(input_features) + len(gen_tokens)
+            seq_len += prompt_len # add len(input_features)
+            seq_len -= 3 # remove len(prompt_token)
+
             position = seq_len - 1
             input_positions.append([position])

-            prompt_len = len(seq_data.prompt_token_ids)
             prompt_lens.append(prompt_len)

             if self.is_encoder_decoder: