Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
ulmasov committed Feb 7, 2021
1 parent f8492e8 commit f6f44f8
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 5 deletions.
5 changes: 5 additions & 0 deletions src/operator/sequence_mask-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ class SequenceMaskOp : public Operator {
auto d0 = in_data[seq_mask::kData].size(0);
auto d1 = in_data[seq_mask::kData].size(1);
auto dsize = in_data[seq_mask::kData].Size();

if (dsize == 0) {
return; // noop if any input dimension is zero-sized, out_data is of a right shape
}

auto rest_size = dsize / (d0 * d1);

Shape<3> s3 = Shape3(d0, d1, rest_size);
Expand Down
5 changes: 5 additions & 0 deletions src/operator/sequence_reverse-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ class SequenceReverseOp : public Operator {
auto max_seq_len = in_data[seq_reverse::kData].size(0);
auto n = in_data[seq_reverse::kData].size(1);
auto total_size = in_data[seq_reverse::kData].Size();

if (total_size == 0) {
return; // noop if any input dimension is zero-sized, out_data is of a right shape
}

auto rest_dim = static_cast<int>(total_size / n / max_seq_len);

Shape<3> s3 = Shape3(max_seq_len, n, rest_dim);
Expand Down
27 changes: 22 additions & 5 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9429,10 +9429,27 @@ def test_sldwin_atten_op_impl(batch_size, seq_length, num_heads,
test_sldwin_atten_op_impl(1, 8, 2, 4, 2, symmetric, d)

def test_zero_sized_dim():
"""Test for issue: https://github.com/apache/incubator-mxnet/issues/18938"""

mx.util.set_np_shape(True) # Must be done to prevent zero-sized dimension conversion to 'unknown'
data = mx.nd.array(np.random.rand(1, 0, 0))
res = mx.nd.op.SequenceLast(data)
assert data.shape[1:] == res.shape
assert len(res) == 0

def seq_last():
"""Test for issue: https://github.com/apache/incubator-mxnet/issues/18938"""
data = mx.nd.array(np.random.rand(1, 0, 0))
res = mx.nd.op.SequenceLast(data)
assert data.shape[1:] == res.shape

def seq_mask():
"""Test for issue: https://github.com/apache/incubator-mxnet/issues/18939"""
data = mx.nd.array(np.random.rand(0, 1, 1))
res = mx.nd.op.SequenceMask(data)
assert data.shape == res.shape

def seq_reverse():
"""Test for issue: https://github.com/apache/incubator-mxnet/issues/18940"""
data = mx.nd.array(np.random.rand(0, 1, 1))
res = mx.nd.op.SequenceReverse(data)
assert data.shape == res.shape

seq_last()
seq_reverse()
seq_mask()

0 comments on commit f6f44f8

Please sign in to comment.