diff --git a/runner/Dockerfile b/runner/Dockerfile index f68e8591..f4138e8e 100644 --- a/runner/Dockerfile +++ b/runner/Dockerfile @@ -12,7 +12,8 @@ ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update && \ apt-get install -y build-essential libssl-dev zlib1g-dev libbz2-dev \ libreadline-dev libsqlite3-dev wget curl llvm libncurses5-dev libncursesw5-dev \ - xz-utils tk-dev libffi-dev liblzma-dev python3-openssl git + xz-utils tk-dev libffi-dev liblzma-dev python3-openssl git \ + ffmpeg # Install pyenv RUN curl https://pyenv.run | bash diff --git a/runner/app/main.py b/runner/app/main.py index 015734e0..604808b1 100644 --- a/runner/app/main.py +++ b/runner/app/main.py @@ -42,6 +42,10 @@ def load_pipeline(pipeline: str, model_id: str) -> any: from app.pipelines.image_to_video import ImageToVideoPipeline return ImageToVideoPipeline(model_id) + case "audio-to-text": + from app.pipelines.audio_to_text import AudioToTextPipeline + + return AudioToTextPipeline(model_id) case "frame-interpolation": raise NotImplementedError("frame-interpolation pipeline not implemented") case "upscale": @@ -67,6 +71,10 @@ def load_route(pipeline: str) -> any: from app.routes import image_to_video return image_to_video.router + case "audio-to-text": + from app.routes import audio_to_text + + return audio_to_text.router case "frame-interpolation": raise NotImplementedError("frame-interpolation pipeline not implemented") case "upscale": diff --git a/runner/app/pipelines/audio_to_text.py b/runner/app/pipelines/audio_to_text.py new file mode 100644 index 00000000..300831b9 --- /dev/null +++ b/runner/app/pipelines/audio_to_text.py @@ -0,0 +1,78 @@ +import logging +import os +from typing import List + +import torch +from app.pipelines.base import Pipeline +from app.pipelines.utils import get_model_dir, get_torch_device +from app.pipelines.utils.audio import AudioConverter +from fastapi import File, UploadFile +from huggingface_hub import file_download +from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline + +logger = logging.getLogger(__name__) + + +MODEL_INCOMPATIBLE_EXTENSIONS = { + "openai/whisper-large-v3": ["mp4", "m4a", "ac3"], +} + + +class AudioToTextPipeline(Pipeline): + def __init__(self, model_id: str): + self.model_id = model_id + kwargs = {} + + torch_device = get_torch_device() + folder_name = file_download.repo_folder_name( + repo_id=model_id, repo_type="model" + ) + folder_path = os.path.join(get_model_dir(), folder_name) + # Load fp16 variant if fp16 safetensors files are found in cache + has_fp16_variant = any( + ".fp16.safetensors" in fname + for _, _, files in os.walk(folder_path) + for fname in files + ) + if torch_device != "cpu" and has_fp16_variant: + logger.info("AudioToTextPipeline loading fp16 variant for %s", model_id) + + kwargs["torch_dtype"] = torch.float16 + kwargs["variant"] = "fp16" + + if os.environ.get("BFLOAT16"): + logger.info("AudioToTextPipeline using bfloat16 precision for %s", model_id) + kwargs["torch_dtype"] = torch.bfloat16 + + model = AutoModelForSpeechSeq2Seq.from_pretrained( + model_id, low_cpu_mem_usage=True, use_safetensors=True, cache_dir=get_model_dir(), **kwargs + ).to(torch_device) + + processor = AutoProcessor.from_pretrained(model_id, cache_dir=get_model_dir()) + + self.ldm = pipeline( + "automatic-speech-recognition", + model=model, + tokenizer=processor.tokenizer, + feature_extractor=processor.feature_extractor, + max_new_tokens=128, + chunk_length_s=30, + batch_size=16, + return_timestamps=True, + **kwargs, + ) + + def __call__(self, audio: UploadFile, **kwargs) -> List[File]: + # Convert M4A/MP4 files for pipeline compatibility. + if ( + os.path.splitext(audio.filename)[1].lower().lstrip(".") + in MODEL_INCOMPATIBLE_EXTENSIONS[self.model_id] + ): + audio_converter = AudioConverter() + converted_bytes = audio_converter.convert(audio, "mp3") + audio_converter.write_bytes_to_file(converted_bytes, audio) + + return self.ldm(audio.file.read(), **kwargs) + + def __str__(self) -> str: + return f"AudioToTextPipeline model_id={self.model_id}" diff --git a/runner/app/pipelines/image_to_image.py b/runner/app/pipelines/image_to_image.py index fda2dfbf..f2132a08 100644 --- a/runner/app/pipelines/image_to_image.py +++ b/runner/app/pipelines/image_to_image.py @@ -1,5 +1,5 @@ from app.pipelines.base import Pipeline -from app.pipelines.util import ( +from app.pipelines.utils import ( get_torch_device, get_model_dir, SafetyChecker, diff --git a/runner/app/pipelines/image_to_video.py b/runner/app/pipelines/image_to_video.py index c3a528ee..cb11ab97 100644 --- a/runner/app/pipelines/image_to_video.py +++ b/runner/app/pipelines/image_to_video.py @@ -1,5 +1,5 @@ from app.pipelines.base import Pipeline -from app.pipelines.util import get_torch_device, get_model_dir, SafetyChecker +from app.pipelines.utils import get_torch_device, get_model_dir, SafetyChecker from diffusers import StableVideoDiffusionPipeline from huggingface_hub import file_download diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py index 0f9f4795..baf6b63d 100644 --- a/runner/app/pipelines/text_to_image.py +++ b/runner/app/pipelines/text_to_image.py @@ -16,7 +16,7 @@ from safetensors.torch import load_file from app.pipelines.base import Pipeline -from app.pipelines.util import ( +from app.pipelines.utils import ( get_model_dir, get_torch_device, SafetyChecker, @@ -152,7 +152,7 @@ def __init__(self, model_id: str): self.ldm = compile_model(self.ldm) # Warm-up the pipeline. - # TODO: Not yet supported for ImageToImagePipeline. + # TODO: Not yet supported for TextToImagePipeline. if os.getenv("SFAST_WARMUP", "true").lower() == "true": logger.warning( "The 'SFAST_WARMUP' flag is not yet supported for the " diff --git a/runner/app/pipelines/upscale.py b/runner/app/pipelines/upscale.py index bd2e4f2a..bc93eed0 100644 --- a/runner/app/pipelines/upscale.py +++ b/runner/app/pipelines/upscale.py @@ -1,5 +1,5 @@ from app.pipelines.base import Pipeline -from app.pipelines.util import get_torch_device, get_model_dir, SafetyChecker, is_lightning_model, is_turbo_model +from app.pipelines.utils import get_torch_device, get_model_dir, SafetyChecker, is_lightning_model, is_turbo_model from diffusers import ( StableDiffusionUpscalePipeline diff --git a/runner/app/pipelines/util.py b/runner/app/pipelines/util.py deleted file mode 100644 index 584d788c..00000000 --- a/runner/app/pipelines/util.py +++ /dev/null @@ -1,133 +0,0 @@ -import torch -import os -import numpy as np -from torch import dtype as TorchDtype -from pathlib import Path -from PIL import Image -from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker -from transformers import CLIPFeatureExtractor -from typing import Optional -import logging -import re - -logger = logging.getLogger(__name__) - - -def get_model_dir() -> Path: - return Path(os.environ["MODEL_DIR"]) - - -def get_model_path(model_id: str) -> Path: - return get_model_dir() / model_id.lower() - - -def get_torch_device(): - if torch.cuda.is_available(): - return torch.device("cuda") - elif torch.backends.mps.is_available(): - return torch.device("mps") - else: - return torch.device("cpu") - - -def validate_torch_device(device_name: str) -> bool: - """Checks if the given PyTorch device name is valid and available. - - Args: - device_name: Name of the device ('cuda:0', 'cuda', 'cpu'). - - Returns: - True if valid and available, False otherwise. - """ - try: - device = torch.device(device_name) - if device.type == "cuda": - # Check if CUDA is available and the specified index is within range - if device.index is None: - return torch.cuda.is_available() - else: - return device.index < torch.cuda.device_count() - return True - except RuntimeError: - return False - - -def is_lightning_model(model_id: str) -> bool: - """Checks if the model is a Lightning model. - - Args: - model_id: Model ID. - - Returns: - True if the model is a Lightning model, False otherwise. - """ - return re.search(r"[-_]lightning", model_id, re.IGNORECASE) is not None - - -def is_turbo_model(model_id: str) -> bool: - """Checks if the model is a Turbo model. - - Args: - model_id: Model ID. - - Returns: - True if the model is a Turbo model, False otherwise. - """ - return re.search(r"[-_]turbo", model_id, re.IGNORECASE) is not None - - -class SafetyChecker: - """Checks images for unsafe or inappropriate content using a pretrained model. - - Attributes: - device (str): Device for inference. - """ - - def __init__( - self, - device: Optional[str] = "cuda", - dtype: Optional[TorchDtype] = torch.float16, - ): - """Initializes the SafetyChecker. - - Args: - device: Device for inference. Defaults to "cuda". - dtype: Data type for inference. Defaults to `torch.float16`. - """ - device = device.lower() if device else device - if not validate_torch_device(device): - default_device = get_torch_device() - logger.warning( - f"Device '{device}' not found. Defaulting to '{default_device}'." - ) - device = default_device - - self.device = device - self._dtype = dtype - self._safety_checker = StableDiffusionSafetyChecker.from_pretrained( - "CompVis/stable-diffusion-safety-checker" - ).to(self.device) - self._feature_extractor = CLIPFeatureExtractor.from_pretrained( - "openai/clip-vit-base-patch32" - ) - - def check_nsfw_images( - self, images: list[Image.Image] - ) -> tuple[list[Image.Image], list[bool]]: - """Checks images for unsafe content. - - Args: - images: Images to check. - - Returns: - Tuple of images and corresponding NSFW flags. - """ - safety_checker_input = self._feature_extractor(images, return_tensors="pt").to( - self.device - ) - images_np = [np.array(img) for img in images] - _, has_nsfw_concept = self._safety_checker( - images=images_np, - clip_input=safety_checker_input.pixel_values.to(self._dtype), - ) - return images, has_nsfw_concept diff --git a/runner/app/pipelines/utils/__init__.py b/runner/app/pipelines/utils/__init__.py index a5e6f1eb..dd1b9573 100644 --- a/runner/app/pipelines/utils/__init__.py +++ b/runner/app/pipelines/utils/__init__.py @@ -1,12 +1,11 @@ """This module contains several utility functions that are used across the pipelines module.""" from app.pipelines.utils.utils import ( + SafetyChecker, get_model_dir, get_model_path, get_torch_device, - validate_torch_device, is_lightning_model, is_turbo_model, - get_temp_file, - SafetyChecker, + validate_torch_device, ) diff --git a/runner/app/pipelines/utils/utils.py b/runner/app/pipelines/utils/utils.py index 58de41da..51987495 100644 --- a/runner/app/pipelines/utils/utils.py +++ b/runner/app/pipelines/utils/utils.py @@ -81,26 +81,6 @@ def is_turbo_model(model_id: str) -> bool: return re.search(r"[-_]turbo", model_id, re.IGNORECASE) is not None -def get_temp_file(prefix: str, extension: str) -> str: - """Generates a temporary file path with the specified prefix and extension. - - Args: - prefix: The prefix for the temporary file. - extension: The extension for the temporary file. - - Returns: - The path to a non-existing temporary file with the specified prefix and extension. - """ - if not extension.startswith("."): - extension = "." + extension - filename = f"{prefix}{uuid.uuid4()}{extension}" - temp_path = os.path.join(tempfile.gettempdir(), filename) - while os.path.exists(temp_path): - filename = f"{prefix}{uuid.uuid4()}{extension}" - temp_path = os.path.join(tempfile.gettempdir(), filename) - return temp_path - - class SafetyChecker: """Checks images for unsafe or inappropriate content using a pretrained model. diff --git a/runner/app/routes/audio_to_text.py b/runner/app/routes/audio_to_text.py new file mode 100644 index 00000000..5a249d9a --- /dev/null +++ b/runner/app/routes/audio_to_text.py @@ -0,0 +1,90 @@ +import logging +import os +from typing import Annotated + +from app.dependencies import get_pipeline +from app.pipelines.base import Pipeline +from app.pipelines.utils.audio import AudioConversionError +from app.routes.util import HTTPError, TextResponse, file_exceeds_max_size, http_error +from fastapi import APIRouter, Depends, File, Form, UploadFile, status +from fastapi.responses import JSONResponse +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +router = APIRouter() + +logger = logging.getLogger(__name__) + +RESPONSES = { + status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, + status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, + status.HTTP_413_REQUEST_ENTITY_TOO_LARGE: {"model": HTTPError}, + status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, +} + + +def handle_pipeline_error(e: Exception) -> JSONResponse: + """Handles exceptions raised during audio processing. + + Args: + e: The exception raised during audio processing. + + Returns: + A JSONResponse with the appropriate error message and status code. + """ + logger.error(f"Audio processing error: {str(e)}") # Log the detailed error + if "Soundfile is either not in the correct format or is malformed" in str( + e + ) or isinstance(e, AudioConversionError): + status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE + error_message = "Unsupported audio format or malformed file." + else: + status_code = status.HTTP_500_INTERNAL_SERVER_ERROR + error_message = "Internal server error during audio processing." + + return JSONResponse( + status_code=status_code, + content=http_error(error_message), + ) + + +@router.post("/audio-to-text", response_model=TextResponse, responses=RESPONSES) +@router.post( + "/audio-to-text/", + response_model=TextResponse, + responses=RESPONSES, + include_in_schema=False, +) +async def audio_to_text( + audio: Annotated[UploadFile, File()], + model_id: Annotated[str, Form()] = "", + pipeline: Pipeline = Depends(get_pipeline), + token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), +): + auth_token = os.environ.get("AUTH_TOKEN") + if auth_token: + if not token or token.credentials != auth_token: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + headers={"WWW-Authenticate": "Bearer"}, + content=http_error("Invalid bearer token"), + ) + + if model_id != "" and model_id != pipeline.model_id: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=http_error( + f"pipeline configured with {pipeline.model_id} but called with " + f"{model_id}" + ), + ) + + if file_exceeds_max_size(audio, 50 * 1024 * 1024): + return JSONResponse( + status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, + content=http_error("File size exceeds limit"), + ) + + try: + return pipeline(audio=audio) + except Exception as e: + return handle_pipeline_error(e) diff --git a/runner/app/routes/util.py b/runner/app/routes/util.py index 97f09db8..96736305 100644 --- a/runner/app/routes/util.py +++ b/runner/app/routes/util.py @@ -1,7 +1,9 @@ import base64 import io +import os from typing import List +from fastapi import UploadFile from PIL import Image from pydantic import BaseModel @@ -22,6 +24,16 @@ class VideoResponse(BaseModel): frames: List[List[Media]] +class chunk(BaseModel): + timestamp: tuple + text: str + + +class TextResponse(BaseModel): + text: str + chunks: List[chunk] + + class APIError(BaseModel): msg: str @@ -42,3 +54,27 @@ def image_to_base64(img: Image, format: str = "png") -> str: def image_to_data_url(img: Image, format: str = "png") -> str: return "data:image/png;base64," + image_to_base64(img, format=format) + + +def file_exceeds_max_size( + input_file: UploadFile, max_size: int = 10 * 1024 * 1024 +) -> bool: + """Checks if the uploaded file exceeds the specified maximum size. + + Args: + input_file: The uploaded file to check. + max_size: The maximum allowed file size in bytes. Defaults to 10 MB. + + Returns: + True if the file exceeds the maximum size, False otherwise. + """ + try: + if input_file.file: + # Get size by moving the cursor to the end of the file and back. + input_file.file.seek(0, os.SEEK_END) + file_size = input_file.file.tell() + input_file.file.seek(0) + return file_size > max_size + except Exception as e: + print(f"Error checking file size: {e}") + return False diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh index 13902220..822590d4 100755 --- a/runner/dl_checkpoints.sh +++ b/runner/dl_checkpoints.sh @@ -31,6 +31,9 @@ function download_alpha_models() { # Download upscale models huggingface-cli download stabilityai/stable-diffusion-x4-upscaler --include "*.fp16.safetensors" --cache-dir models + # Download audio-to-text models. + huggingface-cli download openai/whisper-large-v3 --include "*.safetensors" "*.json" --cache-dir models + printf "\nDownloading token-gated models...\n" # Download image-to-video models (token-gated). diff --git a/runner/gen_openapi.py b/runner/gen_openapi.py index dcd21eea..fd13217a 100644 --- a/runner/gen_openapi.py +++ b/runner/gen_openapi.py @@ -5,7 +5,8 @@ import yaml from app.main import app, use_route_names_as_operation_ids -from app.routes import health, image_to_image, image_to_video, text_to_image, upscale +from app.routes import health, image_to_image, image_to_video, text_to_image, upscale, audio_to_text + from fastapi.openapi.utils import get_openapi # Specify Endpoints for OpenAPI schema generation. @@ -77,6 +78,7 @@ def write_openapi(fname, entrypoint="runner"): app.include_router(image_to_image.router) app.include_router(image_to_video.router) app.include_router(upscale.router) + app.include_router(audio_to_text.router) use_route_names_as_operation_ids(app) diff --git a/runner/openapi.json b/runner/openapi.json index 390502f8..f28d8026 100644 --- a/runner/openapi.json +++ b/runner/openapi.json @@ -281,6 +281,89 @@ } ] } + }, + "/audio-to-text": { + "post": { + "summary": "Audio To Text", + "operationId": "audio_to_text", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Body_audio_to_text_audio_to_text_post" + } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/TextResponse" + } + } + } + }, + "400": { + "description": "Bad Request", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "401": { + "description": "Unauthorized", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "413": { + "description": "Request Entity Too Large", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "500": { + "description": "Internal Server Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPError" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + }, + "security": [ + { + "HTTPBearer": [] + } + ] + } } }, "components": { @@ -298,6 +381,25 @@ ], "title": "APIError" }, + "Body_audio_to_text_audio_to_text_post": { + "properties": { + "audio": { + "type": "string", + "format": "binary", + "title": "Audio" + }, + "model_id": { + "type": "string", + "title": "Model Id", + "default": "" + } + }, + "type": "object", + "required": [ + "audio" + ], + "title": "Body_audio_to_text_audio_to_text_post" + }, "Body_image_to_image_image_to_image_post": { "properties": { "prompt": { @@ -517,6 +619,27 @@ ], "title": "Media" }, + "TextResponse": { + "properties": { + "text": { + "type": "string", + "title": "Text" + }, + "chunks": { + "items": { + "$ref": "#/components/schemas/chunk" + }, + "type": "array", + "title": "Chunks" + } + }, + "type": "object", + "required": [ + "text", + "chunks" + ], + "title": "TextResponse" + }, "TextToImageParams": { "properties": { "model_id": { @@ -623,6 +746,25 @@ "frames" ], "title": "VideoResponse" + }, + "chunk": { + "properties": { + "timestamp": { + "items": {}, + "type": "array", + "title": "Timestamp" + }, + "text": { + "type": "string", + "title": "Text" + } + }, + "type": "object", + "required": [ + "timestamp", + "text" + ], + "title": "chunk" } }, "securitySchemes": { diff --git a/runner/requirements.txt b/runner/requirements.txt index 17b38644..3852800d 100644 --- a/runner/requirements.txt +++ b/runner/requirements.txt @@ -14,5 +14,6 @@ deepcache==0.1.1 safetensors==0.4.3 scipy==1.13.0 numpy==1.26.4 +av==12.1.0 sentencepiece== 0.2.0 protobuf==5.27.2 diff --git a/worker/docker.go b/worker/docker.go index e081baeb..9b0bf425 100644 --- a/worker/docker.go +++ b/worker/docker.go @@ -37,6 +37,7 @@ var containerHostPorts = map[string]string{ "image-to-image": "8001", "image-to-video": "8002", "upscale": "8003", + "audio-to-text": "8004", } type DockerManager struct { diff --git a/worker/multipart.go b/worker/multipart.go index 4526c3e4..de4fbd80 100644 --- a/worker/multipart.go +++ b/worker/multipart.go @@ -194,3 +194,34 @@ func NewUpscaleMultipartWriter(w io.Writer, req UpscaleMultipartRequestBody) (*m return mw, nil } +func NewAudioToTextMultipartWriter(w io.Writer, req AudioToTextMultipartRequestBody) (*multipart.Writer, error) { + mw := multipart.NewWriter(w) + writer, err := mw.CreateFormFile("audio", req.Audio.Filename()) + if err != nil { + return nil, err + } + audioSize := req.Audio.FileSize() + audioRdr, err := req.Audio.Reader() + if err != nil { + return nil, err + } + copied, err := io.Copy(writer, audioRdr) + if err != nil { + return nil, err + } + if copied != audioSize { + return nil, fmt.Errorf("failed to copy audio to multipart request audioBytes=%v copiedBytes=%v", audioSize, copied) + } + + if req.ModelId != nil { + if err := mw.WriteField("model_id", *req.ModelId); err != nil { + return nil, err + } + } + + if err := mw.Close(); err != nil { + return nil, err + } + + return mw, nil +} diff --git a/worker/runner.gen.go b/worker/runner.gen.go index 20f36007..7957a4af 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -31,6 +31,12 @@ type APIError struct { Msg string `json:"msg"` } +// BodyAudioToTextAudioToTextPost defines model for Body_audio_to_text_audio_to_text_post. +type BodyAudioToTextAudioToTextPost struct { + Audio openapi_types.File `json:"audio"` + ModelId *string `json:"model_id,omitempty"` +} + // BodyImageToImageImageToImagePost defines model for Body_image_to_image_image_to_image_post. type BodyImageToImageImageToImagePost struct { GuidanceScale *float32 `json:"guidance_scale,omitempty"` @@ -94,6 +100,12 @@ type Media struct { Url string `json:"url"` } +// TextResponse defines model for TextResponse. +type TextResponse struct { + Chunks []Chunk `json:"chunks"` + Text string `json:"text"` +} + // TextToImageParams defines model for TextToImageParams. type TextToImageParams struct { GuidanceScale *float32 `json:"guidance_scale,omitempty"` @@ -131,6 +143,15 @@ type VideoResponse struct { Frames [][]Media `json:"frames"` } +// Chunk defines model for chunk. +type Chunk struct { + Text string `json:"text"` + Timestamp []interface{} `json:"timestamp"` +} + +// AudioToTextMultipartRequestBody defines body for AudioToText for multipart/form-data ContentType. +type AudioToTextMultipartRequestBody = BodyAudioToTextAudioToTextPost + // ImageToImageMultipartRequestBody defines body for ImageToImage for multipart/form-data ContentType. type ImageToImageMultipartRequestBody = BodyImageToImageImageToImagePost @@ -278,6 +299,9 @@ func WithRequestEditorFn(fn RequestEditorFn) ClientOption { // The interface specification for the client above. type ClientInterface interface { + // AudioToTextWithBody request with any body + AudioToTextWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + // Health request Health(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error) @@ -296,6 +320,18 @@ type ClientInterface interface { UpscaleWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) } +func (c *Client) AudioToTextWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewAudioToTextRequestWithBody(c.Server, contentType, body) + if err != nil { + return nil, err + } + req = req.WithContext(ctx) + if err := c.applyEditors(ctx, req, reqEditors); err != nil { + return nil, err + } + return c.Client.Do(req) +} + func (c *Client) Health(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error) { req, err := NewHealthRequest(c.Server) if err != nil { @@ -368,6 +404,35 @@ func (c *Client) UpscaleWithBody(ctx context.Context, contentType string, body i return c.Client.Do(req) } +// NewAudioToTextRequestWithBody generates requests for AudioToText with any type of body +func NewAudioToTextRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { + var err error + + serverURL, err := url.Parse(server) + if err != nil { + return nil, err + } + + operationPath := fmt.Sprintf("/audio-to-text") + if operationPath[0] == '/' { + operationPath = "." + operationPath + } + + queryURL, err := serverURL.Parse(operationPath) + if err != nil { + return nil, err + } + + req, err := http.NewRequest("POST", queryURL.String(), body) + if err != nil { + return nil, err + } + + req.Header.Add("Content-Type", contentType) + + return req, nil +} + // NewHealthRequest generates requests for Health func NewHealthRequest(server string) (*http.Request, error) { var err error @@ -565,6 +630,9 @@ func WithBaseURL(baseURL string) ClientOption { // ClientWithResponsesInterface is the interface specification for the client with responses above. type ClientWithResponsesInterface interface { + // AudioToTextWithBodyWithResponse request with any body + AudioToTextWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*AudioToTextResponse, error) + // HealthWithResponse request HealthWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*HealthResponse, error) @@ -583,6 +651,33 @@ type ClientWithResponsesInterface interface { UpscaleWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*UpscaleResponse, error) } +type AudioToTextResponse struct { + Body []byte + HTTPResponse *http.Response + JSON200 *TextResponse + JSON400 *HTTPError + JSON401 *HTTPError + JSON413 *HTTPError + JSON422 *HTTPValidationError + JSON500 *HTTPError +} + +// Status returns HTTPResponse.Status +func (r AudioToTextResponse) Status() string { + if r.HTTPResponse != nil { + return r.HTTPResponse.Status + } + return http.StatusText(0) +} + +// StatusCode returns HTTPResponse.StatusCode +func (r AudioToTextResponse) StatusCode() int { + if r.HTTPResponse != nil { + return r.HTTPResponse.StatusCode + } + return 0 +} + type HealthResponse struct { Body []byte HTTPResponse *http.Response @@ -705,6 +800,15 @@ func (r UpscaleResponse) StatusCode() int { return 0 } +// AudioToTextWithBodyWithResponse request with arbitrary body returning *AudioToTextResponse +func (c *ClientWithResponses) AudioToTextWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*AudioToTextResponse, error) { + rsp, err := c.AudioToTextWithBody(ctx, contentType, body, reqEditors...) + if err != nil { + return nil, err + } + return ParseAudioToTextResponse(rsp) +} + // HealthWithResponse request returning *HealthResponse func (c *ClientWithResponses) HealthWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*HealthResponse, error) { rsp, err := c.Health(ctx, reqEditors...) @@ -758,6 +862,67 @@ func (c *ClientWithResponses) UpscaleWithBodyWithResponse(ctx context.Context, c return ParseUpscaleResponse(rsp) } +// ParseAudioToTextResponse parses an HTTP response from a AudioToTextWithResponse call +func ParseAudioToTextResponse(rsp *http.Response) (*AudioToTextResponse, error) { + bodyBytes, err := io.ReadAll(rsp.Body) + defer func() { _ = rsp.Body.Close() }() + if err != nil { + return nil, err + } + + response := &AudioToTextResponse{ + Body: bodyBytes, + HTTPResponse: rsp, + } + + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: + var dest TextResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON200 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON400 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 401: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON401 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 413: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON413 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 422: + var dest HTTPValidationError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON422 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 500: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON500 = &dest + + } + + return response, nil +} + // ParseHealthResponse parses an HTTP response from a HealthWithResponse call func ParseHealthResponse(rsp *http.Response) (*HealthResponse, error) { bodyBytes, err := io.ReadAll(rsp.Body) @@ -974,6 +1139,9 @@ func ParseUpscaleResponse(rsp *http.Response) (*UpscaleResponse, error) { // ServerInterface represents all server handlers. type ServerInterface interface { + // Audio To Text + // (POST /audio-to-text) + AudioToText(w http.ResponseWriter, r *http.Request) // Health // (GET /health) Health(w http.ResponseWriter, r *http.Request) @@ -995,6 +1163,12 @@ type ServerInterface interface { type Unimplemented struct{} +// Audio To Text +// (POST /audio-to-text) +func (_ Unimplemented) AudioToText(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotImplemented) +} + // Health // (GET /health) func (_ Unimplemented) Health(w http.ResponseWriter, r *http.Request) { @@ -1034,6 +1208,23 @@ type ServerInterfaceWrapper struct { type MiddlewareFunc func(http.Handler) http.Handler +// AudioToText operation middleware +func (siw *ServerInterfaceWrapper) AudioToText(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + ctx = context.WithValue(ctx, HTTPBearerScopes, []string{}) + + handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + siw.Handler.AudioToText(w, r) + })) + + for _, middleware := range siw.HandlerMiddlewares { + handler = middleware(handler) + } + + handler.ServeHTTP(w, r.WithContext(ctx)) +} + // Health operation middleware func (siw *ServerInterfaceWrapper) Health(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -1230,6 +1421,9 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl ErrorHandlerFunc: options.ErrorHandlerFunc, } + r.Group(func(r chi.Router) { + r.Post(options.BaseURL+"/audio-to-text", wrapper.AudioToText) + }) r.Group(func(r chi.Router) { r.Get(options.BaseURL+"/health", wrapper.Health) }) @@ -1252,28 +1446,31 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xY3W/bNhD/Vwhuj47teM06+C3JttbY2ga12z0UgcFIZ5mtRHIkldYI/L8PPMoS9VU7", - "y0e3Ik+2pPv43fF+dyfd0EhmSgoQ1tDpDTXRGjKGf08vZr9pLbX7r7RUoC0HfJKZxP1YblOgU/rKJHRA", - "7Ua5C2M1FwndbgdUw9851xDT6QdUuRyUKqXtUk9efYTI0u2Ansl4s+QZS2BpZfGncamksW1YSc5jJiJY", - "mog5Lzc0hhXLU0unz4cnlfMXhRyZo1wJQeTZFWgHAb04AyupM2bplF5xwfSGVkZmKNIKu9BdfgXLcYgF", - "zZD9iDIZQ7rkcc0SDfC8cgJkFndBEpAwy69hqbTMlO218bqQIxderstUnvkzMEsFusvgcWAvzwgGaMgF", - "6JZVLiwkPrzKzk63H4JhK7CbZbSG6FPNs9U5VM7nKEbOUaw0cyVlCkygHYA49Dh3113gjNUgEruuORsP", - "fwl87SRaJ9eggdpF5SssYMShVb+XMNc8Btm87CbMSplaTD9XcH5XpjMXa+DJun7gJ88DvZf+eZfqXUh1", - "p/LPpOVSLK/y6BPYppHjyfPQipMkZyhZsxbEISQ3sGR5suwpjPEkoIATJqd5Qvpr5BuU9GceN2AfjyfP", - "Kk9/4fO2ZqOc91Rxfyn2VXGusF+Wv911+80q6T/cpm7XaDrz3HEoLxeLi54lIAbLeOr+/ahhRaf0h1G1", - "SoyKPWJUDvomwEI9AFb56gHynqU8Zo6heyFxC5nZh61pb1th+dVbKoEwrdkGYwjRNg104QaW2vX5rgjq", - "eI1lNq+3YPrmDxrOFRToWq6qhls56PCPPHgLRklhoIdJ5uCMvYKYszBPfrp35anVJkx41nVYHbi9pxZe", - "YVafQzK8dtd36oS5TkO5dzrdu8vmKGO8RUQUROaBd0S0gC92ITHwC6aZT/ZDLbDVmD5gMH/nuyWaFSvQ", - "gKm10Nh6TsYNqztZMkfZ/92+Wg73W07zIqigmNs121HYe9tyKqNah2Fi82ZFpx9uWrm6aUG8DJrNnzJC", - "N612M2i9kYIxPfPf36hEETNZuLv7eO/i8K4KySBTB4yC92796W/FK82yRiu+ZU9u5KRc5b3hPT26cB+G", - "VMPbCggrMso1t5u5g+Kxu7F4BkyDLr8mYBn7W6WRtbWKbp0NLlbSs8JEmis83yk9FYQplXJ/4MRKonNB", - "TmdEcQUpFz6eXV3wa1AA2j1/mwuBjq5BG29rPDwejl1CpALBFKdT+hPeGlDF7Bphj9Y4RrEJA/LaHQ06", - "n8XllKUuZT4fqDUZj91PJIUFgVoB6NFH49zvPqnsO8ZwjmNi6gmZ51EExqzylJRHgkeQZ5nbfEuI7uYI", - "u+iRlUflprxbpOthIbMLglNfD2Cs2xAbcWV5arli2o7cyn0UM8sOD+3QV9ttvSZde9w+YMbrO8ihOR/Q", - "Z/d56uXO2+H/jMXkrT8S9DuZ3Kvf1vrbRlCJkHJFPnms8GfCghYsJXPQ16BJ9R6x6zs4Q8KO8+Fyexly", - "wn9dW0i/KTS4ga+je7mBXfCxuNH/wvzI3Kj3/idufM/c8BWO3LDwxR4wNoK18KvM+PfBtxfPp+HwRID7", - "JYCrscZsKD6H9Vf+u0LgYedB59e5JwI8EeB+CbAr5q3XcmYMKtU9la9X56nMY3IusywX3G7IC2bhM9vQ", - "4iMavtSZ6WgUa2DZUeKfDtNCfRg5dfdG/08AAAD//8a69BN0HgAA", + "H4sIAAAAAAAC/+xZWXPbNhD+Kxi0j7IlO3HT0ZvtpomnOTyRkj5kPBqYXFFISIDF4UT16L93sKBI8Arl", + "8ZEm4yeJ4mL32+NbLKBrGskslwKE0XR6TXW0gozh1+Pzs+dKSeW+50rmoAwHfJPpxH0YblKgU/paJ3RE", + "zTp3D9ooLhK62Yyogn8sVxDT6UdccjEql5S6y3Xy8hNEhm5G9ETG6wWzMZcLIxcGvprGUy61aYNCGfdl", + "KVXGDJ3SSy6YWtPAKoq0oI5oJmNIFzx2y2NYMpu69cHK106AnMWDfnoUgae7edMXBp6xBJyo/9J47A5E", + "YnnMRAQLHTEHIXDp2f5RhexFIUdmKFdCEDa7BOUgoJVvh/QMRTpC6hF+A8tBiAXVkGFEt0jUiApImOFX", + "sMiVzHLTq+NNIUfOvVyXKpv5HOhFDqpL4UGgz2YEHdTkHFRLKxcGEu9epWe7th+CZksw60W0guhzzbJR", + "FirjMxQjpyhWqrmUMgUmUA9AHFqcuecucNooEIlZ1YxN9n8PbG0lWplrsCTfeuUrrEmXHap+kDBXPAbZ", + "fOwmzDLXNZ9+q+D8mevOWKyAJ6t6wo+eBete+vddS29DqluVfyYNl2JxaaPPYJpKDg6fhVqcJDlByZq2", + "wA8huYYFs8mipzAmhwEFnDA5tgnpr5HvUNJfeNyAfTA5fFpZ+hvft1c2ynmgivtLsa+KbY79svzsrtvv", + "Vkn/4zZ1s0bTGeeOpLycz897ZqEYDOOp+/argiWd0l/G1UQ1LsapcTnvNAEWywNgla0eIB9YymPmGDoI", + "iRvI9BC2pr5NheUPr6kEwpRia/QhRNtU0IUbWGpWp9siqOPVhhlbb8H07V803FdQoGv2qhpuZaDDPvLg", + "HehcCg09TNI7R+w1xJyFcfK7e1ecWm1Ch7muw+rA7S218Aq9/BKS4Y17vlUntCoN5d6rdHDUtSijvUZE", + "FHjmgXd4NIevpj8R0cqKz7snAsXDRJz69c1EjKgbtUMHHYxBD40XKkAF3tWc6HFyLjG750wx78h9TenV", + "LLLD9PGTD9CoVixBAYbWQGO0O5o0tG5lyQxlf7ihvJxgbjiyFE41arpesx2FPbj3pDKqsZeJ9dslnX68", + "bsXqugXxIiDyKxmhmQ4qN28fQOueIcf/UIkiZjJ3vw5R3/nhTRWSQaR22O8+uBmvv80tFcsa+80NN55m", + "e9ueV7zigY2oMB+6VMPb4ZDvtC1Hdmurzk4G2rAsD10NcM/L9wPQTSjojAVOeIwt8EinyCpu1jMXR4/c", + "DS4nwBSo8toLOeh/KpWsjMnpxungYik9pXWkeI7FOaXHgrA8T7mvVmIkUVaQ4zOS8xxSLnwytkXNryAH", + "UO79OysEGroCpb2uyf7B/sRFS+YgWM7plD7Bn0Y0Z2aFsMd4ebRn5N429NuzgUsLgjiLt1ddc1nkw0UQ", + "tHEzL+6yUhgQuCqzqeE5U2bsDhF7MTOsugYcKsfd7rY29Ry6Tog/+GJDrw4nkwauIKjjT9qFZ1dQtb0Z", + "bdczNrNRBFovbUoqsRF9eocQqhG+w/4Ji8k7nw9v9+Bh7L4XzJqVVPxfiNHwwZOHMVw4S54Lw82azKUk", + "r5hKfNQPD+8UROss04ZTiZDyvHP0UMk/EwaUYCmZgboCRapD4bZF4V4ZNqePF5uLEdU2y9zRvmA2mUuC", + "3HZLxys8/OBUCR29wJ+N6D1yLjx97Uq5TehUARG9wbHQdbjyfqO7xeGoUkws99zjdriQfOAuVz85/gBt", + "7pHoNyW6/09kLv3Rp8ENvEQc5AaOdQ/Fjf5rzgfmRn2YfeTGz8wNX+HIDTdx7rBtBOfcbzLjdhNo/ST9", + "uDk8EuBuCeBqrLE3FH9i9Ff++0LgfveDzv9UHgnwSIC7JcC2mDd+lVOjcVHdUnnlcppKG5NTmWVWuGPo", + "C2bgC1vT4q8PvOjR0/E4VsCyvcS/3U+L5fuRW043F5v/AgAA//+0aZ0EMSUAAA==", } // GetSwagger returns the content of the embedded swagger specification file diff --git a/worker/worker.go b/worker/worker.go index 726021fd..7c6b384d 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -254,6 +254,60 @@ func (w *Worker) Upscale(ctx context.Context, req UpscaleMultipartRequestBody) ( return resp.JSON200, nil } +func (w *Worker) AudioToText(ctx context.Context, req AudioToTextMultipartRequestBody) (*TextResponse, error) { + c, err := w.borrowContainer(ctx, "audio-to-text", *req.ModelId) + if err != nil { + return nil, err + } + defer w.returnContainer(c) + + var buf bytes.Buffer + mw, err := NewAudioToTextMultipartWriter(&buf, req) + if err != nil { + return nil, err + } + + resp, err := c.Client.AudioToTextWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf) + if err != nil { + return nil, err + } + + if resp.JSON422 != nil { + val, err := json.Marshal(resp.JSON422) + if err != nil { + return nil, err + } + slog.Error("audio-to-text container returned 422", slog.String("err", string(val))) + return nil, errors.New("audio-to-text container returned 422") + } + + if resp.JSON400 != nil { + val, err := json.Marshal(resp.JSON400) + if err != nil { + return nil, err + } + slog.Error("audio-to-text container returned 400", slog.String("err", string(val))) + return nil, errors.New("audio-to-text container returned 400") + } + + if resp.JSON413 != nil { + msg := "audio-to-text container returned 413 file too large; max file size is 50MB" + slog.Error("audio-to-text container returned 413", slog.String("err", string(msg))) + return nil, errors.New(msg) + } + + if resp.JSON500 != nil { + val, err := json.Marshal(resp.JSON500) + if err != nil { + return nil, err + } + slog.Error("audio-to-text container returned 500", slog.String("err", string(val))) + return nil, errors.New("audio-to-text container returned 500") + } + + return resp.JSON200, nil +} + func (w *Worker) Warm(ctx context.Context, pipeline string, modelID string, endpoint RunnerEndpoint, optimizationFlags OptimizationFlags) error { if endpoint.URL == "" { return w.manager.Warm(ctx, pipeline, modelID, optimizationFlags)