Committing changes to experiments_with_tts
Dicklesworthstone committed May 29, 2024
1 parent c436648 commit 6b73cdf
Showing 5 changed files with 100 additions and 65 deletions.
3 changes: 3 additions & 0 deletions misc_utility_functions.py
@@ -332,3 +332,6 @@ def find_clip_model_path(llm_model_name: str) -> Optional[str]:
logger.error(f"No mmproj file found matching: {mmproj_model_name}")
return None
return mmproj_files[0]

def sanitize_filename(text):
return re.sub(r'_{2,}', '_', re.sub(r'\s+', '_', text[:30]))
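
A quick usage sketch of the new helper (this assumes `re` is already imported at the top of misc_utility_functions.py; that import sits outside this hunk). It truncates the input to 30 characters, turns whitespace runs into underscores, and collapses repeated underscores:

    # Illustrative usage only; hypothetical inputs, not part of the commit.
    from misc_utility_functions import sanitize_filename

    sanitize_filename("This is a sample text")    # -> "This_is_a_sample_text"
    sanitize_filename("tabs\tand  spaces__here")  # -> "tabs_and_spaces_here"
    sanitize_filename("x" * 100)                  # truncated to the first 30 characters
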
3 changes: 3 additions & 0 deletions requirements.txt
@@ -2,6 +2,7 @@ aioredis
aioredlock
aiosqlite
apscheduler
ChatTTS @ git+https://github.com/Dicklesworthstone/ChatTTS
faiss-cpu
fast_vector_similarity
fastapi
@@ -12,11 +13,13 @@ llama-cpp-python
magika
mutagen
nvgpu
openai
pandas
pillow
psutil
pydantic
PyPDF2
pydub
pytest
python-decouple
python-multipart
46 changes: 22 additions & 24 deletions service_functions.py
@@ -1,6 +1,6 @@
from logger_config import setup_logger
import shared_resources
from shared_resources import load_model, text_completion_model_cache, is_gpu_available, evict_model_from_gpu
from shared_resources import load_model, text_completion_model_cache, is_gpu_available
from database_functions import AsyncSessionLocal, execute_with_retry
from misc_utility_functions import clean_filename_for_url_func, FakeUploadFile, sophisticated_sentence_splitter, merge_transcript_segments_into_combined_text, suppress_stdout_stderr, image_to_base64_data_uri, process_image, find_clip_model_path
from embeddings_data_models import TextEmbedding, DocumentEmbedding, Document, AudioTranscript
@@ -483,40 +483,38 @@ def load_text_completion_model(llm_model_name: str, raise_http_exception: bool =
matching_files.sort(key=os.path.getmtime, reverse=True)
model_file_path = matching_files[0]
is_llava_multimodal_model = 'llava' in llm_model_name and 'mmproj' not in llm_model_name
chat_handler = None
chat_handler = None # Determine the appropriate chat handler based on the model name
if 'llava' in llm_model_name:
clip_model_path = find_clip_model_path(llm_model_name)
if clip_model_path is None:
raise FileNotFoundError
chat_handler = Llava16ChatHandler(clip_model_path=clip_model_path)
with suppress_stdout_stderr():
gpu_info = is_gpu_available()
llama_split_mode = 2 if gpu_info and gpu_info['num_gpus'] > 1 else 0
while True:
try:
model_instance = Llama(
model_path=model_file_path,
embedding=True if is_llava_multimodal_model else False,
n_ctx=TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS,
flash_attn=USE_FLASH_ATTENTION,
verbose=USE_VERBOSE,
llama_split_mode=llama_split_mode,
n_gpu_layers=-1 if gpu_info['gpu_found'] else 0,
clip_model_path=clip_model_path if is_llava_multimodal_model else None,
chat_handler=chat_handler
)
break
except ValueError as e:
if "cudaMalloc failed: out of memory" in str(e):
evict_model_from_gpu()
else:
raise
if gpu_info:
num_gpus = gpu_info['num_gpus']
if num_gpus > 1:
llama_split_mode = 2  # 2 = split rows across GPUs; 1 = split layers and KV across GPUs
else:
llama_split_mode = 0
else:
num_gpus = 0
model_instance = Llama(
model_path=model_file_path,
embedding=True if is_llava_multimodal_model else False,
n_ctx=TEXT_COMPLETION_CONTEXT_SIZE_IN_TOKENS,
flash_attn=USE_FLASH_ATTENTION,
verbose=USE_VERBOSE,
llama_split_mode=llama_split_mode,
n_gpu_layers=-1 if gpu_info['gpu_found'] else 0,
clip_model_path=clip_model_path if is_llava_multimodal_model else None,
chat_handler=chat_handler
)
text_completion_model_cache[llm_model_name] = model_instance
shared_resources.loaded_models[llm_model_name] = model_instance
return model_instance
except TypeError as e:
logger.error(f"TypeError occurred while loading the model: {e}")
logger.error(traceback.format_exc())
raise
except Exception as e:
logger.error(f"Exception occurred while loading the model: {e}")
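
The rewritten block above keys everything off the dictionary returned by is_gpu_available(). Judging from the fields referenced here ('gpu_found' and 'num_gpus'), the decision it makes reduces to roughly the following sketch; the shape of gpu_info is inferred from usage, not taken from the actual implementation of is_gpu_available:

    # Assumed shape of the is_gpu_available() result, inferred from the fields used above.
    gpu_info = {"gpu_found": True, "num_gpus": 2}
    # 2 = split rows across GPUs when more than one GPU is present; 0 = no split.
    llama_split_mode = 2 if gpu_info and gpu_info["num_gpus"] > 1 else 0
    # -1 offloads all layers to the GPU; 0 keeps the model on the CPU.
    n_gpu_layers = -1 if gpu_info["gpu_found"] else 0
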
51 changes: 13 additions & 38 deletions shared_resources.py
@@ -1,4 +1,4 @@
from misc_utility_functions import is_redis_running, start_redis_server, build_faiss_indexes, suppress_stdout_stderr
from misc_utility_functions import is_redis_running, start_redis_server, build_faiss_indexes
from database_functions import DatabaseWriter, initialize_db, AsyncSessionLocal, delete_expired_rows
from ramdisk_functions import setup_ramdisk, copy_models_to_ramdisk, check_that_user_has_required_permissions_to_manage_ramdisks
from logger_config import setup_logger
@@ -17,13 +17,11 @@
from decouple import config
from fastapi import HTTPException
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from collections import OrderedDict

logger = setup_logger()

embedding_model_cache = OrderedDict() # Model cache to store loaded models with LRU eviction
text_completion_model_cache = OrderedDict() # Model cache to store loaded text completion models with LRU eviction
loaded_models = OrderedDict() # Track loaded models to manage GPU memory
embedding_model_cache = {} # Model cache to store loaded models
text_completion_model_cache = {} # Model cache to store loaded text completion models

SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT = config("SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT", default=8089, cast=int)
DEFAULT_MODEL_NAME = config("DEFAULT_MODEL_NAME", default="openchat_v3.2_super", cast=str)
@@ -114,7 +112,7 @@ async def initialize_globals():
lock_manager = None

def download_models() -> Tuple[List[str], List[Dict[str, str]]]:
download_status = []
json_path = os.path.join(BASE_DIRECTORY, "model_urls.json")
if not os.path.exists(json_path):
initial_model_urls = [
@@ -150,7 +148,7 @@ def download_models() -> Tuple[List[str], List[Dict[str, str]]]:
status = {"url": url, "status": "success", "message": "File already exists."}
filename = os.path.join(models_dir, model_name_with_extension)
try:
with lock.acquire(timeout=1200):  # Wait up to 20 minutes for the file to be downloaded before returning failure
if not os.path.exists(filename):
logger.info(f"Downloading model {model_name_with_extension} from {url}...")
urllib.request.urlretrieve(url, filename)
@@ -173,12 +171,6 @@ def download_models() -> Tuple[List[str], List[Dict[str, str]]]:
logger.info("Model downloads completed.")
return model_names, download_status

def evict_model_from_gpu():
if loaded_models:
evicted_model_name, evicted_model_instance = loaded_models.popitem(last=False)
del evicted_model_instance
logger.info(f"Evicted model {evicted_model_name} from GPU memory")

def load_model(llm_model_name: str, raise_http_exception: bool = True):
global USE_VERBOSE
model_instance = None
@@ -193,33 +185,16 @@ def load_model(llm_model_name: str, raise_http_exception: bool = True):
matching_files.sort(key=os.path.getmtime, reverse=True)
model_file_path = matching_files[0]
gpu_info = is_gpu_available()
is_llava_multimodal_model = 'llava' in llm_model_name
with suppress_stdout_stderr():
if is_llava_multimodal_model:
pass
if 'llava' in llm_model_name:
is_llava_multimodal_model = 1
else:
is_llava_multimodal_model = 0
if not is_llava_multimodal_model:
if gpu_info['gpu_found']:
model_instance = llama_cpp.Llama(model_path=model_file_path, embedding=True, n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS, verbose=USE_VERBOSE, n_gpu_layers=-1) # Load the model with GPU acceleration
else:
while True:
try:
if gpu_info['gpu_found']:
model_instance = llama_cpp.Llama(
model_path=model_file_path, embedding=True,
n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS,
verbose=USE_VERBOSE, n_gpu_layers=-1
) # Load the model with GPU acceleration
else:
model_instance = llama_cpp.Llama(
model_path=model_file_path, embedding=True,
n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS,
verbose=USE_VERBOSE
) # Load the model without GPU acceleration
break
except ValueError as e:
if "cudaMalloc failed: out of memory" in str(e):
evict_model_from_gpu()
else:
raise
model_instance = llama_cpp.Llama(model_path=model_file_path, embedding=True, n_ctx=LLM_CONTEXT_SIZE_IN_TOKENS, verbose=USE_VERBOSE) # Load the model without GPU acceleration
embedding_model_cache[llm_model_name] = model_instance
loaded_models[llm_model_name] = model_instance
return model_instance
except TypeError as e:
logger.error(f"TypeError occurred while loading the model: {e}")
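
With the retry-and-evict path removed, load_model now simply loads the file (offloading all layers when a GPU is found) and stores the instance in embedding_model_cache. A minimal caller-side sketch, assuming the DEFAULT_MODEL_NAME configured above and that llama-cpp-python exposes Llama.embed() as in current releases:

    # Hypothetical usage; the cache-lookup branch of load_model sits above this hunk and is not shown.
    from shared_resources import load_model, DEFAULT_MODEL_NAME

    model = load_model(DEFAULT_MODEL_NAME)        # loaded once, then reused via embedding_model_cache
    vector = model.embed("An example sentence.")  # works because the model is loaded with embedding=True
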
62 changes: 59 additions & 3 deletions swiss_army_llama.py
@@ -3,7 +3,7 @@
from logger_config import setup_logger
from database_functions import AsyncSessionLocal
from ramdisk_functions import clear_ramdisk
from misc_utility_functions import build_faiss_indexes, configure_redis_optimally
from misc_utility_functions import build_faiss_indexes, configure_redis_optimally, sanitize_filename
from embeddings_data_models import DocumentEmbedding, ShowLogsIncrementalModel
from embeddings_data_models import EmbeddingRequest, SemanticSearchRequest, AdvancedSemanticSearchRequest, SimilarityRequest, TextCompletionRequest, AddGrammarRequest
from embeddings_data_models import EmbeddingResponse, SemanticSearchResponse, AdvancedSemanticSearchResponse, SimilarityResponse, AllStringsResponse, AllDocumentsResponse, TextCompletionResponse, AudioTranscriptResponse, ImageQuestionResponse, AddGrammarResponse
@@ -22,6 +22,7 @@
import tempfile
import traceback
import zipfile
from io import BytesIO
from pathlib import Path
from datetime import datetime
from hashlib import sha3_256
@@ -31,7 +32,7 @@
from decouple import config
import uvicorn
import fastapi
from fastapi import FastAPI, HTTPException, Request, UploadFile, File, Form
from fastapi import FastAPI, HTTPException, Request, UploadFile, File, Form, Query
from fastapi.responses import JSONResponse, FileResponse, HTMLResponse
from contextlib import asynccontextmanager
from sqlalchemy import select
@@ -41,6 +42,8 @@
import fast_vector_similarity as fvs
import uvloop
from magika import Magika
import ChatTTS
from pydub import AudioSegment

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
logger = setup_logger()
@@ -1460,4 +1463,57 @@ async def convert_document_to_sentences(
result = await convert_document_to_sentences_func(temp_file_path, result.output.mime_type)
finally:
os.remove(temp_file_path)
return JSONResponse(content=result)



@app.get("/generate_tts_audio/",
summary="Generate Text-to-Speech Audio",
description="""Generate a text-to-speech audio file from a given input text string.
### Parameters:
- `text`: The input text string to convert into speech. This is a required query parameter.
### Example Request:
The request should be a GET request with the `text` query parameter, for example:
```
/generate_tts_audio/?text=This+is+a+sample+text
```
### Response:
The response will include the generated audio file in MP3 format.
### Example Response:
The endpoint will return the MP3 file as an audio response.
""",
response_description="An MP3 audio file generated from the input text.")
async def generate_tts_audio(text: str = Query(..., description="The input text string to convert into speech")):
try:
chat = ChatTTS.Chat()
chat.load_models()

# Infer the audio from the provided text
texts = [text]
wavs = chat.infer(texts, use_decoder=True)

# Ensure the WAV data is in bytes format and has a correct sample rate
if not isinstance(wavs[0], bytes):
raise ValueError("Generated WAV data is not in bytes format.")

# Save WAV data to a temporary BytesIO buffer
wav_buffer = BytesIO(wavs[0])
audio = AudioSegment.from_raw(wav_buffer, sample_width=2, frame_rate=24000, channels=1)

audio_output_file_name = sanitize_filename(text)
wav_path = f"/tmp/{audio_output_file_name}.wav"
mp3_path = f"/tmp/{audio_output_file_name}.mp3"

# Export to WAV file
audio.export(wav_path, format="wav")

# Convert WAV to MP3, keeping the original sample rate
audio.export(mp3_path, format="mp3", parameters=["-ar", "24000"])

return FileResponse(mp3_path, media_type="audio/mpeg", filename=f"{audio_output_file_name}.mp3")
except Exception as e:
raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
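
A client-side sketch of exercising the new endpoint, assuming the service is running on the default SWISS_ARMY_LLAMA_SERVER_LISTEN_PORT of 8089 and that the `requests` package is available (it is not part of the requirements.txt shown above):

    # Hypothetical client call; adjust host and port to your deployment.
    import requests

    resp = requests.get(
        "http://localhost:8089/generate_tts_audio/",
        params={"text": "This is a sample text"},
        timeout=600,  # model loading and inference can take a while
    )
    resp.raise_for_status()
    with open("sample_tts.mp3", "wb") as f:
        f.write(resp.content)  # the endpoint returns the generated MP3 bytes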
