fixing var-seq-len rnn backward() operator (#15278)
* fixing var-seq-len rnn backward() operator

* updating var-length lstm to test backward pass

* removing a bit of debug printing to stderr that I forgot to remove earlier

* resolving TODO about using int32 for sequence_length

* setting rtol and atol similar to other tests in this file
stephenrawls authored and szha committed Jun 20, 2019
1 parent 145f82d commit 4d96671
Showing 2 changed files with 51 additions and 25 deletions.
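
For context on what is being exercised here: the operator being fixed is the fused RNN/LSTM whose forward pass takes a per-example sequence_length input, and the bug was in its backward pass. A minimal sketch of that usage, condensed from the GPU test added below (layer sizes are arbitrary and a GPU context is assumed, since the test lives in tests/python/gpu/test_gluon_gpu.py):

import mxnet as mx
from mxnet import autograd, gluon, nd

ctx = mx.gpu(0)

# Bidirectional LSTM that ignores each batch element's steps beyond its valid length.
net = gluon.rnn.LSTM(20, bidirectional=True, use_sequence_length=True, prefix='lstm_')
net.initialize(ctx=ctx)

data = mx.random.uniform(shape=(11, 10, 5), ctx=ctx)  # (num_timesteps, batch_size, in_size)
sequence_length = nd.random.randint(1, 12, shape=(10,), ctx=ctx).astype("int32")  # valid steps per example

with autograd.record():
    output = net(data, sequence_length=sequence_length)
output.backward()  # the backward() path exercised here is what this commit repairs
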
18 changes: 15 additions & 3 deletions src/operator/rnn-inl.h
@@ -1583,8 +1583,11 @@ static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs,
   int dtype = in_types[rnn_enum::kData];
   int itype = dtype;
   if (param.use_sequence_length) {
-    itype = in_types[rnn_enum::kSequenceLength];
-    if (param.mode == rnn_enum::kLstm) itype -= 1;
+    size_t seq_len_input_idx = rnn_enum::kSequenceLength;
+    if (param.mode != rnn_enum::kLstm) {
+      seq_len_input_idx -= 1;
+    }
+    itype = in_types[seq_len_input_idx];
   }
 
   MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
@@ -1649,7 +1652,7 @@ void RNNStatefulGradCompute(const OpStatePtr& state,
   // Hacky. This relies on fact that seq-len type is either the last input,
   // or we aren't using seq-len input and this type should be same as dtype.
   // Would prefer direct access to RNNParam object here but not sure how to get.
-  int itype = inputs[inputs.size()-1].type_flag_;
+  int itype = outputs[outputs.size()-1].type_flag_;
 
   MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
     MSHADOW_TYPE_SWITCH(itype, IType, {
@@ -1669,6 +1672,15 @@ void RNNStatefulGradCompute(const OpStatePtr& state,
         }
       }
 
+
+      if (param.use_sequence_length) {
+        size_t seq_len_input_idx = rnn_enum::kSequenceLength;
+        if (param.mode != rnn_enum::kLstm) {
+          seq_len_input_idx -= 1;
+        }
+        in_data.push_back(outputs[seq_len_input_idx]);
+      }
+
       op.Backward(ctx, out_grad, in_data, out_data, req, in_grad);
     });
   });
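
The CreateRNNState change above is the core of the fix: the old code indexed in_types at rnn_enum::kSequenceLength unconditionally and then subtracted 1 from the type value itself when the mode was LSTM; the new code adjusts the input index instead and only then reads the type, and the backward path likewise resolves that index before appending outputs[seq_len_input_idx] to in_data ahead of op.Backward. The index arithmetic rests on the RNN op's input ordering, in which the sequence-length array comes last and only LSTM mode carries an extra cell-state input. A rough, self-checking sketch of that arithmetic (the input lists and positions are assumptions for illustration, not quoted from rnn-inl.h):

# Assumed input ordering; only LSTM mode has the cell-state entry, so the
# sequence_length slot sits one position earlier for every other mode.
LSTM_INPUTS = ["data", "parameters", "state", "state_cell", "sequence_length"]
OTHER_INPUTS = ["data", "parameters", "state", "sequence_length"]

def seq_len_input_index(mode_is_lstm):
    # Mirrors the C++ fix: start from the LSTM position and step back for other modes.
    k_sequence_length = LSTM_INPUTS.index("sequence_length")  # stands in for rnn_enum::kSequenceLength
    return k_sequence_length if mode_is_lstm else k_sequence_length - 1

assert LSTM_INPUTS[seq_len_input_index(True)] == "sequence_length"
assert OTHER_INPUTS[seq_len_input_index(False)] == "sequence_length"
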
58 changes: 36 additions & 22 deletions tests/python/gpu/test_gluon_gpu.py
@@ -227,19 +227,6 @@ def forward(self, inpt):
 
 
 def check_layer_bidirectional_varseqlen(size, in_size):
-    class RefBiLSTMVarSeqLen(gluon.Block):
-        def __init__(self, size, **kwargs):
-            super(RefBiLSTMVarSeqLen, self).__init__(**kwargs)
-            with self.name_scope():
-                self._lstm_fwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='l0')
-                self._lstm_bwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='r0')
-
-        def forward(self, inpt, sequence_length):
-            fwd = self._lstm_fwd(inpt)
-            bwd_inpt = nd.SequenceReverse(inpt, sequence_length=sequence_length, use_sequence_length=True)
-            bwd = self._lstm_bwd(bwd_inpt)
-            bwd = nd.SequenceReverse(bwd, sequence_length=sequence_length, use_sequence_length=True)
-            return nd.concat(fwd, bwd, dim=2)
     weights = {}
     for d in ['l', 'r']:
         weights['lstm_{}0_i2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, in_size))
@@ -248,31 +235,58 @@ def forward(self, inpt, sequence_length):
         weights['lstm_{}0_h2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,))
 
     net = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=True, prefix='lstm_')
-    ref_net = RefBiLSTMVarSeqLen(size, prefix='lstm_')
+    ref_net = gluon.rnn.LSTM(size, bidirectional=True, use_sequence_length=False, prefix='lstm_ref_')
     net.initialize()
     ref_net.initialize()
     net_params = net.collect_params()
     ref_net_params = ref_net.collect_params()
     for k in weights:
         net_params[k].set_data(weights[k])
-        ref_net_params[k.replace('l0', 'l0l0').replace('r0', 'r0l0')].set_data(weights[k])
+
+        ref_net_params[k.replace("lstm_", "lstm_ref_")].set_data(weights[k])
 
     batch_size = 10
     num_timesteps = 11
     data = mx.random.uniform(shape=(num_timesteps, batch_size, in_size))
+    data_np = data.asnumpy()
 
-    # TODO: figure out why int32 doesn't work here
-    sequence_length = nd.random.randint(1, num_timesteps+1, shape=(batch_size)).astype("float")
-
-    net_output = net(data, sequence_length=sequence_length).asnumpy()
-    ref_net_output = ref_net(data, sequence_length).asnumpy()
+    sequence_length = nd.random.randint(1, num_timesteps+1, shape=(batch_size)).astype("int32")
     sequence_length_np = sequence_length.asnumpy().astype("int32")
 
+    # Reference net is processing batch elements one at a time, so that it is "perfectly sized"
+    # Because of that, we need to accumulate gradients in reference net.
+    for p in ref_net.collect_params().values():
+        p.grad_req = 'add'
+
+    ref_net_output = []
+    with autograd.record():
+        net_output = net(data.copy(), sequence_length=sequence_length.copy())
+
+        for b in range(batch_size):
+            data_slice = mx.nd.array(data_np[:sequence_length_np[b], b, :]).reshape(sequence_length_np[b], 1, in_size)
+            ref_output_slice = ref_net(data_slice)
+            ref_net_output.append(ref_output_slice)
+
+    net_output_np = net_output.asnumpy()
+
     # TODO: test state return value as well output
     # Only compare the valid sections for each batch entry
     for b in range(batch_size):
-        assert_allclose(net_output[:sequence_length_np[b], b], ref_net_output[:sequence_length_np[b], b])
+        assert_allclose(net_output_np[:sequence_length_np[b], b], ref_net_output[b].asnumpy().squeeze(1),
+                        rtol=1e-2, atol=1e-6)
+
+    # Now test backward
+    net_output.backward()
+
+    for ref_output_slice in ref_net_output:
+        ref_output_slice.backward()
+
+    ref_net_params = ref_net.collect_params()
+
+    for k in weights:
+        net_grad = net_params[k].grad()
+        ref_net_grad = ref_net_params[k.replace('lstm_', 'lstm_ref_')].grad()
+        assert_almost_equal(net_grad.asnumpy(), ref_net_grad.asnumpy(),
+                            rtol=1e-2, atol=1e-6)
 
 
 @with_seed()
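
One detail of the updated test worth calling out: the reference LSTM runs once per batch element on an exactly sized slice, so the parameter gradients produced by its per-example backward() calls have to be accumulated rather than overwritten, which is what grad_req = 'add' arranges. A condensed sketch of that pattern (ref_net stands for any Gluon block here, and per_example_slices is a placeholder list of (length_b, 1, in_size) arrays, not a name from the test):

# Accumulate gradients across several backward() calls instead of overwriting them.
for p in ref_net.collect_params().values():
    p.grad_req = 'add'

with autograd.record():
    outputs = [ref_net(x) for x in per_example_slices]
for out in outputs:
    out.backward()  # each call adds into the shared parameter gradient buffers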
