Skip to content

Commit

Permalink
Fix bug in RNNT Joint WER calculation for fused batch (#8587)
Browse files (browse the repository at this point in the history)
Signed-off-by: smajumdar <[email protected]>
Signed-off-by: Pablo Garay <[email protected]>
  • Loading branch information
titu1994 authored and pablo-garay committed Mar 19, 2024
1 parent fc87bd3 commit 60bb3de
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions nemo/collections/asr/modules/rnnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1394,6 +1394,7 @@ def forward(
)

losses = []
+ wers, wer_nums, wer_denoms = [], [], []
target_lengths = []
batch_size = int(encoder_outputs.size(0)) # actual batch size

Expand Down Expand Up @@ -1476,6 +1477,13 @@ def forward(
targets=sub_transcripts,
targets_lengths=sub_transcript_lens,
)
+ # Sync and all_reduce on all processes, compute global WER
+ wer, wer_num, wer_denom = self.wer.compute()
+ self.wer.reset()
+
+ wers.append(wer)
+ wer_nums.append(wer_num)
+ wer_denoms.append(wer_denom)

del sub_enc, sub_transcripts, sub_enc_lens, sub_transcript_lens

Expand All @@ -1485,9 +1493,9 @@ def forward(

# Collect sub batch wer results
if compute_wer:
- # Sync and all_reduce on all processes, compute global WER
- wer, wer_num, wer_denom = self.wer.compute()
- self.wer.reset()
+ wer = sum(wers) / len(wers)
+ wer_num = sum(wer_nums)
+ wer_denom = sum(wer_denoms)
else:
wer = None
wer_num = None
Expand Down Expand Up @@ -1899,6 +1907,7 @@ def forward(
)

losses = []
+ wers, wer_nums, wer_denoms = [], [], []
target_lengths = []
batch_size = int(encoder_outputs.size(0)) # actual batch size

Expand Down Expand Up @@ -1992,6 +2001,14 @@ def forward(
targets_lengths=sub_transcript_lens,
)

+ # Sync and all_reduce on all processes, compute global WER
+ wer, wer_num, wer_denom = self.wer.compute()
+ self.wer.reset()
+
+ wers.append(wer)
+ wer_nums.append(wer_num)
+ wer_denoms.append(wer_denom)

del sub_enc, sub_transcripts, sub_enc_lens, sub_transcript_lens

# Reduce over sub batches
Expand All @@ -2000,9 +2017,9 @@ def forward(

# Collect sub batch wer results
if compute_wer:
- # Sync and all_reduce on all processes, compute global WER
- wer, wer_num, wer_denom = self.wer.compute()
- self.wer.reset()
+ wer = sum(wers) / len(wers)
+ wer_num = sum(wer_nums)
+ wer_denom = sum(wer_denoms)
else:
wer = None
wer_num = None
Expand Down

0 comments on commit 60bb3de

Please sign in to comment.