# Source: tdnn.py (135 lines, 4.81 KB), forked from Gastron/sb-fin-parl-2015-2020-kevat
# Copyright (c) Yiwen Shao
# Apache 2.0
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
class tdnn_bn_relu(nn.Module):
    """One TDNN layer: 1-D convolution followed by batch-norm and ReLU.

    Also exposes `output_lengths` so padded-batch sequence lengths can be
    tracked through the (possibly strided/dilated) convolution.
    """

    def __init__(self, in_dim, out_dim, kernel_size, stride=1, dilation=1):
        super(tdnn_bn_relu, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        # "same"-style padding for odd kernels: keeps T unchanged at stride 1.
        self.padding = dilation * (kernel_size - 1) // 2
        self.tdnn = nn.Conv1d(in_dim, out_dim, kernel_size,
                              stride=stride, padding=self.padding, dilation=dilation)
        self.bn = nn.BatchNorm1d(out_dim)
        self.relu = nn.ReLU(inplace=True)

    def output_lengths(self, in_lengths):
        """Map input sequence lengths to post-convolution lengths
        (standard Conv1d length formula; integer division rounds like floor)."""
        numer = (in_lengths + 2 * self.padding
                 - self.dilation * (self.kernel_size - 1) + self.stride - 1)
        return numer // self.stride

    def forward(self, x, x_lengths):
        """x: (N, F, T) tensor; returns (activations, updated lengths)."""
        assert x.dim() == 3  # x is of size (N, F, T)
        out = self.relu(self.bn(self.tdnn(x)))
        return out, self.output_lengths(x_lengths)
class TDNN(nn.Module):
    """Stack of tdnn_bn_relu layers with optional dropout and residual
    connections.

    Input/output layout is (B, T, D); internally the tensor is transposed
    to (B, D, T) for Conv1d.
    """

    def __init__(self, in_dim, num_layers, hidden_dims, kernel_sizes, strides, dilations,
                 dropout=0, residual=False):
        super(TDNN, self).__init__()
        # Every per-layer hyperparameter list must cover all layers.
        assert len(hidden_dims) == num_layers
        assert len(kernel_sizes) == num_layers
        assert len(strides) == num_layers
        assert len(dilations) == num_layers
        self.dropout = dropout
        self.residual = residual
        self.num_layers = num_layers
        self.tdnn = nn.ModuleList([
            tdnn_bn_relu(
                in_dim if layer == 0 else hidden_dims[layer - 1],
                hidden_dims[layer], kernel_sizes[layer],
                strides[layer], dilations[layer],
            )
            for layer in range(num_layers)
        ])

    def forward(self, x, x_lengths=None):
        """Run the TDNN stack.

        Args:
            x: (B, T, D) input features.
            x_lengths: per-utterance lengths; if None, every sequence is
                assumed to span the full time dimension.

        Returns:
            (B, T', D') output features. NOTE: updated lengths are tracked
            internally but not returned (interface kept as-is for callers).
        """
        assert len(x.size()) == 3  # x is of size (B, T, D)
        if x_lengths is None:
            # Fix: build the lengths tensor on the same device as x so that
            # GPU inputs don't mix CPU/GPU tensors in output_lengths().
            x_lengths = torch.full((x.shape[0],), x.shape[1],
                                   dtype=torch.long, device=x.device)
        # turn x to (B, D, T) for tdnn/cnn input
        x = x.transpose(1, 2).contiguous()
        for i in range(len(self.tdnn)):
            prev_x = x
            x, x_lengths = self.tdnn[i](x, x_lengths)
            # Residual starts from the 2nd layer, and only when shapes still
            # match. Fix: also require matching channel dim (size(1)), not
            # just time (size(2)); otherwise the add raised a RuntimeError
            # whenever consecutive hidden dims differed.
            if (self.residual and i > 0
                    and x.size(1) == prev_x.size(1)
                    and x.size(2) == prev_x.size(2)):
                x = x + prev_x
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = x.transpose(2, 1).contiguous()  # turn it back to (B, T, D)
        return x
class TDNN_MFCC(nn.Module):
    """Waveform-in TDNN: extracts MFCC features with torchaudio, runs a
    TDNN stack, then projects to out_dim with a final linear layer.

    Returns both the projected features and the updated sequence lengths.
    """

    def __init__(self, in_dim, out_dim, num_layers, hidden_dims, kernel_sizes, strides, dilations, dropout=0):
        super(TDNN_MFCC, self).__init__()
        # Each per-layer hyperparameter list must have one entry per layer.
        assert num_layers == len(hidden_dims)
        assert num_layers == len(kernel_sizes)
        assert num_layers == len(strides)
        assert num_layers == len(dilations)
        self.dropout = dropout
        self.num_layers = num_layers
        layers = []
        prev_dim = in_dim
        for k in range(num_layers):
            layers.append(tdnn_bn_relu(prev_dim, hidden_dims[k],
                                       kernel_sizes[k], strides[k], dilations[k]))
            prev_dim = hidden_dims[k]
        self.tdnn = nn.ModuleList(layers)
        self.mfcc = torchaudio.transforms.MFCC()
        self.final_layer = nn.Linear(hidden_dims[-1], out_dim, True)

    def mfcc_output_lengths(self, in_lengths):
        """Frame count the MFCC front-end produces for each waveform length
        (centered STFT: one frame per hop, plus one)."""
        hop_length = self.mfcc.MelSpectrogram.hop_length
        return in_lengths // hop_length + 1

    def forward(self, x, x_lengths):
        """x: (B, T, D) waveform batch; returns ((B, T', out_dim), lengths)."""
        assert x.dim() == 3  # x is of size (B, T, D)
        x = x.transpose(1, 2).contiguous()  # (B, D, T) for the front-end
        x = self.mfcc(x).squeeze(1)  # x of size (B, D, T)
        x_lengths = self.mfcc_output_lengths(x_lengths)
        for layer in self.tdnn:
            x, x_lengths = layer(x, x_lengths)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = x.transpose(2, 1).contiguous()  # turn it back to (B, T, D)
        return self.final_layer(x), x_lengths
if __name__ == "__main__":
    # Smoke test: build a one-layer TDNN and push a padded batch through it.
    num_layers = 1
    hidden_dims = [20]
    kernel_sizes = [3]
    strides = [2]
    dilations = [2]
    in_dim = 10
    # Fix: TDNN.__init__ has no out_dim parameter; the old call passed
    # out_dim positionally, shifting every later argument and crashing in
    # the constructor's length asserts.
    net = TDNN(in_dim, num_layers,
               hidden_dims, kernel_sizes, strides, dilations)
    print(net)
    # Renamed from `input`/`input_lengths` to avoid shadowing the builtin.
    features = torch.randn(2, 8, 10)
    feature_lengths = torch.IntTensor([8, 6])
    # Fix: TDNN.forward returns only the output tensor, not a
    # (output, lengths) tuple — the old two-value unpack raised ValueError.
    output = net(features, feature_lengths)
    print(output.size())