diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 4daace2fb35c..f06d63e62392 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -670,13 +670,13 @@ inline void GetIndexRange(const mxnet::TShape& dshape, << "Static array size=" << ndim << " is not equal to data shape ndim=" << dshape.ndim(); - if (param_step.ndim() != 0) { + if (param_step.ndim() > 0) { CHECK_EQ(param_step.ndim(), param_begin.ndim()) << "step and begin must have the same length"; } for (int i = 0; i < param_begin.ndim(); ++i) { - index_t s = param_step.ndim() != 0U && param_step[i].has_value() ? param_step[i].value() : 1; + index_t s = param_step.ndim() > 0 && param_step[i].has_value() ? param_step[i].value() : 1; CHECK_NE(s, 0) << "slice op step[" << i << "] cannot be 0"; index_t b = 0, e = 0; @@ -685,30 +685,23 @@ inline void GetIndexRange(const mxnet::TShape& dshape, b = param_begin[i].has_value() ? param_begin[i].value() : (s < 0 ? len - 1 : 0); e = param_end[i].has_value() ? param_end[i].value() : (s < 0 ? -1 : len); - // checking upper and lower bounds for begin if (b < 0) { b += len; - CHECK_GE(b, 0) << "slicing with begin[" << i << "]=" << b - len - << " exceeds limit of input dimension[" << i << "]=" << len; } - CHECK_LT(b, len) << "slicing with begin[" << i << "]=" << b - << " exceeds limit of input dimension[" << i << "]=" << len; - - // checking upper and lower bounds for end if (e < 0 && param_end[i].has_value()) { - if (!(s < 0 && e == -1)) { - // Keep end=-1 as one-beyond-limits index for negative stride - e += len; - } - CHECK_GE(e, 0) << "slicing with end[" << i << "]=" << e - len - << " exceeds limit of input dimension[" << i << "]=" << len; + e += len; } - CHECK_LE(e, len) << "slicing with end[" << i << "]=" << e - << " exceeds limit of input dimension[" << i << "]=" << len; - // checking begin==end case which is not supported - CHECK_NE(b, e) << "slicing with begin[" << i << "]=end[" << i << "]=" - << e << " results in an empty tensor and is not supported"; + // move the begin and end to correct position for calculating dim size + b = (b < 0 && s > 0) ? 0 : b; + b = (b > len - 1 && s < 0) ? len - 1 : b; + // if the start value lead to empty tensor under step s, use -1 for indication + b = (b < 0 || b > len - 1) ? -1 : b; + e = e > -1 ? e : -1; + e = e > len ? len : e; + } else if (len == 0) { + b = 0; + e = 0; } (*begin)[i] = b; @@ -716,27 +709,30 @@ inline void GetIndexRange(const mxnet::TShape& dshape, (*step)[i] = s; } - for (index_t i = param_begin.ndim(); i < dshape.ndim(); ++i) { + for (int i = param_begin.ndim(); i < dshape.ndim(); ++i) { (*begin)[i] = 0; (*end)[i] = dshape[i]; (*step)[i] = 1; } } -inline void SetSliceOpOutputDimSize(const index_t i, const int b, +inline void SetSliceOpOutputDimSize(const mxnet::TShape& dshape, + const index_t i, const int b, const int e, const int s, mxnet::TShape* oshape) { - if (e != b) { + if (!mxnet::dim_size_is_known(dshape, i)) { + (*oshape)[i] = -1; + return; + } + if (e != b && b >= 0) { if (s > 0) { - CHECK_LT(b, e) << "slicing with begin[" << i << "]=" << b << ", end[" << i << "]=" - << e << ", and step[" << i << "]=" << s << " is invalid"; - (*oshape)[i] = (e - b - 1) / s + 1; + (*oshape)[i] = e > b ? (e - b - 1) / s + 1 : 0; } else { - CHECK_LT(e, b) << "slicing with begin[" << i << "]=" << b << ", end[" << i << "]=" - << e << ", and step[" << i << "]=" << s << " is invalid"; - (*oshape)[i] = (b - e - 1) / (-s) + 1; + (*oshape)[i] = e < b ? (b - e - 1) / (-s) + 1 : 0; } - } // else leave oshape[i] as 0 for partial infer + } else { + (*oshape)[i] = 0; + } } inline bool SliceOpShape(const nnvm::NodeAttrs& attrs, @@ -746,6 +742,7 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1U); const mxnet::TShape& dshape = (*in_attrs)[0]; if (!mxnet::ndim_is_known(dshape)) return false; + CHECK_GT(dshape.ndim(), 0) << "slice only works for ndim > 0"; const SliceParam& param = nnvm::get(attrs.parsed); mxnet::TShape oshape = dshape; @@ -754,12 +751,12 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs, GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step); for (int i = 0; i < param.begin.ndim(); ++i) { const int b = begin[i], e = end[i], s = step[i]; - SetSliceOpOutputDimSize(i, b, e, s, &oshape); + SetSliceOpOutputDimSize(dshape, i, b, e, s, &oshape); } }) SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); - return shape_is_known(oshape); + return shape_is_known(dshape) && shape_is_known(oshape); } template @@ -837,6 +834,7 @@ void SliceOpForward(const nnvm::NodeAttrs& attrs, Stream* s = ctx.get_stream(); const TBlob& data = inputs[0]; const TBlob& out = outputs[0]; + if (out.Size() == 0) return; const SliceParam& param = nnvm::get(attrs.parsed); MXNET_NDIM_SWITCH(data.ndim(), ndim, { common::StaticArray begin, end, step; @@ -936,6 +934,7 @@ void SliceOpBackward(const nnvm::NodeAttrs& attrs, } else if (req[0] == kWriteInplace) { LOG(FATAL) << "_slice_backward does not support kWriteInplace"; } + if (ograd.Size() == 0) return; MXNET_NDIM_SWITCH(ograd.ndim(), ndim, { common::StaticArray begin, end, step; GetIndexRange(igrad.shape_, param.begin, param.end, param.step, &begin, &end, &step); @@ -967,7 +966,7 @@ inline bool SliceAssignOpShape(const nnvm::NodeAttrs& attrs, GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step); for (int i = 0; i < param.begin.ndim(); ++i) { const int b = begin[i], e = end[i], s = step[i]; - SetSliceOpOutputDimSize(i, b, e, s, &vshape); + SetSliceOpOutputDimSize(dshape, i, b, e, s, &vshape); } }) SHAPE_ASSIGN_CHECK(*in_attrs, 1, vshape); @@ -1106,7 +1105,7 @@ void SliceAssignScalarOpForward(const nnvm::NodeAttrs& attrs, GetIndexRange(data.shape_, param.begin, param.end, param.step, &begin, &end, &step); for (index_t i = 0; i < param.begin.ndim(); ++i) { const int b = begin[i], e = end[i], s = step[i]; - SetSliceOpOutputDimSize(i, b, e, s, &vshape); + SetSliceOpOutputDimSize(data.shape_, i, b, e, s, &vshape); } MSHADOW_TYPE_SWITCH(out.type_flag_, DType, { mxnet_op::Kernel, xpu>::Launch(s, vshape.FlatTo2D()[0], diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index bff76bc6bbb0..db7644edc4a5 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -506,6 +506,7 @@ Example:: [5., 7.], [1., 3.]] )code" ADD_FILELINE) +.add_alias("_npx_slice") .set_attr_parser(ParamParser) .set_attr("FInferShape", SliceOpShape) .set_attr("FInferType", ElemwiseType<1, 1>) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index fee5ebbbbc29..eaf7932970ee 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -7336,15 +7336,6 @@ def test_slice_forward_backward(a, index): for index in index_list: test_slice_forward_backward(arr, index) - def test_begin_equals_end(shape, begin, end, step): - in_arr = mx.nd.arange(np.prod(shape)).reshape(shape=shape) - out_arr = mx.nd.slice(in_arr, begin=begin, end=end, step=step) - - assertRaises(MXNetError, test_begin_equals_end, (4,), (2,), (2,), (1,)) - assertRaises(MXNetError, test_begin_equals_end, (1, 5), (None, 3), (None, 3), (-1, 1)) - assertRaises(MXNetError, test_begin_equals_end, (3, 4, 5), (1, 3, 1), (3, 3, 1), (1, -3, 2)) - assertRaises(MXNetError, test_begin_equals_end, (2, 4), (None, 2), (None, 2), (1, -1)) - # check numeric gradient in_data = np.arange(36).reshape(2, 2, 3, 3) data = mx.sym.Variable('data')