forked from livepeer/ai-worker
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(runner): add support for SD3-medium model (livepeer#118)
This commit introduces support for the Stable Diffusion 3 Medium model from Hugging Face: [https://huggingface.co/stabilityai/stable-diffusion-3-medium](https://huggingface.co/stabilityai/stable-diffusion-3-medium). Please be aware that this model has restrictive licensing at the time of writing and is not yet advised for public use. Ensure you read and understand the [licensing terms](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE) before enabling this model on your orchestrator.
- Loading branch information
Showing
7 changed files
with
274 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
"""This module contains several utility functions that are used across the pipelines module.""" | ||
|
||
from app.pipelines.utils.utils import ( | ||
get_model_dir, | ||
get_model_path, | ||
get_torch_device, | ||
validate_torch_device, | ||
is_lightning_model, | ||
is_turbo_model, | ||
get_temp_file, | ||
SafetyChecker, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
"""This module provides functionality for converting audio files between different formats.""" | ||
|
||
from io import BytesIO | ||
|
||
import av | ||
from fastapi import UploadFile | ||
|
||
|
||
class AudioConversionError(Exception): | ||
"""Raised when an audio file cannot be converted.""" | ||
|
||
def __init__(self, message="Audio conversion failed."): | ||
self.message = message | ||
super().__init__(self.message) | ||
|
||
|
||
class AudioConverter: | ||
"""Converts audio files to different formats.""" | ||
|
||
@staticmethod | ||
def convert( | ||
upload_file: UploadFile, output_extension: str, output_codec=None | ||
) -> bytes: | ||
"""Converts an audio file to a different format. | ||
Args: | ||
upload_file: The audio file to convert. | ||
output_extension: The desired output format. | ||
output_codec: The desired output codec. | ||
Returns: | ||
The converted audio file as bytes. | ||
""" | ||
if output_extension.startswith("."): | ||
output_extension = output_extension.lstrip(".") | ||
|
||
output_buffer = BytesIO() | ||
|
||
input_container = av.open(upload_file.file) | ||
output_container = av.open(output_buffer, mode="w", format=output_extension) | ||
|
||
try: | ||
for stream in input_container.streams.audio: | ||
audio_stream = output_container.add_stream( | ||
output_codec if output_codec else output_extension | ||
) | ||
|
||
# Convert input audio to target format. | ||
for frame in input_container.decode(stream): | ||
for packet in audio_stream.encode(frame): | ||
output_container.mux(packet) | ||
|
||
# Flush remaining packets to the output. | ||
for packet in audio_stream.encode(): | ||
output_container.mux(packet) | ||
except Exception as e: | ||
raise AudioConversionError(f"Error during audio conversion: {e}") | ||
finally: | ||
input_container.close() | ||
output_container.close() | ||
|
||
# Return the converted audio bytes. | ||
output_buffer.seek(0) | ||
converted_bytes = output_buffer.read() | ||
return converted_bytes | ||
|
||
@staticmethod | ||
def write_bytes_to_file(bytes: bytes, upload_file: UploadFile): | ||
"""Writes bytes to a file. | ||
Args: | ||
bytes: The bytes to write. | ||
upload_file: The file to write to. | ||
""" | ||
upload_file.file.seek(0) | ||
upload_file.file.write(bytes) | ||
upload_file.file.seek(0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
"""This module contains several utility functions.""" | ||
|
||
import logging | ||
import os | ||
import re | ||
import tempfile | ||
import uuid | ||
from pathlib import Path | ||
from typing import Optional | ||
|
||
import numpy as np | ||
import torch | ||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker | ||
from PIL import Image | ||
from torch import dtype as TorchDtype | ||
from transformers import CLIPFeatureExtractor | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def get_model_dir() -> Path: | ||
return Path(os.environ["MODEL_DIR"]) | ||
|
||
|
||
def get_model_path(model_id: str) -> Path: | ||
return get_model_dir() / model_id.lower() | ||
|
||
|
||
def get_torch_device(): | ||
if torch.cuda.is_available(): | ||
return torch.device("cuda") | ||
elif torch.backends.mps.is_available(): | ||
return torch.device("mps") | ||
else: | ||
return torch.device("cpu") | ||
|
||
|
||
def validate_torch_device(device_name: str) -> bool: | ||
"""Checks if the given PyTorch device name is valid and available. | ||
Args: | ||
device_name: Name of the device ('cuda:0', 'cuda', 'cpu'). | ||
Returns: | ||
True if valid and available, False otherwise. | ||
""" | ||
try: | ||
device = torch.device(device_name) | ||
if device.type == "cuda": | ||
# Check if CUDA is available and the specified index is within range | ||
if device.index is None: | ||
return torch.cuda.is_available() | ||
else: | ||
return device.index < torch.cuda.device_count() | ||
return True | ||
except RuntimeError: | ||
return False | ||
|
||
|
||
def is_lightning_model(model_id: str) -> bool: | ||
"""Checks if the model is a Lightning model. | ||
Args: | ||
model_id: Model ID. | ||
Returns: | ||
True if the model is a Lightning model, False otherwise. | ||
""" | ||
return re.search(r"[-_]lightning", model_id, re.IGNORECASE) is not None | ||
|
||
|
||
def is_turbo_model(model_id: str) -> bool: | ||
"""Checks if the model is a Turbo model. | ||
Args: | ||
model_id: Model ID. | ||
Returns: | ||
True if the model is a Turbo model, False otherwise. | ||
""" | ||
return re.search(r"[-_]turbo", model_id, re.IGNORECASE) is not None | ||
|
||
|
||
def get_temp_file(prefix: str, extension: str) -> str: | ||
"""Generates a temporary file path with the specified prefix and extension. | ||
Args: | ||
prefix: The prefix for the temporary file. | ||
extension: The extension for the temporary file. | ||
Returns: | ||
The path to a non-existing temporary file with the specified prefix and extension. | ||
""" | ||
if not extension.startswith("."): | ||
extension = "." + extension | ||
filename = f"{prefix}{uuid.uuid4()}{extension}" | ||
temp_path = os.path.join(tempfile.gettempdir(), filename) | ||
while os.path.exists(temp_path): | ||
filename = f"{prefix}{uuid.uuid4()}{extension}" | ||
temp_path = os.path.join(tempfile.gettempdir(), filename) | ||
return temp_path | ||
|
||
|
||
class SafetyChecker: | ||
"""Checks images for unsafe or inappropriate content using a pretrained model. | ||
Attributes: | ||
device (str): Device for inference. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
device: Optional[str] = "cuda", | ||
dtype: Optional[TorchDtype] = torch.float16, | ||
): | ||
"""Initializes the SafetyChecker. | ||
Args: | ||
device: Device for inference. Defaults to "cuda". | ||
dtype: Data type for inference. Defaults to `torch.float16`. | ||
""" | ||
device = device.lower() if device else device | ||
if not validate_torch_device(device): | ||
default_device = get_torch_device() | ||
logger.warning( | ||
f"Device '{device}' not found. Defaulting to '{default_device}'." | ||
) | ||
device = default_device | ||
|
||
self.device = device | ||
self._dtype = dtype | ||
self._safety_checker = StableDiffusionSafetyChecker.from_pretrained( | ||
"CompVis/stable-diffusion-safety-checker" | ||
).to(self.device) | ||
self._feature_extractor = CLIPFeatureExtractor.from_pretrained( | ||
"openai/clip-vit-base-patch32" | ||
) | ||
|
||
def check_nsfw_images( | ||
self, images: list[Image.Image] | ||
) -> tuple[list[Image.Image], list[bool]]: | ||
"""Checks images for unsafe content. | ||
Args: | ||
images: Images to check. | ||
Returns: | ||
Tuple of images and corresponding NSFW flags. | ||
""" | ||
safety_checker_input = self._feature_extractor(images, return_tensors="pt").to( | ||
self.device | ||
) | ||
images_np = [np.array(img) for img in images] | ||
_, has_nsfw_concept = self._safety_checker( | ||
images=images_np, | ||
clip_input=safety_checker_input.pixel_values.to(self._dtype), | ||
) | ||
return images, has_nsfw_concept |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters