From 5d7caa95af7e3675305c542253c4e372801897d2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 28 Apr 2020 10:56:24 +0800 Subject: [PATCH] bug re-fix --- main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 72015f0..f12392a 100755 --- a/main.py +++ b/main.py @@ -181,8 +181,8 @@ def outlier(each_target_share_weight): adv_loss_separate += nn.BCELoss()(domain_prob_discriminator_source_separate, torch.ones_like(domain_prob_discriminator_source_separate)) adv_loss_separate += nn.BCELoss()(domain_prob_discriminator_target_separate, torch.zeros_like(domain_prob_discriminator_target_separate)) - # ============================== cross entropy loss, it receives logits as its inputs - ce = nn.CrossEntropyLoss(reduction='none')(fc2_s, label_source) + # ============================== cross entropy loss + ce = nn.CrossEntropyLoss(reduction='none')(predict_prob_source, label_source) ce = torch.mean(ce, dim=0, keepdim=True) with OptimizerManager(