"""
Classes and methods from the nemo-toolkit for using the RNNTJoint module.
"""
from typing import Any, Dict, List, Optional, Union

import torch
from torch.nn import Module


class RNNTJoint(Module):
    def __init__(
        self,
        jointnet: Dict[str, Any],
        num_classes: int,
        num_extra_outputs: int = 0,
        vocabulary: Optional[List] = None,
        log_softmax: Optional[bool] = None,
        preserve_memory: bool = False,
        fuse_loss_wer: bool = False,
        fused_batch_size: Optional[int] = None,
        experimental_fuse_loss_wer: Any = None,
    ):
        super().__init__()
        self.vocabulary = vocabulary
        self._vocab_size = num_classes
        self._num_extra_outputs = num_extra_outputs
        self._num_classes = num_classes + 1 + num_extra_outputs  # 1 is for blank

        if experimental_fuse_loss_wer is not None:
            # Override fuse_loss_wer from the deprecated argument
            fuse_loss_wer = experimental_fuse_loss_wer

        self._fuse_loss_wer = fuse_loss_wer
        self._fused_batch_size = fused_batch_size

        if fuse_loss_wer and (fused_batch_size is None):
            raise ValueError("If `fuse_loss_wer` is set, then `fused_batch_size` cannot be None!")

        self._loss = None
        self._wer = None

        # Log softmax should be applied explicitly only for CPU
        self.log_softmax = log_softmax
        self.preserve_memory = preserve_memory

        # Required arguments
        self.encoder_hidden = jointnet['encoder_hidden']
        self.pred_hidden = jointnet['pred_hidden']
        self.joint_hidden = jointnet['joint_hidden']
        self.activation = jointnet['activation']

        # Optional arguments
        dropout = jointnet.get('dropout', 0.0)

        self.pred, self.enc, self.joint_net = self._joint_net_modules(
            num_classes=self._num_classes,  # already includes the blank symbol and extra outputs
            pred_n_hidden=self.pred_hidden,
            enc_n_hidden=self.encoder_hidden,
            joint_n_hidden=self.joint_hidden,
            activation=self.activation,
            dropout=dropout,
        )

        # Flag needed for RNNT export support
        self._rnnt_export = False

        # To change, requires running ``model.temperature = T`` explicitly
        self.temperature = 1.0

    # The fused step below reads `self.loss` and `self.wer`, which were never
    # defined in this excerpt; as in the NeMo implementation, expose them as
    # properties over the externally set `_loss` / `_wer` modules.
    @property
    def loss(self):
        return self._loss

    @property
    def wer(self):
        return self._wer

    def forward(
        self,
        encoder_outputs: torch.Tensor,
        decoder_outputs: Optional[torch.Tensor],
        encoder_lengths: Optional[torch.Tensor] = None,
        transcripts: Optional[torch.Tensor] = None,
        transcript_lengths: Optional[torch.Tensor] = None,
        compute_wer: bool = False,
    ) -> Union[torch.Tensor, List[Optional[torch.Tensor]]]:
        # encoder = (B, D, T)
        # decoder = (B, D, U) if passed, else None
        encoder_outputs = encoder_outputs.transpose(1, 2)  # (B, T, D)

        if decoder_outputs is not None:
            decoder_outputs = decoder_outputs.transpose(1, 2)  # (B, U, D)

        if not self._fuse_loss_wer:
            if decoder_outputs is None:
                raise ValueError(
                    "decoder_outputs passed is None, and `fuse_loss_wer` is not set. "
                    "decoder_outputs can only be None for the fused step!"
                )

            out = self.joint(encoder_outputs, decoder_outputs)  # [B, T, U, V + 1]
            return out
        else:
            # At least the loss module must be supplied during the fused joint step
            if self._loss is None or self._wer is None:
                raise ValueError("`fuse_loss_wer` flag is set, but `loss` and `wer` modules were not provided!")

            # If the fused joint step is required, the fused batch size is required as well
            if self._fused_batch_size is None:
                raise ValueError("If `fuse_loss_wer` is set, then `fused_batch_size` cannot be None!")

            # When using the fused joint step, both encoder and transcript lengths must be provided
            if (encoder_lengths is None) or (transcript_lengths is None):
                raise ValueError(
                    "`fuse_loss_wer` is set, therefore encoder and target lengths must be provided as well!"
                )

            losses = []
            target_lengths = []
            batch_size = int(encoder_outputs.size(0))  # actual batch size

            # Iterate over the batch in steps of fused_batch_size
            for batch_idx in range(0, batch_size, self._fused_batch_size):
                begin = batch_idx
                end = min(begin + self._fused_batch_size, batch_size)

                # Extract the sub-batch inputs
                # sub_enc = encoder_outputs[begin:end, ...]
                # sub_transcripts = transcripts[begin:end, ...]
                sub_enc = encoder_outputs.narrow(dim=0, start=begin, length=int(end - begin))
                sub_transcripts = transcripts.narrow(dim=0, start=begin, length=int(end - begin))

                sub_enc_lens = encoder_lengths[begin:end]
                sub_transcript_lens = transcript_lengths[begin:end]

                # The sub-batch transcripts do not need the full padding of the
                # entire batch, therefore reduce the decoder time steps to match
                max_sub_enc_length = sub_enc_lens.max()
                max_sub_transcript_length = sub_transcript_lens.max()

                if decoder_outputs is not None:
                    # Reduce encoder length to preserve computation
                    # Encoder: [sub-batch, T, D] -> [sub-batch, T', D]; T' < T
                    if sub_enc.shape[1] != max_sub_enc_length:
                        sub_enc = sub_enc.narrow(dim=1, start=0, length=int(max_sub_enc_length))

                    # sub_dec = decoder_outputs[begin:end, ...]  # [sub-batch, U, D]
                    sub_dec = decoder_outputs.narrow(dim=0, start=begin, length=int(end - begin))  # [sub-batch, U, D]

                    # Reduce decoder length to preserve computation
                    # Decoder: [sub-batch, U, D] -> [sub-batch, U', D]; U' < U
                    if sub_dec.shape[1] != max_sub_transcript_length + 1:
                        sub_dec = sub_dec.narrow(dim=1, start=0, length=int(max_sub_transcript_length + 1))

                    # Perform the joint => [sub-batch, T', U', V + 1]
                    sub_joint = self.joint(sub_enc, sub_dec)
                    del sub_dec

                    # Reduce transcript length to correct the alignment
                    # Transcript: [sub-batch, L] -> [sub-batch, L']; L' <= L
                    if sub_transcripts.shape[1] != max_sub_transcript_length:
                        sub_transcripts = sub_transcripts.narrow(dim=1, start=0, length=int(max_sub_transcript_length))

                    # Compute the sub-batch loss
                    # Preserve the loss reduction type
                    loss_reduction = self.loss.reduction

                    # Disable reduction so per-sample losses can be accumulated,
                    # then reduced once over the full batch below
                    self.loss.reduction = None

                    # Compute and preserve the loss
                    loss_batch = self.loss(
                        log_probs=sub_joint,
                        targets=sub_transcripts,
                        input_lengths=sub_enc_lens,
                        target_lengths=sub_transcript_lens,
                    )
                    losses.append(loss_batch)
                    target_lengths.append(sub_transcript_lens)

                    # Restore the loss reduction type
                    self.loss.reduction = loss_reduction
                else:
                    losses = None

                # Update WER for the sub-batch
                if compute_wer:
                    sub_enc = sub_enc.transpose(1, 2)  # [B, T, D] -> [B, D, T]
                    sub_enc = sub_enc.detach()
                    sub_transcripts = sub_transcripts.detach()

                    # Update WER on each process without syncing
                    self.wer.update(
                        predictions=sub_enc,
                        predictions_lengths=sub_enc_lens,
                        targets=sub_transcripts,
                        targets_lengths=sub_transcript_lens,
                    )

                del sub_enc, sub_transcripts, sub_enc_lens, sub_transcript_lens

            # Reduce over sub-batches
            if losses is not None:
                losses = self.loss.reduce(losses, target_lengths)

            # Collect sub-batch WER results
            if compute_wer:
                # Sync and all_reduce on all processes, compute global WER
                wer, wer_num, wer_denom = self.wer.compute()
                self.wer.reset()
            else:
                wer = None
                wer_num = None
                wer_denom = None

            return losses, wer, wer_num, wer_denom

    def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        return self.joint_after_projection(self.project_encoder(f), self.project_prednet(g))

    def project_encoder(self, encoder_output: torch.Tensor) -> torch.Tensor:
        return self.enc(encoder_output)

    def project_prednet(self, prednet_output: torch.Tensor) -> torch.Tensor:
        return self.pred(prednet_output)

    def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
        """
        Compute the joint step of the network after projection.

        Here,
            B = Batch size
            T = Acoustic model timesteps
            U = Target sequence length
            H1, H2 = Hidden dimensions of the Encoder / Decoder respectively
            H = Hidden dimension of the Joint hidden step
            V = Vocabulary size of the Decoder (excluding the RNNT blank token)

        NOTE:
            The implementation of this model is slightly modified from the original paper.
            The original paper proposes the following steps:
            (enc, dec) -> Expand + Concat + Sum [B, T, U, H1+H2] -> Forward through joint hidden [B, T, U, H] -- *1
            *1 -> Forward through joint final [B, T, U, V + 1].

            We instead split the joint hidden into joint_hidden_enc and joint_hidden_dec and act as follows:
            enc -> Forward through joint_hidden_enc -> Expand [B, T, 1, H] -- *1
            dec -> Forward through joint_hidden_dec -> Expand [B, 1, U, H] -- *2
            (*1, *2) -> Sum [B, T, U, H] -> Forward through joint final [B, T, U, V + 1].

        Args:
            f: Output of the Encoder model. A torch.Tensor of shape [B, T, H1]
            g: Output of the Decoder model. A torch.Tensor of shape [B, U, H2]

        Returns:
            Logits / log softmaxed tensor of shape (B, T, U, V + 1).
        """
        f = f.unsqueeze(dim=2)  # (B, T, 1, H)
        g = g.unsqueeze(dim=1)  # (B, 1, U, H)
        inp = f + g  # [B, T, U, H]
        del f, g

        res = self.joint_net(inp)  # [B, T, U, V + 1]
        del inp

        if self.preserve_memory:
            torch.cuda.empty_cache()

        # If log_softmax is automatic
        if self.log_softmax is None:
            if not res.is_cuda:  # Use log softmax only if on CPU
                if self.temperature != 1.0:
                    res = (res / self.temperature).log_softmax(dim=-1)
                else:
                    res = res.log_softmax(dim=-1)
        else:
            if self.log_softmax:
                if self.temperature != 1.0:
                    res = (res / self.temperature).log_softmax(dim=-1)
                else:
                    res = res.log_softmax(dim=-1)

        return res

    def _joint_net_modules(self, num_classes, pred_n_hidden, enc_n_hidden, joint_n_hidden, activation, dropout):
        """
        Prepare the trainable modules of the Joint Network.

        Args:
            num_classes: Number of output classes (vocab size) including the RNNT blank token.
            pred_n_hidden: Hidden size of the prediction network.
            enc_n_hidden: Hidden size of the encoder network.
            joint_n_hidden: Hidden size of the joint network.
            activation: Activation of the joint. Can be one of [relu, tanh, sigmoid].
            dropout: Dropout value to apply to the joint.
        """
        pred = torch.nn.Linear(pred_n_hidden, joint_n_hidden)
        enc = torch.nn.Linear(enc_n_hidden, joint_n_hidden)

        # Normalize case before validating, so e.g. 'ReLU' is accepted
        activation = activation.lower()
        if activation not in ['relu', 'sigmoid', 'tanh']:
            raise ValueError("Unsupported activation for joint step - please pass one of [relu, sigmoid, tanh]")

        if activation == 'relu':
            activation = torch.nn.ReLU(inplace=True)
        elif activation == 'sigmoid':
            activation = torch.nn.Sigmoid()
        elif activation == 'tanh':
            activation = torch.nn.Tanh()

        layers = (
            [activation]
            + ([torch.nn.Dropout(p=dropout)] if dropout else [])
            + [torch.nn.Linear(joint_n_hidden, num_classes)]
        )
        return pred, enc, torch.nn.Sequential(*layers)
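

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original NeMo module). It shows
    # the non-fused path of forward(): project encoder and prediction-network
    # outputs into the joint space, sum, and produce per-(t, u) token logits.
    # All sizes below are illustrative assumptions chosen for a quick CPU run.
    jointnet = {
        "encoder_hidden": 16,
        "pred_hidden": 16,
        "joint_hidden": 32,
        "activation": "relu",
        "dropout": 0.1,
    }
    # 28 vocabulary tokens; the module adds 1 output for the RNNT blank internally
    joint = RNNTJoint(jointnet=jointnet, num_classes=28)

    B, T, U = 2, 50, 10
    encoder_outputs = torch.randn(B, 16, T)  # (B, D, T), as expected by forward()
    decoder_outputs = torch.randn(B, 16, U)  # (B, D, U)

    out = joint(encoder_outputs, decoder_outputs)
    # On CPU with log_softmax=None, the output is log-softmaxed automatically
    print(out.shape)  # torch.Size([2, 50, 10, 29]) -> [B, T, U, V + 1]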