Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-107] Fused GRU implementation for CPU #10311

Merged
merged 63 commits into from
Jun 6, 2018
Merged
Show file tree
Hide file tree
Changes from 61 commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
1e4aaf2
Add GRU Support and Test Case
Mar 29, 2018
87de652
skip the gpu test case that has nothing to do with RNN GRU
Mar 29, 2018
2b5b43d
fix robust bug for gru backward
Mar 30, 2018
54c64bc
fix bug for unifying weight parameter
Mar 30, 2018
f375c89
add GRU multiple layer and bidirection support with test case
Apr 12, 2018
6719685
fix test case bug
Apr 12, 2018
1ab7869
fix test case bug
Apr 12, 2018
f6ae0d1
fix bug for memory issue
Apr 12, 2018
4e11dc6
fix bug for bidirection
Apr 12, 2018
817fc30
rebase code and fix bug for memory corruption issue
Apr 13, 2018
e0a61cb
fix gpu compile issue
Apr 13, 2018
3ceaa00
fix bug and enable some test cases
Apr 27, 2018
2e7cbb0
fix robust bug
May 4, 2018
1b2288b
trigger the build to check if quantize-gpu case is covered
May 4, 2018
4f10a01
trigger the build to check if MKLDNN+GPU case is covered
May 4, 2018
18f7c5f
Merge pull request #1 from apache/master
May 4, 2018
e271184
disable failed gpu test case of MKLDNN_UTIL_FUNC-MemFormat because it…
May 4, 2018
646766c
skip failed test_reduce test case temporarily as it has nothing to do…
May 4, 2018
be9de01
enable several test cases
May 4, 2018
21e8978
retrigger the build
May 7, 2018
0ae12a2
rebase code from lstm
May 15, 2018
67c1434
rebase code for resolve conflict
May 15, 2018
a1c84eb
add gru code after resolve conflict
May 15, 2018
2ed2e0f
fix bug for resolve conflict
May 15, 2018
f60ef60
Merge pull request #3 from apache/master
May 15, 2018
42c729a
add Fused GRU code with test case
May 15, 2018
89d7326
retrigger the build
May 15, 2018
a06fecf
add GetRecommendedOMPThreadCount for omp
May 15, 2018
fc15942
fix conflict issue
May 16, 2018
a1d1713
Merge pull request #4 from apache/master
May 16, 2018
396fe19
add gru relate code
May 16, 2018
bac611f
fix bug for code
May 16, 2018
1daf4a1
update code for gru
May 16, 2018
759f6d1
retrigger the build
May 16, 2018
90414fd
fix code about gru condition
May 17, 2018
066b7b9
enhance test case to test gradient weights and bias
May 17, 2018
360cda9
fix bug for test case
May 17, 2018
7ea1c28
fix bug for test case
May 17, 2018
b793910
fix bug about dropout condition and test case
May 17, 2018
4655895
Merge pull request #5 from apache/master
May 17, 2018
320bc73
fix bug for test case
May 17, 2018
2d5c270
fix bug for test case
May 17, 2018
3042b95
retrigger the build
May 17, 2018
da2094c
rebase code
May 20, 2018
d7fc335
Merge pull request #6 from apache/master
May 20, 2018
ddebe95
add gru code
May 20, 2018
5f031fc
fix issues about namespace, removing define and memcpy
May 22, 2018
66dc9f7
retrigger the build
May 22, 2018
0bc9585
fix issues and add cudnn_gru_bucketing.py test case
May 25, 2018
6f25c26
retrigger the build
May 25, 2018
f1c43ef
Merge pull request #9 from apache/master
May 25, 2018
7336cc3
update cudnn_rnn_bucketing.py test case
Jun 1, 2018
33060ee
update cudnn_rnn_bucketing.py test case
Jun 1, 2018
0c580df
update cudnn_rnn_bucketing.py test case
Jun 1, 2018
41a1382
add check for req[kParams] and kAddTo from cudnn_rnn-inl.h
Jun 1, 2018
9173088
retrigger the build
Jun 1, 2018
bab3ced
retrigger the build
Jun 1, 2018
146ac33
Merge pull request #10 from apache/master
Jun 1, 2018
242ed83
retrigger the build
Jun 2, 2018
4dfb758
add kNullOp check
Jun 3, 2018
8bd9909
retrigger the build
Jun 3, 2018
daf5a86
update kNullOp support and test case for both GRU and LSTM
Jun 5, 2018
27ebb4f
update kAddToOp support for both GRU and LSTM
Jun 6, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@
help='stack fused RNN cells to reduce communication overhead')
parser.add_argument('--dropout', type=float, default='0.0',
help='dropout probability (1.0 - keep probability)')
parser.add_argument('--rnntype', type=str, default='lstm',
help='rnn type: gru and lstm are supported')

