forked from livepeer/ai-worker
Commit
Add audio to text pipeline (livepeer#103)
Co-authored-by: Rick Staa
Showing 20 changed files with 675 additions and 185 deletions.
New file added by this commit (path not shown in this rendering):

@@ -0,0 +1,78 @@
import logging
import os
from typing import List

import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device
from app.pipelines.utils.audio import AudioConverter
from fastapi import File, UploadFile
from huggingface_hub import file_download
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

logger = logging.getLogger(__name__)


MODEL_INCOMPATIBLE_EXTENSIONS = {
    "openai/whisper-large-v3": ["mp4", "m4a", "ac3"],
}


class AudioToTextPipeline(Pipeline):
    def __init__(self, model_id: str):
        self.model_id = model_id
        kwargs = {}

        torch_device = get_torch_device()
        folder_name = file_download.repo_folder_name(
            repo_id=model_id, repo_type="model"
        )
        folder_path = os.path.join(get_model_dir(), folder_name)
        # Load fp16 variant if fp16 safetensors files are found in cache.
        has_fp16_variant = any(
            ".fp16.safetensors" in fname
            for _, _, files in os.walk(folder_path)
            for fname in files
        )
        if torch_device != "cpu" and has_fp16_variant:
            logger.info("AudioToTextPipeline loading fp16 variant for %s", model_id)

            kwargs["torch_dtype"] = torch.float16
            kwargs["variant"] = "fp16"

        # The BFLOAT16 environment variable overrides the dtype with bfloat16.
        if os.environ.get("BFLOAT16"):
            logger.info(
                "AudioToTextPipeline using bfloat16 precision for %s", model_id
            )
            kwargs["torch_dtype"] = torch.bfloat16

        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            cache_dir=get_model_dir(),
            **kwargs,
        ).to(torch_device)

        processor = AutoProcessor.from_pretrained(model_id, cache_dir=get_model_dir())

        # Hugging Face ASR pipeline with chunked long-form decoding and timestamps.
        self.ldm = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            max_new_tokens=128,
            chunk_length_s=30,
            batch_size=16,
            return_timestamps=True,
            **kwargs,
        )

    def __call__(self, audio: UploadFile, **kwargs) -> List[File]:
        # Convert M4A/MP4/AC3 uploads to MP3 for pipeline compatibility.
        if (
            os.path.splitext(audio.filename)[1].lower().lstrip(".")
            in MODEL_INCOMPATIBLE_EXTENSIONS[self.model_id]
        ):
            audio_converter = AudioConverter()
            converted_bytes = audio_converter.convert(audio, "mp3")
            audio_converter.write_bytes_to_file(converted_bytes, audio)

        return self.ldm(audio.file.read(), **kwargs)

    def __str__(self) -> str:
        return f"AudioToTextPipeline model_id={self.model_id}"
Changed file (path not shown in this rendering; the add/remove markers were lost in extraction, but the hunk reorders the import list so that SafetyChecker and validate_torch_device move to their alphabetical positions):

@@ -1,12 +1,11 @@
"""This module contains several utility functions that are used across the pipelines module."""

from app.pipelines.utils.utils import (
    SafetyChecker,
    get_model_dir,
    get_model_path,
    get_torch_device,
    validate_torch_device,
    is_lightning_model,
    is_turbo_model,
    get_temp_file,
    SafetyChecker,
    validate_torch_device,
)