
Make the code compatible with PyTorch 0.4.1 #9

Merged: 7 commits, Oct 22, 2018
Changes from 5 commits
2 changes: 2 additions & 0 deletions .gitignore
@@ -99,3 +99,5 @@ ENV/

# mypy
.mypy_cache/

data/*
10 changes: 6 additions & 4 deletions dynamiceval.py
@@ -76,9 +76,9 @@ def repackage_hidden(h):
return tuple(repackage_hidden(v) for v in h)


def get_batch(source, i, evaluation=False):
def get_batch(source, i):
seq_len = min(args.bptt, len(source) - 1 - i)
data = Variable(source[i:i+seq_len], volatile=evaluation)
data = Variable(source[i:i+seq_len])
target = Variable(source[i+1:i+1+seq_len].view(-1))
return data, target

@@ -95,7 +95,8 @@ def gradstat():

while i < train_data.size(0) - 1 - 1:
seq_len = args.bptt
model.eval()
model.train()
model.use_dropout = False

[review comment] Why do we turn off dropout?

data, targets = get_batch(train_data, i)
hidden = repackage_hidden(hidden)
@@ -161,7 +162,8 @@ def evaluate():
#loops through data
while i < eval_data.size(0) - 1 - 1:

model.eval()
model.train()
model.use_dropout = False
#gets last chunk of sequence if seqlen doesn't divide full sequence cleanly
if (i+seq_len)>=eval_data.size(0):
if last:
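Note on the dynamiceval.py hunks above: the volatile= keyword is gone in 0.4.x, but this script cannot simply wrap its loops in torch.no_grad(), because gradstat() and evaluate() need gradients with respect to the parameters. Switching from model.eval() to model.train() is most likely required because cuDNN RNN backward can only be called in training mode; the new use_dropout flag (added to the model in model.py below) then keeps the forward pass deterministic anyway. A minimal, self-contained sketch of that pattern with a toy LSTM (not the repo's model):

import torch
import torch.nn as nn

# Toy stand-in for the repo's RNN: stay in train() so that backward() through a
# cuDNN RNN is allowed, but run with dropout effectively off so the pass is
# deterministic.
lstm = nn.LSTM(input_size=8, hidden_size=8, num_layers=2, dropout=0.0)
lstm.train()

x = torch.randn(5, 2, 8)            # (seq_len, batch, features); no volatile= in 0.4.x
out, _ = lstm(x)
loss = out.pow(2).mean()
loss.backward()                     # parameter gradients are still collected
print(lstm.weight_ih_l0.grad.norm())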
3 changes: 2 additions & 1 deletion embed_regularize.py
@@ -1,6 +1,7 @@
import numpy as np

import torch
import torch.nn.functional as F
from torch.autograd import Variable

def embedded_dropout(embed, words, dropout=0.1, scale=None):
@@ -16,7 +17,7 @@ def embedded_dropout(embed, words, dropout=0.1, scale=None):
padding_idx = embed.padding_idx
if padding_idx is None:
padding_idx = -1
X = embed._backend.Embedding.apply(words, masked_embed_weight,
X = F.embedding(words, masked_embed_weight,
padding_idx, embed.max_norm, embed.norm_type,
embed.scale_grad_by_freq, embed.sparse
)
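The embed_regularize.py change swaps the private embed._backend.Embedding.apply hook, which is no longer a stable interface in 0.4.x, for the public torch.nn.functional.embedding call with the same arguments. A minimal sketch of the embedded-dropout idea against that public API (toy sizes, not the repo's exact function):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Drop whole rows of the embedding matrix, rescale the survivors, then look the
# words up against the masked weight via the public F.embedding API.
embed = nn.Embedding(10, 4)
words = torch.randint(0, 10, (3, 2), dtype=torch.long)

dropout = 0.1
mask = embed.weight.new_empty((embed.weight.size(0), 1)).bernoulli_(1 - dropout) / (1 - dropout)
masked_weight = mask.expand_as(embed.weight) * embed.weight

out = F.embedding(words, masked_weight, embed.padding_idx, embed.max_norm,
                  embed.norm_type, embed.scale_grad_by_freq, embed.sparse)
print(out.shape)  # torch.Size([3, 2, 4])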
25 changes: 13 additions & 12 deletions finetune.py
@@ -151,16 +151,17 @@ def evaluate(data_source, batch_size=10):
total_loss = 0
ntokens = len(corpus.dictionary)
hidden = model.init_hidden(batch_size)
for i in range(0, data_source.size(0) - 1, args.bptt):
data, targets = get_batch(data_source, i, args, evaluation=True)
targets = targets.view(-1)

log_prob, hidden = parallel_model(data, hidden)
loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), targets).data

total_loss += len(data) * loss
hidden = repackage_hidden(hidden)
return total_loss[0] / len(data_source)
with torch.no_grad():
for i in range(0, data_source.size(0) - 1, args.bptt):
data, targets = get_batch(data_source, i, args)
targets = targets.view(-1)

log_prob, hidden = parallel_model(data, hidden)
loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), targets).data

total_loss += len(data) * loss
hidden = repackage_hidden(hidden)
return total_loss.item() / len(data_source)

def train():
assert args.batch_size % args.small_batch_size == 0, 'batch_size must be divisible by small_batch_size'
@@ -210,13 +211,13 @@ def train():
end = start + args.small_batch_size

# `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
optimizer.step()

# total_loss += raw_loss.data
optimizer.param_groups[0]['lr'] = lr2
if batch % args.log_interval == 0 and batch > 0:
cur_loss = total_loss[0] / args.log_interval
cur_loss = total_loss.item() / args.log_interval
elapsed = time.time() - start_time
logging('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
'loss {:5.2f} | ppl {:8.2f}'.format(
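Three 0.4.x idioms show up in this finetune.py diff: evaluation runs under torch.no_grad() instead of constructing volatile Variables, gradient clipping uses the renamed in-place clip_grad_norm_, and scalar losses are read with .item() rather than by indexing a 0-dim tensor with [0]. A minimal, self-contained sketch with a toy classifier (not the repo's language model):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,), dtype=torch.long)

# 1) inference without building a graph: torch.no_grad() replaces volatile=True
with torch.no_grad():
    val_loss = nn.functional.cross_entropy(model(x), y)
print(val_loss.item())  # 3) .item() replaces loss[0] / loss.data[0]

# 2) training step with the renamed in-place gradient-clipping helper
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.25)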
17 changes: 9 additions & 8 deletions main.py
@@ -157,16 +157,17 @@ def evaluate(data_source, batch_size=10):
total_loss = 0
ntokens = len(corpus.dictionary)
hidden = model.init_hidden(batch_size)
for i in range(0, data_source.size(0) - 1, args.bptt):
data, targets = get_batch(data_source, i, args, evaluation=True)
targets = targets.view(-1)
with torch.no_grad():
for i in range(0, data_source.size(0) - 1, args.bptt):
data, targets = get_batch(data_source, i, args)
targets = targets.view(-1)

log_prob, hidden = parallel_model(data, hidden)
loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), targets).data
log_prob, hidden = parallel_model(data, hidden)
loss = nn.functional.nll_loss(log_prob.view(-1, log_prob.size(2)), targets).data

total_loss += loss * len(data)
total_loss += loss * len(data)

hidden = repackage_hidden(hidden)
hidden = repackage_hidden(hidden)
return total_loss[0] / len(data_source)


@@ -220,7 +221,7 @@ def train():
gc.collect()

# `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
torch.nn.utils.clip_grad_norm(model.parameters(), args.clip)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
optimizer.step()

# total_loss += raw_loss.data
18 changes: 9 additions & 9 deletions model.py
@@ -15,12 +15,13 @@ def __init__(self, rnn_type, ntoken, ninp, nhid, nhidlast, nlayers,
dropout=0.5, dropouth=0.5, dropouti=0.5, dropoute=0.1, wdrop=0,
tie_weights=False, ldropout=0.5, n_experts=10):
super(RNNModel, self).__init__()
self.use_dropout = True
self.lockdrop = LockedDropout()
self.encoder = nn.Embedding(ntoken, ninp)

self.rnns = [torch.nn.LSTM(ninp if l == 0 else nhid, nhid if l != nlayers - 1 else nhidlast, 1, dropout=0) for l in range(nlayers)]
if wdrop:
self.rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop) for rnn in self.rnns]
self.rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop if self.use_dropout else 0) for rnn in self.rnns]
self.rnns = torch.nn.ModuleList(self.rnns)

self.prior = nn.Linear(nhidlast, n_experts, bias=False)
@@ -68,10 +69,10 @@ def init_weights(self):
def forward(self, input, hidden, return_h=False, return_prob=False):
batch_size = input.size(1)

emb = embedded_dropout(self.encoder, input, dropout=self.dropoute if self.training else 0)
emb = embedded_dropout(self.encoder, input, dropout=self.dropoute if (self.training and self.use_dropout) else 0)
#emb = self.idrop(emb)

emb = self.lockdrop(emb, self.dropouti)
emb = self.lockdrop(emb, self.dropouti if self.use_dropout else 0)

raw_output = emb
new_hidden = []
@@ -85,21 +86,21 @@ def forward(self, input, hidden, return_h=False, return_prob=False):
raw_outputs.append(raw_output)
if l != self.nlayers - 1:
#self.hdrop(raw_output)
raw_output = self.lockdrop(raw_output, self.dropouth)
raw_output = self.lockdrop(raw_output, self.dropouth if self.use_dropout else 0)
outputs.append(raw_output)
hidden = new_hidden

output = self.lockdrop(raw_output, self.dropout)
output = self.lockdrop(raw_output, self.dropout if self.use_dropout else 0)
outputs.append(output)

latent = self.latent(output)
latent = self.lockdrop(latent, self.dropoutl)
latent = self.lockdrop(latent, self.dropoutl if self.use_dropout else 0)
logit = self.decoder(latent.view(-1, self.ninp))

prior_logit = self.prior(output).contiguous().view(-1, self.n_experts)
prior = nn.functional.softmax(prior_logit)
prior = nn.functional.softmax(prior_logit, -1)

prob = nn.functional.softmax(logit.view(-1, self.ntoken)).view(-1, self.n_experts, self.ntoken)
prob = nn.functional.softmax(logit.view(-1, self.ntoken), -1).view(-1, self.n_experts, self.ntoken)
prob = (prob * prior.unsqueeze(2).expand_as(prob)).sum(1)

if return_prob:
@@ -129,4 +130,3 @@ def init_hidden(self, bsz):
# input = Variable(torch.LongTensor(13, 9).random_(0, 10))
# hidden = model.init_hidden(9)
# print(model.sample(input, hidden, 5, 6, 1, 2, sample_latent=True).size())
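
Two things change in model.py: every dropout rate is now gated by a use_dropout flag (so dynamiceval.py can disable dropout without leaving training mode), and the softmax calls pass an explicit dim, since 0.4.x deprecates the implicit-dimension form. A small sketch of the explicit-dim softmax inside the mixture-of-softmaxes combination used by forward() (toy sizes, names chosen for illustration):

import torch
import torch.nn.functional as F

batch, n_experts, ntoken = 6, 10, 20
prior_logit = torch.randn(batch, n_experts)
logit = torch.randn(batch * n_experts, ntoken)

prior = F.softmax(prior_logit, dim=-1)                       # explicit dim: over experts
prob = F.softmax(logit, dim=-1).view(-1, n_experts, ntoken)  # explicit dim: over the vocab
prob = (prob * prior.unsqueeze(2)).sum(1)                    # mix the expert distributions

print(prob.shape)    # torch.Size([6, 20])
print(prob.sum(-1))  # each row still sums to 1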

10 changes: 5 additions & 5 deletions utils.py
@@ -4,10 +4,10 @@

def repackage_hidden(h):
"""Wraps hidden states in new Variables, to detach them from their history."""
if type(h) == Variable:
return Variable(h.data)
else:
if isinstance(h, tuple) or isinstance(h, list):
return tuple(repackage_hidden(v) for v in h)
else:
return h.detach()

def batchify(data, bsz, args):
# Work out how cleanly we can divide the dataset into bsz parts.
@@ -21,9 +21,9 @@ def batchify(data, bsz, args):
data = data.cuda()
return data

def get_batch(source, i, args, seq_len=None, evaluation=False):
def get_batch(source, i, args, seq_len=None):
seq_len = min(seq_len if seq_len else args.bptt, len(source) - 1 - i)
data = Variable(source[i:i+seq_len], volatile=evaluation)
data = Variable(source[i:i+seq_len])
# target = Variable(source[i+1:i+1+seq_len].view(-1))
target = Variable(source[i+1:i+1+seq_len])
return data, target
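With Variables and Tensors merged in 0.4.x, repackage_hidden no longer needs the type(h) == Variable check: plain tensors are detached directly, and tuples/lists (the LSTM's (h, c) state) are handled by recursion. A short usage sketch, mirroring the function from the diff above with a toy LSTM:

import torch

def repackage_hidden(h):
    """Detach hidden states from their history (as in the diff above)."""
    if isinstance(h, (tuple, list)):
        return tuple(repackage_hidden(v) for v in h)
    return h.detach()

lstm = torch.nn.LSTM(4, 4)
out, hidden = lstm(torch.randn(3, 2, 4))  # hidden is an (h, c) tuple carrying a graph
hidden = repackage_hidden(hidden)
print(hidden[0].requires_grad)            # False: truncated BPTT stops at the batch boundary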