From 73685cd363f39c8e2e1b67d20b1ad51b16ecd36b Mon Sep 17 00:00:00 2001 From: Nan Rosemary Ke Date: Mon, 23 Oct 2017 19:21:31 -0400 Subject: [PATCH] correct alignments between back and forward states --- train_twinnet_para.py | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/train_twinnet_para.py b/train_twinnet_para.py index f96e93ef..b9a25735 100644 --- a/train_twinnet_para.py +++ b/train_twinnet_para.py @@ -108,32 +108,46 @@ def train(opt): torch.cuda.synchronize() start = time.time() + # labels = x1, x2, x3, x4 + # rev_labels = x4, x3, x2, x1 reverse_labels = np.flip(data['labels'],1).copy() + # maks = 1, 1, 1, 0 + # rev_masks = 0, 1, 1, 1 reverse_masks = np.flip(data['masks'], 1).copy() tmp = [data['fc_feats'], data['att_feats'], data['labels'],reverse_labels, data['masks'], reverse_masks] tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp] fc_feats, att_feats, labels, reverse_labels, masks, reverse_masks = tmp optimizer.zero_grad() + # affine_states = s1, s2, s3 out, affine_states = model(fc_feats, att_feats, labels) + # back_states = s4, s3, s2 back_out, back_states = back_model(fc_feats, att_feats, reverse_labels) idx = [i for i in range(back_states.size()[1] - 1, -1, -1)] idx = torch.LongTensor(idx) idx = Variable(idx).cuda() + # invert_backstates = s2, s3, s4 invert_backstates = back_states.index_select(1, idx) + # labels[:, 1:] = x2, x3, x4 loss = crit( out, labels[:,1:], masks[:,1:]) + # rev_lables[:, 1:] = x3, x2, x1 back_loss = crit(back_out, reverse_labels[:, 1:], reverse_masks[:, 1:]) + # do affine transform on forward state - - back_states = back_states.detach() affine_states = affine_states * masks[:, 1:].unsqueeze(2).expand_as(affine_states) invert_backstates = invert_backstates * masks[:, 1:].unsqueeze(2).expand_as(invert_backstates) - + + # affine_states[:, :-1] = s1, s2 + affine_states = affine_states[:, :-1, :] + # invert_backstates[:, 1:] = s3, s4 + invert_backstates = invert_backstates[:, 1: , :] + invert_backstates = invert_backstates.detach() + l2_loss = ((affine_states - invert_backstates)** 2).sum(dim=1)[:, 0, :] - l2_loss = (l2_loss / masks[:, 1:].sum(dim=1).expand_as(l2_loss)).mean() + l2_loss = (l2_loss / masks[:, 1:-1].sum(dim=1).expand_as(l2_loss)).mean() - all_loss = loss + 2.0 * l2_loss + back_loss + all_loss = loss + 3.0 * l2_loss + back_loss all_loss.backward() #back_loss.backward()