From c9528bf8c8d5e32f41f64d6d3c294cae19fc2982 Mon Sep 17 00:00:00 2001 From: Ruotian Luo Date: Thu, 3 May 2018 22:26:08 -0500 Subject: [PATCH 01/12] Add transformer --- misc/utils.py | 69 +++++ models/TransformerModel.py | 518 +++++++++++++++++++++++++++++++++++++ models/__init__.py | 4 + opts.py | 8 +- train.py | 35 ++- 5 files changed, 620 insertions(+), 14 deletions(-) create mode 100644 models/TransformerModel.py diff --git a/misc/utils.py b/misc/utils.py index 95b227cf..67799c4e 100644 --- a/misc/utils.py +++ b/misc/utils.py @@ -65,6 +65,38 @@ def forward(self, input, target, mask): return output +class LabelSmoothing(nn.Module): + "Implement label smoothing." + def __init__(self, size=0, padding_idx=0, smoothing=0.0): + super(LabelSmoothing, self).__init__() + self.criterion = nn.KLDivLoss(size_average=False, reduce=False) + # self.padding_idx = padding_idx + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + # self.size = size + self.true_dist = None + + def forward(self, input, target, mask): + # truncate to the same size + target = target[:, :input.size(1)] + mask = mask[:, :input.size(1)] + + input = to_contiguous(input).view(-1, input.size(-1)) + target = to_contiguous(target).view(-1) + mask = to_contiguous(mask).view(-1) + + # assert x.size(1) == self.size + self.size = input.size(1) + # true_dist = x.data.clone() + true_dist = input.data.clone() + # true_dist.fill_(self.smoothing / (self.size - 2)) + true_dist.fill_(self.smoothing / (self.size - 1)) + true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) + # true_dist[:, self.padding_idx] = 0 + # mask = torch.nonzero(target.data == self.padding_idx) + # self.true_dist = true_dist + return (self.criterion(input, true_dist).sum(1) * mask).sum() / mask.sum() + def set_lr(optimizer, lr): for group in optimizer.param_groups: group['lr'] = lr @@ -89,4 +121,41 @@ def build_optimizer(params, opt): return optim.Adam(params, opt.learning_rate, (opt.optim_alpha, opt.optim_beta), opt.optim_epsilon, weight_decay=opt.weight_decay) else: raise Exception("bad option opt.optim: {}".format(opt.optim)) + + +class NoamOpt(object): + "Optim wrapper that implements rate." 
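+    # Note: rate() below implements the Noam schedule from "Attention Is All You Need":
+    #   lr = factor * d_model**(-0.5) * min(step**(-0.5), step * warmup**(-1.5))
+    # i.e. linear warmup for `warmup` steps followed by inverse-square-root decay.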
+ def __init__(self, model_size, factor, warmup, optimizer): + self.optimizer = optimizer + self._step = 0 + self.warmup = warmup + self.factor = factor + self.model_size = model_size + self._rate = 0 + + def step(self): + "Update parameters and rate" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p['lr'] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step = None): + "Implement `lrate` above" + if step is None: + step = self._step + return self.factor * \ + (self.model_size ** (-0.5) * + min(step ** (-0.5), step * self.warmup ** (-1.5))) + + def __getattr__(self, name): + return getattr(self.optimizer, name) + +def get_std_opt(model): + # return NoamOpt(model.tgt_embed[0].d_model, 2, 4000, + # torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) + return NoamOpt(model.tgt_embed[0].d_model, 1, 2000, + torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) \ No newline at end of file diff --git a/models/TransformerModel.py b/models/TransformerModel.py new file mode 100644 index 00000000..3938a8b7 --- /dev/null +++ b/models/TransformerModel.py @@ -0,0 +1,518 @@ +# This file contains Att2in2, AdaAtt, AdaAttMO, TopDown model + +# AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning +# https://arxiv.org/abs/1612.01887 +# AdaAttMO is a modified version with maxout lstm + +# Att2in is from Self-critical Sequence Training for Image Captioning +# https://arxiv.org/abs/1612.00563 +# In this file we only have Att2in2, which is a slightly different version of att2in, +# in which the img feature embedding and word embedding is the same as what in adaatt. + +# TopDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA +# https://arxiv.org/abs/1707.07998 +# However, it may not be identical to the author's architecture. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F +import misc.utils as utils + +import copy +from torch.autograd import Variable +import math +import numpy as np + +from .CaptionModel import CaptionModel +from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper + +class EncoderDecoder(nn.Module): + """ + A standard Encoder-Decoder architecture. Base for this and many + other models. + """ + def __init__(self, encoder, decoder, src_embed, tgt_embed, generator): + super(EncoderDecoder, self).__init__() + self.encoder = encoder + self.decoder = decoder + self.src_embed = src_embed + self.tgt_embed = tgt_embed + self.generator = generator + + def forward(self, src, tgt, src_mask, tgt_mask): + "Take in and process masked src and target sequences." + return self.decode(self.encode(src, src_mask), src_mask, + tgt, tgt_mask) + + def encode(self, src, src_mask): + return self.encoder(self.src_embed(src), src_mask) + + def decode(self, memory, src_mask, tgt, tgt_mask): + return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask) + +class Generator(nn.Module): + "Define standard linear + softmax generation step." + def __init__(self, d_model, vocab): + super(Generator, self).__init__() + self.proj = nn.Linear(d_model, vocab) + + def forward(self, x): + return F.log_softmax(self.proj(x), dim=-1) + +def clones(module, N): + "Produce N identical layers." 
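+    # copy.deepcopy gives each of the N clones its own parameters; the layers are not weight-tied.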
+ return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) + +class Encoder(nn.Module): + "Core encoder is a stack of N layers" + def __init__(self, layer, N): + super(Encoder, self).__init__() + self.layers = clones(layer, N) + self.norm = LayerNorm(layer.size) + + def forward(self, x, mask): + "Pass the input (and mask) through each layer in turn." + for layer in self.layers: + x = layer(x, mask) + return self.norm(x) + +class LayerNorm(nn.Module): + "Construct a layernorm module (See citation for details)." + def __init__(self, features, eps=1e-6): + super(LayerNorm, self).__init__() + self.a_2 = nn.Parameter(torch.ones(features)) + self.b_2 = nn.Parameter(torch.zeros(features)) + self.eps = eps + + def forward(self, x): + mean = x.mean(-1, keepdim=True) + std = x.std(-1, keepdim=True) + return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 + +class SublayerConnection(nn.Module): + """ + A residual connection followed by a layer norm. + Note for code simplicity the norm is first as opposed to last. + """ + def __init__(self, size, dropout): + super(SublayerConnection, self).__init__() + self.norm = LayerNorm(size) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, sublayer): + "Apply residual connection to any sublayer with the same size." + return x + self.dropout(sublayer(self.norm(x))) + +class EncoderLayer(nn.Module): + "Encoder is made up of self-attn and feed forward (defined below)" + def __init__(self, size, self_attn, feed_forward, dropout): + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.sublayer = clones(SublayerConnection(size, dropout), 2) + self.size = size + + def forward(self, x, mask): + "Follow Figure 1 (left) for connections." + x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) + return self.sublayer[1](x, self.feed_forward) + +class Decoder(nn.Module): + "Generic N layer decoder with masking." + def __init__(self, layer, N): + super(Decoder, self).__init__() + self.layers = clones(layer, N) + self.norm = LayerNorm(layer.size) + + def forward(self, x, memory, src_mask, tgt_mask): + for layer in self.layers: + x = layer(x, memory, src_mask, tgt_mask) + return self.norm(x) + +class DecoderLayer(nn.Module): + "Decoder is made of self-attn, src-attn, and feed forward (defined below)" + def __init__(self, size, self_attn, src_attn, feed_forward, dropout): + super(DecoderLayer, self).__init__() + self.size = size + self.self_attn = self_attn + self.src_attn = src_attn + self.feed_forward = feed_forward + self.sublayer = clones(SublayerConnection(size, dropout), 3) + + def forward(self, x, memory, src_mask, tgt_mask): + "Follow Figure 1 (right) for connections." + m = memory + x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) + x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) + return self.sublayer[2](x, self.feed_forward) + +def subsequent_mask(size): + "Mask out subsequent positions." 
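+    # e.g. subsequent_mask(3) yields a (1, 3, 3) mask
+    #   [[1, 0, 0],
+    #    [1, 1, 0],
+    #    [1, 1, 1]]
+    # so position i may only attend to positions <= i during decoding.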
+ attn_shape = (1, size, size) + subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') + return torch.from_numpy(subsequent_mask) == 0 + +def attention(query, key, value, mask=None, dropout=None): + "Compute 'Scaled Dot Product Attention'" + d_k = query.size(-1) + scores = torch.matmul(query, key.transpose(-2, -1)) \ + / math.sqrt(d_k) + + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e9) + p_attn = F.softmax(scores, dim = -1) + if dropout is not None: + p_attn = dropout(p_attn) + return torch.matmul(p_attn, value), p_attn + +class MultiHeadedAttention(nn.Module): + def __init__(self, h, d_model, dropout=0.1): + "Take in model size and number of heads." + super(MultiHeadedAttention, self).__init__() + assert d_model % h == 0 + # We assume d_v always equals d_k + self.d_k = d_model // h + self.h = h + self.linears = clones(nn.Linear(d_model, d_model), 4) + self.attn = None + self.dropout = nn.Dropout(p=dropout) + + def forward(self, query, key, value, mask=None): + "Implements Figure 2" + if mask is not None: + # Same mask applied to all h heads. + mask = mask.unsqueeze(1) + nbatches = query.size(0) + + # 1) Do all the linear projections in batch from d_model => h x d_k + query, key, value = \ + [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) + for l, x in zip(self.linears, (query, key, value))] + + # 2) Apply attention on all the projected vectors in batch. + x, self.attn = attention(query, key, value, mask=mask, + dropout=self.dropout) + + # 3) "Concat" using a view and apply a final linear. + x = x.transpose(1, 2).contiguous() \ + .view(nbatches, -1, self.h * self.d_k) + return self.linears[-1](x) + +class PositionwiseFeedForward(nn.Module): + "Implements FFN equation." + def __init__(self, d_model, d_ff, dropout=0.1): + super(PositionwiseFeedForward, self).__init__() + self.w_1 = nn.Linear(d_model, d_ff) + self.w_2 = nn.Linear(d_ff, d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + return self.w_2(self.dropout(F.relu(self.w_1(x)))) + +class Embeddings(nn.Module): + def __init__(self, d_model, vocab): + super(Embeddings, self).__init__() + self.lut = nn.Embedding(vocab, d_model) + self.d_model = d_model + + def forward(self, x): + return self.lut(x) * math.sqrt(self.d_model) + +class PositionalEncoding(nn.Module): + "Implement the PE function." + def __init__(self, d_model, dropout, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + # Compute the positional encodings once in log space. + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * + -(math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe) + + def forward(self, x): + x = x + Variable(self.pe[:, :x.size(1)], + requires_grad=False) + return self.dropout(x) + +class TransformerModel(CaptionModel): + + def make_model(self, src_vocab, tgt_vocab, N=6, + d_model=512, d_ff=2048, h=8, dropout=0.1): + "Helper: Construct a model from hyperparameters." 
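+        # Option correspondence used in this repo: N=num_layers, d_model=input_encoding_size,
+        # d_ff=rnn_size; h stays at 8.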
+ c = copy.deepcopy + attn = MultiHeadedAttention(h, d_model) + ff = PositionwiseFeedForward(d_model, d_ff, dropout) + position = PositionalEncoding(d_model, dropout) + model = EncoderDecoder( + Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N), + Decoder(DecoderLayer(d_model, c(attn), c(attn), + c(ff), dropout), N), + lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)), + nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)), + Generator(d_model, tgt_vocab)) + + # This was important from their code. + # Initialize parameters with Glorot / fan_avg. + for p in model.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + return model + + def __init__(self, opt): + super(TransformerModel, self).__init__() + self.opt = opt + # self.config = yaml.load(open(opt.config_file)) + # d_model = self.input_encoding_size # 512 + + self.vocab_size = opt.vocab_size + # self.input_encoding_size = opt.input_encoding_size + # #self.rnn_type = opt.rnn_type + self.rnn_size = opt.rnn_size + self.rnn_size = 512 + # self.num_layers = opt.num_layers + self.drop_prob_lm = opt.drop_prob_lm + self.seq_length = opt.seq_length + # self.fc_feat_size = opt.fc_feat_size + self.att_feat_size = opt.att_feat_size + # self.att_hid_size = opt.att_hid_size + + self.use_bn = getattr(opt, 'use_bn', 0) + + self.ss_prob = 0.0 # Schedule sampling probability + + # self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size), + # nn.ReLU(), + # nn.Dropout(self.drop_prob_lm)) + # self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size), + # nn.ReLU(), + # nn.Dropout(self.drop_prob_lm)) + self.att_embed = nn.Sequential(*( + ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+ + (nn.Linear(self.att_feat_size, self.rnn_size), + nn.ReLU(), + nn.Dropout(self.drop_prob_lm))+ + ((nn.BatchNorm1d(self.rnn_size),) if self.use_bn==2 else ()))) + + # self.logit_layers = getattr(opt, 'logit_layers', 1) + # if self.logit_layers == 1: + # self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1) + # else: + # self.logit = [[nn.Linear(self.rnn_size, self.rnn_size), nn.ReLU(), nn.Dropout(0.5)] for _ in range(opt.logit_layers - 1)] + # self.logit = nn.Sequential(*(reduce(lambda x,y:x+y, self.logit) + [nn.Linear(self.rnn_size, self.vocab_size + 1)])) + # self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size) + + tgt_vocab = self.vocab_size + 1 + self.model = self.make_model(0, tgt_vocab) + + # def init_hidden(self, bsz): + # weight = next(self.parameters()) + # return (weight.new_zeros(self.num_layers, bsz, self.rnn_size), + # weight.new_zeros(self.num_layers, bsz, self.rnn_size)) + + def clip_att(self, att_feats, att_masks): + # Clip the length of att_masks and att_feats to the maximum length + if att_masks is not None: + max_len = att_masks.data.long().sum(1).max() + att_feats = att_feats[:, :max_len].contiguous() + att_masks = att_masks[:, :max_len].contiguous() + return att_feats, att_masks + + # def _prepare_feature(self, fc_feats, att_feats, att_masks): + + # # embed fc and att feats + # fc_feats = self.fc_embed(fc_feats) + # att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) + + # # Project the attention feats first to reduce memory and computation comsumptions. 
+ # p_att_feats = self.ctx2att(att_feats) + + # return fc_feats, att_feats, p_att_feats + + def _prepare_feature(self, att_feats, att_masks=None, seq=None): + att_feats, att_masks = self.clip_att(att_feats, att_masks) + + att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) + + if att_masks is None: + att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long) + att_masks = att_masks.unsqueeze(-2) + + if seq is not None: + # crop the last one + seq = seq[:,:-1] + seq_mask = (seq.data > 0) + seq_mask[:,0] += 1 + + seq_mask = seq_mask.unsqueeze(-2) + seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask) + else: + seq_mask = None + + return att_feats, seq, att_masks, seq_mask + + def _forward(self, fc_feats, att_feats, seq, att_masks=None): + att_feats, seq, att_masks, seq_mask = self._prepare_feature(att_feats, att_masks, seq) + + out = self.model(att_feats, seq, att_masks, seq_mask) + + # batch_size = fc_feats.size(0) + # state = self.init_hidden(batch_size) + + # # outputs = [] + # outputs = fc_feats.new_zeros(batch_size, seq.size(1) - 1, self.vocab_size+1) + + # fc_feats, att_feats, p_att_feats = self._prepare_feature(fc_feats, att_feats, att_masks) + + # for i in range(seq.size(1) - 1): + # if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample + # sample_prob = fc_feats.new(batch_size).uniform_(0, 1) + # sample_mask = sample_prob < self.ss_prob + # if sample_mask.sum() == 0: + # it = seq[:, i].clone() + # else: + # sample_ind = sample_mask.nonzero().view(-1) + # it = seq[:, i].data.clone() + # #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1) + # #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1)) + # # prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1) + # prob_prev = torch.exp(outputs[:, i-1].detach()) # fetch prev distribution: shape Nx(M+1) + # it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind)) + # else: + # it = seq[:, i].clone() + # # break if all the sequences end + # if i >= 1 and seq[:, i].sum() == 0: + # break + + # output, state = self.get_logprobs_state(it, fc_feats, att_feats, p_att_feats, att_masks, state) + # outputs[:, i] = output + # # outputs.append(output) + + outputs = self.model.generator(out) + return outputs + # return torch.cat([_.unsqueeze(1) for _ in outputs], 1) + + def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): + beam_size = opt.get('beam_size', 10) + batch_size = fc_feats.size(0) + + fc_feats, att_feats, p_att_feats = self._prepare_feature(fc_feats, att_feats, att_masks) + + assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' + seq = torch.LongTensor(self.seq_length, batch_size).zero_() + seqLogprobs = torch.FloatTensor(self.seq_length, batch_size) + # lets process every image independently for now, for simplicity + + self.done_beams = [[] for _ in range(batch_size)] + for k in range(batch_size): + state = self.init_hidden(beam_size) + tmp_fc_feats = fc_feats[k:k+1].expand(beam_size, fc_feats.size(1)) + tmp_att_feats = att_feats[k:k+1].expand(*((beam_size,)+att_feats.size()[1:])).contiguous() + tmp_p_att_feats = p_att_feats[k:k+1].expand(*((beam_size,)+p_att_feats.size()[1:])).contiguous() + tmp_att_masks = att_masks[k:k+1].expand(*((beam_size,)+att_masks.size()[1:])).contiguous() if att_masks is not None else None + + for t in range(1): + if t == 0: # input + it = fc_feats.new_zeros([beam_size], dtype=torch.long) + + logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state) + + self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt) + seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score + seqLogprobs[:, k] = self.done_beams[k][0]['logps'] + # return the samples and their log likelihoods + return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) + + def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): + sample_max = opt.get('sample_max', 1) + beam_size = opt.get('beam_size', 1) + temperature = opt.get('temperature', 1.0) + decoding_constraint = opt.get('decoding_constraint', 0) + if beam_size > 1: + return self._sample_beam(fc_feats, att_feats, att_masks, opt) + + batch_size = att_feats.shape[0] + + att_feats, seq, att_masks, seq_mask = self._prepare_feature(att_feats, att_masks) + + memory = self.model.encode(att_feats, att_masks) + ys = torch.zeros((batch_size, 1), dtype=torch.long).to(att_feats.device) + + seq = att_feats.new_zeros((batch_size, self.seq_length), dtype=torch.long) + seqLogprobs = att_feats.new_zeros(batch_size, self.seq_length) + + for i in range(self.seq_length): + out = self.model.decode(memory, att_masks, + ys, + subsequent_mask(ys.size(1)) + .to(att_feats.device)) + logprob = self.model.generator(out[:, -1]) + if sample_max: + sampleLogprobs, next_word = torch.max(logprob, dim = 1) + else: + if temperature == 1.0: + prob_prev = torch.exp(logprob.data) # fetch prev distribution: shape Nx(M+1) + else: + # scale logprobs by temperature + prob_prev = torch.exp(torch.div(logprob.data, temperature)) + next_word = torch.multinomial(prob_prev, 1) + sampleLogprobs = logprobs.gather(1, next_word) # gather the logprobs at sampled positions + + seq[:,i] = next_word + seqLogprobs[:,i] = sampleLogprobs + ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1) + return seq, seqLogprobs + + # batch_size = fc_feats.size(0) + # state = self.init_hidden(batch_size) + + # fc_feats, att_feats, p_att_feats = self._prepare_feature(fc_feats, att_feats, att_masks) + + # seq = fc_feats.new_zeros((batch_size, self.seq_length), dtype=torch.long) + # seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length) + # for t in range(self.seq_length + 1): + # if t == 0: # input + # it = fc_feats.new_zeros(batch_size, dtype=torch.long) + # elif sample_max: + # sampleLogprobs, it = torch.max(logprobs.data, 1) + # it = it.view(-1).long() + # else: + # if temperature == 1.0: + # prob_prev = torch.exp(logprobs.data) # fetch prev distribution: shape Nx(M+1) + # else: + # # scale logprobs by temperature + # prob_prev = 
torch.exp(torch.div(logprobs.data, temperature)) + # it = torch.multinomial(prob_prev, 1) + # sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions + # it = it.view(-1).long() # and flatten indices for downstream processing + + # if t >= 1: + # # stop when all finished + # if t == 1: + # unfinished = it > 0 + # else: + # unfinished = unfinished * (it > 0) + # if unfinished.sum() == 0: + # break + # it = it * unfinished.type_as(it) + # seq[:,t-1] = it + # # seq.append(it) #seq[t] the input of t+2 time step + + # # seqLogprobs.append(sampleLogprobs.view(-1)) + # seqLogprobs[:,t-1] = sampleLogprobs.view(-1) + + # logprobs, state = self.get_logprobs_state(it, fc_feats, att_feats, p_att_feats, att_masks, state) + # if decoding_constraint and t > 0: + # tmp = output.new_zeros(output.size(0), self.vocab_size + 1) + # tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf')) + # logprobs = logprobs + tmp + + # return seq, seqLogprobs diff --git a/models/__init__.py b/models/__init__.py index 73d27629..8d329bff 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -14,6 +14,7 @@ from .OldModel import ShowAttendTellModel, AllImgModel from .Att2inModel import Att2inModel from .AttModel import * +from .TransformerModel import TransformerModel def setup(opt): @@ -44,6 +45,9 @@ def setup(opt): # DenseAtt elif opt.caption_model == 'denseatt': model = DenseAttModel(opt) + # Transformer + elif opt.caption_model == 'transformer': + model = TransformerModel(opt) else: raise Exception("Caption model not supported: {}".format(opt.caption_model)) diff --git a/opts.py b/opts.py index 6d519fab..ea44fcd0 100644 --- a/opts.py +++ b/opts.py @@ -25,7 +25,7 @@ def parse_opt(): # Model settings parser.add_argument('--caption_model', type=str, default="show_tell", - help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, topdown, stackatt, denseatt') + help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, topdown, stackatt, denseatt, transformer') parser.add_argument('--rnn_size', type=int, default=512, help='size of the rnn in number of hidden nodes in each layer') parser.add_argument('--num_layers', type=int, default=1, @@ -128,6 +128,12 @@ def parse_opt(): parser.add_argument('--bleu_reward_weight', type=float, default=0, help='The reward weight from bleu4') + # Transformer + parser.add_argument('--label_smoothing', type=float, default=0, + help='') + parser.add_argument('--noamopt', action='store_true', + help='') + args = parser.parse_args() # Check if args are valid diff --git a/train.py b/train.py index 71147e93..d220189f 100644 --- a/train.py +++ b/train.py @@ -71,28 +71,35 @@ def train(opt): model = models.setup(opt).cuda() dp_model = torch.nn.DataParallel(model) - update_lr_flag = True + epoch_done = True # Assure in training mode dp_model.train() - crit = utils.LanguageModelCriterion() + if opt.label_smoothing > 0: + crit = utils.LabelSmoothing(smoothing=opt.label_smoothing) + else: + crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() - optimizer = utils.build_optimizer(model.parameters(), opt) + if opt.noamopt: + optimizer = utils.get_std_opt(model.model) + else: + optimizer = utils.build_optimizer(model.parameters(), opt) # Load the optimizer if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")): optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) while True: - if update_lr_flag: + 
if epoch_done: + if not opt.noamopt: # Assign the learning rate - if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: - frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every - decay_factor = opt.learning_rate_decay_rate ** frac - opt.current_lr = opt.learning_rate * decay_factor - else: - opt.current_lr = opt.learning_rate - utils.set_lr(optimizer, opt.current_lr) # set the decayed rate + if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: + frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every + decay_factor = opt.learning_rate_decay_rate ** frac + opt.current_lr = opt.learning_rate * decay_factor + else: + opt.current_lr = opt.learning_rate + utils.set_lr(optimizer, opt.current_lr) # set the decayed rate # Assign the scheduled sampling prob if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0: frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every @@ -106,7 +113,7 @@ def train(opt): else: sc_flag = False - update_lr_flag = False + epoch_done = False start = time.time() # Load data from train split (0) @@ -145,11 +152,13 @@ def train(opt): iteration += 1 if data['bounds']['wrapped']: epoch += 1 - update_lr_flag = True + epoch_done = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) + if opt.noamopt: + opt.current_lr = optimizer.rate() add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: From 50e3f2e4c29fa802c4f290cf235e2ca6bfa58eb8 Mon Sep 17 00:00:00 2001 From: Ruotian Luo Date: Fri, 4 May 2018 13:18:17 -0500 Subject: [PATCH 02/12] Add noamopt options. 
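Expose the warmup length and scale of the Noam schedule as command-line
options (--noamopt_warmup, default 2000, and --noamopt_factor, default 1)
instead of hard-coding them, and restrict --noamopt to the transformer
model. A rough sketch of the train.py hookup in this patch:

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
        optimizer._step = iteration  # so the schedule resumes from the saved iteration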
--- misc/utils.py | 4 ++-- opts.py | 4 ++++ train.py | 4 +++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/misc/utils.py b/misc/utils.py index 67799c4e..f4b4fdf0 100644 --- a/misc/utils.py +++ b/misc/utils.py @@ -153,9 +153,9 @@ def rate(self, step = None): def __getattr__(self, name): return getattr(self.optimizer, name) -def get_std_opt(model): +def get_std_opt(model, factor=1, warmup=2000): # return NoamOpt(model.tgt_embed[0].d_model, 2, 4000, # torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) - return NoamOpt(model.tgt_embed[0].d_model, 1, 2000, + return NoamOpt(model.model.tgt_embed[0].d_model, factor, warmup, torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) \ No newline at end of file diff --git a/opts.py b/opts.py index ea44fcd0..2f7123ae 100644 --- a/opts.py +++ b/opts.py @@ -133,6 +133,10 @@ def parse_opt(): help='') parser.add_argument('--noamopt', action='store_true', help='') + parser.add_argument('--noamopt_warmup', type=int, default=2000, + help='') + parser.add_argument('--noamopt_factor', type=float, default=1, + help='') args = parser.parse_args() diff --git a/train.py b/train.py index d220189f..d363185c 100644 --- a/train.py +++ b/train.py @@ -82,7 +82,9 @@ def train(opt): rl_crit = utils.RewardCriterion() if opt.noamopt: - optimizer = utils.get_std_opt(model.model) + assert opt.caption_model == 'transformer', 'noamopt can only work with transformer' + optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup) + optimizer._step = iteration else: optimizer = utils.build_optimizer(model.parameters(), opt) # Load the optimizer From 96c4e89b23b4410f0065f06fb23cb791ea1b827d Mon Sep 17 00:00:00 2001 From: Ruotian Luo Date: Fri, 4 May 2018 13:18:53 -0500 Subject: [PATCH 03/12] formatting and remove variable. --- models/TransformerModel.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/models/TransformerModel.py b/models/TransformerModel.py index 3938a8b7..f9be2e8b 100644 --- a/models/TransformerModel.py +++ b/models/TransformerModel.py @@ -23,7 +23,6 @@ import misc.utils as utils import copy -from torch.autograd import Variable import math import numpy as np @@ -161,7 +160,6 @@ def attention(query, key, value, mask=None, dropout=None): d_k = query.size(-1) scores = torch.matmul(query, key.transpose(-2, -1)) \ / math.sqrt(d_k) - if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) p_attn = F.softmax(scores, dim = -1) @@ -187,7 +185,7 @@ def forward(self, query, key, value, mask=None): # Same mask applied to all h heads. mask = mask.unsqueeze(1) nbatches = query.size(0) - + # 1) Do all the linear projections in batch from d_model => h x d_k query, key, value = \ [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) @@ -239,8 +237,7 @@ def __init__(self, d_model, dropout, max_len=5000): self.register_buffer('pe', pe) def forward(self, x): - x = x + Variable(self.pe[:, :x.size(1)], - requires_grad=False) + x = x + self.pe[:, :x.size(1)] return self.dropout(x) class TransformerModel(CaptionModel): From f265b155596ad11b37ffb8d97b5313962941e8c2 Mon Sep 17 00:00:00 2001 From: Ruotian Luo Date: Fri, 4 May 2018 14:06:35 -0500 Subject: [PATCH 04/12] uncomment the original sample. 
--- models/TransformerModel.py | 98 +++++++++++++++++++++----------------- 1 file changed, 55 insertions(+), 43 deletions(-) diff --git a/models/TransformerModel.py b/models/TransformerModel.py index f9be2e8b..ca17ea4b 100644 --- a/models/TransformerModel.py +++ b/models/TransformerModel.py @@ -468,48 +468,60 @@ def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1) return seq, seqLogprobs - # batch_size = fc_feats.size(0) - # state = self.init_hidden(batch_size) + def _sample_(self, fc_feats, att_feats, att_masks=None, opt={}): + att_feats, att_masks = self.clip_att(att_feats, att_masks) - # fc_feats, att_feats, p_att_feats = self._prepare_feature(fc_feats, att_feats, att_masks) + sample_max = opt.get('sample_max', 1) + beam_size = opt.get('beam_size', 1) + temperature = opt.get('temperature', 1.0) + decoding_constraint = opt.get('decoding_constraint', 0) + if beam_size > 1: + return self._sample_beam(fc_feats, att_feats, att_masks, opt) - # seq = fc_feats.new_zeros((batch_size, self.seq_length), dtype=torch.long) - # seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length) - # for t in range(self.seq_length + 1): - # if t == 0: # input - # it = fc_feats.new_zeros(batch_size, dtype=torch.long) - # elif sample_max: - # sampleLogprobs, it = torch.max(logprobs.data, 1) - # it = it.view(-1).long() - # else: - # if temperature == 1.0: - # prob_prev = torch.exp(logprobs.data) # fetch prev distribution: shape Nx(M+1) - # else: - # # scale logprobs by temperature - # prob_prev = torch.exp(torch.div(logprobs.data, temperature)) - # it = torch.multinomial(prob_prev, 1) - # sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions - # it = it.view(-1).long() # and flatten indices for downstream processing - - # if t >= 1: - # # stop when all finished - # if t == 1: - # unfinished = it > 0 - # else: - # unfinished = unfinished * (it > 0) - # if unfinished.sum() == 0: - # break - # it = it * unfinished.type_as(it) - # seq[:,t-1] = it - # # seq.append(it) #seq[t] the input of t+2 time step - - # # seqLogprobs.append(sampleLogprobs.view(-1)) - # seqLogprobs[:,t-1] = sampleLogprobs.view(-1) - - # logprobs, state = self.get_logprobs_state(it, fc_feats, att_feats, p_att_feats, att_masks, state) - # if decoding_constraint and t > 0: - # tmp = output.new_zeros(output.size(0), self.vocab_size + 1) - # tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf')) - # logprobs = logprobs + tmp - - # return seq, seqLogprobs + batch_size = fc_feats.size(0) + state = self.init_hidden(batch_size) + + fc_feats, att_feats, p_att_feats = self._prepare_feature(fc_feats, att_feats, att_masks) + + # seq = [] + # seqLogprobs = [] + seq = fc_feats.new_zeros((batch_size, self.seq_length), dtype=torch.long) + seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length) + for t in range(self.seq_length + 1): + if t == 0: # input + it = fc_feats.new_zeros(batch_size, dtype=torch.long) + elif sample_max: + sampleLogprobs, it = torch.max(logprobs.data, 1) + it = it.view(-1).long() + else: + if temperature == 1.0: + prob_prev = torch.exp(logprobs.data) # fetch prev distribution: shape Nx(M+1) + else: + # scale logprobs by temperature + prob_prev = torch.exp(torch.div(logprobs.data, temperature)) + it = torch.multinomial(prob_prev, 1) + sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions + it = it.view(-1).long() # and flatten indices for downstream processing + + if t >= 1: + # stop when all finished + if t == 
1: + unfinished = it > 0 + else: + unfinished = unfinished * (it > 0) + if unfinished.sum() == 0: + break + it = it * unfinished.type_as(it) + seq[:,t-1] = it + # seq.append(it) #seq[t] the input of t+2 time step + + # seqLogprobs.append(sampleLogprobs.view(-1)) + seqLogprobs[:,t-1] = sampleLogprobs.view(-1) + + logprobs, state = self.get_logprobs_state(it, fc_feats, att_feats, p_att_feats, att_masks, state) + if decoding_constraint and t > 0: + tmp = output.new_zeros(output.size(0), self.vocab_size + 1) + tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf')) + logprobs = logprobs + tmp + + return seq, seqLogprobs From 7e89919dbd6cbc22ae1dd8f1f1988b23af965a01 Mon Sep 17 00:00:00 2001 From: Ruotian Luo Date: Fri, 4 May 2018 14:12:29 -0500 Subject: [PATCH 05/12] Update to previous framework of sample and sample_beam --- models/TransformerModel.py | 99 +++++++++++++++++++++++--------------- 1 file changed, 60 insertions(+), 39 deletions(-) diff --git a/models/TransformerModel.py b/models/TransformerModel.py index ca17ea4b..55d92abf 100644 --- a/models/TransformerModel.py +++ b/models/TransformerModel.py @@ -397,11 +397,28 @@ def _forward(self, fc_feats, att_feats, seq, att_masks=None): return outputs # return torch.cat([_.unsqueeze(1) for _ in outputs], 1) + def get_logprobs_state(self, it, memory, mask, state): + """ + state = [ys.unsqueeze(0)] + """ + if state is None: + ys = it.unsqueeze(1) + else: + ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) + out = self.model.decode(memory, mask, + ys, + subsequent_mask(ys.size(1)) + .to(memory.device)) + logprobs = self.model.generator(out[:, -1]) + + return logprobs, [ys.unsqueeze(0)] + def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): beam_size = opt.get('beam_size', 10) batch_size = fc_feats.size(0) - fc_feats, att_feats, p_att_feats = self._prepare_feature(fc_feats, att_feats, att_masks) + att_feats, seq, att_masks, seq_mask = self._prepare_feature(att_feats, att_masks) + memory = self.model.encode(att_feats, att_masks) assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. 
can be dealt with in future if needed' seq = torch.LongTensor(self.seq_length, batch_size).zero_() @@ -410,25 +427,23 @@ def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): self.done_beams = [[] for _ in range(batch_size)] for k in range(batch_size): - state = self.init_hidden(beam_size) - tmp_fc_feats = fc_feats[k:k+1].expand(beam_size, fc_feats.size(1)) - tmp_att_feats = att_feats[k:k+1].expand(*((beam_size,)+att_feats.size()[1:])).contiguous() - tmp_p_att_feats = p_att_feats[k:k+1].expand(*((beam_size,)+p_att_feats.size()[1:])).contiguous() + state = None + tmp_memory = memory[k:k+1].expand(*((beam_size,)+memory.size()[1:])).contiguous() tmp_att_masks = att_masks[k:k+1].expand(*((beam_size,)+att_masks.size()[1:])).contiguous() if att_masks is not None else None for t in range(1): if t == 0: # input it = fc_feats.new_zeros([beam_size], dtype=torch.long) - logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state) + logprobs, state = self.get_logprobs_state(it, tmp_memory, tmp_att_masks, state) - self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, opt=opt) + self.done_beams[k] = self.beam_search(state, logprobs, tmp_memory, tmp_att_masks, opt=opt) seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score seqLogprobs[:, k] = self.done_beams[k][0]['logps'] # return the samples and their log likelihoods return seq.transpose(0, 1), seqLogprobs.transpose(0, 1) - def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): + def _sample_(self, fc_feats, att_feats, att_masks=None, opt={}): sample_max = opt.get('sample_max', 1) beam_size = opt.get('beam_size', 1) temperature = opt.get('temperature', 1.0) @@ -436,6 +451,10 @@ def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): if beam_size > 1: return self._sample_beam(fc_feats, att_feats, att_masks, opt) + if sample_max: + with torch.no_grad(): + seq_, seqLogprobs_ = self._sample_(fc_feats, att_feats, att_masks, opt) + batch_size = att_feats.shape[0] att_feats, seq, att_masks, seq_mask = self._prepare_feature(att_feats, att_masks) @@ -466,11 +485,11 @@ def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): seq[:,i] = next_word seqLogprobs[:,i] = sampleLogprobs ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1) + assert (seq*((seq_>0).long())==seq_).all(), 'seq doens\'t match' + assert (seqLogprobs*((seq_>0).float()) - seqLogprobs_*((seq_>0).float())).abs().max() < 1e-5, 'logprobs doens\'t match' return seq, seqLogprobs - def _sample_(self, fc_feats, att_feats, att_masks=None, opt={}): - att_feats, att_masks = self.clip_att(att_feats, att_masks) - + def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): sample_max = opt.get('sample_max', 1) beam_size = opt.get('beam_size', 1) temperature = opt.get('temperature', 1.0) @@ -478,19 +497,30 @@ def _sample_(self, fc_feats, att_feats, att_masks=None, opt={}): if beam_size > 1: return self._sample_beam(fc_feats, att_feats, att_masks, opt) - batch_size = fc_feats.size(0) - state = self.init_hidden(batch_size) + batch_size = att_feats.shape[0] + + att_feats, seq, att_masks, seq_mask = self._prepare_feature(att_feats, att_masks) - fc_feats, att_feats, p_att_feats = self._prepare_feature(fc_feats, att_feats, att_masks) + state = None + memory = self.model.encode(att_feats, att_masks) + + seq = att_feats.new_zeros((batch_size, self.seq_length), dtype=torch.long) + seqLogprobs = att_feats.new_zeros(batch_size, 
self.seq_length) - # seq = [] - # seqLogprobs = [] - seq = fc_feats.new_zeros((batch_size, self.seq_length), dtype=torch.long) - seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length) for t in range(self.seq_length + 1): if t == 0: # input it = fc_feats.new_zeros(batch_size, dtype=torch.long) - elif sample_max: + + logprobs, state = self.get_logprobs_state(it, memory, att_masks, state) + if decoding_constraint and t > 0: + tmp = output.new_zeros(output.size(0), self.vocab_size + 1) + tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf')) + logprobs = logprobs + tmp + + # sample the next word + if t == self.seq_length: # skip if we achieve maximum length + break + if sample_max: sampleLogprobs, it = torch.max(logprobs.data, 1) it = it.view(-1).long() else: @@ -503,25 +533,16 @@ def _sample_(self, fc_feats, att_feats, att_masks=None, opt={}): sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions it = it.view(-1).long() # and flatten indices for downstream processing - if t >= 1: - # stop when all finished - if t == 1: - unfinished = it > 0 - else: - unfinished = unfinished * (it > 0) - if unfinished.sum() == 0: - break - it = it * unfinished.type_as(it) - seq[:,t-1] = it - # seq.append(it) #seq[t] the input of t+2 time step - - # seqLogprobs.append(sampleLogprobs.view(-1)) - seqLogprobs[:,t-1] = sampleLogprobs.view(-1) - - logprobs, state = self.get_logprobs_state(it, fc_feats, att_feats, p_att_feats, att_masks, state) - if decoding_constraint and t > 0: - tmp = output.new_zeros(output.size(0), self.vocab_size + 1) - tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf')) - logprobs = logprobs + tmp + # stop when all finished + if t == 0: + unfinished = it > 0 + else: + unfinished = unfinished * (it > 0) + it = it * unfinished.type_as(it) + seq[:,t] = it + seqLogprobs[:,t] = sampleLogprobs.view(-1) + # quit loop if all sequences have finished + if unfinished.sum() == 0: + break return seq, seqLogprobs From 97e6cbf6514a6d4b2903ff800abfb4fb86b25764 Mon Sep 17 00:00:00 2001 From: Ruotian Luo Date: Sun, 6 May 2018 20:36:20 -0500 Subject: [PATCH 06/12] Use the original options to represent hyperparameters in transofmer --- models/TransformerModel.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/models/TransformerModel.py b/models/TransformerModel.py index 55d92abf..ad3ad21d 100644 --- a/models/TransformerModel.py +++ b/models/TransformerModel.py @@ -271,10 +271,9 @@ def __init__(self, opt): # d_model = self.input_encoding_size # 512 self.vocab_size = opt.vocab_size - # self.input_encoding_size = opt.input_encoding_size + self.input_encoding_size = opt.input_encoding_size # #self.rnn_type = opt.rnn_type self.rnn_size = opt.rnn_size - self.rnn_size = 512 # self.num_layers = opt.num_layers self.drop_prob_lm = opt.drop_prob_lm self.seq_length = opt.seq_length @@ -294,10 +293,10 @@ def __init__(self, opt): # nn.Dropout(self.drop_prob_lm)) self.att_embed = nn.Sequential(*( ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+ - (nn.Linear(self.att_feat_size, self.rnn_size), + (nn.Linear(self.att_feat_size, self.input_encoding_size), nn.ReLU(), nn.Dropout(self.drop_prob_lm))+ - ((nn.BatchNorm1d(self.rnn_size),) if self.use_bn==2 else ()))) + ((nn.BatchNorm1d(self.input_encoding_size),) if self.use_bn==2 else ()))) # self.logit_layers = getattr(opt, 'logit_layers', 1) # if self.logit_layers == 1: @@ -308,7 +307,10 @@ def __init__(self, opt): # self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size) tgt_vocab = 
self.vocab_size + 1 - self.model = self.make_model(0, tgt_vocab) + self.model = self.make_model(0, tgt_vocab, + N=opt.num_layers, + d_model=opt.input_encoding_size, + d_ff=opt.rnn_size) # def init_hidden(self, bsz): # weight = next(self.parameters()) From df1688c0bb4ab308bd33f38350c08033eefd7cc8 Mon Sep 17 00:00:00 2001 From: Ruotian Luo Date: Tue, 15 May 2018 22:12:19 -0500 Subject: [PATCH 07/12] Add reduce on plateau. --- misc/utils.py | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ opts.py | 3 +++ train.py | 13 ++++++++++++- 3 files changed, 63 insertions(+), 1 deletion(-) diff --git a/misc/utils.py b/misc/utils.py index f4b4fdf0..e0f04b47 100644 --- a/misc/utils.py +++ b/misc/utils.py @@ -101,6 +101,10 @@ def set_lr(optimizer, lr): for group in optimizer.param_groups: group['lr'] = lr +def get_lr(optimizer): + for group in optimizer.param_groups: + return group['lr'] + def clip_gradient(optimizer, grad_clip): for group in optimizer.param_groups: for param in group['params']: @@ -152,6 +156,50 @@ def rate(self, step = None): def __getattr__(self, name): return getattr(self.optimizer, name) + +class ReduceLROnPlateau(object): + "Optim wrapper that implements rate." + def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08): + self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode, factor, patience, verbose, threshold, threshold_mode, cooldown, min_lr, eps) + self.optimizer = optimizer + self.current_lr = get_lr(optimizer) + + def step(self): + "Update parameters and rate" + self.optimizer.step() + + def scheduler_step(self, val): + self.scheduler.step(val) + self.current_lr = get_lr(self.optimizer) + + def state_dict(self): + return {'current_lr':self.current_lr, + 'scheduler_state_dict': {key: value for key, value in self.scheduler.__dict__.items() if key not in {'optimizer', 'is_better'}}, + 'optimizer_state_dict': self.optimizer.state_dict()} + + def load_state_dict(self, state_dict): + if 'current_lr' not in state_dict: + # it's normal optimizer + self.optimizer.load_state_dict(state_dict) + set_lr(self.optimizer, self.current_lr) # use the lr fromt the option + else: + # it's a schduler + self.current_lr = state_dict['current_lr'] + self.scheduler.__dict__.update(state_dict['scheduler_state_dict']) + self.scheduler._init_is_better(mode=self.scheduler.mode, threshold=self.scheduler.threshold, threshold_mode=self.scheduler.threshold_mode) + self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) + # current_lr is actually useless in this case + + def rate(self, step = None): + "Implement `lrate` above" + if step is None: + step = self._step + return self.factor * \ + (self.model_size ** (-0.5) * + min(step ** (-0.5), step * self.warmup ** (-1.5))) + + def __getattr__(self, name): + return getattr(self.optimizer, name) def get_std_opt(model, factor=1, warmup=2000): # return NoamOpt(model.tgt_embed[0].d_model, 2, 4000, diff --git a/opts.py b/opts.py index 2f7123ae..002fbb8d 100644 --- a/opts.py +++ b/opts.py @@ -138,6 +138,9 @@ def parse_opt(): parser.add_argument('--noamopt_factor', type=float, default=1, help='') + parser.add_argument('--reduce_on_plateau', action='store_true', + help='') + args = parser.parse_args() # Check if args are valid diff --git a/train.py b/train.py index d363185c..1a1470ae 100644 --- a/train.py +++ b/train.py @@ -85,6 +85,9 @@ def train(opt): assert opt.caption_model == 'transformer', 'noamopt can only work with transformer' 
optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup) optimizer._step = iteration + elif opt.reduce_on_plateau: + optimizer = utils.build_optimizer(model.parameters(), opt) + optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3) else: optimizer = utils.build_optimizer(model.parameters(), opt) # Load the optimizer @@ -93,7 +96,7 @@ def train(opt): while True: if epoch_done: - if not opt.noamopt: + if not opt.noamopt and not opt.reduce_on_plateau: # Assign the learning rate if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every @@ -161,6 +164,8 @@ def train(opt): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) if opt.noamopt: opt.current_lr = optimizer.rate() + elif opt.reduce_on_plateau: + opt.current_lr = optimizer.current_lr add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration) if sc_flag: @@ -178,6 +183,12 @@ def train(opt): eval_kwargs.update(vars(opt)) val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, crit, loader, eval_kwargs) + if opt.reduce_on_plateau: + if 'CIDEr' in lang_stats: + optimizer.scheduler_step(-lang_stats['CIDEr']) + else: + optimizer.scheduler_step(val_loss) + # Write validation result into summary add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration) for k,v in lang_stats.items(): From f0b12c6a7f83762c599adbc61623419d2f932911 Mon Sep 17 00:00:00 2001 From: Ruotian Luo Date: Fri, 18 May 2018 14:03:53 -0500 Subject: [PATCH 08/12] [0.5]Match the new arange behaior. --- models/TransformerModel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/models/TransformerModel.py b/models/TransformerModel.py index ad3ad21d..8a7fbf59 100644 --- a/models/TransformerModel.py +++ b/models/TransformerModel.py @@ -228,8 +228,8 @@ def __init__(self, d_model, dropout, max_len=5000): # Compute the positional encodings once in log space. pe = torch.zeros(max_len, d_model) - position = torch.arange(0, max_len).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2) * + position = torch.arange(0, max_len).unsqueeze(1).float() + div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) From ab5f36492b45aef7f8061c379d5485866f108876 Mon Sep 17 00:00:00 2001 From: Ruotian Luo Date: Thu, 31 May 2018 01:12:12 -0500 Subject: [PATCH 09/12] Update readme for transformer --- README.md | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/README.md b/README.md index 6ecc9d99..fb50d8d9 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,42 @@ +# Transformer for captioning + +This is an experiment to use transformer model to do captioning. Most of the code is copy from [Harvard detailed tutorial for transformer(http://nlp.seas.harvard.edu/2018/04/03/attention.html). + +Also, notice, this repository is a fork of my [self-critical.pytorch](https://github.com/ruotianluo/self-critical.pytorch) repository. Most of the code are shared. 
+ +The addition to self-critical.pytorch is following: +- transformer model +- Add warmup adam for training transformer (important) +- Add reduce_on_paltaeu (not really useful) + +A training script that could achieve 1.25 on validation set without beam search. + +```bash +id="transformer" +ckpt_path="log_"$id +if [ ! -d $ckpt_path ]; then + mkdir $ckpt_path +fi +if [ ! -f $ckpt_path"/infos_"$id".pkl" ]; then +start_from="" +else +start_from="--start_from "$ckpt_path +fi + +python train.py --id $id --caption_model transformer --noamopt --noamopt_warmup 20000 --label_smoothing 0.0 --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --seq_per_img 5 --batch_size 10 --beam_size 1 --learning_rate 5e-4 --num_layers 6 --input_encoding_size 512 --rnn_size 2048 --learning_rate_decay_start 0 --scheduled_sampling_start 0 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 3000 --language_eval 1 --val_images_use 5000 --max_epochs 15 + +python train.py --id $id --caption_model transformer --reduce_on_plateau --input_json data/cocotalk.json --input_label_h5 data/cocotalk_label.h5 --input_fc_dir data/cocobu_fc --input_att_dir data/cocobu_att --input_box_dir data/cocobu_box --seq_per_img 5 --batch_size 10 --beam_size 1 --learning_rate 1e-5 --num_layers 6 --input_encoding_size 512 --rnn_size 2048 --checkpoint_path $ckpt_path $start_from --save_checkpoint_every 3000 --language_eval 1 --val_images_use 5000 --self_critical_after 10 +``` + +**Notice**: because I'm to lazy, I reuse the option name for RNNs to set the hyperparameters for transformer: +``` +N=num_layers +d_model=input_encoding_size +d_ff=rnn_size +h is always 8 +``` + + # Self-critical Sequence Training for Image Captioning (+ misc.) This repository includes the unofficial implementation [Self-critical Sequence Training for Image Captioning](https://arxiv.org/abs/1612.00563) and [Bottom-Up and Top-Down Attention for Image Captioning and Visual Question Answering](https://arxiv.org/abs/1707.07998). From f7930976d877a179c24b941fbe9bbcc145374d91 Mon Sep 17 00:00:00 2001 From: "Ruotian(RT) Luo" Date: Mon, 15 Oct 2018 02:09:01 -0500 Subject: [PATCH 10/12] Fix #1 --- models/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/models/__init__.py b/models/__init__.py index 8d329bff..91e1faa8 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -12,7 +12,6 @@ from .ShowTellModel import ShowTellModel from .FCModel import FCModel from .OldModel import ShowAttendTellModel, AllImgModel -from .Att2inModel import Att2inModel from .AttModel import * from .TransformerModel import TransformerModel @@ -58,4 +57,4 @@ def setup(opt): assert os.path.isfile(os.path.join(opt.start_from,"infos_"+opt.id+".pkl")),"infos.pkl file does not exist in path %s"%opt.start_from model.load_state_dict(torch.load(os.path.join(opt.start_from, 'model.pth'))) - return model \ No newline at end of file + return model From 068697ecef10ec6ce68e0a81bc509633a2e9d442 Mon Sep 17 00:00:00 2001 From: Ruotian Luo Date: Mon, 8 Apr 2019 01:27:17 -0500 Subject: [PATCH 11/12] clean up the code, and make transformer as a subclass of AttModel. 
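TransformerModel now reuses the generic _sample and _sample_beam from
AttModel: the encoder output (memory) is passed where AttModel expects
p_att_feats, and the decoding state is simply the prefix of tokens
generated so far. A rough sketch of the per-step logic that stays in
get_logprobs_state (placeholder feature arguments omitted):

    ys = it.unsqueeze(1) if state is None else torch.cat([state[0][0], it.unsqueeze(1)], dim=1)
    out = self.model.decode(memory, mask, ys, subsequent_mask(ys.size(1)).to(memory.device))
    logprobs = self.model.generator(out[:, -1])
    return logprobs, [ys.unsqueeze(0)]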
---
 models/TransformerModel.py | 261 ++++---------------------------------
 1 file changed, 28 insertions(+), 233 deletions(-)

diff --git a/models/TransformerModel.py b/models/TransformerModel.py
index 8a7fbf59..2b8e06bc 100644
--- a/models/TransformerModel.py
+++ b/models/TransformerModel.py
@@ -1,17 +1,11 @@
-# This file contains Att2in2, AdaAtt, AdaAttMO, TopDown model
+# This file contains Transformer network
+# Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html
 
-# AdaAtt is from Knowing When to Look: Adaptive Attention via A Visual Sentinel for Image Captioning
-# https://arxiv.org/abs/1612.01887
-# AdaAttMO is a modified version with maxout lstm
-
-# Att2in is from Self-critical Sequence Training for Image Captioning
-# https://arxiv.org/abs/1612.00563
-# In this file we only have Att2in2, which is a slightly different version of att2in,
-# in which the img feature embedding and word embedding is the same as what in adaatt.
-
-# TopDown is from Bottom-Up and Top-Down Attention for Image Captioning and VQA
-# https://arxiv.org/abs/1707.07998
-# However, it may not be identical to the author's architecture.
+# The cfg name correspondance:
+# N=num_layers
+# d_model=input_encoding_size
+# d_ff=rnn_size
+# h is always 8
 
 from __future__ import absolute_import
 from __future__ import division
@@ -27,7 +21,7 @@
 import numpy as np
 
 from .CaptionModel import CaptionModel
-from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper
+from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel
 
 class EncoderDecoder(nn.Module):
     """
@@ -240,7 +234,7 @@ def forward(self, x):
         x = x + self.pe[:, :x.size(1)]
         return self.dropout(x)
 
-class TransformerModel(CaptionModel):
+class TransformerModel(AttModel):
 
     def make_model(self, src_vocab, tgt_vocab, N=6, 
             d_model=512, d_ff=2048, h=8, dropout=0.1):
@@ -265,46 +259,25 @@ def make_model(self, src_vocab, tgt_vocab, N=6,
         return model
 
     def __init__(self, opt):
-        super(TransformerModel, self).__init__()
+        super(TransformerModel, self).__init__(opt)
         self.opt = opt
         # self.config = yaml.load(open(opt.config_file))
         # d_model = self.input_encoding_size # 512
 
-        self.vocab_size = opt.vocab_size
-        self.input_encoding_size = opt.input_encoding_size
-        # #self.rnn_type = opt.rnn_type
-        self.rnn_size = opt.rnn_size
-        # self.num_layers = opt.num_layers
-        self.drop_prob_lm = opt.drop_prob_lm
-        self.seq_length = opt.seq_length
-        # self.fc_feat_size = opt.fc_feat_size
-        self.att_feat_size = opt.att_feat_size
-        # self.att_hid_size = opt.att_hid_size
-
-        self.use_bn = getattr(opt, 'use_bn', 0)
-
-        self.ss_prob = 0.0 # Schedule sampling probability
-
-        # self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size),
-        #                         nn.ReLU(),
-        #                         nn.Dropout(self.drop_prob_lm))
-        # self.fc_embed = nn.Sequential(nn.Linear(self.fc_feat_size, self.rnn_size),
-        #                         nn.ReLU(),
-        #                         nn.Dropout(self.drop_prob_lm))
+        delattr(self, 'att_embed')
         self.att_embed = nn.Sequential(*(
                                     ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
                                     (nn.Linear(self.att_feat_size, self.input_encoding_size),
                                     nn.ReLU(),
                                     nn.Dropout(self.drop_prob_lm))+
                                     ((nn.BatchNorm1d(self.input_encoding_size),) if self.use_bn==2 else ())))
-
-        # self.logit_layers = getattr(opt, 'logit_layers', 1)
-        # if self.logit_layers == 1:
-        #     self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
-        # else:
-        #     self.logit = [[nn.Linear(self.rnn_size, self.rnn_size), nn.ReLU(), nn.Dropout(0.5)] for _ in range(opt.logit_layers - 1)]
-        #     self.logit = nn.Sequential(*(reduce(lambda x,y:x+y, self.logit) + [nn.Linear(self.rnn_size, self.vocab_size + 1)]))
-        # self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size)
+
+        delattr(self, 'embed')
+        self.embed = lambda x : x
+        delattr(self, 'fc_embed')
+        self.fc_embed = lambda x : x
+        del self.logit
+        del self.ctx2att
 
         tgt_vocab = self.vocab_size + 1
         self.model = self.make_model(0, tgt_vocab,
@@ -312,31 +285,17 @@ def __init__(self, opt):
             d_model=opt.input_encoding_size,
             d_ff=opt.rnn_size)
 
-    # def init_hidden(self, bsz):
-    #     weight = next(self.parameters())
-    #     return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
-    #             weight.new_zeros(self.num_layers, bsz, self.rnn_size))
+    def init_hidden(self, bsz):
+        return None
 
-    def clip_att(self, att_feats, att_masks):
-        # Clip the length of att_masks and att_feats to the maximum length
-        if att_masks is not None:
-            max_len = att_masks.data.long().sum(1).max()
-            att_feats = att_feats[:, :max_len].contiguous()
-            att_masks = att_masks[:, :max_len].contiguous()
-        return att_feats, att_masks
+    def _prepare_feature(self, fc_feats, att_feats, att_masks):
 
-    # def _prepare_feature(self, fc_feats, att_feats, att_masks):
-
-    #     # embed fc and att feats
-    #     fc_feats = self.fc_embed(fc_feats)
-    #     att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
-
-    #     # Project the attention feats first to reduce memory and computation comsumptions.
-    #     p_att_feats = self.ctx2att(att_feats)
+        att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks)
+        memory = self.model.encode(att_feats, att_masks)
 
-    #     return fc_feats, att_feats, p_att_feats
+        return fc_feats[...,:1], att_feats[...,:1], memory, att_masks
 
-    def _prepare_feature(self, att_feats, att_masks=None, seq=None):
+    def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None):
         att_feats, att_masks = self.clip_att(att_feats, att_masks)
 
         att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
@@ -359,47 +318,15 @@ def _prepare_feature(self, att_feats, att_masks=None, seq=None):
         return att_feats, seq, att_masks, seq_mask
 
     def _forward(self, fc_feats, att_feats, seq, att_masks=None):
-        att_feats, seq, att_masks, seq_mask = self._prepare_feature(att_feats, att_masks, seq)
+        att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq)
 
         out = self.model(att_feats, seq, att_masks, seq_mask)
 
-        # batch_size = fc_feats.size(0)
-        # state = self.init_hidden(batch_size)
-
-        # # outputs = []
-        # outputs = fc_feats.new_zeros(batch_size, seq.size(1) - 1, self.vocab_size+1)
-
-        # fc_feats, att_feats, p_att_feats = self._prepare_feature(fc_feats, att_feats, att_masks)
-
-        # for i in range(seq.size(1) - 1):
-        #     if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample
-        #         sample_prob = fc_feats.new(batch_size).uniform_(0, 1)
-        #         sample_mask = sample_prob < self.ss_prob
-        #         if sample_mask.sum() == 0:
-        #             it = seq[:, i].clone()
-        #         else:
-        #             sample_ind = sample_mask.nonzero().view(-1)
-        #             it = seq[:, i].data.clone()
-        #             #prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
-        #             #it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
-        #             # prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
-        #             prob_prev = torch.exp(outputs[:, i-1].detach()) # fetch prev distribution: shape Nx(M+1)
-        #             it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
-        #     else:
-        #         it = seq[:, i].clone()
-        #     # break if all the sequences end
-        #     if i >= 1 and seq[:, i].sum() == 0:
-        #         break
-
-        #     output, state = self.get_logprobs_state(it, fc_feats, att_feats, p_att_feats, att_masks, state)
-        #     outputs[:, i] = output
-        #     # outputs.append(output)
-
         outputs = self.model.generator(out)
         return outputs
         # return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
 
-    def get_logprobs_state(self, it, memory, mask, state):
+    def get_logprobs_state(self, it, fc_feats_ph, att_feats_ph, memory, mask, state):
         """
         state = [ys.unsqueeze(0)]
         """
@@ -415,136 +342,4 @@ def get_logprobs_state(self, it, memory, mask, state):
 
         return logprobs, [ys.unsqueeze(0)]
 
-    def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
-        beam_size = opt.get('beam_size', 10)
-        batch_size = fc_feats.size(0)
-
-        att_feats, seq, att_masks, seq_mask = self._prepare_feature(att_feats, att_masks)
-        memory = self.model.encode(att_feats, att_masks)
-
-        assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
-        seq = torch.LongTensor(self.seq_length, batch_size).zero_()
-        seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)
-        # lets process every image independently for now, for simplicity
-
-        self.done_beams = [[] for _ in range(batch_size)]
-        for k in range(batch_size):
-            state = None
-            tmp_memory = memory[k:k+1].expand(*((beam_size,)+memory.size()[1:])).contiguous()
-            tmp_att_masks = att_masks[k:k+1].expand(*((beam_size,)+att_masks.size()[1:])).contiguous() if att_masks is not None else None
-
-            for t in range(1):
-                if t == 0: # input
-                    it = fc_feats.new_zeros([beam_size], dtype=torch.long)
-
-                logprobs, state = self.get_logprobs_state(it, tmp_memory, tmp_att_masks, state)
-
-            self.done_beams[k] = self.beam_search(state, logprobs, tmp_memory, tmp_att_masks, opt=opt)
-            seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
-            seqLogprobs[:, k] = self.done_beams[k][0]['logps']
-        # return the samples and their log likelihoods
-        return seq.transpose(0, 1), seqLogprobs.transpose(0, 1)
-
-    def _sample_(self, fc_feats, att_feats, att_masks=None, opt={}):
-        sample_max = opt.get('sample_max', 1)
-        beam_size = opt.get('beam_size', 1)
-        temperature = opt.get('temperature', 1.0)
-        decoding_constraint = opt.get('decoding_constraint', 0)
-        if beam_size > 1:
-            return self._sample_beam(fc_feats, att_feats, att_masks, opt)
-
-        if sample_max:
-            with torch.no_grad():
-                seq_, seqLogprobs_ = self._sample_(fc_feats, att_feats, att_masks, opt)
-
-        batch_size = att_feats.shape[0]
-
-        att_feats, seq, att_masks, seq_mask = self._prepare_feature(att_feats, att_masks)
-
-        memory = self.model.encode(att_feats, att_masks)
-        ys = torch.zeros((batch_size, 1), dtype=torch.long).to(att_feats.device)
-
-        seq = att_feats.new_zeros((batch_size, self.seq_length), dtype=torch.long)
-        seqLogprobs = att_feats.new_zeros(batch_size, self.seq_length)
-
-        for i in range(self.seq_length):
-            out = self.model.decode(memory, att_masks, 
-                               ys, 
-                               subsequent_mask(ys.size(1))
-                                        .to(att_feats.device))
-            logprob = self.model.generator(out[:, -1])
-            if sample_max:
-                sampleLogprobs, next_word = torch.max(logprob, dim = 1)
-            else:
-                if temperature == 1.0:
-                    prob_prev = torch.exp(logprob.data) # fetch prev distribution: shape Nx(M+1)
-                else:
-                    # scale logprobs by temperature
-                    prob_prev = torch.exp(torch.div(logprob.data, temperature))
-                next_word = torch.multinomial(prob_prev, 1)
-                sampleLogprobs = logprobs.gather(1, next_word) # gather the logprobs at sampled positions
-
-            seq[:,i] = next_word
-            seqLogprobs[:,i] = sampleLogprobs
-            ys = torch.cat([ys, next_word.unsqueeze(1)], dim=1)
-        assert (seq*((seq_>0).long())==seq_).all(), 'seq doens\'t match'
-        assert (seqLogprobs*((seq_>0).float()) - seqLogprobs_*((seq_>0).float())).abs().max() < 1e-5, 'logprobs doens\'t match'
-        return seq, seqLogprobs
-
-    def _sample(self, fc_feats, att_feats, att_masks=None, opt={}):
-        sample_max = opt.get('sample_max', 1)
-        beam_size = opt.get('beam_size', 1)
-        temperature = opt.get('temperature', 1.0)
-        decoding_constraint = opt.get('decoding_constraint', 0)
-        if beam_size > 1:
-            return self._sample_beam(fc_feats, att_feats, att_masks, opt)
-
-        batch_size = att_feats.shape[0]
-
-        att_feats, seq, att_masks, seq_mask = self._prepare_feature(att_feats, att_masks)
-
-        state = None
-        memory = self.model.encode(att_feats, att_masks)
-
-        seq = att_feats.new_zeros((batch_size, self.seq_length), dtype=torch.long)
-        seqLogprobs = att_feats.new_zeros(batch_size, self.seq_length)
-
-        for t in range(self.seq_length + 1):
-            if t == 0: # input
-                it = fc_feats.new_zeros(batch_size, dtype=torch.long)
-
-            logprobs, state = self.get_logprobs_state(it, memory, att_masks, state)
-            if decoding_constraint and t > 0:
-                tmp = output.new_zeros(output.size(0), self.vocab_size + 1)
-                tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf'))
-                logprobs = logprobs + tmp
-
-            # sample the next word
-            if t == self.seq_length: # skip if we achieve maximum length
-                break
-            if sample_max:
-                sampleLogprobs, it = torch.max(logprobs.data, 1)
-                it = it.view(-1).long()
-            else:
-                if temperature == 1.0:
-                    prob_prev = torch.exp(logprobs.data) # fetch prev distribution: shape Nx(M+1)
-                else:
-                    # scale logprobs by temperature
-                    prob_prev = torch.exp(torch.div(logprobs.data, temperature))
-                it = torch.multinomial(prob_prev, 1)
-                sampleLogprobs = logprobs.gather(1, it) # gather the logprobs at sampled positions
-                it = it.view(-1).long() # and flatten indices for downstream processing
-
-            # stop when all finished
-            if t == 0:
-                unfinished = it > 0
-            else:
-                unfinished = unfinished * (it > 0)
-            it = it * unfinished.type_as(it)
-            seq[:,t] = it
-            seqLogprobs[:,t] = sampleLogprobs.view(-1)
-            # quit loop if all sequences have finished
-            if unfinished.sum() == 0:
-                break
-
-        return seq, seqLogprobs
+# For _sample and _sample_beam, now p_att_feats = memory
\ No newline at end of file

From c8af7482731f9a7ac28df23b3e9436b35be4d212 Mon Sep 17 00:00:00 2001
From: Ruotian Luo
Date: Tue, 9 Apr 2019 15:38:04 -0500
Subject: [PATCH 12/12] Further towards AttModel: remove the use of get_logprobs

---
 models/TransformerModel.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/models/TransformerModel.py b/models/TransformerModel.py
index 2b8e06bc..0670157f 100644
--- a/models/TransformerModel.py
+++ b/models/TransformerModel.py
@@ -276,7 +276,7 @@ def __init__(self, opt):
         self.embed = lambda x : x
         delattr(self, 'fc_embed')
         self.fc_embed = lambda x : x
-        del self.logit
+        delattr(self, 'logit')
         del self.ctx2att
 
         tgt_vocab = self.vocab_size + 1
@@ -285,6 +285,9 @@ def __init__(self, opt):
             d_model=opt.input_encoding_size,
             d_ff=opt.rnn_size)
 
+    def logit(self, x): # unsafe way
+        return self.model.generator.proj(x)
+
     def init_hidden(self, bsz):
         return None
 
@@ -326,7 +329,7 @@ def _forward(self, fc_feats, att_feats, seq, att_masks=None):
         return outputs
         # return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
 
-    def get_logprobs_state(self, it, fc_feats_ph, att_feats_ph, memory, mask, state):
+    def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask):
         """
         state = [ys.unsqueeze(0)]
         """
@@ -338,8 +341,4 @@ def get_logprobs_state(self, it, fc_feats_ph, att_feats_ph, memory, mask, state)
                                ys, 
                                subsequent_mask(ys.size(1))
                                         .to(memory.device))
-        logprobs = self.model.generator(out[:, -1])
-
-        return logprobs, [ys.unsqueeze(0)]
-
-# For _sample and _sample_beam, now p_att_feats = memory
\ No newline at end of file
+        return out[:, -1], [ys.unsqueeze(0)]
\ No newline at end of file
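
Note on the refactoring above: a minimal, illustrative sketch (not part of the patches) of how AttModel's generic decoding loop is expected to drive the refactored TransformerModel after PATCH 12/12. `core` returns the decoder output at the last position together with the growing prefix as `state`, and `logit` projects it through the generator, so `F.log_softmax(model.logit(out), dim=1)` matches what `model.model.generator(out)` used to return. The objects `model`, `memory`, and `att_masks` are assumed to already exist (e.g. `memory` and `att_masks` as returned by `_prepare_feature`); the greedy loop below only illustrates the calling convention.

    import torch
    import torch.nn.functional as F

    state = None                                              # core() keeps the decoded prefix here
    it = memory.new_zeros(memory.size(0), dtype=torch.long)   # token index 0 is used as the start token
    for t in range(model.seq_length):
        # fc_feats_ph / att_feats_ph are unused placeholders in the transformer core
        out, state = model.core(it, None, None, memory, state, att_masks)
        logprobs = F.log_softmax(model.logit(out), dim=1)     # equivalent to the old generator output
        it = logprobs.argmax(dim=1)                           # greedy choice; AttModel also handles sampling/beam search

In other words, `p_att_feats` in the shared `_sample`/`_sample_beam` code now carries the encoder `memory`, and the transformer only has to supply `core` and `logit`.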