Question about Differentiable Duration Modeling #4

Open · LEECHOONGHO opened this issue Jun 2, 2022 · 0 comments

LEECHOONGHO commented Jun 2, 2022

Hello, I'm trying to implement the Differentiable Duration Modeling (DDM) module introduced in
Differentiable Duration Modeling for End-to-End Text-to-Speech.

I opened this issue to get advice on implementing DDM.
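
As I understand the paper, with l_k(d) the probability that phoneme k lasts d frames, the recursion is

    q_k(t) = sum_{d <= t} q_{k-1}(t - d) * l_k(d)
    s_k(t) = sum_{d <= t} q_{k-1}(t - d) * sum_{d' >= d} l_k(d')

where q_k(t) is the probability that phoneme k ends at frame t, and s_k(t) is the soft attention weight of phoneme k at frame t (this is my reading of the paper, so it may be off). Since each q_k depends on q_{k-1}, the computation is inherently recursive over phonemes.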

My implementation of the Differentiable Alignment Encoder produces an attention-like output from a noise input,
but training with DDM is far too slow (about 10 s per iteration); it seems to be stuck in the backward pass.

Can anyone give me advice on speeding up this kind of recursive tensor operation?
Should I use cuda.jit, as Soft-DTW implementations do? Or is there something wrong with the approach itself?
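
For reference, this is roughly how I checked where the time goes (a minimal CPU sketch; on GPU the timers would need torch.cuda.synchronize()):

import time
import torch

dae = DifferentiableAlignmentEncoder()
x = torch.randn(5, 25, 256)
x_len = torch.full((5,), 25, dtype=torch.long)
mel_len = torch.full((5,), 85, dtype=torch.long)

t0 = time.time()
s, l, q, dur = dae(x, x_len, mel_len)
t1 = time.time()
s.sum().backward()  # dummy loss, just to exercise the backward pass
t2 = time.time()
print(f"forward {t1 - t0:.2f}s / backward {t2 - t1:.2f}s")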

The module's outputs from a noise input, along with the code, are shown below.

Thank you.

import torch
import matplotlib.pyplot as plt

dae = DifferentiableAlignmentEncoder()

b = 5              # batch size
text_max_len = 25  # max phoneme sequence length
mel_max_len = 85   # max mel frame length
dim = 256          # hidden dimension

# random lengths and a pure-noise input
x_len = torch.randint(1, text_max_len, (b,))
mel_len = torch.randint(2, mel_max_len, (b,))
x = torch.randn(b, int(x_len.max()), dim)

s, l, q, dur = dae(x, x_len, mel_len)

# plot one batch element, cropped to its true lengths
i = 2
plt.imshow(l[i, :x_len[i], :mel_len[i]].detach().numpy())
plt.imshow(q[i, :x_len[i], :mel_len[i]].detach().numpy())
plt.imshow(s[i, :x_len[i], :mel_len[i]].detach().numpy())
plt.plot(dur[i, :x_len[i]])

[plots: the L, Q, and S (soft attention) matrices, and the predicted duration curve]

Code
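
(ConvNorm, LinearNorm, and get_mask_from_lengths below are helper layers and utilities from my codebase.)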

class DifferentiableAlignmentEncoder(nn.Module):
    def __init__(
        self,
        hidden_dim=256,
        conv_kernels=3,
        num_layers=3,
        dropout_p=0.2,
        max_mel_len=1150  # max number of mel-spectrogram frames in the training data
    ):
        super().__init__()
        
        self.conv_layer_blocks = nn.ModuleList([
            nn.Sequential(
                ConvNorm(hidden_dim, hidden_dim, conv_kernels, bias=True, transpose=True),
                nn.ReLU(),
                nn.LayerNorm(hidden_dim),
                nn.Dropout(dropout_p)
            )
            for _ in range(num_layers)
        ])
        self.dur_prob_proj = LinearNorm(hidden_dim, max_mel_len, bias=False)
        
        self.ddm = DifferentiableDurationModeling()
    
    def forward(self, x, phon_lens, mel_lens, x_masks=None):
        
        """
        x  : Tensor[B, T_phon, C_phone]
        phon_lens : LongTensor[B]
        mel_lens : LongTensor[B]
        s : S Matrix : Tensor[B, T_phon, T_mel]
        dur : Duration Matrix : Tensor[B, T_phon]
        """
        
        max_mel_len = int(torch.max(mel_lens))
        
        for layer in self.conv_layer_blocks:
            if x_masks is not None:
                x = x * (1 - x_masks.float())
            x = layer(x)
        x = self.dur_prob_proj(x)
                
        # add Gaussian noise before the sigmoid
        x = x + torch.randn_like(x)
        
        # per-(phoneme, frame) probabilities, truncated to this batch's max mel length
        p = torch.sigmoid(x)[:, :, :max_mel_len]
        
        s, l, q, dur = self.ddm(p, phon_lens, mel_lens)
        
        dur = dur.detach()
        
        return s, l, q, dur
    
    
class DifferentiableDurationModeling(nn.Module):
    def __init__(self):
        super().__init__()
        
    def _get_attn_mask(self, phon_lens, mel_lens):
        phon_mask = ~get_mask_from_lengths(phon_lens)
        mel_mask = ~get_mask_from_lengths(mel_lens)
        
        return phon_mask.unsqueeze(-1) * mel_mask.unsqueeze(1), phon_mask
    
    def forward(self, p, phon_lens, mel_lens):
        attn_mask, phon_mask = self._get_attn_mask(phon_lens, mel_lens)
        p = p * attn_mask

        l = self._get_l(p, attn_mask)
        l = l * attn_mask

        dur = self._get_duration(l)
        dur = dur * phon_mask

        q = self._get_q(l)
        q = q * attn_mask

        s = self._get_s(q, l)
        s = s * attn_mask

        return s, l, q, dur
    
    def _get_duration(self, l):
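        # expected duration per phoneme: E[d] = sum_d d * l(d), with durations indexed from 1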
        with torch.no_grad():
            m = torch.arange(1, l.shape[-1] + 1)[None, :].expand_as(l).to(l.device)
            dur = torch.sum(m * l, dim=-1)
        return dur
    
    def _get_l(self, p, mask):
        # Computing l directly is numerically unstable for the gradient,
        # so, following the paper's authors, the product is computed in log-space.
        _p = torch.log(mask[:, :, 1:].float() - p[:, :, 1:] + 1e-8)  # log(1 - p)
        p = torch.log(p + 1e-8)                                      # log(p)
        com = torch.cumsum(_p, dim=-1)
        l_0 = com[:, :, -1].unsqueeze(-1)
        l_1 = p[:, :, 1].unsqueeze(-1)
        
        l_m = com[:, :, :-1] + p[:, :, 2:]
                
        l = torch.cat([l_0, l_1, l_m], dim=-1)

        l = torch.exp(l)
        
        return l
    
    def _variable_kernel_size_convolution(self, x, y, length):
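        # truncated linear convolution:
        #   output[:, t] = sum_{i=0..t} x[:, t - i] * y[:, i]
        # realized by summing anti-diagonals of the outer product of x and y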
        matrix = torch.flip(x.unsqueeze(1) * y.unsqueeze(-1), dims=[-1])
        output =  torch.flip(
            torch.cat(
                [
                    torch.sum(
                        torch.diagonal(
                            matrix, offset=idx, dim1=-2, dim2=-1
                        ), dim=1
                    ).unsqueeze(1) 
                    for idx in range(length)
                ],
                dim=1
            ),
            dims=[1] 
        )
        return output
    
    def _get_q(self, l):
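        # q_0 = l_0; q_k = q_{k-1} convolved with l_k (truncated convolution over frames),
        # so each phoneme's end-frame distribution depends on the previous one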
        length = l.shape[-1]
        q = [l[:, 0, :]]
        for i in range(1, l.shape[1]):
            q.append(self._variable_kernel_size_convolution(q[i - 1], l[:, i], length))
        q = torch.stack(q, dim=1)
        return q

    def _reverse_cumsum(self, x):
        return torch.flip(torch.cumsum(torch.flip(x, dims=[-1]), dim=-1), dims=[-1])
    
    def _get_s(self, q, l):
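        # s_k = q_{k-1} convolved with the reverse cumulative sum of l_k:
        # the soft attention of phoneme k over mel frames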
        length = l.shape[-1]
        l_rev_cumsum = self._reverse_cumsum(l)
        s = [l_rev_cumsum[:, 0, :]]
        for i in range(1, q.shape[1]):
            s.append(self._variable_kernel_size_convolution(q[:, i - 1], l_rev_cumsum[:, i], length))
        s = torch.stack(s, dim=1)
        return s
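
One idea I have been considering for the speed issue (not sure it is correct or faster in practice): replacing the per-diagonal Python loop in _variable_kernel_size_convolution with an FFT-based linear convolution, so that only the recursion over phonemes remains a loop. A rough sketch:

import torch

def fft_truncated_convolution(x, y):
    # same contract as _variable_kernel_size_convolution:
    #   out[:, t] = sum_{i <= t} x[:, t - i] * y[:, i]
    # i.e. a full linear convolution of x and y truncated to the first L outputs
    L = x.shape[-1]
    n = 2 * L  # zero-pad so the circular convolution equals the linear one
    X = torch.fft.rfft(x, n=n)
    Y = torch.fft.rfft(y, n=n)
    return torch.fft.irfft(X * Y, n=n)[..., :L]

torch.fft is differentiable, and each step would drop from O(T_mel^2) to O(T_mel log T_mel); the sequential loop over phonemes in _get_q/_get_s would still remain, though.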