diff --git a/src/operator/sequence_last-inl.h b/src/operator/sequence_last-inl.h index 167f04c8165f..2f386dd820e8 100644 --- a/src/operator/sequence_last-inl.h +++ b/src/operator/sequence_last-inl.h @@ -153,6 +153,10 @@ class SequenceLastOp : public Operator { auto d1 = in_data[seq_last::kData].size(1); auto dsize = in_data[seq_last::kData].Size(); + if (dsize == 0) { + return; // noop if any input dimension is zero-sized, out_data is of a right shape + } + auto batch = (axis != 0) ? d0 : d1; auto max_seq_len = in_data[seq_last::kData].size(axis); auto rest_size = dsize / (d0 * d1); diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 7a8536474211..4394d80a7584 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -9428,3 +9428,11 @@ def test_sldwin_atten_op_impl(batch_size, seq_length, num_heads, test_sldwin_atten_op_impl(2, 128, 2, 8, 16, symmetric, d) test_sldwin_atten_op_impl(1, 8, 2, 4, 2, symmetric, d) +def test_zero_sized_dim(): + """Test for issue: https://github.com/apache/incubator-mxnet/issues/18938""" + mx.util.set_np_shape(True) # Must be done to prevent zero-sized dimension conversion to 'unknown' + data = mx.nd.array(np.random.rand(1, 0, 0)) + res = mx.nd.op.SequenceLast(data) + assert data.shape[1:] == res.shape + assert len(res) == 0 +