Skip to content

Commit

Permalink
Fix prototypical_loss bug orobix#23
Browse files Browse the repository at this point in the history
  • Loading branch information
rcy17 authored Feb 3, 2021
1 parent 5e18a5e commit 58b72c3
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/prototypical_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,6 @@ def supp_idxs(c):

loss_val = -log_p_y.gather(2, target_inds).squeeze().view(-1).mean()
_, y_hat = log_p_y.max(2)
acc_val = y_hat.eq(target_inds.squeeze()).float().mean()
acc_val = y_hat.eq(target_inds.squeeze(2)).float().mean()

return loss_val, acc_val

0 comments on commit 58b72c3

Please sign in to comment.