do_tts.py
import os
import pandas as pd
import numpy as np
import json
import re
import nltk
import soundfile as sf
import torch
import tempfile
import wave
import logging
from tqdm import tqdm
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from TTS.utils.generic_utils import get_user_data_dir
from TTS.utils.manage import ModelManager
from preprocess_text_with_gemini import preprocess_text

nltk.download("punkt_tab")  # Download the necessary tokenizer data (only needed once)


def postprocess(wav):
    """Post-process the output waveform into 16-bit PCM."""
    if isinstance(wav, list):
        wav = torch.cat(wav, dim=0)
    wav = wav.clone().detach().cpu().numpy()
    wav = wav[None, : int(wav.shape[0])]
    wav = np.clip(wav, -1, 1)
    wav = (wav * 32767).astype(np.int16)
    return wav


def load_model():
    """Load the XTTS v2 model, downloading the checkpoint first if necessary."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        torch.cuda.empty_cache()
    # Load or download the model
    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
    model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
    if not os.path.exists(model_path):
        print("Downloading XTTS Model:", model_name)
        ModelManager().download_model(model_name)
        print("XTTS Model downloaded")
    print("Loading XTTS")
    config = XttsConfig()
    config_file = os.path.join(model_path, "config.json")  # Use the default config file
    config.load_json(config_file)
    model = Xtts.init_from_config(config)
    model.load_checkpoint(
        config, checkpoint_dir=model_path, eval=True, use_deepspeed=False
    )
    model.to(device)
    print("XTTS Loaded.")
    return model


def load_speaker_embedding(file_path):
    """Load a precomputed speaker embedding and GPT conditioning latent from a JSON file."""
    with open(file_path, "r") as file:
        data = json.load(file)
    return data["speaker_embedding"], data["gpt_cond_latent"]


def synthesize_speech(text, language, speaker_embedding, gpt_cond_latent, model):
    """Generate speech from text using the specified speaker embedding."""
    # Convert to tensors with the shapes XTTS inference expects
    speaker_embedding_tensor = (
        torch.tensor(speaker_embedding).unsqueeze(0).unsqueeze(-1)
    )
    gpt_cond_latent_tensor = (
        torch.tensor(gpt_cond_latent).reshape((-1, 1024)).unsqueeze(0)
    )
    # Synthesize speech
    out = model.inference(
        text,
        language,
        gpt_cond_latent_tensor,
        speaker_embedding_tensor,
    )
    # Post-process and return the audio data
    wav = postprocess(torch.tensor(out["wav"]))
    return wav


def get_wav_duration(file_path):
    """Calculate the duration of a WAV file in seconds."""
    with wave.open(file_path, "rb") as wav_file:
        framerate = wav_file.getframerate()
        nframes = wav_file.getnframes()
        # Duration is the frame count divided by the frame rate
        duration = nframes / float(framerate)
    return duration


def split_sentence(sentence, max_length=250):
    """
    Split a sentence at the nearest comma to the midpoint for sentences longer than max_length.
    If no comma is found, split at the nearest space.
    """
    if len(sentence) <= max_length:
        return [sentence]
    # Calculate the midpoint of the sentence
    midpoint = len(sentence) // 2
    # Find the nearest comma to the midpoint
    left_comma = sentence.rfind(",", 0, midpoint)
    right_comma = sentence.find(",", midpoint)
    # Choose the closest comma to split
    if left_comma != -1 or right_comma != -1:
        # Prioritize splitting at a comma
        if right_comma == -1 or (
            left_comma != -1 and midpoint - left_comma <= right_comma - midpoint
        ):
            split_index = left_comma
        else:
            split_index = right_comma
    else:
        # If no comma is found, find the nearest space
        left_space = sentence.rfind(" ", 0, midpoint)
        right_space = sentence.find(" ", midpoint)
        if right_space == -1 or (
            left_space != -1 and midpoint - left_space <= right_space - midpoint
        ):
            split_index = left_space
        else:
            split_index = right_space
    # Split the sentence, keeping the comma (or space) with the first part
    first_part = sentence[: split_index + 1].rstrip()
    second_part = sentence[split_index + 1 :].lstrip()
    # Debug print
    print(f"Original Sentence: '{sentence}'")
    print(f"Split Index: {split_index}")
    print(f"First Part: '{first_part}'")
    print(f"Second Part: '{second_part}'")
    return [first_part, second_part]


def split_text_into_sentences(text):
    """Split the text into sentences using NLTK's sentence tokenizer."""
    return nltk.sent_tokenize(text)


def generate_silence(duration_ms, sample_rate):
    """Generate a silence (zero amplitude) segment of a given duration in milliseconds."""
    num_samples = int(sample_rate * duration_ms / 1000)
    return np.zeros(num_samples, dtype=np.int16)


def create_srt(subtitles, output_file_path):
    """Create an SRT file for the given subtitles with timings."""
    with open(output_file_path, "w", encoding="utf-8") as file:
        for index, (start, end, text) in enumerate(subtitles, 1):
            start_time_str = f"{int(start // 3600):02d}:{int((start % 3600) // 60):02d}:{int(start % 60):02d},000"
            end_time_str = f"{int(end // 3600):02d}:{int((end % 3600) // 60):02d}:{int(end % 60):02d},000"
            file.write(f"{index}\n")
            file.write(f"{start_time_str} --> {end_time_str}\n")
            file.write(f"{text.strip()}\n")
            file.write("\n")  # Blank line separates entries


def process_sentences(
    sentences,
    model,
    language,
    speaker_embedding,
    gpt_cond_latent,
    sample_rate,
    sample_width=2,
):
    """Synthesize each sentence, insert pauses, and collect subtitle timings."""
    full_wav_data = np.array([], dtype=np.int16)
    subtitles = []
    start_time = 0.0
    pause_duration_ms = 500  # Adjust this value as needed
    # Initialize progress bar
    pbar = tqdm(total=len(sentences), desc="Processing sentences")
    for sentence in sentences:
        parts = split_sentence(sentence)
        subtitle_start = start_time
        for i, part in enumerate(parts):
            wav_data = synthesize_speech(
                part, language, speaker_embedding, gpt_cond_latent, model
            )
            # Flatten to a 1-D int16 array before appending any silence
            wav_data = np.asarray(wav_data, dtype=np.int16).flatten()
            part_end_index = sentence.find(part) + len(part)
            if (
                i < len(parts) - 1
                and sentence[part_end_index : part_end_index + 1] == ","
            ):
                # Short pause where the sentence was split at a comma
                silence = generate_silence(pause_duration_ms, sample_rate)
                wav_data = np.concatenate((wav_data, silence))
            # Write to a temporary file to measure this part's duration
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
                sf.write(temp_file.name, wav_data, sample_rate, subtype="PCM_16")
                duration = get_wav_duration(temp_file.name)
            # Append to the full WAV data
            full_wav_data = np.concatenate((full_wav_data, wav_data))
            # Advance the running time by this part's duration
            end_time = start_time + duration
            start_time = end_time
            # Clean up the temporary file
            os.remove(temp_file.name)
        # Generate and append silence after each sentence
        silence = generate_silence(pause_duration_ms, sample_rate)
        full_wav_data = np.concatenate((full_wav_data, silence))
        # Account for the pause in the next sentence's start time
        start_time += pause_duration_ms / 1000.0
        # Append a subtitle entry covering the whole sentence
        subtitles.append((subtitle_start, end_time, sentence))
        # Update the progress bar
        pbar.update(1)
    pbar.close()
    return full_wav_data, subtitles


def text_to_tts(input_text, speaker_name, output_name):
    """Preprocess the text, synthesize it with XTTS, and write WAV and SRT outputs."""
    model = load_model()
    # Preprocess the text using Gemini
    text = preprocess_text(input_text)
    if not text:
        raise Exception("Preprocessed text not received from Gemini")
    # For debugging
    with open("last_gemini_preprocessed.txt", "w", encoding="utf-8") as file:
        file.write(text)
    # Split the text into sentences
    sentences = split_text_into_sentences(text)
    language = "en"
    # Set the sample rate and sample width
    sample_rate = 24000  # 24 kHz sampling rate
    sample_width = 2  # 16 bits (2 bytes)
    # Load the speaker embedding and GPT conditioning latent once
    speaker_embedding, gpt_cond_latent = load_speaker_embedding(
        f"speakers/{speaker_name}.json"
    )
    # Process the sentences and build the subtitles with the loaded embeddings
    full_wav_data, subtitles = process_sentences(
        sentences,
        model,
        language,
        speaker_embedding,
        gpt_cond_latent,
        sample_rate,
        sample_width,
    )
    # Make sure the output directory exists
    os.makedirs("output", exist_ok=True)
    # Write the subtitles and the full WAV file
    output_wav_file_path = f"output/{output_name}.wav"
    subtitles_output_path = f"output/{output_name}.srt"
    create_srt(subtitles, subtitles_output_path)
    sf.write(output_wav_file_path, full_wav_data, sample_rate, subtype="PCM_16")
    print(f"Synthesized speech saved to {output_wav_file_path}")


if __name__ == "__main__":
with open("text_example.txt", "r", encoding="utf-8") as file:
input_text = file.read()
speaker = "audiobook_lady"
output_name = "test_output"
text_to_tts(input_text, speaker, output_name)
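
# A minimal sketch (not part of the original script) of how a speaker JSON under
# speakers/ could be produced. It assumes the same XTTS model and the key names used
# in load_speaker_embedding above; the reference audio path is hypothetical:
#
#   model = load_model()
#   gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(
#       audio_path=["reference_speaker.wav"]
#   )
#   with open("speakers/audiobook_lady.json", "w") as f:
#       json.dump(
#           {
#               "speaker_embedding": speaker_embedding.cpu().squeeze().tolist(),
#               "gpt_cond_latent": gpt_cond_latent.cpu().squeeze().tolist(),
#           },
#           f,
#       )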