-
Notifications
You must be signed in to change notification settings - Fork 4.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add XTTS Fine tuning gradio demo (#3296)
* Add XTTS FT demo data processing pipeline * Add training and inference columns * Uses tabs instead of columns * Fix demo freezing issue * Update demo * Convert stereo to mono * Bug fix on XTTS inference * Update gradio demo * Update gradio demo * Update gradio demo * Update gradio demo * Add parameters to be able to set then on colab demo * Add erros messages * Add intuitive error messages * Update * Add max_audio_length parameter * Add XTTS fine-tuner docs * Update XTTS finetuner docs * Delete trainer to freeze memory * Delete unused variables * Add gc.collect() * Update xtts.md --------- Co-authored-by: Eren Gölge <[email protected]>
- Loading branch information
Showing
7 changed files
with
800 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
faster_whisper==0.9.0 | ||
gradio==4.7.1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
import os | ||
import gc | ||
import torchaudio | ||
import pandas | ||
from faster_whisper import WhisperModel | ||
from glob import glob | ||
|
||
from tqdm import tqdm | ||
|
||
import torch | ||
import torchaudio | ||
# torch.set_num_threads(1) | ||
|
||
from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners | ||
|
||
torch.set_num_threads(16) | ||
|
||
|
||
import os | ||
|
||
audio_types = (".wav", ".mp3", ".flac") | ||
|
||
|
||
def list_audios(basePath, contains=None): | ||
# return the set of files that are valid | ||
return list_files(basePath, validExts=audio_types, contains=contains) | ||
|
||
def list_files(basePath, validExts=None, contains=None): | ||
# loop over the directory structure | ||
for (rootDir, dirNames, filenames) in os.walk(basePath): | ||
# loop over the filenames in the current directory | ||
for filename in filenames: | ||
# if the contains string is not none and the filename does not contain | ||
# the supplied string, then ignore the file | ||
if contains is not None and filename.find(contains) == -1: | ||
continue | ||
|
||
# determine the file extension of the current file | ||
ext = filename[filename.rfind("."):].lower() | ||
|
||
# check to see if the file is an audio and should be processed | ||
if validExts is None or ext.endswith(validExts): | ||
# construct the path to the audio and yield it | ||
audioPath = os.path.join(rootDir, filename) | ||
yield audioPath | ||
|
||
def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None): | ||
audio_total_size = 0 | ||
# make sure that ooutput file exists | ||
os.makedirs(out_path, exist_ok=True) | ||
|
||
# Loading Whisper | ||
device = "cuda" if torch.cuda.is_available() else "cpu" | ||
|
||
print("Loading Whisper Model!") | ||
asr_model = WhisperModel("large-v2", device=device, compute_type="float16") | ||
|
||
metadata = {"audio_file": [], "text": [], "speaker_name": []} | ||
|
||
if gradio_progress is not None: | ||
tqdm_object = gradio_progress.tqdm(audio_files, desc="Formatting...") | ||
else: | ||
tqdm_object = tqdm(audio_files) | ||
|
||
for audio_path in tqdm_object: | ||
wav, sr = torchaudio.load(audio_path) | ||
# stereo to mono if needed | ||
if wav.size(0) != 1: | ||
wav = torch.mean(wav, dim=0, keepdim=True) | ||
|
||
wav = wav.squeeze() | ||
audio_total_size += (wav.size(-1) / sr) | ||
|
||
segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language) | ||
segments = list(segments) | ||
i = 0 | ||
sentence = "" | ||
sentence_start = None | ||
first_word = True | ||
# added all segments words in a unique list | ||
words_list = [] | ||
for _, segment in enumerate(segments): | ||
words = list(segment.words) | ||
words_list.extend(words) | ||
|
||
# process each word | ||
for word_idx, word in enumerate(words_list): | ||
if first_word: | ||
sentence_start = word.start | ||
# If it is the first sentence, add buffer or get the begining of the file | ||
if word_idx == 0: | ||
sentence_start = max(sentence_start - buffer, 0) # Add buffer to the sentence start | ||
else: | ||
# get previous sentence end | ||
previous_word_end = words_list[word_idx - 1].end | ||
# add buffer or get the silence midle between the previous sentence and the current one | ||
sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start)/2) | ||
|
||
sentence = word.word | ||
first_word = False | ||
else: | ||
sentence += word.word | ||
|
||
if word.word[-1] in ["!", ".", "?"]: | ||
sentence = sentence[1:] | ||
# Expand number and abbreviations plus normalization | ||
sentence = multilingual_cleaners(sentence, target_language) | ||
audio_file_name, _ = os.path.splitext(os.path.basename(audio_path)) | ||
|
||
audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}.wav" | ||
|
||
# Check for the next word's existence | ||
if word_idx + 1 < len(words_list): | ||
next_word_start = words_list[word_idx + 1].start | ||
else: | ||
# If don't have more words it means that it is the last sentence then use the audio len as next word start | ||
next_word_start = (wav.shape[0] - 1) / sr | ||
|
||
# Average the current word end and next word start | ||
word_end = min((word.end + next_word_start) / 2, word.end + buffer) | ||
|
||
absoulte_path = os.path.join(out_path, audio_file) | ||
os.makedirs(os.path.dirname(absoulte_path), exist_ok=True) | ||
i += 1 | ||
first_word = True | ||
|
||
audio = wav[int(sr*sentence_start):int(sr*word_end)].unsqueeze(0) | ||
# if the audio is too short ignore it (i.e < 0.33 seconds) | ||
if audio.size(-1) >= sr/3: | ||
torchaudio.save(absoulte_path, | ||
audio, | ||
sr | ||
) | ||
else: | ||
continue | ||
|
||
metadata["audio_file"].append(audio_file) | ||
metadata["text"].append(sentence) | ||
metadata["speaker_name"].append(speaker_name) | ||
|
||
df = pandas.DataFrame(metadata) | ||
df = df.sample(frac=1) | ||
num_val_samples = int(len(df)*eval_percentage) | ||
|
||
df_eval = df[:num_val_samples] | ||
df_train = df[num_val_samples:] | ||
|
||
df_train = df_train.sort_values('audio_file') | ||
train_metadata_path = os.path.join(out_path, "metadata_train.csv") | ||
df_train.to_csv(train_metadata_path, sep="|", index=False) | ||
|
||
eval_metadata_path = os.path.join(out_path, "metadata_eval.csv") | ||
df_eval = df_eval.sort_values('audio_file') | ||
df_eval.to_csv(eval_metadata_path, sep="|", index=False) | ||
|
||
# deallocate VRAM and RAM | ||
del asr_model, df_train, df_eval, df, metadata | ||
gc.collect() | ||
|
||
return train_metadata_path, eval_metadata_path, audio_total_size |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
import os | ||
import gc | ||
|
||
from trainer import Trainer, TrainerArgs | ||
|
||
from TTS.config.shared_configs import BaseDatasetConfig | ||
from TTS.tts.datasets import load_tts_samples | ||
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig | ||
from TTS.utils.manage import ModelManager | ||
|
||
|
||
def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path, max_audio_length=255995): | ||
# Logging parameters | ||
RUN_NAME = "GPT_XTTS_FT" | ||
PROJECT_NAME = "XTTS_trainer" | ||
DASHBOARD_LOGGER = "tensorboard" | ||
LOGGER_URI = None | ||
|
||
# Set here the path that the checkpoints will be saved. Default: ./run/training/ | ||
OUT_PATH = os.path.join(output_path, "run", "training") | ||
|
||
# Training Parameters | ||
OPTIMIZER_WD_ONLY_ON_WEIGHTS = True # for multi-gpu training please make it False | ||
START_WITH_EVAL = False # if True it will star with evaluation | ||
BATCH_SIZE = batch_size # set here the batch size | ||
GRAD_ACUMM_STEPS = grad_acumm # set here the grad accumulation steps | ||
|
||
|
||
# Define here the dataset that you want to use for the fine-tuning on. | ||
config_dataset = BaseDatasetConfig( | ||
formatter="coqui", | ||
dataset_name="ft_dataset", | ||
path=os.path.dirname(train_csv), | ||
meta_file_train=train_csv, | ||
meta_file_val=eval_csv, | ||
language=language, | ||
) | ||
|
||
# Add here the configs of the datasets | ||
DATASETS_CONFIG_LIST = [config_dataset] | ||
|
||
# Define the path where XTTS v2.0.1 files will be downloaded | ||
CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/") | ||
os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True) | ||
|
||
|
||
# DVAE files | ||
DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth" | ||
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth" | ||
|
||
# Set the path to the downloaded files | ||
DVAE_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(DVAE_CHECKPOINT_LINK)) | ||
MEL_NORM_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(MEL_NORM_LINK)) | ||
|
||
# download DVAE files if needed | ||
if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE): | ||
print(" > Downloading DVAE files!") | ||
ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True) | ||
|
||
|
||
# Download XTTS v2.0 checkpoint if needed | ||
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json" | ||
XTTS_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth" | ||
XTTS_CONFIG_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json" | ||
|
||
# XTTS transfer learning parameters: You we need to provide the paths of XTTS model checkpoint that you want to do the fine tuning. | ||
TOKENIZER_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(TOKENIZER_FILE_LINK)) # vocab.json file | ||
XTTS_CHECKPOINT = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CHECKPOINT_LINK)) # model.pth file | ||
XTTS_CONFIG_FILE = os.path.join(CHECKPOINTS_OUT_PATH, os.path.basename(XTTS_CONFIG_LINK)) # config.json file | ||
|
||
# download XTTS v2.0 files if needed | ||
if not os.path.isfile(TOKENIZER_FILE) or not os.path.isfile(XTTS_CHECKPOINT): | ||
print(" > Downloading XTTS v2.0 files!") | ||
ModelManager._download_model_files( | ||
[TOKENIZER_FILE_LINK, XTTS_CHECKPOINT_LINK, XTTS_CONFIG_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True | ||
) | ||
|
||
# init args and config | ||
model_args = GPTArgs( | ||
max_conditioning_length=132300, # 6 secs | ||
min_conditioning_length=66150, # 3 secs | ||
debug_loading_failures=False, | ||
max_wav_length=max_audio_length, # ~11.6 seconds | ||
max_text_length=200, | ||
mel_norm_file=MEL_NORM_FILE, | ||
dvae_checkpoint=DVAE_CHECKPOINT, | ||
xtts_checkpoint=XTTS_CHECKPOINT, # checkpoint path of the model that you want to fine-tune | ||
tokenizer_file=TOKENIZER_FILE, | ||
gpt_num_audio_tokens=1026, | ||
gpt_start_audio_token=1024, | ||
gpt_stop_audio_token=1025, | ||
gpt_use_masking_gt_prompt_approach=True, | ||
gpt_use_perceiver_resampler=True, | ||
) | ||
# define audio config | ||
audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000) | ||
# training parameters config | ||
config = GPTTrainerConfig( | ||
epochs=num_epochs, | ||
output_path=OUT_PATH, | ||
model_args=model_args, | ||
run_name=RUN_NAME, | ||
project_name=PROJECT_NAME, | ||
run_description=""" | ||
GPT XTTS training | ||
""", | ||
dashboard_logger=DASHBOARD_LOGGER, | ||
logger_uri=LOGGER_URI, | ||
audio=audio_config, | ||
batch_size=BATCH_SIZE, | ||
batch_group_size=48, | ||
eval_batch_size=BATCH_SIZE, | ||
num_loader_workers=8, | ||
eval_split_max_size=256, | ||
print_step=50, | ||
plot_step=100, | ||
log_model_step=100, | ||
save_step=1000, | ||
save_n_checkpoints=1, | ||
save_checkpoints=True, | ||
# target_loss="loss", | ||
print_eval=False, | ||
# Optimizer values like tortoise, pytorch implementation with modifications to not apply WD to non-weight parameters. | ||
optimizer="AdamW", | ||
optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS, | ||
optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2}, | ||
lr=5e-06, # learning rate | ||
lr_scheduler="MultiStepLR", | ||
# it was adjusted accordly for the new step scheme | ||
lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1}, | ||
test_sentences=[], | ||
) | ||
|
||
# init the model from config | ||
model = GPTTrainer.init_from_config(config) | ||
|
||
# load training samples | ||
train_samples, eval_samples = load_tts_samples( | ||
DATASETS_CONFIG_LIST, | ||
eval_split=True, | ||
eval_split_max_size=config.eval_split_max_size, | ||
eval_split_size=config.eval_split_size, | ||
) | ||
|
||
# init the trainer and 🚀 | ||
trainer = Trainer( | ||
TrainerArgs( | ||
restore_path=None, # xtts checkpoint is restored via xtts_checkpoint key so no need of restore it using Trainer restore_path parameter | ||
skip_train_epoch=False, | ||
start_with_eval=START_WITH_EVAL, | ||
grad_accum_steps=GRAD_ACUMM_STEPS, | ||
), | ||
config, | ||
output_path=OUT_PATH, | ||
model=model, | ||
train_samples=train_samples, | ||
eval_samples=eval_samples, | ||
) | ||
trainer.fit() | ||
|
||
# get the longest text audio file to use as speaker reference | ||
samples_len = [len(item["text"].split(" ")) for item in train_samples] | ||
longest_text_idx = samples_len.index(max(samples_len)) | ||
speaker_ref = train_samples[longest_text_idx]["audio_file"] | ||
|
||
trainer_out_path = trainer.output_path | ||
|
||
# deallocate VRAM and RAM | ||
del model, trainer, train_samples, eval_samples | ||
gc.collect() | ||
|
||
return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer_out_path, speaker_ref |
Oops, something went wrong.