Further towards AttModel: remove the use of get_logprobs
ruotianluo committed Apr 11, 2019
1 parent cde27bc commit 46b7a94
Showing 1 changed file with 6 additions and 7 deletions.
models/TransformerModel.py: 13 changes (6 additions & 7 deletions)
@@ -276,7 +276,7 @@ def __init__(self, opt):
         self.embed = lambda x : x
         delattr(self, 'fc_embed')
         self.fc_embed = lambda x : x
-        del self.logit
+        delattr(self, 'logit')
         del self.ctx2att
 
         tgt_vocab = self.vocab_size + 1
@@ -285,6 +285,9 @@ def __init__(self, opt):
             d_model=opt.input_encoding_size,
             d_ff=opt.rnn_size)
 
+    def logit(self, x): # unsafe way
+        return self.model.generator.proj(x)
+
     def init_hidden(self, bsz):
         return None
 
@@ -326,7 +329,7 @@ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
         return outputs
         # return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
 
-    def get_logprobs_state(self, it, fc_feats_ph, att_feats_ph, memory, mask, state):
+    def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
         """
         state = [ys.unsqueeze(0)]
         """
@@ -338,8 +341,4 @@ def get_logprobs_state(self, it, fc_feats_ph, att_feats_ph, memory, mask, state)
                                ys,
                                subsequent_mask(ys.size(1))
                                        .to(memory.device))
-        logprobs = self.model.generator(out[:, -1])
-
-        return logprobs, [ys.unsqueeze(0)]
-
-    # For _sample and _sample_beam, now p_att_feats = memory
+        return out[:, -1], [ys.unsqueeze(0)]
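
After this change, TransformerModel no longer overrides get_logprobs_state. Instead it exposes the two pieces the generic AttModel decoding path expects: core, which now returns the raw last-step decoder state out[:, -1] (with the running token prefix ys kept in state), and logit, which projects that state to vocabulary scores through self.model.generator.proj. Assuming the Generator follows the annotated-Transformer definition and applies log_softmax inside its forward, going through .proj instead of the full generator keeps the scores unnormalized so the base class can apply log_softmax itself. The reordered core signature (state before mask) matches the argument order AttModel uses when calling core, and, as the removed comment noted, p_att_feats now carries the encoder memory for _sample and _sample_beam.

A minimal sketch of how the inherited get_logprobs_state is assumed to tie these pieces together (illustrative only, not copied from the repository's AttModel):

    import torch.nn.functional as F

    def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state):
        # 'it' holds the previous word indices; for the Transformer, self.embed is an
        # identity lambda, so the raw indices are passed through and embedded inside core.
        xt = self.embed(it)
        # core() rebuilds ys from state, decodes against memory (= p_att_feats),
        # and returns the last-step hidden state plus the updated state.
        output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks)
        # logit() is the unnormalized projection added in this commit;
        # log_softmax is applied once, here in the shared code path.
        logprobs = F.log_softmax(self.logit(output), dim=1)
        return logprobs, state

With this in place, the shared _sample and _sample_beam loops can drive the Transformer without model-specific overrides, which is the direction the commit message points to.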
