Make the code compatible with PyTorch 0.4.1 (#9)
* make main compatible with torch 0.4.1

* clean

* change finetune

* change [0] to item() in finetune

* turn off dropout during dynamic evaluation (the model stays in train mode)

* remove warnings

* fix some bugs
billy-inn authored and Zhilin Yang committed Oct 22, 2018
1 parent 4c43dee commit 6f89b28
Showing 7 changed files with 52 additions and 45 deletions.
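The changes below all stem from the PyTorch 0.3 to 0.4 migration. As a quick reference, here is a minimal, self-contained sketch of the idioms the diffs adopt: torch.no_grad() instead of volatile=True inputs, tensor.item() instead of indexing a zero-dimensional tensor with [0], h.detach() instead of re-wrapping with Variable(h.data), and clip_grad_norm_ instead of the deprecated clip_grad_norm. The code is hypothetical and not taken from this repository; the LSTM, the random data, and the chunk size of 5 are placeholders.

import torch
import torch.nn as nn

model = nn.LSTM(input_size=8, hidden_size=8)   # placeholder model
data = torch.randn(20, 4, 8)                   # placeholder batch: (seq_len, batch, features)

def repackage_hidden(h):
    """Detach hidden state from its history (0.4 style: no Variable wrapper)."""
    if isinstance(h, (tuple, list)):
        return tuple(repackage_hidden(v) for v in h)
    return h.detach()

def evaluate():
    total_loss = 0.0
    hidden = None
    with torch.no_grad():                      # replaces Variable(..., volatile=True)
        for i in range(0, data.size(0), 5):
            chunk = data[i:i + 5]
            output, hidden = model(chunk) if hidden is None else model(chunk, hidden)
            loss = nn.functional.mse_loss(output, chunk)
            total_loss += loss.item() * len(chunk)   # .item() replaces loss.data[0] / loss[0]
            hidden = repackage_hidden(hidden)
    return total_loss / data.size(0)

output, _ = model(data)
nn.functional.mse_loss(output, data).backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)   # note the trailing underscore
print('held-out loss:', evaluate())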
2 changes: 2 additions & 0 deletions .gitignore
@@ -99,3 +99,5 @@ ENV/
 
 # mypy
 .mypy_cache/
+
+data/*
18 changes: 10 additions & 8 deletions dynamiceval.py
@@ -70,15 +70,15 @@ def batchify(data, bsz)
 
 def repackage_hidden(h):
     """Wraps hidden states in new Variables, to detach them from their history."""
-    if type(h) == Variable:
-        return Variable(h.data)
-    else:
+    if isinstance(h, tuple) or isinstance(h, list):
         return tuple(repackage_hidden(v) for v in h)
+    else:
+        return h.detach()
 
 
-def get_batch(source, i, evaluation=False):
+def get_batch(source, i):
     seq_len = min(args.bptt, len(source) - 1 - i)
-    data = Variable(source[i:i+seq_len], volatile=evaluation)
+    data = Variable(source[i:i+seq_len])
     target = Variable(source[i+1:i+1+seq_len].view(-1))
     return data, target
 
@@ -95,7 +95,8 @@ def gradstat():
 
     while i < train_data.size(0) - 1 - 1:
         seq_len = args.bptt
-        model.eval()
+        model.train()
+        model.use_dropout = False
 
         data, targets = get_batch(train_data, i)
         hidden = repackage_hidden(hidden)
@@ -161,7 +162,8 @@ def evaluate():
     #loops through data
     while i < eval_data.size(0) - 1 - 1:
 
-        model.eval()
+        model.train()
+        model.use_dropout = False
         #gets last chunk of seqlence if seqlen doesn't divide full sequence cleanly
         if (i+seq_len)>=eval_data.size(0):
             if last:
@@ -231,4 +233,4 @@ def evaluate():
 print('running dynamic evaluation')
 #apply dynamic evaluation
 loss = evaluate()
-print('perplexity loss: ' + str(loss[0]))
+print('perplexity loss: ' + str(loss))
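Note that gradstat() and evaluate() above now call model.train() rather than model.eval(): dynamic evaluation adapts the weights from gradient statistics, so training-mode behaviour is wanted, but dropout must still be off, which is what the new use_dropout flag (added in model.py below) provides. For comparison only, and not what this commit does, a plain PyTorch model built from nn.Dropout could get the same effect by switching just the dropout modules to eval mode; this repository needs the flag because its dropout variants (LockedDropout, embedded_dropout, WeightDrop) are not nn.Dropout modules. A hypothetical sketch:

import torch.nn as nn

# Hypothetical stand-in model; not the repo's architecture.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5), nn.Linear(4, 4))

model.train()                        # keep training-mode behaviour overall
for m in model.modules():
    if isinstance(m, nn.Dropout):
        m.eval()                     # but make dropout a no-op

print(model[1].training)             # False: dropout disabled, rest still in train mode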
3 changes: 2 additions & 1 deletion embed_regularize.py
@@ -1,6 +1,7 @@
 import numpy as np
 
 import torch
+import torch.nn.functional as F
 from torch.autograd import Variable
 
 def embedded_dropout(embed, words, dropout=0.1, scale=None):
@@ -16,7 +17,7 @@ def embedded_dropout(embed, words, dropout=0.1, scale=None):
   padding_idx = embed.padding_idx
   if padding_idx is None:
       padding_idx = -1
-  X = embed._backend.Embedding.apply(words, masked_embed_weight,
+  X = F.embedding(words, masked_embed_weight,
     padding_idx, embed.max_norm, embed.norm_type,
     embed.scale_grad_by_freq, embed.sparse
   )
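The hunk above replaces the private embed._backend.Embedding.apply call, which disappeared in PyTorch 0.4, with the public torch.nn.functional.embedding API called with the same arguments. A small, hypothetical sanity check of that equivalence (not part of the repository; the vocabulary size and indices are made up):

import torch
import torch.nn as nn
import torch.nn.functional as F

embed = nn.Embedding(10, 4)                          # placeholder: vocab 10, dim 4
words = torch.tensor([[1, 2, 3], [4, 5, 6]])

out_module = embed(words)                            # normal module lookup
out_functional = F.embedding(words, embed.weight)    # same lookup via the functional API

print(torch.equal(out_module, out_functional))       # True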
25 changes: 13 additions & 12 deletions finetune.py
@@ -151,16 +151,17 @@ def evaluate(data_source, batch_size=10):
     total_loss = 0
     ntokens = len(corpus.dictionary)
     hidden = model.init_hidden(batch_size)
-    for i in range(0, data_source.size(0) - 1, args.bptt):
-        data, targets = get_batch(data_source, i, args, evaluation=True)
-        targets = targets.view(-1)
-
-        log_prob, hidden = parallel_model(data, hidden)
-        loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), targets).data
-
-        total_loss += len(data) * loss
-        hidden = repackage_hidden(hidden)
-    return total_loss[0] / len(data_source)
+    with torch.no_grad():
+        for i in range(0, data_source.size(0) - 1, args.bptt):
+            data, targets = get_batch(data_source, i, args)
+            targets = targets.view(-1)
+
+            log_prob, hidden = parallel_model(data, hidden)
+            loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), targets).data
+
+            total_loss += len(data) * loss
+            hidden = repackage_hidden(hidden)
+    return total_loss.item() / len(data_source)
 
 def train():
     assert args.batch_size % args.small_batch_size == 0, 'batch_size must be divisible by small_batch_size'
@@ -210,13 +211,13 @@ def train():
             end = start + args.small_batch_size
 
         # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
-        torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
+        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
         optimizer.step()
 
         # total_loss += raw_loss.data
         optimizer.param_groups[0]['lr'] = lr2
         if batch % args.log_interval == 0 and batch > 0:
-            cur_loss = total_loss[0] / args.log_interval
+            cur_loss = total_loss.item() / args.log_interval
             elapsed = time.time() - start_time
             logging('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                 'loss {:5.2f} | ppl {:8.2f}'.format(
21 changes: 11 additions & 10 deletions main.py
@@ -157,17 +157,18 @@ def evaluate(data_source, batch_size=10):
     total_loss = 0
     ntokens = len(corpus.dictionary)
     hidden = model.init_hidden(batch_size)
-    for i in range(0, data_source.size(0) - 1, args.bptt):
-        data, targets = get_batch(data_source, i, args, evaluation=True)
-        targets = targets.view(-1)
+    with torch.no_grad():
+        for i in range(0, data_source.size(0) - 1, args.bptt):
+            data, targets = get_batch(data_source, i, args)
+            targets = targets.view(-1)
 
-        log_prob, hidden = parallel_model(data, hidden)
-        loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), targets).data
+            log_prob, hidden = parallel_model(data, hidden)
+            loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), targets).data
 
-        total_loss += loss * len(data)
+            total_loss += loss * len(data)
 
-        hidden = repackage_hidden(hidden)
-    return total_loss[0] / len(data_source)
+            hidden = repackage_hidden(hidden)
+    return total_loss.item() / len(data_source)
 
 
 def train():
@@ -220,13 +221,13 @@ def train():
             gc.collect()
 
         # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
-        torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
+        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
         optimizer.step()
 
         # total_loss += raw_loss.data
         optimizer.param_groups[0]['lr'] = lr2
         if batch % args.log_interval == 0 and batch > 0:
-            cur_loss = total_loss[0] / args.log_interval
+            cur_loss = total_loss.item() / args.log_interval
            elapsed = time.time() - start_time
            logging('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                'loss {:5.2f} | ppl {:8.2f}'.format(
18 changes: 9 additions & 9 deletions model.py
@@ -15,12 +15,13 @@ def __init__(self, rnn_type, ntoken, ninp, nhid, nhidlast, nlayers,
                  dropout=0.5, dropouth=0.5, dropouti=0.5, dropoute=0.1, wdrop=0,
                  tie_weights=False, ldropout=0.5, n_experts=10):
         super(RNNModel, self).__init__()
+        self.use_dropout = True
         self.lockdrop = LockedDropout()
         self.encoder = nn.Embedding(ntoken, ninp)
 
         self.rnns = [torch.nn.LSTM(ninp if l == 0 else nhid, nhid if l != nlayers - 1 else nhidlast, 1, dropout=0) for l in range(nlayers)]
         if wdrop:
-            self.rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop) for rnn in self.rnns]
+            self.rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop if self.use_dropout else 0) for rnn in self.rnns]
         self.rnns = torch.nn.ModuleList(self.rnns)
 
         self.prior = nn.Linear(nhidlast, n_experts, bias=False)
@@ -68,10 +69,10 @@ def init_weights(self):
     def forward(self, input, hidden, return_h=False, return_prob=False):
         batch_size = input.size(1)
 
-        emb = embedded_dropout(self.encoder, input, dropout=self.dropoute if self.training else 0)
+        emb = embedded_dropout(self.encoder, input, dropout=self.dropoute if (self.training and self.use_dropout) else 0)
         #emb = self.idrop(emb)
 
-        emb = self.lockdrop(emb, self.dropouti)
+        emb = self.lockdrop(emb, self.dropouti if self.use_dropout else 0)
 
         raw_output = emb
         new_hidden = []
@@ -85,21 +86,21 @@ def forward(self, input, hidden, return_h=False, return_prob=False):
             raw_outputs.append(raw_output)
             if l != self.nlayers - 1:
                 #self.hdrop(raw_output)
-                raw_output = self.lockdrop(raw_output, self.dropouth)
+                raw_output = self.lockdrop(raw_output, self.dropouth if self.use_dropout else 0)
                 outputs.append(raw_output)
         hidden = new_hidden
 
-        output = self.lockdrop(raw_output, self.dropout)
+        output = self.lockdrop(raw_output, self.dropout if self.use_dropout else 0)
         outputs.append(output)
 
         latent = self.latent(output)
-        latent = self.lockdrop(latent, self.dropoutl)
+        latent = self.lockdrop(latent, self.dropoutl if self.use_dropout else 0)
         logit = self.decoder(latent.view(-1, self.ninp))
 
         prior_logit = self.prior(output).contiguous().view(-1, self.n_experts)
-        prior = nn.functional.softmax(prior_logit)
+        prior = nn.functional.softmax(prior_logit, -1)
 
-        prob = nn.functional.softmax(logit.view(-1, self.ntoken)).view(-1, self.n_experts, self.ntoken)
+        prob = nn.functional.softmax(logit.view(-1, self.ntoken), -1).view(-1, self.n_experts, self.ntoken)
         prob = (prob * prior.unsqueeze(2).expand_as(prob)).sum(1)
 
         if return_prob:
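The forward() changes above also pass an explicit dimension to softmax; since 0.4, calling softmax on a multi-dimensional tensor without dim triggers a deprecation warning (likely among the warnings the commit message refers to) and the implicit choice of axis is easy to get wrong. A brief, hypothetical illustration with made-up shapes:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 3, 5)        # e.g. (batch, experts, vocab)
probs = F.softmax(logits, dim=-1)    # normalize over the last (vocabulary) axis

print(probs.sum(dim=-1))             # every slice sums to 1.0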
@@ -129,4 +130,3 @@ def init_hidden(self, bsz):
 # input = Variable(torch.LongTensor(13, 9).random_(0, 10))
 # hidden = model.init_hidden(9)
 # print(model.sample(input, hidden, 5, 6, 1, 2, sample_latent=True).size())
-
10 changes: 5 additions & 5 deletions utils.py
@@ -4,10 +4,10 @@
 
 def repackage_hidden(h):
     """Wraps hidden states in new Variables, to detach them from their history."""
-    if type(h) == Variable:
-        return Variable(h.data)
-    else:
+    if isinstance(h, tuple) or isinstance(h, list):
         return tuple(repackage_hidden(v) for v in h)
+    else:
+        return h.detach()
 
 def batchify(data, bsz, args):
     # Work out how cleanly we can divide the dataset into bsz parts.
@@ -21,9 +21,9 @@ def batchify(data, bsz, args):
         data = data.cuda()
     return data
 
-def get_batch(source, i, args, seq_len=None, evaluation=False):
+def get_batch(source, i, args, seq_len=None):
     seq_len = min(seq_len if seq_len else args.bptt, len(source) - 1 - i)
-    data = Variable(source[i:i+seq_len], volatile=evaluation)
+    data = Variable(source[i:i+seq_len])
     # target = Variable(source[i+1:i+1+seq_len].view(-1))
     target = Variable(source[i+1:i+1+seq_len])
     return data, target
