Update megatron_bert_model.py
Signed-off-by: Shanmugam Ramasamy <[email protected]>
shanmugamr1992 authored Nov 16, 2022
1 parent ff57d7a commit a4afc8d
Showing 1 changed file with 1 addition and 0 deletions.
@@ -360,6 +360,7 @@ def loss_func(self, loss_mask, sentence_order, output_tensor):
        loss_mask = loss_mask.float()

        # When the number of tokens is very small, sometimes none of the tokens get masked for prediction. In that case the loss mask is all zeros,
        # i.e. the entire batch is masked out (in practice, when MBS=1 or 2 and the number of tokens in each batch is < 7).
        if loss_mask.sum() == 0:
            lm_loss = torch.sum(lm_loss_.view(-1)) * 0.0
        else:
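
A minimal standalone sketch of the guard in this hunk, for readers who want to try it outside the model class. The function name `masked_lm_loss` and the sample tensors are hypothetical, and the `else` path is an assumed masked mean, since the diff does not show that branch. Multiplying the summed loss by `0.0`, rather than returning a constant, keeps the result attached to the autograd graph, so a later `backward()` still runs and yields zero gradients.

```python
import torch


def masked_lm_loss(lm_loss_: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    """Average per-token losses over the positions selected by loss_mask."""
    lm_loss_ = lm_loss_.float()
    loss_mask = loss_mask.float()
    if loss_mask.sum() == 0:
        # No position is selected: return a zero that is still connected to
        # the autograd graph instead of dividing by zero.
        return torch.sum(lm_loss_.view(-1)) * 0.0
    # Assumed normal path (not visible in the diff): a masked mean.
    return torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()


# Hypothetical per-token losses for one sequence of three tokens.
per_token = torch.tensor([[0.5, 1.5, 2.0]], requires_grad=True)

# All-zero mask: the loss is zero and backward() still works.
empty = masked_lm_loss(per_token, torch.zeros(1, 3))
empty.backward()

# Mask selecting the first and last token: mean of 0.5 and 2.0.
partial = masked_lm_loss(per_token.detach(), torch.tensor([[1.0, 0.0, 1.0]]))
```

Without the guard, an all-zero mask would divide by `loss_mask.sum() == 0` and produce `nan`, which would then propagate into the gradients.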
