From 7b604ab0d17b58f5623846ece1fa3ed846fbb82d Mon Sep 17 00:00:00 2001 From: cgwang Date: Tue, 27 Mar 2018 11:24:42 -0700 Subject: [PATCH] Update language model (#19) * udpate lmexp * udpate word_language_model.py * udpate word_language_model.py * update with lm_decay * update word_language_model.py with new updated hiddensize, standardrnn exchange tied and dropout; update base.py with rnn_relu config; update lm.py with awd_lstm_lm_1150 pretrained setting, and with new sentiment analysis and lm example * remove lm test file --- example/gluon/word_language_model.py | 12 +++++++----- python/mxnet/gluon/model_zoo/text/base.py | 2 +- python/mxnet/gluon/model_zoo/text/lm.py | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/example/gluon/word_language_model.py b/example/gluon/word_language_model.py index 382f0b17ff5f..7dd8c697c69e 100644 --- a/example/gluon/word_language_model.py +++ b/example/gluon/word_language_model.py @@ -22,7 +22,7 @@ import mxnet as mx from mxnet import gluon, autograd from mxnet.gluon import data, text -from mxnet.gluon.model_zoo.text.lm import SimpleRNN, AWDRNN +from mxnet.gluon.model_zoo.text.lm import StandardRNN, AWDRNN parser = argparse.ArgumentParser(description='MXNet Autograd RNN/LSTM Language Model on Wikitext-2.') parser.add_argument('--model', type=str, default='lstm', @@ -68,6 +68,8 @@ help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu. (the result of multi-gpu training might be slightly different compared to single-gpu training, still need to be finalized)') args = parser.parse_args() +print(args) + ############################################################################### # Load data @@ -82,7 +84,7 @@ def get_frequencies(dataset): return collections.Counter(x for tup in dataset for x in tup[0] if x) -vocab = text.vocab.Vocabulary(get_frequencies(train_dataset)) +vocab = text.vocab.Vocabulary(get_frequencies(train_dataset), reserved_tokens=['', '']) def index_tokens(data, label): return vocab[data], vocab[label] @@ -124,8 +126,8 @@ def index_tokens(data, label): model = AWDRNN(args.model, len(vocab), args.emsize, args.nhid, args.nlayers, args.tied, args.dropout, args.weight_dropout, args.dropout_h, args.dropout_i) else: - model = SimpleRNN(args.model, len(vocab), args.emsize, args.nhid, args.nlayers, - args.tied, args.dropout) + model = StandardRNN(args.model, len(vocab), args.emsize, args.nhid, args.nlayers, args.dropout, + args.tied) model.initialize(mx.init.Xavier(), ctx=context) @@ -169,7 +171,7 @@ def train(): for epoch in range(args.epochs): total_L = 0.0 start_epoch_time = time.time() - hiddens = [model.begin_state(args.batch_size, func=mx.nd.zeros, ctx=ctx) for ctx in context] + hiddens = [model.begin_state(args.batch_size//len(context), func=mx.nd.zeros, ctx=ctx) for ctx in context] for i, (data, target) in enumerate(train_data): start_batch_time = time.time() data = data.T diff --git a/python/mxnet/gluon/model_zoo/text/base.py b/python/mxnet/gluon/model_zoo/text/base.py index 6f5f5557b880..7d84bdeda73e 100644 --- a/python/mxnet/gluon/model_zoo/text/base.py +++ b/python/mxnet/gluon/model_zoo/text/base.py @@ -118,7 +118,7 @@ def get_rnn_cell(mode, num_layers, input_size, hidden_size, def get_rnn_layer(mode, num_layers, input_size, hidden_size, dropout, weight_dropout): """create rnn layer given specs""" if mode == 'rnn_relu': - block = rnn.RNN(hidden_size, 'relu', num_layers, dropout=dropout, + block = rnn.RNN(hidden_size, num_layers, 'relu', dropout=dropout, input_size=input_size) elif mode == 'rnn_tanh': block = rnn.RNN(hidden_size, num_layers, dropout=dropout, diff --git a/python/mxnet/gluon/model_zoo/text/lm.py b/python/mxnet/gluon/model_zoo/text/lm.py index 7607cf5bf13c..34f060935ca1 100644 --- a/python/mxnet/gluon/model_zoo/text/lm.py +++ b/python/mxnet/gluon/model_zoo/text/lm.py @@ -209,7 +209,7 @@ def awd_lstm_lm_1150(dataset_name=None, vocab=None, pretrained=False, ctx=cpu(), 'tie_weights': True, 'dropout': 0.4, 'weight_drop': 0.5, - 'drop_h': 0.3, + 'drop_h': 0.2, 'drop_i': 0.65} assert all(k not in kwargs for k in predefined_args), \ "Cannot override predefined model settings."