Skip to content

Commit

Permalink
Merge pull request #25 from rcy17/patch-1
Browse files Browse the repository at this point in the history
Fix prototypical_loss bug #23
  • Loading branch information
belerico authored Apr 5, 2022
2 parents 5e18a5e + 58b72c3 commit df89808
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/prototypical_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,6 @@ def supp_idxs(c):

loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()
_, y_hat = log_p_y.max(2)
acc_val = y_hat.eq(target_inds.squeeze()).float().mean()
acc_val = y_hat.eq(target_inds.squeeze(2)).float().mean()

return loss_val, acc_val

0 comments on commit df89808

Please sign in to comment.