Skip to content

Commit

Permalink
update kd qa in roberta modeling (huggingface#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
stevezheng23 authored Oct 29, 2019
1 parent 8b67c21 commit 6645298
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion transformers/modeling_roberta.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,7 @@ def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_
kd_end_probs = nn.LogSoftmax(kd_end_logits)
kd_start_loss = kd_loss_fct(kd_start_probs, start_targets)
kd_end_loss = kd_loss_fct(kd_end_probs, end_targets)
kd_span_loss = (kd_start_loss + kd_end_loss) / 2
kd_span_loss = (self.kd_temperature ** 2) * (kd_start_loss + kd_end_loss) / 2
total_loss = kd_span_loss if total_loss == None else total_loss + kd_span_loss

if total_loss is not None:
Expand Down

0 comments on commit 6645298

Please sign in to comment.