Skip to content

Commit

Permalink
Add XTTS training unit test
Browse files — browse the repository at this point in the history
  • Loading branch information
Edresson committed Oct 18, 2023
1 parent 75dc0e1 commit 2e029c1
Show file tree
Hide file tree
Showing 5 changed files with 12,859 additions and 17 deletions.
5 changes: 4 additions & 1 deletion TTS/tts/layers/xtts/trainer/gpt_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def format_batch_on_device(self, batch):
dvae_wav = batch["wav"]
dvae_mel_spec = self.torch_mel_spectrogram_dvae(dvae_wav)
codes = self.dvae.get_codebook_indices(dvae_mel_spec)

batch["audio_codes"] = codes
# delete useless batch tensors
del batch["padded_text"]
Expand Down Expand Up @@ -454,7 +455,9 @@ def load_checkpoint(
target_options={"anon": True},
): # pylint: disable=unused-argument, disable=W0201, disable=W0102, redefined-builtin
"""Load the model checkpoint and setup for training or inference"""
state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"))["model"]

state, _ = self.xtts.get_compatible_checkpoint_state(checkpoint_path)

# load the model weights
self.xtts.load_state_dict(state, strict=strict)

Expand Down
38 changes: 22 additions & 16 deletions TTS/tts/models/xtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,7 @@ def inference(
expected_output_len = torch.tensor(
[gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
)

text_len = torch.tensor([text_tokens.shape[-1]], device=self.device)
gpt_latents = self.gpt(
text_tokens,
Expand Down Expand Up @@ -757,6 +758,26 @@ def eval(self): # pylint: disable=redefined-builtin
self.gpt.init_gpt_for_inference()
super().eval()

def get_compatible_checkpoint_state(self, model_path):
    """Load a checkpoint from *model_path* and normalize it for this model.

    Handles two layouts:
      1. A plain XTTS state dict — returned after pruning unused modules.
      2. A Coqui Trainer checkpoint, whose keys are wrapped under an
         ``"xtts."`` prefix — the prefix is stripped before pruning.

    Args:
        model_path: Path/URL of the checkpoint file (loaded via ``load_fsspec``
            onto CPU, so no GPU is required to restore weights).

    Returns:
        Tuple ``(state_dict, is_coqui_trainer_checkpoint)`` where the second
        element tells the caller whether trainer-style key rewriting occurred.
    """
    checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
    # Drop the decoder branch that is not selected by the current config...
    ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else ["hifigan_decoder"]
    # ...plus training-only modules never needed when loading for this model.
    ignore_keys.extend(["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"])
    coqui_trainer_checkpoint = False
    prefix = "xtts."
    for key in list(checkpoint.keys()):
        # Coqui Trainer checkpoints nest the model under an "xtts." prefix.
        if key.startswith(prefix):
            coqui_trainer_checkpoint = True
            # FIX: strip only the leading prefix. The previous
            # key.replace("xtts.", "") removed EVERY occurrence of the
            # substring, mangling any key that contains "xtts." later on.
            new_key = key[len(prefix):]
            checkpoint[new_key] = checkpoint.pop(key)
            key = new_key

        # Prune keys belonging to the ignored top-level modules.
        if key.split(".", 1)[0] in ignore_keys:
            del checkpoint[key]

    return checkpoint, coqui_trainer_checkpoint

def load_checkpoint(
self,
config,
Expand Down Expand Up @@ -790,22 +811,7 @@ def load_checkpoint(

self.init_models()

checkpoint = load_fsspec(model_path, map_location=torch.device("cpu"))["model"]
ignore_keys = ["diffusion_decoder", "vocoder"] if self.args.use_hifigan else ["hifigan_decoder"]
ignore_keys.extend(["torch_mel_spectrogram_style_encoder", "torch_mel_spectrogram_dvae", "dvae"])
coqui_trainer_checkpoint = False
for key in list(checkpoint.keys()):
# check if it is from the coqui Trainer if so convert it
if key.startswith("xtts."):
coqui_trainer_checkpoint = True
new_key = key.replace("xtts.", "")
checkpoint[new_key] = checkpoint[key]
del checkpoint[key]
key = new_key

# remove unused keys
if key.split(".")[0] in ignore_keys:
del checkpoint[key]
checkpoint, coqui_trainer_checkpoint = self.get_compatible_checkpoint_state(model_path)

if eval and not coqui_trainer_checkpoint:
self.gpt.init_gpt_for_inference(kv_cache=self.args.kv_cache)
Expand Down
1 change: 1 addition & 0 deletions recipes/ljspeech/xtts_v1/train_gpt_xtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
)
LANGUAGE = config_dataset.language


def main():
# init args and config
model_args = GPTArgs(
Expand Down
Loading

0 comments on commit 2e029c1

Please sign in to comment.