#buckets = [32]
buckets = [10, 20, 30, 40, 50, 60]
Expand Down Expand Up @@ -97,13 +99,13 @@ def train(args):
cell = mx.rnn.SequentialRNNCell()
for i in range(args.num_layers):
cell.add(mx.rnn.FusedRNNCell(args.num_hidden, num_layers=1,
mode='lstm', prefix='lstm_l%d'%i,
mode=args.rnntype, prefix='%s_l%d'%(args.rnntype,i),
bidirectional=args.bidirectional))
if args.dropout > 0 and i < args.num_layers - 1:
cell.add(mx.rnn.DropoutCell(args.dropout, prefix='lstm_d%d'%i))
if args.dropout > 0 and i < args.num_layers - 1 and args.rnntype == 'lstm':
cell.add(mx.rnn.DropoutCell(args.dropout, prefix='%s_d%d'%(args.rnntype,i)))
else:
cell = mx.rnn.FusedRNNCell(args.num_hidden, num_layers=args.num_layers, dropout=args.dropout,
mode='lstm', bidirectional=args.bidirectional)
mode=args.rnntype, bidirectional=args.bidirectional)

def sym_gen(seq_len):
data = mx.sym.Variable('data')
Expand Down Expand Up @@ -168,16 +170,25 @@ def test(args):

if not args.stack_rnn:
stack = mx.rnn.FusedRNNCell(args.num_hidden, num_layers=args.num_layers,
mode='lstm', bidirectional=args.bidirectional).unfuse()
mode=args.rnntype, bidirectional=args.bidirectional).unfuse()
else:
stack = mx.rnn.SequentialRNNCell()
for i in range(args.num_layers):
cell = mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_%dl0_'%i)
if args.bidirectional:
cell = mx.rnn.BidirectionalCell(
cell,
mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='lstm_%dr0_'%i),
output_prefix='bi_lstm_%d'%i)
if args.rnntype == 'lstm':
cell = mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='%s_%dl0_'%(args.rnntype,i))
if args.bidirectional:
cell = mx.rnn.BidirectionalCell(
cell,
mx.rnn.LSTMCell(num_hidden=args.num_hidden, prefix='%s_%dr0_'%(args.rnntype,i)),
output_prefix='bi_%s_%d'%(args.rnntype,i))
elif args.rnntype == 'gru':
cell = mx.rnn.GRUCell(num_hidden=args.num_hidden, prefix='%s_%dl0_'%(args.rnntype,i))
if args.bidirectional:
cell = mx.rnn.BidirectionalCell(
cell,
mx.rnn.GRUCell(num_hidden=args.num_hidden, prefix='%s_%dr0_'%(args.rnntype,i)),
output_prefix='bi_%s_%d'%(args.rnntype,i))

stack.add(cell)

