diff --git a/apex/contrib/multihead_attn/self_multihead_attn_func.py b/apex/contrib/multihead_attn/self_multihead_attn_func.py
index f3bba008c..c00d139f5 100644
--- a/apex/contrib/multihead_attn/self_multihead_attn_func.py
+++ b/apex/contrib/multihead_attn/self_multihead_attn_func.py
@@ -183,7 +183,7 @@ def backward(ctx, output_grads):
         values_grads = torch.bmm(dropout_results.transpose(1,2), output_lin_grads, out=values_grads.transpose(0,1))
 
         # Mask and Scaling for Dropout (not a publically documented op)
-        dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, dropout_prob_t[0])
+        dropout_grads = torch._masked_scale(matmul2_dgrad1, dropout_mask, 1.0/(1.0-dropout_prob_t[0]))
 
         # Softmax Grad (not a publically documented op)
         softmax_grads = torch._softmax_backward_data(dropout_grads, softmax_results, -1, softmax_results)
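
For context, a minimal sketch of why the corrected scale factor is 1.0/(1.0 - dropout_prob) rather than the probability itself: PyTorch dropout is "inverted", multiplying the kept activations by 1/(1 - p) in the forward pass, so the backward mask-and-scale must apply the same factor to the incoming gradients. The sketch below uses only public ops and illustrative names (p, mask, grad_out); it assumes the private torch._masked_scale(grad, mask, scale) in the diff computes grad * mask * scale.

```python
import torch

# Minimal sketch (not the apex kernel) of inverted-dropout forward/backward.
p = 0.1
x = torch.randn(4, 8, requires_grad=True)
mask = (torch.rand_like(x) > p).to(x.dtype)   # 1 where kept, 0 where dropped
y = x * mask / (1.0 - p)                      # forward: inverted dropout
grad_out = torch.randn_like(y)
y.backward(grad_out)

# Backward of inverted dropout: grad_in = grad_out * mask * (1/(1 - p)),
# i.e. the mask-and-scale the fixed call relies on, not grad_out * mask * p.
expected = grad_out * mask / (1.0 - p)
assert torch.allclose(x.grad, expected)
```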