diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h
index 1046f01cf6e2..328e28de8537 100644
--- a/src/operator/rnn-inl.h
+++ b/src/operator/rnn-inl.h
@@ -1583,8 +1583,11 @@ static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs,
   int dtype = in_types[rnn_enum::kData];
   int itype = dtype;
   if (param.use_sequence_length) {
-    itype = in_types[rnn_enum::kSequenceLength];
-    if (param.mode == rnn_enum::kLstm) itype -= 1;
+    size_t seq_len_input_idx = rnn_enum::kSequenceLength;
+    if (param.mode != rnn_enum::kLstm) {
+      seq_len_input_idx -= 1;
+    }
+    itype = in_types[seq_len_input_idx];
   }
 
   MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
@@ -1649,7 +1652,7 @@ void RNNStatefulGradCompute(const OpStatePtr& state,
   // Hacky. This relies on fact that seq-len type is either the last input,
   // or we aren't using seq-len input and this type should be same as dtype.
   // Would prefer direct access to RNNParam object here but not sure how to get.
-  int itype = inputs[inputs.size()-1].type_flag_;
+  int itype = outputs[outputs.size()-1].type_flag_;
 
   MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
     MSHADOW_TYPE_SWITCH(itype, IType, {
@@ -1669,6 +1672,15 @@ void RNNStatefulGradCompute(const OpStatePtr& state,
         }
       }
 
+
+      if (param.use_sequence_length) {
+        size_t seq_len_input_idx = rnn_enum::kSequenceLength;
+        if (param.mode != rnn_enum::kLstm) {
+          seq_len_input_idx -= 1;
+        }
+        in_data.push_back(outputs[seq_len_input_idx]);
+      }
+
       op.Backward(ctx, out_grad, in_data, out_data, req, in_grad);
     });
   });
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index b60814a47a81..fc650294a538 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -227,19 +227,6 @@ def forward(self, inpt):
 
 
 def check_layer_bidirectional_varseqlen(size, in_size):
-    class RefBiLSTMVarSeqLen(gluon.Block):
-        def __init__(self, size, **kwargs):
-            super(RefBiLSTMVarSeqLen, self).__init__(**kwargs)
-            with self.name_scope():
-                self._lstm_fwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='l0')
-                self._lstm_bwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='r0')
-
-        def forward(self, inpt, sequence_length):
-            fwd = self._lstm_fwd(inpt)
-            bwd_inpt = nd.SequenceReverse(inpt, sequence_length=sequence_length, use_sequence_length=True)
-            bwd = self._lstm_bwd(bwd_inpt)
-            bwd = nd.SequenceReverse(bwd, sequence_length=sequence_length, use_sequence_length=True)
-            return nd.concat(fwd, bwd, dim=2)
     weights = {}
     for d in ['l', 'r']:
         weights['lstm_{}0_i2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, in_size))
@@ -248,31 +235,58 @@ def forward(self, inpt, sequence_length):
         weights['lstm_{}0_h2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,))
 
     net = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=True, prefix='lstm_')
-    ref_net = RefBiLSTMVarSeqLen(size, prefix='lstm_')
+    ref_net = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=False, prefix='lstm_ref_')
     net.initialize()
     ref_net.initialize()
     net_params = net.collect_params()
     ref_net_params = ref_net.collect_params()
     for k in weights:
         net_params[k].set_data(weights[k])
-        ref_net_params[k.replace('l0', 'l0l0').replace('r0', 'r0l0')].set_data(weights[k])
-
+        ref_net_params[k.replace("lstm_", "lstm_ref_")].set_data(weights[k])
 
     batch_size = 10
     num_timesteps = 11
     data = mx.random.uniform(shape=(num_timesteps, batch_size, in_size))
+    data_np = data.asnumpy()
 
-    # TODO: figure out why int32 doesn't work here
-    sequence_length = nd.random.randint(1, num_timesteps+1, shape=(batch_size)).astype("float")
-
-    net_output = net(data, sequence_length=sequence_length).asnumpy()
-    ref_net_output = ref_net(data, sequence_length).asnumpy()
+    sequence_length = nd.random.randint(1, num_timesteps+1, shape=(batch_size,)).astype("int32")
     sequence_length_np = sequence_length.asnumpy().astype("int32")
 
+    # The reference net processes batch elements one at a time, so each slice is "perfectly sized".
+    # Because of that, we need to accumulate gradients in the reference net.
+    for p in ref_net.collect_params().values():
+        p.grad_req = 'add'
+
+    ref_net_output = []
+    with autograd.record():
+        net_output = net(data.copy(), sequence_length=sequence_length.copy())
+
+        for b in range(batch_size):
+            data_slice = mx.nd.array(data_np[:sequence_length_np[b], b, :]).reshape(sequence_length_np[b], 1, in_size)
+            ref_output_slice = ref_net(data_slice)
+            ref_net_output.append(ref_output_slice)
+
+    net_output_np = net_output.asnumpy()
+
+    # TODO: test the state return value as well as the output
     # Only compare the valid sections for each batch entry
     for b in range(batch_size):
-        assert_allclose(net_output[:sequence_length_np[b], b], ref_net_output[:sequence_length_np[b], b])
+        assert_allclose(net_output_np[:sequence_length_np[b], b], ref_net_output[b].asnumpy().squeeze(1),
+                        rtol=1e-2, atol=1e-6)
+
+    # Now test backward
+    net_output.backward()
+
+    for ref_output_slice in ref_net_output:
+        ref_output_slice.backward()
+
+    ref_net_params = ref_net.collect_params()
+
+    for k in weights:
+        net_grad = net_params[k].grad()
+        ref_net_grad = ref_net_params[k.replace('lstm_', 'lstm_ref_')].grad()
+        assert_almost_equal(net_grad.asnumpy(), ref_net_grad.asnumpy(),
+                            rtol=1e-2, atol=1e-6)
 
 
 @with_seed()
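
All three C++ hunks hinge on where the sequence-length tensor sits in the operator's input list. A minimal sketch of that index logic, assuming MXNet's `rnn_enum` input ordering (`kData`, `kParams`, `kState`, `kStateCell`, `kSequenceLength`), where only LSTM mode carries a `kStateCell` input; the concrete numbers below are illustrative, not taken from the source:

```python
# Assumed to mirror rnn_enum's input ordering; treat the value as illustrative.
K_SEQUENCE_LENGTH = 4  # slot of the sequence-length input when kStateCell exists

def seq_len_input_idx(mode_is_lstm):
    # Only LSTM has a cell-state input, so for RNN/GRU modes the
    # sequence-length input sits one slot earlier.
    idx = K_SEQUENCE_LENGTH
    if not mode_is_lstm:
        idx -= 1
    return idx

assert seq_len_input_idx(True) == 4   # LSTM: data, params, state, state_cell, seq_len
assert seq_len_input_idx(False) == 3  # RNN/GRU: data, params, state, seq_len
```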
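The rewritten test also leans on gradient accumulation: the reference net runs once per batch element, and each per-slice `backward()` must add to, not overwrite, the previous gradients. A small self-contained sketch of that pattern, using a toy `Dense` layer (a stand-in, not the LSTM from the test):

```python
import mxnet as mx
from mxnet import autograd, gluon

# Gluon parameters default to grad_req='write', so a second backward()
# would overwrite the gradient from the first.
net = gluon.nn.Dense(1)
net.initialize()
for p in net.collect_params().values():
    p.grad_req = 'add'  # accumulate gradients across backward() calls

with autograd.record():
    out1 = net(mx.nd.ones((1, 3)))
    out2 = net(mx.nd.ones((1, 3)) * 2)

out1.backward()
out2.backward()  # gradients from both heads are summed into p.grad()
```

With `grad_req='add'`, the accumulated gradient matches a single backward pass over the full batch, which is what the test's `assert_almost_equal` on per-parameter gradients relies on.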