diff --git a/misc_utility_functions.py b/misc_utility_functions.py
index 2f83fd7..d1dda3d 100644
--- a/misc_utility_functions.py
+++ b/misc_utility_functions.py
@@ -332,3 +332,6 @@ def find_clip_model_path(llm_model_name: str) -> Optional[str]:
         logger.error(f"No mmproj file found matching: {mmproj_model_name}")
         return None
     return mmproj_files[0]
+
+def sanitize_filename(text):
+    return re.sub(r'_{2,}', '_', re.sub(r'\s+', '_', text[:30]))
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 5761ce8..afee4ea 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,6 +2,7 @@ aioredis
 aioredlock
 aiosqlite
 apscheduler
+ChatTTS @ git+https://github.com/Dicklesworthstone/ChatTTS
 faiss-cpu
 fast_vector_similarity
 fastapi
@@ -12,11 +13,13 @@ llama-cpp-python
 magika
 mutagen
 nvgpu
+openai
 pandas
 pillow
 psutil
 pydantic
 PyPDF2
+pydub
 pytest
 python-decouple
 python-multipart
diff --git a/service_functions.py b/service_functions.py
index f06946e..268d365 100644
--- a/service_functions.py
+++ b/service_functions.py
@@ -1,6 +1,6 @@
 from logger_config import setup_logger
 import shared_resources
-from shared_resources import load_model, text_completion_model_cache, is_gpu_available, evict_model_from_gpu
+from shared_resources import load_model, text_completion_model_cache, is_gpu_available
 from database_functions import AsyncSessionLocal, execute_with_retry
 from misc_utility_functions import clean_filename_for_url_func, FakeUploadFile, sophisticated_sentence_splitter, merge_transcript_segments_into_combined_text, suppress_stdout_stderr, image_to_base64_data_uri, process_image, find_clip_model_path
 from embeddings_data_models import TextEmbedding, DocumentEmbedding, Document, AudioTranscript
@@ -483,7 +483,7 @@ def load_text_completion_model(llm_model_name: str, raise_http_exception: bool =
         matching_files.sort(key=os.path.getmtime, reverse=True)
         model_file_path = matching_files[0]
         is_llava_multimodal_model = 'llava' in llm_model_name and 'mmproj' not in llm_model_name
-        chat_handler = None
+        chat_handler = None # Determine the appropriate chat handler based on the model name
         if 'llava' in llm_model_name:
             clip_model_path = find_clip_model_path(llm_model_name)
             if clip_model_path is None:
@@ -491,32 +491,30 @@ def load_text_completion_model(llm_model_name: str, raise_http_exception: bool =
             chat_handler = Llava16ChatHandler(clip_model_path=clip_model_path)
         with suppress_stdout_stderr():
             gpu_info = is_gpu_available()
-            llama_split_mode = 2 if gpu_info and gpu_info['num_gpus'] > 1 else 0
-            while True:
-                try:
-                    model_instance = Llama(
-                        model_path=model_file_path,
-                        embedding=True if is_llava_multimodal_model else False,
-                        n_ctx=TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS,
-                        flash_attn=USE_FLASH_ATTENTION,
-                        verbose=USE_VERBOSE,
-                        llama_split_mode=llama_split_mode,
-                        n_gpu_layers=-1 if gpu_info['gpu_found'] else 0,
-                        clip_model_path=clip_model_path if is_llava_multimodal_model else None,
-                        chat_handler=chat_handler
-                    )
-                    break
-                except ValueError as e:
-                    if "cudaMalloc failed: out of memory" in str(e):
-                        evict_model_from_gpu()
-                    else:
-                        raise
+            if gpu_info:
+                num_gpus = gpu_info['num_gpus']
+                if num_gpus > 1:
+                    llama_split_mode = 2 # 2, // split rows across GPUs | 1, // split layers and KV across GPUs
+                else:
+                    llama_split_mode = 0
+            else:
+                num_gpus = 0
+            model_instance = Llama(
+                model_path=model_file_path,
+                embedding=True if is_llava_multimodal_model else False,
+                n_ctx=TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS,
+                flash_attn=USE_FLASH_ATTENTION,
+                verbose=USE_VERBOSE,
+                llama_split_mode=llama_split_mode,
+                n_gpu_layers=-1 if gpu_info['gpu_found'] else 0,
+                clip_model_path=clip_model_path if is_llava_multimodal_model else None,
+                chat_handler=chat_handler
+            )
         text_completion_model_cache[llm_model_name] = model_instance
-        shared_resources.loaded_models[llm_model_name] = model_instance
         return model_instance
     except TypeError as e:
         logger.error(f"TypeError occurred while loading the model: {e}")
-        logger.error(traceback.format_exc())
+        logger.error(traceback.format_exc())
         raise
     except Exception as e:
         logger.error(f"Exception occurred while loading the model: {e}")
diff --git a/shared_resources.py b/shared_resources.py
index 9c65eb5..41e5c57 100644
--- a/shared_resources.py
+++ b/shared_resources.py
@@ -1,4 +1,4 @@
-from misc_utility_functions import is_redis_running, start_redis_server, build_faiss_indexes, suppress_stdout_stderr
+from misc_utility_functions import is_redis_running, start_redis_server, build_faiss_indexes
 from database_functions import DatabaseWriter, initialize_db, AsyncSessionLocal, delete_expired_rows
 from ramdisk_functions import setup_ramdisk, copy_models_to_ramdisk, check_that_user_has_required_permissions_to_manage_ramdisks
 from logger_config import setup_logger
@@ -17,13 +17,11 @@
 from decouple import config
 from fastapi import HTTPException
 from apscheduler.schedulers.asyncio import AsyncIOScheduler
-from collections import OrderedDict
 
 logger = setup_logger()
 
-embedding_model_cache = OrderedDict() # Model cache to store loaded models with LRU eviction
-text_completion_model_cache = OrderedDict() # Model cache to store loaded text completion models with LRU eviction
-loaded_models = OrderedDict() # Track loaded models to manage GPU memory
+embedding_model_cache = {} # Model cache to store loaded models
+text_completion_model_cache = {} # Model cache to store loaded text completion models
 
 SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT = config("SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT", default=8089, cast=int)
 DEFAULT_MODEL_NAME = config("DEFAULT_MODEL_NAME", default="openchat_v3.2_super", cast=str)
@@ -114,7 +112,7 @@ async def initialize_globals():
         lock_manager = None
 
 def download_models() -> Tuple[List[str], List[Dict[str, str]]]:
-    download_status = []
+    download_status = []
     json_path = os.path.join(BASE_DIRECTORY, "model_urls.json")
     if not os.path.exists(json_path):
         initial_model_urls = [
@@ -150,7 +148,7 @@ def download_models() -> Tuple[List[str], List[Dict[str, str]]]:
                 status = {"url": url, "status": "success", "message": "File already exists."}
             filename = os.path.join(models_dir, model_name_with_extension)
             try:
-                with lock.acquire(timeout=1200): # Wait up to 20 minutes for the file to be downloaded before returning failure
+                with lock.acquire(timeout=1200): # Wait up to 20 minutes for the file to be downloaded before returning failure
                     if not os.path.exists(filename):
                         logger.info(f"Downloading model {model_name_with_extension} from {url}...")
                         urllib.request.urlretrieve(url, filename)
@@ -173,12 +171,6 @@ def download_models() -> Tuple[List[str], List[Dict[str, str]]]:
     logger.info("Model downloads completed.")
     return model_names, download_status
 
-def evict_model_from_gpu():
-    if loaded_models:
-        evicted_model_name, evicted_model_instance = loaded_models.popitem(last=False)
-        del evicted_model_instance
-        logger.info(f"Evicted model {evicted_model_name} from GPU memory")
-
 def load_model(llm_model_name: str, raise_http_exception: bool = True):
     global USE_VERBOSE
     model_instance = None
@@ -193,33 +185,16 @@ def load_model(llm_model_name: str, raise_http_exception: bool = True):
         matching_files.sort(key=os.path.getmtime, reverse=True)
         model_file_path = matching_files[0]
         gpu_info = is_gpu_available()
-        is_llava_multimodal_model = 'llava' in llm_model_name
-        with suppress_stdout_stderr():
-            if is_llava_multimodal_model:
-                pass
+        if 'llava' in llm_model_name:
+            is_llava_multimodal_model = 1
+        else:
+            is_llava_multimodal_model = 0
+        if not is_llava_multimodal_model:
+            if gpu_info['gpu_found']:
+                model_instance = llama_cpp.Llama(model_path=model_file_path, embedding=True, n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS, verbose=USE_VERBOSE, n_gpu_layers=-1) # Load the model with GPU acceleration
             else:
-                while True:
-                    try:
-                        if gpu_info['gpu_found']:
-                            model_instance = llama_cpp.Llama(
-                                model_path=model_file_path, embedding=True,
-                                n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS,
-                                verbose=USE_VERBOSE, n_gpu_layers=-1
-                            ) # Load the model with GPU acceleration
-                        else:
-                            model_instance = llama_cpp.Llama(
-                                model_path=model_file_path, embedding=True,
-                                n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS,
-                                verbose=USE_VERBOSE
-                            ) # Load the model without GPU acceleration
-                        break
-                    except ValueError as e:
-                        if "cudaMalloc failed: out of memory" in str(e):
-                            evict_model_from_gpu()
-                        else:
-                            raise
+                model_instance = llama_cpp.Llama(model_path=model_file_path, embedding=True, n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS, verbose=USE_VERBOSE) # Load the model without GPU acceleration
         embedding_model_cache[llm_model_name] = model_instance
-        loaded_models[llm_model_name] = model_instance
         return model_instance
     except TypeError as e:
         logger.error(f"TypeError occurred while loading the model: {e}")
diff --git a/swiss_army_llama.py b/swiss_army_llama.py
index 70c8989..89b3ea1 100644
--- a/swiss_army_llama.py
+++ b/swiss_army_llama.py
@@ -3,7 +3,7 @@
 from logger_config import setup_logger
 from database_functions import AsyncSessionLocal
 from ramdisk_functions import clear_ramdisk
-from misc_utility_functions import build_faiss_indexes, configure_redis_optimally
+from misc_utility_functions import build_faiss_indexes, configure_redis_optimally, sanitize_filename
 from embeddings_data_models import DocumentEmbedding, ShowLogsIncrementalModel
 from embeddings_data_models import EmbeddingRequest, SemanticSearchRequest, AdvancedSemanticSearchRequest, SimilarityRequest, TextCompletionRequest, AddGrammarRequest
 from embeddings_data_models import EmbeddingResponse, SemanticSearchResponse, AdvancedSemanticSearchResponse, SimilarityResponse, AllStringsResponse, AllDocumentsResponse, TextCompletionResponse, AudioTranscriptResponse, ImageQuestionResponse, AddGrammarResponse
@@ -22,6 +22,7 @@
 import tempfile
 import traceback
 import zipfile
+from io import BytesIO
 from pathlib import Path
 from datetime import datetime
 from hashlib import sha3_256
@@ -31,7 +32,7 @@
 from decouple import config
 import uvicorn
 import fastapi
-from fastapi import FastAPI, HTTPException, Request, UploadFile, File, Form
+from fastapi import FastAPI, HTTPException, Request, UploadFile, File, Form, Query
 from fastapi.responses import JSONResponse, FileResponse, HTMLResponse
 from contextlib import asynccontextmanager
 from sqlalchemy import select
@@ -41,6 +42,8 @@
 import fast_vector_similarity as fvs
 import uvloop
 from magika import Magika
+import ChatTTS
+from pydub import AudioSegment
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 logger = setup_logger()
@@ -1460,4 +1463,57 @@ async def convert_document_to_sentences(
         result = await convert_document_to_sentences_func(temp_file_path, result.output.mime_type)
     finally:
         os.remove(temp_file_path)
-    return JSONResponse(content=result)
\ No newline at end of file
+    return JSONResponse(content=result)
+
+
+
+@app.get("/generate_tts_audio/",
+         summary="Generate Text-to-Speech Audio",
+         description="""Generate a text-to-speech audio file from a given input text string.
+
+### Parameters:
+- `text`: The input text string to convert into speech. This is a required query parameter.
+
+### Example Request:
+The request should be a GET request with the `text` query parameter, for example:
+```
+/generate_tts_audio/?text=This+is+a+sample+text
+```
+
+### Response:
+The response will include the generated audio file in MP3 format.
+
+### Example Response:
+The endpoint will return the MP3 file as an audio response.
+""",
+         response_description="An MP3 audio file generated from the input text.")
+async def generate_tts_audio(text: str = Query(..., description="The input text string to convert into speech")):
+    try:
+        chat = ChatTTS.Chat()
+        chat.load_models()
+
+        # Infer the audio from the provided text
+        texts = [text]
+        wavs = chat.infer(texts, use_decoder=True)
+
+        # Ensure the WAV data is in bytes format and has a correct sample rate
+        if not isinstance(wavs[0], bytes):
+            raise ValueError("Generated WAV data is not in bytes format.")
+
+        # Save WAV data to a temporary BytesIO buffer
+        wav_buffer = BytesIO(wavs[0])
+        audio = AudioSegment.from_raw(wav_buffer, sample_width=2, frame_rate=24000, channels=1)
+
+        audio_output_file_name = sanitize_filename(text)
+        wav_path = f"/tmp/{audio_output_file_name}.wav"
+        mp3_path = f"/tmp/{audio_output_file_name}.mp3"
+
+        # Export to WAV file
+        audio.export(wav_path, format="wav")
+
+        # Convert WAV to MP3, keeping the original sample rate
+        audio.export(mp3_path, format="mp3", parameters=["-ar", "24000"])
+
+        return FileResponse(mp3_path, media_type="audio/mpeg", filename=f"{audio_output_file_name}.mp3")
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
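
A quick illustration of what the new `sanitize_filename` helper does to the file names used by the TTS endpoint. The function body below is copied verbatim from the patch; the example input string and the standalone-script framing are only illustrative.

```
import re

def sanitize_filename(text):
    # Same logic as the helper added to misc_utility_functions.py: keep only the
    # first 30 characters, replace each run of whitespace with a single underscore,
    # then collapse any remaining runs of underscores into one.
    return re.sub(r'_{2,}', '_', re.sub(r'\s+', '_', text[:30]))

print(sanitize_filename("This is a sample text"))
# -> This_is_a_sample_text
```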
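
The endpoint turns the raw ChatTTS output into MP3 via pydub. A minimal standalone sketch of that conversion step, assuming the input is mono 16-bit PCM at 24 kHz as the patch does; the synthetic test tone, the `pcm16_to_mp3` name, and the output path are made up for illustration, and pydub needs ffmpeg available to export MP3.

```
import math
import struct
from io import BytesIO

from pydub import AudioSegment  # requires ffmpeg for MP3 export


def pcm16_to_mp3(pcm_bytes: bytes, mp3_path: str, sample_rate: int = 24000) -> None:
    # Mirrors the conversion in the endpoint: interpret the raw bytes as mono,
    # 16-bit PCM at the given rate, then export as MP3 at the same rate.
    audio = AudioSegment.from_raw(BytesIO(pcm_bytes), sample_width=2, frame_rate=sample_rate, channels=1)
    audio.export(mp3_path, format="mp3", parameters=["-ar", str(sample_rate)])


if __name__ == "__main__":
    # Illustrative input: one second of a 440 Hz tone as little-endian 16-bit samples.
    sample_rate = 24000
    samples = (int(32767 * 0.3 * math.sin(2 * math.pi * 440 * n / sample_rate)) for n in range(sample_rate))
    pcm16_to_mp3(b"".join(struct.pack("<h", s) for s in samples), "/tmp/tone.mp3")
```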
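
Finally, a sketch of how a client might call the new `/generate_tts_audio/` endpoint once the server is running. It assumes the default `SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT` of 8089 on localhost and uses the third-party `requests` package, which is not in requirements.txt; any HTTP client would work the same way.

```
import requests  # assumption: not part of the project's requirements

BASE_URL = "http://localhost:8089"  # default SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT


def fetch_tts_mp3(text: str, output_path: str = "tts_output.mp3") -> str:
    # The endpoint takes the input text as a required `text` query parameter
    # and returns the generated MP3 bytes as the response body.
    response = requests.get(
        f"{BASE_URL}/generate_tts_audio/",
        params={"text": text},
        timeout=600,  # ChatTTS model loading and inference can be slow
    )
    response.raise_for_status()
    with open(output_path, "wb") as f:
        f.write(response.content)
    return output_path


if __name__ == "__main__":
    print(fetch_tts_mp3("This is a sample text"))
```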