Add audio to text pipeline (livepeer#103)
Co-authored-by: Rick Staa
eliteprox committed Jul 26, 2024
1 parent 20226d2 commit e7f939b
Showing 20 changed files with 675 additions and 185 deletions.
3 changes: 2 additions & 1 deletion runner/Dockerfile
@@ -12,7 +12,8 @@ ENV DEBIAN_FRONTEND=noninteractive
 RUN apt-get update && \
     apt-get install -y build-essential libssl-dev zlib1g-dev libbz2-dev \
     libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev libncursesw5-dev \
-    xz-utils tk-dev libffi-dev liblzma-dev python3-openssl git
+    xz-utils tk-dev libffi-dev liblzma-dev python3-openssl git \
+    ffmpeg

 # Install pyenv
 RUN curl https://pyenv.run | bash
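The only functional change in this hunk is the addition of ffmpeg to the image. The diff does not spell out why, but the new audio pipeline below converts uploads and feeds raw audio bytes to a transformers ASR pipeline, both of which commonly depend on an ffmpeg binary being on PATH. A small, hypothetical sanity check (not part of this commit) could look like this:

# Hypothetical sanity check (not part of this commit): confirm the ffmpeg
# binary added by the Dockerfile is actually on PATH inside the container.
import shutil
import subprocess

ffmpeg_path = shutil.which("ffmpeg")
if ffmpeg_path is None:
    raise RuntimeError("ffmpeg not found on PATH; audio decoding will fail")

# Print the first line of `ffmpeg -version` as a quick smoke test.
version = subprocess.run(
    [ffmpeg_path, "-version"], capture_output=True, text=True, check=True
)
print(version.stdout.splitlines()[0])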
8 changes: 8 additions & 0 deletions runner/app/main.py
@@ -42,6 +42,10 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
             from app.pipelines.image_to_video import ImageToVideoPipeline

             return ImageToVideoPipeline(model_id)
+        case "audio-to-text":
+            from app.pipelines.audio_to_text import AudioToTextPipeline
+
+            return AudioToTextPipeline(model_id)
         case "frame-interpolation":
             raise NotImplementedError("frame-interpolation pipeline not implemented")
         case "upscale":
@@ -67,6 +71,10 @@ def load_route(pipeline: str) -> any:
             from app.routes import image_to_video

             return image_to_video.router
+        case "audio-to-text":
+            from app.routes import audio_to_text
+
+            return audio_to_text.router
         case "frame-interpolation":
             raise NotImplementedError("frame-interpolation pipeline not implemented")
         case "upscale":
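Both hunks extend the same match/case dispatch: the pipeline class and its router are only imported inside the matching arm, so unrelated dependencies never get loaded for other pipelines. A simplified sketch of that pattern follows (the real load_pipeline in runner/app/main.py has more arms than shown here, and the fallback arm is an assumption for illustration):

# Simplified sketch of the dispatch pattern used by load_pipeline above.
# Imports happen inside each arm so only the requested pipeline's
# dependencies are loaded. The fallback arm is an assumption, not the
# actual error handling in runner/app/main.py.
def load_pipeline(pipeline: str, model_id: str):
    match pipeline:
        case "audio-to-text":
            from app.pipelines.audio_to_text import AudioToTextPipeline

            return AudioToTextPipeline(model_id)
        case "image-to-video":
            from app.pipelines.image_to_video import ImageToVideoPipeline

            return ImageToVideoPipeline(model_id)
        case _:
            raise NotImplementedError(f"{pipeline} is not a supported pipeline")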
78 changes: 78 additions & 0 deletions runner/app/pipelines/audio_to_text.py
@@ -0,0 +1,78 @@
import logging
import os
from typing import List

import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device
from app.pipelines.utils.audio import AudioConverter
from fastapi import File, UploadFile
from huggingface_hub import file_download
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

logger = logging.getLogger(__name__)


MODEL_INCOMPATIBLE_EXTENSIONS = {
    "openai/whisper-large-v3": ["mp4", "m4a", "ac3"],
}


class AudioToTextPipeline(Pipeline):
    def __init__(self, model_id: str):
        self.model_id = model_id
        kwargs = {}

        torch_device = get_torch_device()
        folder_name = file_download.repo_folder_name(
            repo_id=model_id, repo_type="model"
        )
        folder_path = os.path.join(get_model_dir(), folder_name)
        # Load fp16 variant if fp16 safetensors files are found in cache
        has_fp16_variant = any(
            ".fp16.safetensors" in fname
            for _, _, files in os.walk(folder_path)
            for fname in files
        )
        if torch_device != "cpu" and has_fp16_variant:
            logger.info("AudioToTextPipeline loading fp16 variant for %s", model_id)

            kwargs["torch_dtype"] = torch.float16
            kwargs["variant"] = "fp16"

        if os.environ.get("BFLOAT16"):
            logger.info("AudioToTextPipeline using bfloat16 precision for %s", model_id)
            kwargs["torch_dtype"] = torch.bfloat16

        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id, low_cpu_mem_usage=True, use_safetensors=True, cache_dir=get_model_dir(), **kwargs
        ).to(torch_device)

        processor = AutoProcessor.from_pretrained(model_id, cache_dir=get_model_dir())

        self.ldm = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            max_new_tokens=128,
            chunk_length_s=30,
            batch_size=16,
            return_timestamps=True,
            **kwargs,
        )

    def __call__(self, audio: UploadFile, **kwargs) -> List[File]:
        # Convert M4A/MP4 files for pipeline compatibility.
        if (
            os.path.splitext(audio.filename)[1].lower().lstrip(".")
            in MODEL_INCOMPATIBLE_EXTENSIONS[self.model_id]
        ):
            audio_converter = AudioConverter()
            converted_bytes = audio_converter.convert(audio, "mp3")
            audio_converter.write_bytes_to_file(converted_bytes, audio)

        return self.ldm(audio.file.read(), **kwargs)

    def __str__(self) -> str:
        return f"AudioToTextPipeline model_id={self.model_id}"
2 changes: 1 addition & 1 deletion runner/app/pipelines/image_to_image.py
@@ -1,5 +1,5 @@
 from app.pipelines.base import Pipeline
-from app.pipelines.util import (
+from app.pipelines.utils import (
     get_torch_device,
     get_model_dir,
     SafetyChecker,
2 changes: 1 addition & 1 deletion runner/app/pipelines/image_to_video.py
@@ -1,5 +1,5 @@
 from app.pipelines.base import Pipeline
-from app.pipelines.util import get_torch_device, get_model_dir, SafetyChecker
+from app.pipelines.utils import get_torch_device, get_model_dir, SafetyChecker

 from diffusers import StableVideoDiffusionPipeline
 from huggingface_hub import file_download
4 changes: 2 additions & 2 deletions runner/app/pipelines/text_to_image.py
@@ -16,7 +16,7 @@
 from safetensors.torch import load_file

 from app.pipelines.base import Pipeline
-from app.pipelines.util import (
+from app.pipelines.utils import (
     get_model_dir,
     get_torch_device,
     SafetyChecker,
@@ -152,7 +152,7 @@ def __init__(self, model_id: str):
             self.ldm = compile_model(self.ldm)

             # Warm-up the pipeline.
-            # TODO: Not yet supported for ImageToImagePipeline.
+            # TODO: Not yet supported for TextToImagePipeline.
             if os.getenv("SFAST_WARMUP", "true").lower() == "true":
                 logger.warning(
                     "The 'SFAST_WARMUP' flag is not yet supported for the "
2 changes: 1 addition & 1 deletion runner/app/pipelines/upscale.py
@@ -1,5 +1,5 @@
 from app.pipelines.base import Pipeline
-from app.pipelines.util import get_torch_device, get_model_dir, SafetyChecker, is_lightning_model, is_turbo_model
+from app.pipelines.utils import get_torch_device, get_model_dir, SafetyChecker, is_lightning_model, is_turbo_model

 from diffusers import (
     StableDiffusionUpscalePipeline
133 changes: 0 additions & 133 deletions runner/app/pipelines/util.py

This file was deleted.

5 changes: 2 additions & 3 deletions runner/app/pipelines/utils/__init__.py
@@ -1,12 +1,11 @@
"""This module contains several utility functions that are used across the pipelines module."""

from app.pipelines.utils.utils import (
SafetyChecker,
get_model_dir,
get_model_path,
get_torch_device,
validate_torch_device,
is_lightning_model,
is_turbo_model,
get_temp_file,
SafetyChecker,
validate_torch_device,
)
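The net effect of deleting runner/app/pipelines/util.py and keeping only the utils package is that every pipeline now imports the shared helpers the same way, as in this trivial sketch (SafetyChecker's constructor is truncated above, so its usage is omitted):

# After this commit, shared helpers come from the app.pipelines.utils package
# (which re-exports them from app/pipelines/utils/utils.py), not the deleted
# app.pipelines.util module.
from app.pipelines.utils import get_model_dir, get_torch_device, is_turbo_model

device = get_torch_device()
models_dir = get_model_dir()
print(device, models_dir, is_turbo_model("stabilityai/sdxl-turbo"))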
20 changes: 0 additions & 20 deletions runner/app/pipelines/utils/utils.py
@@ -81,26 +81,6 @@ def is_turbo_model(model_id: str) -> bool:
     return re.search(r"[-_]turbo", model_id, re.IGNORECASE) is not None


-def get_temp_file(prefix: str, extension: str) -> str:
-    """Generates a temporary file path with the specified prefix and extension.
-
-    Args:
-        prefix: The prefix for the temporary file.
-        extension: The extension for the temporary file.
-
-    Returns:
-        The path to a non-existing temporary file with the specified prefix and extension.
-    """
-    if not extension.startswith("."):
-        extension = "." + extension
-    filename = f"{prefix}{uuid.uuid4()}{extension}"
-    temp_path = os.path.join(tempfile.gettempdir(), filename)
-    while os.path.exists(temp_path):
-        filename = f"{prefix}{uuid.uuid4()}{extension}"
-        temp_path = os.path.join(tempfile.gettempdir(), filename)
-    return temp_path
-
-
 class SafetyChecker:
     """Checks images for unsafe or inappropriate content using a pretrained model.