Add speed control for inference (#3214)
* Add speed control for inference

* Fix XTTS tests

* Add speed control tests
WeberJulian authored Nov 14, 2023
1 parent d96f388 commit 04901fb
Showing 2 changed files with 48 additions and 3 deletions.
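
In short: `inference` and `inference_stream` gain a `speed` argument (default 1.0) that is converted to `length_scale = 1.0 / max(speed, 0.05)` and applied by linearly interpolating the GPT latent sequence along its time axis before the HiFi-GAN decoder runs, so `speed > 1` yields fewer latent frames (faster speech) and `speed < 1` yields more (slower speech). A minimal standalone sketch of that trick, assuming `(batch, time, channels)` latents; the helper name and tensor sizes are illustrative, not part of the model's API:

import torch
import torch.nn.functional as F

def stretch_latents(gpt_latents: torch.Tensor, speed: float = 1.0) -> torch.Tensor:
    """Resample latents along time; speed > 1 shortens the sequence."""
    length_scale = 1.0 / max(speed, 0.05)  # clamp so the scale stays positive
    if length_scale == 1.0:
        return gpt_latents
    # F.interpolate's "linear" mode expects (batch, channels, length),
    # so swap the time and channel axes around the call.
    return F.interpolate(
        gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
    ).transpose(1, 2)

latents = torch.randn(1, 100, 1024)  # made-up batch/time/channel sizes
assert stretch_latents(latents, speed=2.0).shape[1] == 50  # 2x speed -> half the frames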
17 changes: 17 additions & 0 deletions TTS/tts/models/xtts.py
@@ -530,8 +530,10 @@ def inference(
         top_p=0.85,
         do_sample=True,
         num_beams=1,
+        speed=1.0,
         **hf_generate_kwargs,
     ):
+        length_scale = 1.0 / max(speed, 0.05)
         text = text.strip().lower()
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)

@@ -584,6 +586,13 @@ def inference(
                     gpt_latents = gpt_latents[:, :k]
                     break
 
+            if length_scale != 1.0:
+                gpt_latents = F.interpolate(
+                    gpt_latents.transpose(1, 2),
+                    scale_factor=length_scale,
+                    mode="linear"
+                ).transpose(1, 2)
+
             wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding)
 
         return {
@@ -634,8 +643,10 @@ def inference_stream(
         top_k=50,
         top_p=0.85,
         do_sample=True,
+        speed=1.0,
         **hf_generate_kwargs,
     ):
+        length_scale = 1.0 / max(speed, 0.05)
         text = text.strip().lower()
         text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device)

@@ -674,6 +685,12 @@ def inference_stream(
 
             if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size):
                 gpt_latents = torch.cat(all_latents, dim=0)[None, :]
+                if length_scale != 1.0:
+                    gpt_latents = F.interpolate(
+                        gpt_latents.transpose(1, 2),
+                        scale_factor=length_scale,
+                        mode="linear"
+                    ).transpose(1, 2)
                 wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
                 wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
                     wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len
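
Usage-wise, `speed` rides along with the other sampling options; a hedged sketch (model loading and conditioning elided, call order assumed from the streaming tests below):

# Hypothetical call; `model`, `gpt_cond_latent`, and `speaker_embedding` are
# assumed to be prepared as in the tests further down.
out = model.inference(
    "This is an example.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
    speed=1.5,  # >1 speaks faster, <1 slower; clamped internally at 0.05
)
wav = out["wav"]  # assuming the dict return begun by `return {` above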
34 changes: 31 additions & 3 deletions tests/zoo_tests/test_models.py
@@ -111,7 +111,7 @@ def test_xtts_streaming():
     model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 
     print("Computing speaker latents...")
-    gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
+    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
 
     print("Inference...")
     chunks = model.inference_stream(
@@ -139,7 +139,7 @@ def test_xtts_v2():
             "yes | "
             f"tts --model_name tts_models/multilingual/multi-dataset/xtts_v2 "
             f'--text "This is an example." --out_path "{output_path}" --progress_bar False --use_cuda True '
-            f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" "--language_idx "en"'
+            f'--speaker_wav "{speaker_wav}" "{speaker_wav_2}" --language_idx "en"'
         )
     else:
         run_cli(
@@ -164,7 +164,7 @@ def test_xtts_v2_streaming():
     model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
 
     print("Computing speaker latents...")
-    gpt_cond_latent, _, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
+    gpt_cond_latent, speaker_embedding = model.get_conditioning_latents(audio_path=speaker_wav)
 
     print("Inference...")
     chunks = model.inference_stream(
@@ -179,6 +179,34 @@ def test_xtts_v2_streaming():
             assert chunk.shape[-1] > 5000
         wav_chuncks.append(chunk)
     assert len(wav_chuncks) > 1
+    normal_len = sum([len(chunk) for chunk in wav_chuncks])
+
+    chunks = model.inference_stream(
+        "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
+        "en",
+        gpt_cond_latent,
+        speaker_embedding,
+        speed=1.5
+    )
+    wav_chuncks = []
+    for i, chunk in enumerate(chunks):
+        wav_chuncks.append(chunk)
+    fast_len = sum([len(chunk) for chunk in wav_chuncks])
+
+    chunks = model.inference_stream(
+        "It took me quite a long time to develop a voice and now that I have it I am not going to be silent.",
+        "en",
+        gpt_cond_latent,
+        speaker_embedding,
+        speed=0.66
+    )
+    wav_chuncks = []
+    for i, chunk in enumerate(chunks):
+        wav_chuncks.append(chunk)
+    slow_len = sum([len(chunk) for chunk in wav_chuncks])
+
+    assert slow_len > normal_len
+    assert normal_len > fast_len


def test_tortoise():
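
The closing assertions lean on output length scaling roughly with 1/speed; a back-of-the-envelope check with made-up numbers:

# Illustrative arithmetic only, not test code: for a ~100k-sample clip at
# speed 1.0, speed=1.5 should give about 2/3 the samples and speed=0.66
# about 1.5x as many.
normal_len = 100_000
fast_len = int(normal_len / 1.5)   # ~66,666 samples
slow_len = int(normal_len / 0.66)  # ~151,515 samples
assert slow_len > normal_len > fast_len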
