diff --git a/README.md b/README.md index 6ecc9d99..329a9015 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ This is based on my [ImageCaptioning.pytorch](https://github.com/ruotianluo/Imag - Bottom up feature support from [ref](https://arxiv.org/abs/1707.07998). (Evaluation on arbitrary images is not supported.) - Ensemble - Multi-GPU training +- Add transformer (merged from [Transformer_captioning](https://github.com/ruotianluo/Transformer_Captioning)) ## Requirements Python 2.7 (because there is no [coco-caption](https://github.com/tylin/coco-caption) version for python 3) diff --git a/misc/utils.py b/misc/utils.py index b2e208ad..eca2efe2 100644 --- a/misc/utils.py +++ b/misc/utils.py @@ -108,6 +108,38 @@ def forward(self, input, target, mask): return output +class LabelSmoothing(nn.Module): + "Implement label smoothing." + def __init__(self, size=0, padding_idx=0, smoothing=0.0): + super(LabelSmoothing, self).__init__() + self.criterion = nn.KLDivLoss(size_average=False, reduce=False) + # self.padding_idx = padding_idx + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + # self.size = size + self.true_dist = None + + def forward(self, input, target, mask): + # truncate to the same size + target = target[:, :input.size(1)] + mask = mask[:, :input.size(1)] + + input = to_contiguous(input).view(-1, input.size(-1)) + target = to_contiguous(target).view(-1) + mask = to_contiguous(mask).view(-1) + + # assert x.size(1) == self.size + self.size = input.size(1) + # true_dist = x.data.clone() + true_dist = input.data.clone() + # true_dist.fill_(self.smoothing / (self.size - 2)) + true_dist.fill_(self.smoothing / (self.size - 1)) + true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) + # true_dist[:, self.padding_idx] = 0 + # mask = torch.nonzero(target.data == self.padding_idx) + # self.true_dist = true_dist + return (self.criterion(input, true_dist).sum(1) * mask).sum() / mask.sum() + def set_lr(optimizer, lr): for group in optimizer.param_groups: group['lr'] = lr @@ -164,6 +196,37 @@ def length_average(length, logprobs, alpha=0.): """ return logprobs / length + +class NoamOpt(object): + "Optim wrapper that implements rate." + def __init__(self, model_size, factor, warmup, optimizer): + self.optimizer = optimizer + self._step = 0 + self.warmup = warmup + self.factor = factor + self.model_size = model_size + self._rate = 0 + + def step(self): + "Update parameters and rate" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p['lr'] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step = None): + "Implement `lrate` above" + if step is None: + step = self._step + return self.factor * \ + (self.model_size ** (-0.5) * + min(step ** (-0.5), step * self.warmup ** (-1.5))) + + def __getattr__(self, name): + return getattr(self.optimizer, name) + class ReduceLROnPlateau(object): "Optim wrapper that implements rate." def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08): @@ -195,6 +258,21 @@ def load_state_dict(self, state_dict): self.scheduler.load_state_dict(state_dict['scheduler_state_dict']) self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) # current_lr is actually useless in this case + + def rate(self, step = None): + "Implement `lrate` above" + if step is None: + step = self._step + return self.factor * \ + (self.model_size ** (-0.5) * + min(step ** (-0.5), step * self.warmup ** (-1.5))) def __getattr__(self, name): return getattr(self.optimizer, name) + +def get_std_opt(model, factor=1, warmup=2000): + # return NoamOpt(model.tgt_embed[0].d_model, 2, 4000, + # torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) + return NoamOpt(model.model.tgt_embed[0].d_model, factor, warmup, + torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) + diff --git a/models/TransformerModel.py b/models/TransformerModel.py new file mode 100644 index 00000000..0670157f --- /dev/null +++ b/models/TransformerModel.py @@ -0,0 +1,344 @@ +# This file contains Transformer network +# Most of the code is copied from http://nlp.seas.harvard.edu/2018/04/03/attention.html + +# The cfg name correspondance: +# N=num_layers +# d_model=input_encoding_size +# d_ff=rnn_size +# h is always 8 + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F +import misc.utils as utils + +import copy +import math +import numpy as np + +from .CaptionModel import CaptionModel +from .AttModel import sort_pack_padded_sequence, pad_unsort_packed_sequence, pack_wrapper, AttModel + +class EncoderDecoder(nn.Module): + """ + A standard Encoder-Decoder architecture. Base for this and many + other models. + """ + def __init__(self, encoder, decoder, src_embed, tgt_embed, generator): + super(EncoderDecoder, self).__init__() + self.encoder = encoder + self.decoder = decoder + self.src_embed = src_embed + self.tgt_embed = tgt_embed + self.generator = generator + + def forward(self, src, tgt, src_mask, tgt_mask): + "Take in and process masked src and target sequences." + return self.decode(self.encode(src, src_mask), src_mask, + tgt, tgt_mask) + + def encode(self, src, src_mask): + return self.encoder(self.src_embed(src), src_mask) + + def decode(self, memory, src_mask, tgt, tgt_mask): + return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask) + +class Generator(nn.Module): + "Define standard linear + softmax generation step." + def __init__(self, d_model, vocab): + super(Generator, self).__init__() + self.proj = nn.Linear(d_model, vocab) + + def forward(self, x): + return F.log_softmax(self.proj(x), dim=-1) + +def clones(module, N): + "Produce N identical layers." + return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) + +class Encoder(nn.Module): + "Core encoder is a stack of N layers" + def __init__(self, layer, N): + super(Encoder, self).__init__() + self.layers = clones(layer, N) + self.norm = LayerNorm(layer.size) + + def forward(self, x, mask): + "Pass the input (and mask) through each layer in turn." + for layer in self.layers: + x = layer(x, mask) + return self.norm(x) + +class LayerNorm(nn.Module): + "Construct a layernorm module (See citation for details)." + def __init__(self, features, eps=1e-6): + super(LayerNorm, self).__init__() + self.a_2 = nn.Parameter(torch.ones(features)) + self.b_2 = nn.Parameter(torch.zeros(features)) + self.eps = eps + + def forward(self, x): + mean = x.mean(-1, keepdim=True) + std = x.std(-1, keepdim=True) + return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 + +class SublayerConnection(nn.Module): + """ + A residual connection followed by a layer norm. + Note for code simplicity the norm is first as opposed to last. + """ + def __init__(self, size, dropout): + super(SublayerConnection, self).__init__() + self.norm = LayerNorm(size) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, sublayer): + "Apply residual connection to any sublayer with the same size." + return x + self.dropout(sublayer(self.norm(x))) + +class EncoderLayer(nn.Module): + "Encoder is made up of self-attn and feed forward (defined below)" + def __init__(self, size, self_attn, feed_forward, dropout): + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.sublayer = clones(SublayerConnection(size, dropout), 2) + self.size = size + + def forward(self, x, mask): + "Follow Figure 1 (left) for connections." + x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask)) + return self.sublayer[1](x, self.feed_forward) + +class Decoder(nn.Module): + "Generic N layer decoder with masking." + def __init__(self, layer, N): + super(Decoder, self).__init__() + self.layers = clones(layer, N) + self.norm = LayerNorm(layer.size) + + def forward(self, x, memory, src_mask, tgt_mask): + for layer in self.layers: + x = layer(x, memory, src_mask, tgt_mask) + return self.norm(x) + +class DecoderLayer(nn.Module): + "Decoder is made of self-attn, src-attn, and feed forward (defined below)" + def __init__(self, size, self_attn, src_attn, feed_forward, dropout): + super(DecoderLayer, self).__init__() + self.size = size + self.self_attn = self_attn + self.src_attn = src_attn + self.feed_forward = feed_forward + self.sublayer = clones(SublayerConnection(size, dropout), 3) + + def forward(self, x, memory, src_mask, tgt_mask): + "Follow Figure 1 (right) for connections." + m = memory + x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) + x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) + return self.sublayer[2](x, self.feed_forward) + +def subsequent_mask(size): + "Mask out subsequent positions." + attn_shape = (1, size, size) + subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') + return torch.from_numpy(subsequent_mask) == 0 + +def attention(query, key, value, mask=None, dropout=None): + "Compute 'Scaled Dot Product Attention'" + d_k = query.size(-1) + scores = torch.matmul(query, key.transpose(-2, -1)) \ + / math.sqrt(d_k) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e9) + p_attn = F.softmax(scores, dim = -1) + if dropout is not None: + p_attn = dropout(p_attn) + return torch.matmul(p_attn, value), p_attn + +class MultiHeadedAttention(nn.Module): + def __init__(self, h, d_model, dropout=0.1): + "Take in model size and number of heads." + super(MultiHeadedAttention, self).__init__() + assert d_model % h == 0 + # We assume d_v always equals d_k + self.d_k = d_model // h + self.h = h + self.linears = clones(nn.Linear(d_model, d_model), 4) + self.attn = None + self.dropout = nn.Dropout(p=dropout) + + def forward(self, query, key, value, mask=None): + "Implements Figure 2" + if mask is not None: + # Same mask applied to all h heads. + mask = mask.unsqueeze(1) + nbatches = query.size(0) + + # 1) Do all the linear projections in batch from d_model => h x d_k + query, key, value = \ + [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) + for l, x in zip(self.linears, (query, key, value))] + + # 2) Apply attention on all the projected vectors in batch. + x, self.attn = attention(query, key, value, mask=mask, + dropout=self.dropout) + + # 3) "Concat" using a view and apply a final linear. + x = x.transpose(1, 2).contiguous() \ + .view(nbatches, -1, self.h * self.d_k) + return self.linears[-1](x) + +class PositionwiseFeedForward(nn.Module): + "Implements FFN equation." + def __init__(self, d_model, d_ff, dropout=0.1): + super(PositionwiseFeedForward, self).__init__() + self.w_1 = nn.Linear(d_model, d_ff) + self.w_2 = nn.Linear(d_ff, d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + return self.w_2(self.dropout(F.relu(self.w_1(x)))) + +class Embeddings(nn.Module): + def __init__(self, d_model, vocab): + super(Embeddings, self).__init__() + self.lut = nn.Embedding(vocab, d_model) + self.d_model = d_model + + def forward(self, x): + return self.lut(x) * math.sqrt(self.d_model) + +class PositionalEncoding(nn.Module): + "Implement the PE function." + def __init__(self, d_model, dropout, max_len=5000): + super(PositionalEncoding, self).__init__() + self.dropout = nn.Dropout(p=dropout) + + # Compute the positional encodings once in log space. + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len).unsqueeze(1).float() + div_term = torch.exp(torch.arange(0, d_model, 2).float() * + -(math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe) + + def forward(self, x): + x = x + self.pe[:, :x.size(1)] + return self.dropout(x) + +class TransformerModel(AttModel): + + def make_model(self, src_vocab, tgt_vocab, N=6, + d_model=512, d_ff=2048, h=8, dropout=0.1): + "Helper: Construct a model from hyperparameters." + c = copy.deepcopy + attn = MultiHeadedAttention(h, d_model) + ff = PositionwiseFeedForward(d_model, d_ff, dropout) + position = PositionalEncoding(d_model, dropout) + model = EncoderDecoder( + Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N), + Decoder(DecoderLayer(d_model, c(attn), c(attn), + c(ff), dropout), N), + lambda x:x, # nn.Sequential(Embeddings(d_model, src_vocab), c(position)), + nn.Sequential(Embeddings(d_model, tgt_vocab), c(position)), + Generator(d_model, tgt_vocab)) + + # This was important from their code. + # Initialize parameters with Glorot / fan_avg. + for p in model.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + return model + + def __init__(self, opt): + super(TransformerModel, self).__init__(opt) + self.opt = opt + # self.config = yaml.load(open(opt.config_file)) + # d_model = self.input_encoding_size # 512 + + delattr(self, 'att_embed') + self.att_embed = nn.Sequential(*( + ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+ + (nn.Linear(self.att_feat_size, self.input_encoding_size), + nn.ReLU(), + nn.Dropout(self.drop_prob_lm))+ + ((nn.BatchNorm1d(self.input_encoding_size),) if self.use_bn==2 else ()))) + + delattr(self, 'embed') + self.embed = lambda x : x + delattr(self, 'fc_embed') + self.fc_embed = lambda x : x + delattr(self, 'logit') + del self.ctx2att + + tgt_vocab = self.vocab_size + 1 + self.model = self.make_model(0, tgt_vocab, + N=opt.num_layers, + d_model=opt.input_encoding_size, + d_ff=opt.rnn_size) + + def logit(self, x): # unsafe way + return self.model.generator.proj(x) + + def init_hidden(self, bsz): + return None + + def _prepare_feature(self, fc_feats, att_feats, att_masks): + + att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks) + memory = self.model.encode(att_feats, att_masks) + + return fc_feats[...,:1], att_feats[...,:1], memory, att_masks + + def _prepare_feature_forward(self, att_feats, att_masks=None, seq=None): + att_feats, att_masks = self.clip_att(att_feats, att_masks) + + att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) + + if att_masks is None: + att_masks = att_feats.new_ones(att_feats.shape[:2], dtype=torch.long) + att_masks = att_masks.unsqueeze(-2) + + if seq is not None: + # crop the last one + seq = seq[:,:-1] + seq_mask = (seq.data > 0) + seq_mask[:,0] += 1 + + seq_mask = seq_mask.unsqueeze(-2) + seq_mask = seq_mask & subsequent_mask(seq.size(-1)).to(seq_mask) + else: + seq_mask = None + + return att_feats, seq, att_masks, seq_mask + + def _forward(self, fc_feats, att_feats, seq, att_masks=None): + att_feats, seq, att_masks, seq_mask = self._prepare_feature_forward(att_feats, att_masks, seq) + + out = self.model(att_feats, seq, att_masks, seq_mask) + + outputs = self.model.generator(out) + return outputs + # return torch.cat([_.unsqueeze(1) for _ in outputs], 1) + + def core(self, it, fc_feats_ph, att_feats_ph, memory, state, mask): + """ + state = [ys.unsqueeze(0)] + """ + if state is None: + ys = it.unsqueeze(1) + else: + ys = torch.cat([state[0][0], it.unsqueeze(1)], dim=1) + out = self.model.decode(memory, mask, + ys, + subsequent_mask(ys.size(1)) + .to(memory.device)) + return out[:, -1], [ys.unsqueeze(0)] \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py index 4bf28531..8bd1f28d 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -12,8 +12,8 @@ from .ShowTellModel import ShowTellModel from .FCModel import FCModel from .OldModel import ShowAttendTellModel, AllImgModel -# from .Att2inModel import Att2inModel from .AttModel import * +from .TransformerModel import TransformerModel def setup(opt): if opt.caption_model == 'fc': @@ -47,6 +47,9 @@ def setup(opt): # DenseAtt elif opt.caption_model == 'denseatt': model = DenseAttModel(opt) + # Transformer + elif opt.caption_model == 'transformer': + model = TransformerModel(opt) else: raise Exception("Caption model not supported: {}".format(opt.caption_model)) diff --git a/opts.py b/opts.py index 69c1f605..27087275 100644 --- a/opts.py +++ b/opts.py @@ -25,7 +25,7 @@ def parse_opt(): # Model settings parser.add_argument('--caption_model', type=str, default="show_tell", - help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, topdown, stackatt, denseatt') + help='show_tell, show_attend_tell, all_img, fc, att2in, att2in2, att2all2, adaatt, adaattmo, topdown, stackatt, denseatt, transformer') parser.add_argument('--rnn_size', type=int, default=512, help='size of the rnn in number of hidden nodes in each layer') parser.add_argument('--num_layers', type=int, default=1, @@ -100,6 +100,15 @@ def parse_opt(): help='epsilon that goes into denominator for smoothing') parser.add_argument('--weight_decay', type=float, default=0, help='weight_decay') + # Transformer + parser.add_argument('--label_smoothing', type=float, default=0, + help='') + parser.add_argument('--noamopt', action='store_true', + help='') + parser.add_argument('--noamopt_warmup', type=int, default=2000, + help='') + parser.add_argument('--noamopt_factor', type=float, default=1, + help='') parser.add_argument('--reduce_on_plateau', action='store_true', help='') diff --git a/train.py b/train.py index 6125a455..a806f542 100644 --- a/train.py +++ b/train.py @@ -72,16 +72,23 @@ def train(opt): model = models.setup(opt).cuda() dp_model = torch.nn.DataParallel(model) - update_lr_flag = True + epoch_done = True # Assure in training mode dp_model.train() - crit = utils.LanguageModelCriterion() + if opt.label_smoothing > 0: + crit = utils.LabelSmoothing(smoothing=opt.label_smoothing) + else: + crit = utils.LanguageModelCriterion() rl_crit = utils.RewardCriterion() - if opt.reduce_on_plateau: - optimizer = utils.build_optimizer(model.parameters(), opt) - optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3) + if opt.noamopt: + assert opt.caption_model == 'transformer', 'noamopt can only work with transformer' + optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup) + optimizer._step = iteration + elif opt.reduce_on_plateau: + optimizer = utils.build_optimizer(model.parameters(), opt) + optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3) else: optimizer = utils.build_optimizer(model.parameters(), opt) # Load the optimizer @@ -89,9 +96,9 @@ def train(opt): optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth'))) while True: - if update_lr_flag: + if epoch_done: + if not opt.noamopt and not opt.reduce_on_plateau: # Assign the learning rate - if not opt.reduce_on_plateau: if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0: frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every decay_factor = opt.learning_rate_decay_rate ** frac @@ -112,7 +119,7 @@ def train(opt): else: sc_flag = False - update_lr_flag = False + epoch_done = False start = time.time() # Load data from train split (0) @@ -151,12 +158,14 @@ def train(opt): iteration += 1 if data['bounds']['wrapped']: epoch += 1 - update_lr_flag = True + epoch_done = True # Write the training loss summary if (iteration % opt.losses_log_every == 0): add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration) - if opt.reduce_on_plateau: + if opt.noamopt: + opt.current_lr = optimizer.rate() + elif opt.reduce_on_plateau: opt.current_lr = optimizer.current_lr add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration) add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)