Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

[WiP] Whisper Implementation #147

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 62 additions & 13 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import pytest
import torch
from transformers import AutoModelForCausalLM

from transformers import AutoModelForCausalLM, WhisperForConditionalGeneration, AutoProcessor
from vllm.sequence import MultiModalData
from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer

Expand All @@ -16,7 +16,10 @@
def _read_prompts(filename: str) -> List[str]:
with open(filename, "r") as f:
prompts = f.readlines()
return prompts
return prompts


AUDIO_MODELS = {"openai/whisper-tiny": WhisperForConditionalGeneration}


@pytest.fixture
Expand Down Expand Up @@ -51,26 +54,51 @@ def __init__(
dtype: str = "half",
) -> None:
assert dtype in _STR_DTYPE_TO_TORCH_DTYPE
torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=True,
).cuda()
self.torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
self.model_name = model_name
if model_name in AUDIO_MODELS:
self.model = AUDIO_MODELS[model_name].from_pretrained(
model_name,
torch_dtype=self.torch_dtype,
trust_remote_code=True).cuda()
self.processor = AutoProcessor.from_pretrained(
model_name, torch_dtype=self.torch_dtype)
else:
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=self.torch_dtype,
trust_remote_code=True).cuda()
self.processor = None
if tokenizer_name is None:
tokenizer_name = model_name
self.tokenizer = get_tokenizer(tokenizer_name, trust_remote_code=True)

def generate(
self,
prompts: List[str],
audio_samples: Optional["TODO"] = None,
**kwargs,
) -> List[Tuple[List[int], str]]:
outputs: List[Tuple[List[int], str]] = []
for prompt in prompts:
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
if audio_samples:
assert len(prompts) == len(audio_samples)
for i, prompt in enumerate(prompts):
if self.model_name in AUDIO_MODELS:
audio_sample = audio_samples[i]
input_features = self.processor(
audio_sample["array"],
sampling_rate=audio_sample["sampling_rate"],
return_tensors="pt").input_features

inputs = dict(
input_features=input_features.cuda()) #TODO: Fix this
else:
input_ids = self.tokenizer(prompt,
return_tensors="pt").input_ids
inputs = dict(input_ids=input_ids.cuda())

output_ids = self.model.generate(
input_ids.cuda(),
**inputs,
use_cache=True,
**kwargs,
)
Expand All @@ -87,8 +115,10 @@ def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
audio_samples: Optional["TODO"] = None,
) -> List[Tuple[List[int], str]]:
outputs = self.generate(prompts,
audio_samples=audio_samples,
do_sample=False,
max_new_tokens=max_tokens)
for i in range(len(outputs)):
Expand Down Expand Up @@ -182,8 +212,24 @@ def generate(
self,
prompts: List[str],
sampling_params: SamplingParams,
audio_samples: Optional["TODO"] = None,
) -> List[Tuple[List[int], str]]:
if audio_samples:
assert len(prompts) == len(audio_samples)
processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
input_features = processor(
audio_samples[0]["array"],
sampling_rate=audio_samples[0]["sampling_rate"],
return_tensors="pt").input_features
# change type of input features
input_features = input_features.to(
dtype=self.model.llm_engine.model_config.dtype)
multi_modal_data = MultiModalData(type=input_features.dtype,
data=input_features[0])
else:
multi_modal_data = None
req_outputs = self.model.generate(prompts,
multi_modal_data=multi_modal_data,
sampling_params=sampling_params)
outputs = []
for req_output in req_outputs:
Expand Down Expand Up @@ -221,9 +267,12 @@ def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
audio_samples: Optional["TODO"] = None,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts, greedy_params)
outputs = self.generate(prompts,
greedy_params,
audio_samples=audio_samples)
return [(output_ids[0], output_str[0])
for output_ids, output_str in outputs]

Expand Down
67 changes: 67 additions & 0 deletions tests/models/test_whisper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import pytest
import torch
from datasets import load_dataset
from vllm.config import AudioFeaturesConfig

import os

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ['CUDA_VISIBLE_DEVICES'] = "6"


@pytest.fixture()
def model_id():
return "openai/whisper-tiny"


@pytest.fixture()
def audio_features_config():
return AudioFeaturesConfig()


def sample_from_librispeech():
dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy",
"clean",
split="validation")
return dataset[0]


audio_sample = sample_from_librispeech()["audio"]