def sym_gen(seq_len):
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/gluon/rnn/rnn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def forward(self, inputs, states=None):
self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
self.i2h_weight[i]._finish_deferred_init()
if inputs.context.device_type == 'gpu' or \
self._mode == 'lstm' and not self._dropout:
self._mode in ['lstm', 'gru'] and not self._dropout:
out = self._forward_kernel(inputs, states)
else:
out = self._forward(inputs, states)
Expand Down
46 changes: 35 additions & 11 deletions src/operator/rnn-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,15 @@ inline size_t GetRNNWorkspaceSize(int seq_length,
switch (mode) {
case rnn_enum::kRnnRelu:
case rnn_enum::kRnnTanh:
case rnn_enum::kGru:
LOG(FATAL) << "Only LSTM is supported at the moment";
LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
break;
case rnn_enum::kLstm:
size = (seq_length + 1) * batch_size * hidden_size * 4 + batch_size * hidden_size * 2
+ seq_length * batch_size * hidden_size * direction;
break;
case rnn_enum::kGru:
size = seq_length * batch_size * hidden_size * direction * 4 + batch_size * hidden_size * 8;
break;
default:
LOG(FATAL) << "unknown RNN mode " << mode;
break;
Expand All @@ -125,12 +127,16 @@ inline size_t GetRNNReserveSpaceSize(int num_layer,
switch (mode) {
case rnn_enum::kRnnRelu:
case rnn_enum::kRnnTanh:
case rnn_enum::kGru:
LOG(FATAL) << "Only LSTM is supported at the moment";
LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
break;
case rnn_enum::kLstm:
size = num_layer * direction * seq_length * batch_size * hidden_size * 6;
break;
case rnn_enum::kGru:
size = seq_length * batch_size * hidden_size * direction * num_layer * 8 +
batch_size * hidden_size * direction * 9 +
seq_length * batch_size * 7 * hidden_size * direction;
break;
default:
LOG(FATAL) << "unknown RNN mode " << mode;
break;
Expand Down Expand Up @@ -221,14 +227,18 @@ void RNNForwardTraining(DType* ws,
switch (mode) {
case rnn_enum::kRnnTanh:
case rnn_enum::kRnnRelu:
case rnn_enum::kGru:
LOG(FATAL) << "Only LSTM is supported at the moment";
LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
break;
case rnn_enum::kLstm:
LstmForwardTraining<DType>(ws, rs, state_outputs, num_layers, direction, seq_length,
batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
break;
case rnn_enum::kGru:
GruForwardTraining<DType>(ws, rs, state_outputs, num_layers, direction, seq_length,
batch_size, input_size, state_size, x_ptr, hx_ptr,
w_ptr, y_ptr, hy_ptr);
break;
default:
LOG(FATAL) << "unknown RNN mode " << mode;
break;
Expand Down Expand Up @@ -256,14 +266,18 @@ void RNNForwardInference(DType* ws,
switch (mode) {
case rnn_enum::kRnnRelu:
case rnn_enum::kRnnTanh:
case rnn_enum::kGru:
LOG(FATAL) << "Only LSTM is supported at the moment";
LOG(FATAL) << "Only LSTM and GRU are supported at the moment";
break;
case rnn_enum::kLstm:
LstmForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
break;
case rnn_enum::kGru:
GruForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
batch_size, input_size, state_size, x_ptr, hx_ptr,
w_ptr, y_ptr, hy_ptr);
break;
default:
LOG(FATAL) << "unknown RNN mode" << mode;
break;
Expand Down Expand Up @@ -296,13 +310,17 @@ void RNNBackward(DType* ws,
switch (mode) {
case rnn_enum::kRnnRelu:
case rnn_enum::kRnnTanh:
case rnn_enum::kGru:
break;
case rnn_enum::kLstm:
LstmBackward<DType>(ws, rs, num_layers, direction, seq_length, batch_size,
input_size, state_size, x_ptr, hx_ptr, cx_ptr, w_ptr, y_ptr,
dy_ptr, dhy_ptr, dcy_ptr, dx_ptr, dhx_ptr, dcx_ptr, dw_ptr, db_ptr);
break;
case rnn_enum::kGru:
GruBackward<DType>(ws, rs, num_layers, direction, seq_length, batch_size,
input_size, state_size, x_ptr, hx_ptr, w_ptr,
dy_ptr, dhy_ptr, dx_ptr, dhx_ptr, dw_ptr);
break;
default:
LOG(FATAL) << "unknown RNN mode" << mode;
break;
Expand Down Expand Up @@ -330,7 +348,8 @@ class RNNOp : public Operator{
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(param_.mode, rnn_enum::kLstm) << "Only lstm mode is supported at the moment.";
CHECK(param_.mode == rnn_enum::kLstm || param_.mode == rnn_enum::kGru)
<< "Only lstm and gru mode are supported at the moment.";
CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment.";

size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
Expand Down Expand Up @@ -442,8 +461,10 @@ class RNNOp : public Operator{
const std::vector<TBlob> &aux_args) {
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(param_.mode, rnn_enum::kLstm) << "Only lstm mode is supported at the moment.";
CHECK(param_.mode == rnn_enum::kLstm || param_.mode == rnn_enum::kGru)
<< "Only lstm and gru mode are supported at the moment.";
CHECK_EQ(param_.p, 0) << "Dropout is not supported at the moment.";

size_t in_expected = (param_.mode == rnn_enum::kLstm) ? 4 : 3;
size_t out_expected = (param_.mode == rnn_enum::kLstm) ? 3 : 2;
if (!param_.state_outputs) {
Expand Down Expand Up @@ -474,6 +495,9 @@ class RNNOp : public Operator{
CHECK(dw.CheckContiguous());
CHECK(dhx.CheckContiguous());
CHECK(dy.CheckContiguous());
if (req[rnn_enum::kParams] != kAddTo && req[rnn_enum::kParams] != kNullOp) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't the gradient still going to be overwritten by the backward kernel later?

dw = mshadow::expr::ScalarExp<DType>(0.0f);
}
param_.seq_length_ = x.shape_[0];
param_.batch_size_ = x.shape_[1];
param_.input_size_ = x.shape_[2];
Expand Down
Loading