This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

[WiP] Whisper Implementation #147

Closed · wants to merge 5 commits
Changes from 3 commits
6 changes: 1 addition & 5 deletions examples/offline_inference_enc_dec.py
@@ -22,11 +22,7 @@

hf_model_id = "t5-small"
dtype = "bfloat16"
prompts = [
"Who are you?",
"Who are you?",
"How do you like your egg made",
"How do you like your egg made",
prompts = ["How do you like your egg made" * 50,
]

dtype_obj = getattr(torch, dtype)
75 changes: 62 additions & 13 deletions tests/conftest.py
@@ -3,8 +3,8 @@

import pytest
import torch
from transformers import AutoModelForCausalLM

from transformers import AutoModelForCausalLM, WhisperForConditionalGeneration, AutoProcessor
from vllm.sequence import MultiModalData
from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer

@@ -16,7 +16,10 @@
def _read_prompts(filename: str) -> List[str]:
with open(filename, "r") as f:
prompts = f.readlines()
return prompts
return prompts


AUDIO_MODELS = {"openai/whisper-tiny": WhisperForConditionalGeneration}


@pytest.fixture
@@ -51,26 +54,51 @@ def __init__(
dtype: str = "half",
) -> None:
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
).cuda()
self.torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
self.model_name = model_name
if model_name in AUDIO_MODELS:
self.model = AUDIO_MODELS[model_name].from_pretrained(
model_name,
torch_dtype=self.torch_dtype,
trust_remote_code=True).cuda()
self.processor = AutoProcessor.from_pretrained(
model_name, torch_dtype=self.torch_dtype)
else:
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=self.torch_dtype,
trust_remote_code=True).cuda()
self.processor = None
if tokenizer_name is None:
tokenizer_name = model_name
self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)

def generate(
self,
prompts: List[str],
audio_samples: Optional["TODO"] = None,
**kwargs,
) -> List[Tuple[List[int], str]]:
outputs: List[Tuple[List[int], str]] = []
for prompt in prompts:
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
if audio_samples:
assert len(prompts) == len(audio_samples)
for i, prompt in enumerate(prompts):
if self.model_name in AUDIO_MODELS:
audio_sample = audio_samples[i]
input_features = self.processor(
audio_sample["array"],
sampling_rate=audio_sample["sampling_rate"],
return_tensors="pt").input_features

inputs = dict(
input_features=input_features.cuda()) #TODO: Fix this
else:
input_ids = self.tokenizer(prompt,
return_tensors="pt").input_ids
inputs = dict(input_ids=input_ids.cuda())

output_ids = self.model.generate(
input_ids.cuda(),
**inputs,
use_cache=True,
**kwargs,
)
@@ -87,8 +115,10 @@ def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
audio_samples: Optional["TODO"] = None,
) -> List[Tuple[List[int], str]]:
outputs = self.generate(prompts,
audio_samples=audio_samples,
do_sample=False,
max_new_tokens=max_tokens)
for i in range(len(outputs)):
@@ -182,8 +212,24 @@ def generate(
self,
prompts: List[str],
sampling_params: SamplingParams,
audio_samples: Optional["TODO"] = None,
) -> List[Tuple[List[int], str]]:
if audio_samples:
assert len(prompts) == len(audio_samples)
processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
input_features = processor(
audio_samples[0]["array"],
sampling_rate=audio_samples[0]["sampling_rate"],
return_tensors="pt").input_features
# change type of input features
input_features = input_features.to(
dtype=self.model.llm_engine.model_config.dtype)
multi_modal_data = MultiModalData(type=input_features.dtype,
data=input_features[0])
else:
multi_modal_data = None
req_outputs = self.model.generate(prompts,
multi_modal_data=multi_modal_data,
sampling_params=sampling_params)
outputs = []
for req_output in req_outputs:
@@ -221,9 +267,12 @@ def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
audio_samples: Optional["TODO"] = None,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts, greedy_params)
outputs = self.generate(prompts,
greedy_params,
audio_samples=audio_samples)
return [(output_ids[0], output_str[0])
for output_ids, output_str in outputs]

68 changes: 68 additions & 0 deletions tests/models/test_whisper.py
@@ -0,0 +1,68 @@
import pytest
import torch
from datasets import load_dataset
from vllm.config import AudioFeaturesConfig

import os

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ['CUDA_VISIBLE_DEVICES'] = "6"


@pytest.fixture()
def model_id():
return "openai/whisper-tiny"


@pytest.fixture()
def audio_features_config():
return AudioFeaturesConfig()


def sample_from_librispeech():
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy",
"clean",
split="validation")
return dataset[0]


audio_sample = sample_from_librispeech()["audio"]


@pytest.mark.parametrize("dtype", ["float"]) # TODO fix that
@pytest.mark.parametrize("max_tokens", [2])
@pytest.mark.parametrize("prompts, audio_samples", [([""], [audio_sample])])
def test_text_to_audio_scenario(hf_runner, vllm_runner, model_id, prompts,
audio_samples, dtype: str,
max_tokens: int) -> None:
# hf_model = hf_runner(model_id, dtype=dtype)
# hf_outputs = hf_model.generate_greedy(prompts=prompts,
# audio_samples=audio_samples,
# max_tokens=max_tokens)
# del hf_model

# Truly cleans up GPU memory.
torch.cuda.empty_cache()

vllm_model = vllm_runner(model_id,
dtype=dtype,
enforce_eager=True,
tensor_parallel_size=1,
gpu_memory_utilization=0.4,
max_model_len=64)
vllm_outputs = vllm_model.generate_greedy(prompts,
max_tokens,
audio_samples=audio_samples)
del vllm_model
# Truly cleans up GPU memory.
torch.cuda.empty_cache()

    # NOTE: this comparison assumes the HF reference run above is re-enabled.
    for i in range(len(prompts)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_outputs[i]
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")

6 changes: 6 additions & 0 deletions vllm/config.py
@@ -510,6 +510,12 @@ def is_neuron(self):
return self.device_type == "neuron"


@dataclass
class AudioFeaturesConfig:
feature_dims: int = 80
sequence_length: int = 3000


@dataclass
class LoRAConfig:
max_lora_rank: int
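Note: the two defaults match Whisper's log-mel frontend, which emits 80 mel bins over 3000 frames (roughly 30 s of audio at a 10 ms hop). A minimal sketch of reading the expected feature shape back out of the new config, assuming this branch's vllm.config is importable:

from vllm.config import AudioFeaturesConfig

audio_config = AudioFeaturesConfig()
# Whisper's processor produces input_features of shape
# (batch, feature_dims, sequence_length) == (batch, 80, 3000).
expected_shape = (audio_config.feature_dims, audio_config.sequence_length)
assert expected_shape == (80, 3000)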
2 changes: 2 additions & 0 deletions vllm/core/scheduler.py
@@ -388,6 +388,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
lora_request=seq_group.lora_request,
prefix=seq_group.prefix,
state=seq_group.state,
multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.prompt_run else None,
)
seq_group_metadata_list.append(seq_group_metadata)
return seq_group_metadata_list, scheduler_outputs
10 changes: 7 additions & 3 deletions vllm/engine/arg_utils.py
@@ -4,7 +4,8 @@
from typing import Optional, Tuple

from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig)
ParallelConfig, SchedulerConfig, LoRAConfig,
AudioFeaturesConfig)


@dataclass
@@ -50,6 +51,7 @@ def __post_init__(self):
if self.tokenizer is None:
self.tokenizer = self.model

# TODO: Add Whisper-specific args maybe
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
@@ -282,7 +284,8 @@ def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
def create_engine_configs(
self,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
DeviceConfig, Optional[LoRAConfig]]:
DeviceConfig, Optional[LoRAConfig],
Optional[AudioFeaturesConfig]]:
device_config = DeviceConfig(self.device)
model_config = ModelConfig(
self.model, self.tokenizer, self.tokenizer_mode,
@@ -310,8 +313,9 @@ def create_engine_configs(
lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None
audio_features_config = AudioFeaturesConfig() # for now
return (model_config, cache_config, parallel_config, scheduler_config,
device_config, lora_config)
device_config, lora_config, audio_features_config)


@dataclass
26 changes: 16 additions & 10 deletions vllm/engine/async_llm_engine.py
@@ -12,6 +12,7 @@
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData

logger = init_logger(__name__)

@@ -226,6 +227,7 @@ async def add_request_async(
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@@ -246,6 +248,7 @@
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos,
multi_modal_data=multi_modal_data,
)

async def _run_workers_async(
@@ -423,6 +426,7 @@ async def add_request(
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> AsyncStream:
if self.log_requests:
shortened_prompt = prompt
@@ -439,6 +443,7 @@
f"sampling_params: {sampling_params}, "
f"prompt_token_ids: {shortened_token_ids}, "
f"lora_request: {lora_request}.")
# TODO: Add multi_modal_data to the log.

if not self.is_running:
if self.start_engine_loop:
@@ -473,7 +478,8 @@ async def add_request(
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos)
prefix_pos=prefix_pos,
multi_modal_data=multi_modal_data)

return stream

@@ -485,6 +491,7 @@ async def generate(
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.

@@ -558,15 +565,14 @@ async def generate(
arrival_time = time.monotonic()

try:
stream = await self.add_request(
request_id,
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos,
)
stream = await self.add_request(request_id,
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos,
multi_modal_data=multi_modal_data)

async for request_output in stream:
yield request_output
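For context, a rough sketch of how a caller might thread Whisper features through the new multi_modal_data keyword on AsyncLLMEngine.generate. The MultiModalData construction mirrors the pattern used in tests/conftest.py above; the engine argument (an already-running AsyncLLMEngine), the empty prompt, and the 16 kHz waveform are illustrative assumptions, not part of this diff:

from transformers import AutoProcessor
from vllm import SamplingParams
from vllm.sequence import MultiModalData

async def transcribe(engine, waveform):
    # Turn the raw 16 kHz waveform into Whisper log-mel input features.
    processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
    input_features = processor(waveform,
                               sampling_rate=16000,
                               return_tensors="pt").input_features
    stream = engine.generate(
        prompt="",
        sampling_params=SamplingParams(temperature=0.0, max_tokens=32),
        request_id="whisper-demo-0",
        multi_modal_data=MultiModalData(type=input_features.dtype,
                                        data=input_features[0]))
    final = None
    async for request_output in stream:
        final = request_output
    return final.outputs[0].text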