diff --git a/TTS/tts/layers/tacotron/capacitron_layers.py b/TTS/tts/layers/tacotron/capacitron_layers.py
index 2d9be0f2bb..529560bd41 100755
--- a/TTS/tts/layers/tacotron/capacitron_layers.py
+++ b/TTS/tts/layers/tacotron/capacitron_layers.py
@@ -8,7 +8,7 @@ class CapacitronVAE(nn.Module):
     See https://arxiv.org/abs/1906.03402 """
 
-    def __init__(self, num_mel, capacitron_embedding_dim, encoder_output_dim=256, reference_encoder_out_dim=128, speaker_embedding_dim=None, text_summary_embedding_dim=None):
+    def __init__(self, num_mel, capacitron_embedding_dim, text_encoder_output_dim=256, reference_encoder_out_dim=128, speaker_embedding_dim=None, text_summary_embedding_dim=None):
         super().__init__()
         # Init distributions
         self.prior_distribution = MVN(torch.zeros(capacitron_embedding_dim), torch.eye(capacitron_embedding_dim))
@@ -21,7 +21,7 @@ def __init__(self, num_mel, capacitron_embedding_dim, encoder_output_dim=256, re
         mlp_input_dimension = reference_encoder_out_dim
 
         if text_summary_embedding_dim is not None:
-            self.text_summary_net = TextSummary(text_summary_embedding_dim, encoder_output_dim=encoder_output_dim)
+            self.text_summary_net = TextSummary(text_summary_embedding_dim, text_encoder_output_dim=text_encoder_output_dim)
             mlp_input_dimension += text_summary_embedding_dim
         if speaker_embedding_dim is not None:
             # TODO: Figure out what to do with speaker_embedding_dim
@@ -157,9 +157,9 @@ def calculate_post_conv_height(height, kernel_size, stride, pad,
         return height
 
 class TextSummary(nn.Module):
-    def __init__(self, embedding_dim, encoder_output_dim):
+    def __init__(self, embedding_dim, text_encoder_output_dim):
         super().__init__()
-        self.lstm = nn.LSTM(encoder_output_dim,  # text embedding dimension from the text encoder
+        self.lstm = nn.LSTM(text_encoder_output_dim,  # text embedding dimension from the text encoder
                             embedding_dim,  # fixed length output summary the lstm creates from the input
                             batch_first=True,
                             bidirectional=False)
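
For reference, a minimal usage sketch of the renamed keyword argument. Only the constructor signature and the name text_encoder_output_dim come from the patch above; the concrete dimension values below are hypothetical:

    # Hypothetical values; only the keyword names mirror the patched signature.
    from TTS.tts.layers.tacotron.capacitron_layers import CapacitronVAE

    capacitron_vae = CapacitronVAE(
        num_mel=80,                      # assumed mel-spectrogram channel count
        capacitron_embedding_dim=128,    # assumed VAE embedding size
        text_encoder_output_dim=256,     # was encoder_output_dim before this patch
        text_summary_embedding_dim=64,   # assumed; enables the TextSummary LSTM path
    )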