diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 9668843a61dd..de175cd428fe 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -730,14 +730,17 @@ inline void GetIndexRange(const mxnet::TShape& dshape, CHECK_NE(b, e) << "slicing with begin[" << i << "]=end[" << i << "]=" << e << " results in an empty tensor and is not supported"; } + } else if (len == 0) { + b = 0; + e = 0; } - if (Imperative::Get()->is_np_shape()) { + if (Imperative::Get()->is_np_shape() && len > 0) { // move the begin and end to correct position for calculating dim size b = b < 0 && s > 0 ? 0 : b; - b = b > len-1 && s < 0 ? len-1 : b; + b = b > len - 1 && s < 0 ? len-1 : b; // if the start value lead to empty tensor under step s, use -1 for indication - b = b < 0 || b > len-1 ? -1 : b; + b = b < 0 || b > len - 1 ? -1 : b; e = e > -1 ? e : -1; e = e > len ? len : e; } @@ -753,9 +756,14 @@ inline void GetIndexRange(const mxnet::TShape& dshape, } } -inline void SetSliceOpOutputDimSize(const index_t i, const int b, +inline void SetSliceOpOutputDimSize(const mxnet::TShape& dshape, + const index_t i, const int b, const int e, const int s, mxnet::TShape* oshape) { + if (!mxnet::dim_size_is_known(dshape, i)) { + (*oshape)[i] = -1; + return; + } if (!Imperative::Get()->is_np_shape()) { // handle as ndarray if (e != b) { if (s > 0) { @@ -788,6 +796,7 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1U); const mxnet::TShape& dshape = (*in_attrs)[0]; if (!mxnet::ndim_is_known(dshape)) return false; + CHECK_GT(dshape.ndim(), 0) << "slice only works for ndim > 0"; const SliceParam& param = nnvm::get(attrs.parsed); mxnet::TShape oshape = dshape; @@ -796,12 +805,12 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs, GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step); for (int i = 0; i < param.begin.ndim(); ++i) { const int b = begin[i], e = end[i], s = step[i]; - SetSliceOpOutputDimSize(i, b, e, s, &oshape); + SetSliceOpOutputDimSize(dshape, i, b, e, s, &oshape); } }) SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); - return shape_is_known(oshape); + return shape_is_known(dshape) && shape_is_known(oshape); } template @@ -879,7 +888,7 @@ void SliceOpForward(const nnvm::NodeAttrs& attrs, Stream* s = ctx.get_stream(); const TBlob& data = inputs[0]; const TBlob& out = outputs[0]; - if (Imperative::Get()->is_np_shape() && out.Size() == 0) return; + if (out.Size() == 0) return; const SliceParam& param = nnvm::get(attrs.parsed); MXNET_NDIM_SWITCH(data.ndim(), ndim, { common::StaticArray begin, end, step; @@ -979,7 +988,7 @@ void SliceOpBackward(const nnvm::NodeAttrs& attrs, } else if (req[0] == kWriteInplace) { LOG(FATAL) << "_slice_backward does not support kWriteInplace"; } - if (Imperative::Get()->is_np_shape() && ograd.Size() == 0) return; + if (ograd.Size() == 0) return; MXNET_NDIM_SWITCH(ograd.ndim(), ndim, { common::StaticArray begin, end, step; GetIndexRange(igrad.shape_, param.begin, param.end, param.step, &begin, &end, &step); @@ -1011,7 +1020,7 @@ inline bool SliceAssignOpShape(const nnvm::NodeAttrs& attrs, GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step); for (int i = 0; i < param.begin.ndim(); ++i) { const int b = begin[i], e = end[i], s = step[i]; - SetSliceOpOutputDimSize(i, b, e, s, &vshape); + SetSliceOpOutputDimSize(dshape, i, b, e, s, &vshape); } }) SHAPE_ASSIGN_CHECK(*in_attrs, 1, vshape); @@ -1150,7 +1159,7 @@ void SliceAssignScalarOpForward(const nnvm::NodeAttrs& attrs, GetIndexRange(data.shape_, param.begin, param.end, param.step, &begin, &end, &step); for (index_t i = 0; i < param.begin.ndim(); ++i) { const int b = begin[i], e = end[i], s = step[i]; - SetSliceOpOutputDimSize(i, b, e, s, &vshape); + SetSliceOpOutputDimSize(data.shape_, i, b, e, s, &vshape); } MSHADOW_TYPE_SWITCH(out.type_flag_, DType, { mxnet_op::Kernel, xpu>::Launch(s, vshape.FlatTo2D()[0],