Merge branch 'transformer' into master
* transformer:
  Further towards AttModel: remove the use of get_logprobs
  clean up the code, and make transformer a subclass of AttModel.
  Fix #1
  Update readme for transformer
  [0.5] Match the new arange behavior.
  Add reduce on plateau.
  Use the original options to represent hyperparameters in transformer
  Update to previous framework of sample and sample_beam
  uncomment the original sample.
  formatting and remove variable.
  Add noamopt options.
  Add transformer

# Conflicts:
#	misc/utils.py
#	models/__init__.py
#	train.py
ruotianluo committed Apr 10, 2019
2 parents a70d024 + 662135d commit e87275d
Showing 6 changed files with 456 additions and 12 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -9,6 +9,7 @@ This is based on my [ImageCaptioning.pytorch](https://github.com/ruotianluo/Imag
- Bottom up feature support from [ref](https://arxiv.org/abs/1707.07998). (Evaluation on arbitrary images is not supported.)
- Ensemble
- Multi-GPU training
- Add transformer (merged from [Transformer_captioning](https://github.com/ruotianluo/Transformer_Captioning))

## Requirements
Python 2.7 (because there is no [coco-caption](https://github.com/tylin/coco-caption) version for Python 3)
78 changes: 78 additions & 0 deletions misc/utils.py
@@ -108,6 +108,38 @@ def forward(self, input, target, mask):

return output

class LabelSmoothing(nn.Module):
"Implement label smoothing."
def __init__(self, size=0, padding_idx=0, smoothing=0.0):
super(LabelSmoothing, self).__init__()
self.criterion = nn.KLDivLoss(size_average=False, reduce=False)
# self.padding_idx = padding_idx
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
# self.size = size
self.true_dist = None

def forward(self, input, target, mask):
# truncate to the same size
target = target[:, :input.size(1)]
mask = mask[:, :input.size(1)]

input = to_contiguous(input).view(-1, input.size(-1))
target = to_contiguous(target).view(-1)
mask = to_contiguous(mask).view(-1)

# assert x.size(1) == self.size
self.size = input.size(1)
# true_dist = x.data.clone()
true_dist = input.data.clone()
# true_dist.fill_(self.smoothing / (self.size - 2))
true_dist.fill_(self.smoothing / (self.size - 1))
true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
# true_dist[:, self.padding_idx] = 0
# mask = torch.nonzero(target.data == self.padding_idx)
# self.true_dist = true_dist
return (self.criterion(input, true_dist).sum(1) * mask).sum() / mask.sum()

def set_lr(optimizer, lr):
for group in optimizer.param_groups:
group['lr'] = lr
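
For reference, a minimal sketch of how the new LabelSmoothing criterion could be exercised on its own. The shapes and the log_softmax call are assumptions: the criterion expects per-token log-probabilities, integer word targets, and a 0/1 mask of the same length, and the import assumes the repo root is on PYTHONPATH.

import torch
import torch.nn.functional as F
from misc.utils import LabelSmoothing

# Hypothetical shapes: batch of 2 captions, 5 time steps, vocabulary of 10 words.
logits = torch.randn(2, 5, 10)
logprobs = F.log_softmax(logits, dim=-1)   # KLDivLoss inside expects log-probabilities
target = torch.randint(0, 10, (2, 5))      # ground-truth word indices
mask = torch.ones(2, 5)                    # 1 for real tokens, 0 for padding

crit = LabelSmoothing(smoothing=0.1)
loss = crit(logprobs, target, mask)        # scalar, averaged over unmasked tokens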
@@ -164,6 +196,37 @@ def length_average(length, logprobs, alpha=0.):
"""
return logprobs / length


class NoamOpt(object):
"Optim wrapper that implements rate."
def __init__(self, model_size, factor, warmup, optimizer):
self.optimizer = optimizer
self._step = 0
self.warmup = warmup
self.factor = factor
self.model_size = model_size
self._rate = 0

def step(self):
"Update parameters and rate"
self._step += 1
rate = self.rate()
for p in self.optimizer.param_groups:
p['lr'] = rate
self._rate = rate
self.optimizer.step()

def rate(self, step = None):
"Implement `lrate` above"
if step is None:
step = self._step
return self.factor * \
(self.model_size ** (-0.5) *
min(step ** (-0.5), step * self.warmup ** (-1.5)))

def __getattr__(self, name):
return getattr(self.optimizer, name)

class ReduceLROnPlateau(object):
"Optim wrapper that implements rate."
def __init__(self, optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08):
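
For orientation, NoamOpt.rate() above implements the schedule from "Attention Is All You Need": lr = factor * d_model^(-1/2) * min(step^(-1/2), step * warmup^(-3/2)), i.e. linear warm-up for `warmup` steps followed by inverse-square-root decay. A quick numeric check of the formula; the d_model=512 and warmup=2000 values are just example settings.

# Peak lr occurs at step == warmup, where both branches of min() coincide.
d_model, factor, warmup = 512, 1.0, 2000
rate = lambda step: factor * d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)
print(rate(1))       # ~4.9e-07  (early warm-up)
print(rate(2000))    # ~9.9e-04  (peak, equals d_model**-0.5 * warmup**-0.5)
print(rate(20000))   # ~3.1e-04  (decaying as step**-0.5)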
@@ -195,6 +258,21 @@ def load_state_dict(self, state_dict):
self.scheduler.load_state_dict(state_dict['scheduler_state_dict'])
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
# current_lr is actually useless in this case

def rate(self, step = None):
"Implement `lrate` above"
if step is None:
step = self._step
return self.factor * \
(self.model_size ** (-0.5) *
min(step ** (-0.5), step * self.warmup ** (-1.5)))

def __getattr__(self, name):
return getattr(self.optimizer, name)

def get_std_opt(model, factor=1, warmup=2000):
# return NoamOpt(model.tgt_embed[0].d_model, 2, 4000,
# torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
return NoamOpt(model.model.tgt_embed[0].d_model, factor, warmup,
torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
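
A minimal sketch of driving the wrapper in a training step. The nn.Linear stand-in model and the direct NoamOpt construction are assumptions; get_std_opt additionally expects the model to expose model.model.tgt_embed[0].d_model, as read in the line above.

import torch
import torch.nn as nn
from misc.utils import NoamOpt   # assumes the repo root is on PYTHONPATH

model = nn.Linear(512, 512)      # stand-in for the captioning model
optimizer = NoamOpt(512, factor=1, warmup=2000,
                    optimizer=torch.optim.Adam(model.parameters(), lr=0,
                                               betas=(0.9, 0.98), eps=1e-9))

loss = model(torch.randn(4, 512)).pow(2).mean()
optimizer.zero_grad()            # forwarded to the wrapped Adam via __getattr__
loss.backward()
optimizer.step()                 # writes rate() into every param group, then steps Adam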