@pytest.mark.parametrize("dtype", ["float"]) # TODO fix that
@pytest.mark.parametrize("max_tokens", [2])
@pytest.mark.parametrize("prompts, audio_samples", [([""], [audio_sample])])
def test_text_to_audio_scenario(hf_runner, vllm_runner, model_id, prompts,
audio_samples, dtype: str,
max_tokens: int) -> None:
# hf_model = hf_runner(model_id, dtype=dtype)
# hf_outputs = hf_model.generate_greedy(prompts=prompts,
# audio_samples=audio_samples,
# max_tokens=max_tokens)
# del hf_model

# Truly cleans up GPU memory.
torch.cuda.empty_cache()

vllm_model = vllm_runner(model_id,
dtype=dtype,
enforce_eager=True,
tensor_parallel_size=1,
gpu_memory_utilization=0.9)
vllm_outputs = vllm_model.generate_greedy(prompts,
max_tokens,
audio_samples=audio_samples)
del vllm_model
# Truly cleans up GPU memory.
torch.cuda.empty_cache()

for i in range(len(hf_image_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]
vllm_output_ids, vllm_output_str = sanitize_vllm_output(
vllm_outputs[i], vision_language_config, model_id)
assert hf_output_str == vllm_output_str, (
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
assert hf_output_ids == vllm_output_ids, (
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")

6 changes: 6 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,12 @@ def is_neuron(self):
return self.device_type == "neuron"


@dataclass
class AudioFeaturesConfig:
feature_dims: int = 80
sequence_length: int = 3000


@dataclass
class LoRAConfig:
max_lora_rank: int
Expand Down
2 changes: 2 additions & 0 deletions vllm/core/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
lora_request=seq_group.lora_request,
prefix=seq_group.prefix,
state=seq_group.state,
multi_modal_data=seq_group.multi_modal_data
if scheduler_outputs.prompt_run else None,
)
seq_group_metadata_list.append(seq_group_metadata)
return seq_group_metadata_list, scheduler_outputs
Expand Down
10 changes: 7 additions & 3 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from typing import Optional, Tuple

from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig)
ParallelConfig, SchedulerConfig, LoRAConfig,
AudioFeaturesConfig)


@dataclass
Expand Down Expand Up @@ -50,6 +51,7 @@ def __post_init__(self):
if self.tokenizer is None:
self.tokenizer = self.model

# TODO: Add Whisper-specific args maybe
@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
Expand Down Expand Up @@ -282,7 +284,8 @@ def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
def create_engine_configs(
self,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
DeviceConfig, Optional[LoRAConfig]]:
DeviceConfig, Optional[LoRAConfig],
Optional[AudioFeaturesConfig]]:
device_config = DeviceConfig(self.device)
model_config = ModelConfig(
self.model, self.tokenizer, self.tokenizer_mode,
Expand Down Expand Up @@ -310,8 +313,9 @@ def create_engine_configs(
lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None
audio_features_config = AudioFeaturesConfig() # for now
return (model_config, cache_config, parallel_config, scheduler_config,
device_config, lora_config)
device_config, lora_config, audio_features_config)


@dataclass
Expand Down
26 changes: 16 additions & 10 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import MultiModalData

logger = init_logger(__name__)

Expand Down Expand Up @@ -226,6 +227,7 @@ async def add_request_async(
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
Expand All @@ -246,6 +248,7 @@ async def add_request_async(
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos,
multi_modal_data=multi_modal_data,
)

async def _run_workers_async(
Expand Down Expand Up @@ -423,6 +426,7 @@ async def add_request(
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> AsyncStream:
if self.log_requests:
shortened_prompt = prompt
Expand All @@ -439,6 +443,7 @@ async def add_request(
f"sampling_params: {sampling_params}, "
f"prompt_token_ids: {shortened_token_ids}, "
f"lora_request: {lora_request}.")
# TODO: Add multi_modal_data to the log.

if not self.is_running:
if self.start_engine_loop:
Expand Down Expand Up @@ -473,7 +478,8 @@ async def add_request(
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos)
prefix_pos=prefix_pos,
multi_modal_data=multi_modal_data)

return stream

Expand All @@ -485,6 +491,7 @@ async def generate(
prompt_token_ids: Optional[List[int]] = None,
lora_request: Optional[LoRARequest] = None,
prefix_pos: Optional[int] = None,
multi_modal_data: Optional[MultiModalData] = None,
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.

Expand Down Expand Up @@ -558,15 +565,14 @@ async def generate(
arrival_time = time.monotonic()

try:
stream = await self.add_request(
request_id,
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos,
)
stream = await self.add_request(request_id,
prompt,
sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time,
lora_request=lora_request,
prefix_pos=prefix_pos,
multi_modal_data=multi_modal_data)

async for request_output in stream:
yield request_output
Expand Down
Loading