diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index f2a4f983ca15..3a87344421d0 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -1394,6 +1394,7 @@ def forward( ) losses = [] + wers, wer_nums, wer_denoms = [], [], [] target_lengths = [] batch_size = int(encoder_outputs.size(0)) # actual batch size @@ -1476,6 +1477,13 @@ def forward( targets=sub_transcripts, targets_lengths=sub_transcript_lens, ) + # Sync and all_reduce on all processes, compute global WER + wer, wer_num, wer_denom = self.wer.compute() + self.wer.reset() + + wers.append(wer) + wer_nums.append(wer_num) + wer_denoms.append(wer_denom) del sub_enc, sub_transcripts, sub_enc_lens, sub_transcript_lens @@ -1485,9 +1493,9 @@ def forward( # Collect sub batch wer results if compute_wer: - # Sync and all_reduce on all processes, compute global WER - wer, wer_num, wer_denom = self.wer.compute() - self.wer.reset() + wer = sum(wers) / len(wers) + wer_num = sum(wer_nums) + wer_denom = sum(wer_denoms) else: wer = None wer_num = None @@ -1899,6 +1907,7 @@ def forward( ) losses = [] + wers, wer_nums, wer_denoms = [], [], [] target_lengths = [] batch_size = int(encoder_outputs.size(0)) # actual batch size @@ -1992,6 +2001,14 @@ def forward( targets_lengths=sub_transcript_lens, ) + # Sync and all_reduce on all processes, compute global WER + wer, wer_num, wer_denom = self.wer.compute() + self.wer.reset() + + wers.append(wer) + wer_nums.append(wer_num) + wer_denoms.append(wer_denom) + del sub_enc, sub_transcripts, sub_enc_lens, sub_transcript_lens # Reduce over sub batches @@ -2000,9 +2017,9 @@ def forward( # Collect sub batch wer results if compute_wer: - # Sync and all_reduce on all processes, compute global WER - wer, wer_num, wer_denom = self.wer.compute() - self.wer.reset() + wer = sum(wers) / len(wers) + wer_num = sum(wer_nums) + wer_denom = sum(wer_denoms) else: wer = None wer_num = None