From 46b7a94fbec2ad4323a0e703c4a4fd2cd5c85ff8 Mon Sep 17 00:00:00 2001
From: Ruotian Luo
Date: Tue, 9 Apr 2019 15:38:04 -0500
Subject: [PATCH] Further towards AttModel: remove the use of get_logprobs

---
 models/TransformerModel.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/models/TransformerModel.py b/models/TransformerModel.py
index 2b8e06bc..0670157f 100644
--- a/models/TransformerModel.py
+++ b/models/TransformerModel.py
@@ -276,7 +276,7 @@ def __init__(self, opt):
         self.embed = lambda x : x
         delattr(self, 'fc_embed')
         self.fc_embed = lambda x : x
-        del self.logit
+        delattr(self, 'logit')
         del self.ctx2att
 
         tgt_vocab = self.vocab_size + 1
@@ -285,6 +285,9 @@ def __init__(self, opt):
             d_model=opt.input_encoding_size,
             d_ff=opt.rnn_size)
 
+    def logit(self, x): # unsafe way
+        return self.model.generator.proj(x)
+
     def init_hidden(self, bsz):
         return None
 
@@ -326,7 +329,7 @@ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
         return outputs
         # return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
 
-    def get_logprobs_state(self, it, fc_feats_ph, att_feats_ph, memory, mask, state):
+    def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
         """
         state = [ys.unsqueeze(0)]
         """
@@ -338,8 +341,4 @@ def get_logprobs_state(self, it, fc_feats_ph, att_feats_ph, memory, mask, state)
                                ys,
                                subsequent_mask(ys.size(1))
                                         .to(memory.device))
-        logprobs = self.model.generator(out[:, -1])
-
-        return logprobs, [ys.unsqueeze(0)]
-
-# For _sample and _sample_beam, now p_att_feats = memory
\ No newline at end of file
+        return out[:, -1], [ys.unsqueeze(0)]
\ No newline at end of file
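
Context on why get_logprobs_state can be dropped: the base AttModel implements it generically on top of the core()/logit() interface, which this patch makes TransformerModel provide (the argument order of core is also swapped to state-before-mask to match, and p_att_feats stands in for the encoder memory, per the removed comment). Below is a minimal sketch of that inherited method, assuming the usual AttModel signature; the body is an approximation for illustration, not code from this patch.

    # Assumed shape of the inherited AttModel.get_logprobs_state that the
    # Transformer now reuses instead of overriding (approximate sketch).
    import torch.nn.functional as F

    def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state):
        xt = self.embed(it)  # identity for the Transformer (self.embed = lambda x: x)
        # For the Transformer subclass, p_att_feats carries the encoder memory.
        output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks)
        logprobs = F.log_softmax(self.logit(output), dim=1)  # logit -> model.generator.proj
        return logprobs, state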