From b452f0b2e626efb99b4f3590b944b97c5aea5cc9 Mon Sep 17 00:00:00 2001 From: Mike Mao Date: Thu, 8 Aug 2019 08:23:42 +0000 Subject: [PATCH 1/5] Modify ndarray slice to have numpy compatbile behaviou --- src/operator/tensor/matrix_op-inl.h | 79 ++++++++++++++++++-------- src/operator/tensor/matrix_op.cc | 1 + tests/python/unittest/test_numpy_op.py | 72 +++++++++++++++++++++++ 3 files changed, 127 insertions(+), 25 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 49847502258f..3b0d4136d7aa 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -685,13 +685,13 @@ inline void GetIndexRange(const mxnet::TShape& dshape, << "Static array size=" << ndim << " is not equal to data shape ndim=" << dshape.ndim(); - if (param_step.ndim() != 0) { + if (param_step.ndim() > 0) { CHECK_EQ(param_step.ndim(), param_begin.ndim()) << "step and begin must have the same length"; } for (int i = 0; i < param_begin.ndim(); ++i) { - index_t s = param_step.ndim() != 0U && param_step[i].has_value() ? param_step[i].value() : 1; + index_t s = param_step.ndim() > 0 && param_step[i].has_value() ? param_step[i].value() : 1; CHECK_NE(s, 0) << "slice op step[" << i << "] cannot be 0"; index_t b = 0, e = 0; @@ -703,29 +703,44 @@ inline void GetIndexRange(const mxnet::TShape& dshape, // checking upper and lower bounds for begin if (b < 0) { b += len; - CHECK_GE(b, 0) << "slicing with begin[" << i << "]=" << b - len - << " exceeds limit of input dimension[" << i << "]=" << len; + if (!Imperative::Get()->is_np_shape()) { + CHECK_GE(b, 0) << "slicing with begin[" << i << "]=" << b - len + << " exceeds limit of input dimension[" << i << "]=" << len; + } + } + if (!Imperative::Get()->is_np_shape()) { + CHECK_LT(b, len) << "slicing with begin[" << i << "]=" << b + << " exceeds limit of input dimension[" << i << "]=" << len; } - CHECK_LT(b, len) << "slicing with begin[" << i << "]=" << b - << " exceeds limit of input dimension[" << i << "]=" << len; - // checking upper and lower bounds for end if (e < 0 && param_end[i].has_value()) { - if (!(s < 0 && e == -1)) { - // Keep end=-1 as one-beyond-limits index for negative stride - e += len; + e += len; + if (!Imperative::Get()->is_np_shape()) { + CHECK_GE(e, 0) << "slicing with end[" << i << "]=" << e - len + << " exceeds limit of input dimension[" << i << "]=" << len; } - CHECK_GE(e, 0) << "slicing with end[" << i << "]=" << e - len - << " exceeds limit of input dimension[" << i << "]=" << len; } - CHECK_LE(e, len) << "slicing with end[" << i << "]=" << e - << " exceeds limit of input dimension[" << i << "]=" << len; + if (!Imperative::Get()->is_np_shape()) { + CHECK_LE(e, len) << "slicing with end[" << i << "]=" << e + << " exceeds limit of input dimension[" << i << "]=" << len; + } // checking begin==end case which is not supported - CHECK_NE(b, e) << "slicing with begin[" << i << "]=end[" << i << "]=" - << e << " results in an empty tensor and is not supported"; + if (!Imperative::Get()->is_np_shape()) { + CHECK_NE(b, e) << "slicing with begin[" << i << "]=end[" << i << "]=" + << e << " results in an empty tensor and is not supported"; + } } + if (Imperative::Get()->is_np_shape()) { + // move the begin and end to correct position for calculating dim size + b = b < 0 && s > 0 ? 0 : b; + b = b > len-1 && s < 0 ? len-1 : b; + // if the start value lead to empty tensor under step s, use -1 for indication + b = b < 0 || b > len-1 ? -1 : b; + e = e > -1 ? e : -1; + e = e > len ? len : e; + } (*begin)[i] = b; (*end)[i] = e; (*step)[i] = s; @@ -741,17 +756,29 @@ inline void GetIndexRange(const mxnet::TShape& dshape, inline void SetSliceOpOutputDimSize(const index_t i, const int b, const int e, const int s, mxnet::TShape* oshape) { - if (e != b) { - if (s > 0) { - CHECK_LT(b, e) << "slicing with begin[" << i << "]=" << b << ", end[" << i << "]=" - << e << ", and step[" << i << "]=" << s << " is invalid"; - (*oshape)[i] = (e - b - 1) / s + 1; + if (!Imperative::Get()->is_np_shape()) { //handle as ndarray + if (e != b) { + if (s > 0) { + CHECK_LT(b, e) << "slicing with begin=[" << i << "]=" << b << ", end[" << i << "]=" + << e << ", and step[" << i << "]=" << s << " is invalid"; + (*oshape)[i] = (e - b - 1) / s + 1; + } else { + CHECK_LT(e, b) << "slicing with begin=[" << i << "]=" << b << ", end[" << i << "]=" + << e << ", and step[" << i << "]=" << s << " is invalid"; + (*oshape)[i] = (b - e - 1) / (-s) + 1; + } + } // else leave oshape[i] as 0 for partial infer + } else { //handle as numpy compatible array + if (e != b && b >= 0) { + if (s > 0) { + (*oshape)[i] = e > b ? (e - b - 1) / s + 1 : 0; + } else { + (*oshape)[i] = e < b ? (b - e - 1) / (-s) + 1 : 0; + } } else { - CHECK_LT(e, b) << "slicing with begin[" << i << "]=" << b << ", end[" << i << "]=" - << e << ", and step[" << i << "]=" << s << " is invalid"; - (*oshape)[i] = (b - e - 1) / (-s) + 1; + (*oshape)[i] = 0; } - } // else leave oshape[i] as 0 for partial infer + } } inline bool SliceOpShape(const nnvm::NodeAttrs& attrs, @@ -852,6 +879,7 @@ void SliceOpForward(const nnvm::NodeAttrs& attrs, Stream* s = ctx.get_stream(); const TBlob& data = inputs[0]; const TBlob& out = outputs[0]; + if (Imperative::Get()->is_np_shape() && out.Size() == 0) return; const SliceParam& param = nnvm::get(attrs.parsed); MXNET_NDIM_SWITCH(data.ndim(), ndim, { common::StaticArray begin, end, step; @@ -951,6 +979,7 @@ void SliceOpBackward(const nnvm::NodeAttrs& attrs, } else if (req[0] == kWriteInplace) { LOG(FATAL) << "_slice_backward does not support kWriteInplace"; } + if (Imperative::Get()->is_np_shape() && ograd.Size() == 0) return; MXNET_NDIM_SWITCH(ograd.ndim(), ndim, { common::StaticArray begin, end, step; GetIndexRange(igrad.shape_, param.begin, param.end, param.step, &begin, &end, &step); diff --git a/src/operator/tensor/matrix_op.cc b/src/operator/tensor/matrix_op.cc index 59e8386d6679..6d1d02a5bc34 100644 --- a/src/operator/tensor/matrix_op.cc +++ b/src/operator/tensor/matrix_op.cc @@ -506,6 +506,7 @@ Example:: [5., 7.], [1., 3.]] )code" ADD_FILELINE) +.add_alias("_npx_slice") .set_attr_parser(ParamParser) .set_attr("FInferShape", SliceOpShape) .set_attr("FInferType", ElemwiseType<1, 1>) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index b179f67e6128..9cbd6f6fb2d8 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -92,6 +92,78 @@ def is_int(dtype): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) +@with_seed() +@use_np +def test_npx_slice(): + class TestSlice(HybridBlock): + def __init__(self, begin, end, step): + super(TestSlice, self).__init__() + self._begin = begin + self._end = end + self._step = step + + def hybrid_forward(self, F, a, *args, **kwargs): + return F.npx.slice(a, begin=self._begin, end=self._end, step=self._step) + + def get_start_end_step(shape): + start = [] + end = [] + step_switch = random.randint(-1,1) + step = None if step_switch == 0 else [] + for i in range(len(shape)): + s = random.randint(0, shape[i]-1) + e = random.randint(s+1, shape[i]) + if step_switch == 1: + step.append(1) + start.append(s) + end.append(e) + elif step_switch == -1: + step.append(-1) + if e == shape[i]: + e -= 1 + s -= 1 + if s == -1: + s = None + start.append(e) + end.append(s) + else: + start.append(s) + end.append(e) + return start, end, step + + for hybridize in [True, False]: + for i in range(10): + dim = random.randint(1,4) + shape = [random.randint(1,5) for i in range(dim)] + + # test gluon + start, end, step = get_start_end_step(shape) + test_slice = TestSlice(begin=start, end=end, step=step) + if hybridize: + test_slice.hybridize() + + a = mx.nd.random.uniform(shape=shape).as_np_ndarray() + a.attach_grad() + if step is not None: + expected_ret = a.as_nd_ndarray().slice(start, end, step) + else: + expected_ret = a.as_nd_ndarray().slice(start, end) + with mx.autograd.record(): + y = test_slice(a) + + assert_almost_equal(y.asnumpy(), expected_ret.asnumpy(), rtol=1e-3, atol=1e-5) + + # test backward + mx.autograd.backward(y) + expected_grad = _np.zeros(shape) + basic_index = tuple([ + slice(start[i], end[i], step[i]) if step is not None else slice(start[i], end[i]) + for i in range(len(start)) + ]) + expected_grad[basic_index] = 1 + assert_almost_equal(a.grad.asnumpy(), expected_grad, rtol=1e-3, atol=1e-5) + + if __name__ == '__main__': import nose nose.runmodule() From c507da6d932a8743023f2e6711b24095e9584753 Mon Sep 17 00:00:00 2001 From: Mike Mao Date: Fri, 9 Aug 2019 07:42:33 +0000 Subject: [PATCH 2/5] Minor syntax fix --- src/operator/tensor/matrix_op-inl.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 3b0d4136d7aa..9668843a61dd 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -756,7 +756,7 @@ inline void GetIndexRange(const mxnet::TShape& dshape, inline void SetSliceOpOutputDimSize(const index_t i, const int b, const int e, const int s, mxnet::TShape* oshape) { - if (!Imperative::Get()->is_np_shape()) { //handle as ndarray + if (!Imperative::Get()->is_np_shape()) { // handle as ndarray if (e != b) { if (s > 0) { CHECK_LT(b, e) << "slicing with begin=[" << i << "]=" << b << ", end[" << i << "]=" @@ -768,7 +768,7 @@ inline void SetSliceOpOutputDimSize(const index_t i, const int b, (*oshape)[i] = (b - e - 1) / (-s) + 1; } } // else leave oshape[i] as 0 for partial infer - } else { //handle as numpy compatible array + } else { // handle as numpy compatible array if (e != b && b >= 0) { if (s > 0) { (*oshape)[i] = e > b ? (e - b - 1) / s + 1 : 0; From 61bb9c7cd4c968333a0b2107ca2cb32ec159dcbb Mon Sep 17 00:00:00 2001 From: reminisce Date: Fri, 9 Aug 2019 14:13:11 -0700 Subject: [PATCH 3/5] Fix slice inconsistency --- src/operator/tensor/matrix_op-inl.h | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 9668843a61dd..de175cd428fe 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -730,14 +730,17 @@ inline void GetIndexRange(const mxnet::TShape& dshape, CHECK_NE(b, e) << "slicing with begin[" << i << "]=end[" << i << "]=" << e << " results in an empty tensor and is not supported"; } + } else if (len == 0) { + b = 0; + e = 0; } - if (Imperative::Get()->is_np_shape()) { + if (Imperative::Get()->is_np_shape() && len > 0) { // move the begin and end to correct position for calculating dim size b = b < 0 && s > 0 ? 0 : b; - b = b > len-1 && s < 0 ? len-1 : b; + b = b > len - 1 && s < 0 ? len-1 : b; // if the start value lead to empty tensor under step s, use -1 for indication - b = b < 0 || b > len-1 ? -1 : b; + b = b < 0 || b > len - 1 ? -1 : b; e = e > -1 ? e : -1; e = e > len ? len : e; } @@ -753,9 +756,14 @@ inline void GetIndexRange(const mxnet::TShape& dshape, } } -inline void SetSliceOpOutputDimSize(const index_t i, const int b, +inline void SetSliceOpOutputDimSize(const mxnet::TShape& dshape, + const index_t i, const int b, const int e, const int s, mxnet::TShape* oshape) { + if (!mxnet::dim_size_is_known(dshape, i)) { + (*oshape)[i] = -1; + return; + } if (!Imperative::Get()->is_np_shape()) { // handle as ndarray if (e != b) { if (s > 0) { @@ -788,6 +796,7 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(out_attrs->size(), 1U); const mxnet::TShape& dshape = (*in_attrs)[0]; if (!mxnet::ndim_is_known(dshape)) return false; + CHECK_GT(dshape.ndim(), 0) << "slice only works for ndim > 0"; const SliceParam& param = nnvm::get(attrs.parsed); mxnet::TShape oshape = dshape; @@ -796,12 +805,12 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs, GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step); for (int i = 0; i < param.begin.ndim(); ++i) { const int b = begin[i], e = end[i], s = step[i]; - SetSliceOpOutputDimSize(i, b, e, s, &oshape); + SetSliceOpOutputDimSize(dshape, i, b, e, s, &oshape); } }) SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); - return shape_is_known(oshape); + return shape_is_known(dshape) && shape_is_known(oshape); } template @@ -879,7 +888,7 @@ void SliceOpForward(const nnvm::NodeAttrs& attrs, Stream* s = ctx.get_stream(); const TBlob& data = inputs[0]; const TBlob& out = outputs[0]; - if (Imperative::Get()->is_np_shape() && out.Size() == 0) return; + if (out.Size() == 0) return; const SliceParam& param = nnvm::get(attrs.parsed); MXNET_NDIM_SWITCH(data.ndim(), ndim, { common::StaticArray begin, end, step; @@ -979,7 +988,7 @@ void SliceOpBackward(const nnvm::NodeAttrs& attrs, } else if (req[0] == kWriteInplace) { LOG(FATAL) << "_slice_backward does not support kWriteInplace"; } - if (Imperative::Get()->is_np_shape() && ograd.Size() == 0) return; + if (ograd.Size() == 0) return; MXNET_NDIM_SWITCH(ograd.ndim(), ndim, { common::StaticArray begin, end, step; GetIndexRange(igrad.shape_, param.begin, param.end, param.step, &begin, &end, &step); @@ -1011,7 +1020,7 @@ inline bool SliceAssignOpShape(const nnvm::NodeAttrs& attrs, GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step); for (int i = 0; i < param.begin.ndim(); ++i) { const int b = begin[i], e = end[i], s = step[i]; - SetSliceOpOutputDimSize(i, b, e, s, &vshape); + SetSliceOpOutputDimSize(dshape, i, b, e, s, &vshape); } }) SHAPE_ASSIGN_CHECK(*in_attrs, 1, vshape); @@ -1150,7 +1159,7 @@ void SliceAssignScalarOpForward(const nnvm::NodeAttrs& attrs, GetIndexRange(data.shape_, param.begin, param.end, param.step, &begin, &end, &step); for (index_t i = 0; i < param.begin.ndim(); ++i) { const int b = begin[i], e = end[i], s = step[i]; - SetSliceOpOutputDimSize(i, b, e, s, &vshape); + SetSliceOpOutputDimSize(data.shape_, i, b, e, s, &vshape); } MSHADOW_TYPE_SWITCH(out.type_flag_, DType, { mxnet_op::Kernel, xpu>::Launch(s, vshape.FlatTo2D()[0], From 33a6f474022f4f1ec434038725100adc27950caf Mon Sep 17 00:00:00 2001 From: reminisce Date: Fri, 9 Aug 2019 15:52:13 -0700 Subject: [PATCH 4/5] Allow empty outputs after slicing ndarrays --- src/operator/tensor/matrix_op-inl.h | 33 ++---------- tests/python/unittest/test_numpy_op.py | 71 ++++++++++---------------- tests/python/unittest/test_operator.py | 9 ---- 3 files changed, 32 insertions(+), 81 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index de175cd428fe..ed6f0d437797 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -700,42 +700,13 @@ inline void GetIndexRange(const mxnet::TShape& dshape, b = param_begin[i].has_value() ? param_begin[i].value() : (s < 0 ? len - 1 : 0); e = param_end[i].has_value() ? param_end[i].value() : (s < 0 ? -1 : len); - // checking upper and lower bounds for begin if (b < 0) { b += len; - if (!Imperative::Get()->is_np_shape()) { - CHECK_GE(b, 0) << "slicing with begin[" << i << "]=" << b - len - << " exceeds limit of input dimension[" << i << "]=" << len; - } - } - if (!Imperative::Get()->is_np_shape()) { - CHECK_LT(b, len) << "slicing with begin[" << i << "]=" << b - << " exceeds limit of input dimension[" << i << "]=" << len; } - // checking upper and lower bounds for end if (e < 0 && param_end[i].has_value()) { e += len; - if (!Imperative::Get()->is_np_shape()) { - CHECK_GE(e, 0) << "slicing with end[" << i << "]=" << e - len - << " exceeds limit of input dimension[" << i << "]=" << len; - } - } - if (!Imperative::Get()->is_np_shape()) { - CHECK_LE(e, len) << "slicing with end[" << i << "]=" << e - << " exceeds limit of input dimension[" << i << "]=" << len; } - // checking begin==end case which is not supported - if (!Imperative::Get()->is_np_shape()) { - CHECK_NE(b, e) << "slicing with begin[" << i << "]=end[" << i << "]=" - << e << " results in an empty tensor and is not supported"; - } - } else if (len == 0) { - b = 0; - e = 0; - } - - if (Imperative::Get()->is_np_shape() && len > 0) { // move the begin and end to correct position for calculating dim size b = b < 0 && s > 0 ? 0 : b; b = b > len - 1 && s < 0 ? len-1 : b; @@ -743,7 +714,11 @@ inline void GetIndexRange(const mxnet::TShape& dshape, b = b < 0 || b > len - 1 ? -1 : b; e = e > -1 ? e : -1; e = e > len ? len : e; + } else if (len == 0) { + b = 0; + e = 0; } + (*begin)[i] = b; (*end)[i] = e; (*step)[i] = s; diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 9cbd6f6fb2d8..c172336be32e 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -102,66 +102,51 @@ def __init__(self, begin, end, step): self._end = end self._step = step - def hybrid_forward(self, F, a, *args, **kwargs): + def hybrid_forward(self, F, a): return F.npx.slice(a, begin=self._begin, end=self._end, step=self._step) - def get_start_end_step(shape): - start = [] - end = [] - step_switch = random.randint(-1,1) - step = None if step_switch == 0 else [] - for i in range(len(shape)): - s = random.randint(0, shape[i]-1) - e = random.randint(s+1, shape[i]) - if step_switch == 1: - step.append(1) - start.append(s) - end.append(e) - elif step_switch == -1: - step.append(-1) - if e == shape[i]: - e -= 1 - s -= 1 - if s == -1: - s = None - start.append(e) - end.append(s) - else: - start.append(s) - end.append(e) - return start, end, step + shape = (8, 16, 9, 9) + np_array = _np.arange(_np.prod(shape), dtype='int32').reshape(shape) + configs = [ + ([], [], None), + ([], [], []), + ([1], [4], None), + ([1], [10], [3]), + ([10], [0], [-2]), + ([None], [None], [None]), + ([None], [None], [-1]), + ([10], [None], [-1]), + ([1, 0, 3], [-2, 10, -4], [None, 2, 3]), + ([-2, -3, -5, -6], [1, 3, 4, 5], None), + ([-2, -3, -5, -6], [1, 3, 4, 5], [-1, -2, -3, -4]), + ([2, -3, -5, -6], [2, 3, 4, 5], None), + ([2, -3, -5, 5], [3, 3, 4, 5], None), + ] for hybridize in [True, False]: - for i in range(10): - dim = random.randint(1,4) - shape = [random.randint(1,5) for i in range(dim)] - - # test gluon - start, end, step = get_start_end_step(shape) + for config in configs: + start, end, step = config[0], config[1], config[2] test_slice = TestSlice(begin=start, end=end, step=step) if hybridize: test_slice.hybridize() - a = mx.nd.random.uniform(shape=shape).as_np_ndarray() + a = np.array(np_array, dtype=np_array.dtype) a.attach_grad() - if step is not None: - expected_ret = a.as_nd_ndarray().slice(start, end, step) - else: - expected_ret = a.as_nd_ndarray().slice(start, end) + basic_index = tuple([ + slice(start[i], end[i], step[i]) if step is not None else slice(start[i], end[i]) + for i in range(len(start)) + ]) + expected_ret = np_array[basic_index] with mx.autograd.record(): y = test_slice(a) - assert_almost_equal(y.asnumpy(), expected_ret.asnumpy(), rtol=1e-3, atol=1e-5) + assert same(y.asnumpy(), expected_ret) # test backward mx.autograd.backward(y) expected_grad = _np.zeros(shape) - basic_index = tuple([ - slice(start[i], end[i], step[i]) if step is not None else slice(start[i], end[i]) - for i in range(len(start)) - ]) expected_grad[basic_index] = 1 - assert_almost_equal(a.grad.asnumpy(), expected_grad, rtol=1e-3, atol=1e-5) + assert same(a.grad.asnumpy(), expected_grad) if __name__ == '__main__': diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 13ab03b94de9..51d4a1580b17 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -7519,15 +7519,6 @@ def test_slice_forward_backward(a, index): for index in index_list: test_slice_forward_backward(arr, index) - def test_begin_equals_end(shape, begin, end, step): - in_arr = mx.nd.arange(np.prod(shape)).reshape(shape=shape) - out_arr = mx.nd.slice(in_arr, begin=begin, end=end, step=step) - - assertRaises(MXNetError, test_begin_equals_end, (4,), (2,), (2,), (1,)) - assertRaises(MXNetError, test_begin_equals_end, (1, 5), (None, 3), (None, 3), (-1, 1)) - assertRaises(MXNetError, test_begin_equals_end, (3, 4, 5), (1, 3, 1), (3, 3, 1), (1, -3, 2)) - assertRaises(MXNetError, test_begin_equals_end, (2, 4), (None, 2), (None, 2), (1, -1)) - # check numeric gradient in_data = np.arange(36).reshape(2, 2, 3, 3) data = mx.sym.Variable('data') From 010d8cf646139412184c55eead99bde68aadf96f Mon Sep 17 00:00:00 2001 From: reminisce Date: Fri, 9 Aug 2019 16:20:21 -0700 Subject: [PATCH 5/5] Fix --- src/operator/tensor/matrix_op-inl.h | 34 +++++++++-------------------- 1 file changed, 10 insertions(+), 24 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index ed6f0d437797..b52e89d8fc0f 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -708,10 +708,10 @@ inline void GetIndexRange(const mxnet::TShape& dshape, } // move the begin and end to correct position for calculating dim size - b = b < 0 && s > 0 ? 0 : b; - b = b > len - 1 && s < 0 ? len-1 : b; + b = (b < 0 && s > 0) ? 0 : b; + b = (b > len - 1 && s < 0) ? len - 1 : b; // if the start value lead to empty tensor under step s, use -1 for indication - b = b < 0 || b > len - 1 ? -1 : b; + b = (b < 0 || b > len - 1) ? -1 : b; e = e > -1 ? e : -1; e = e > len ? len : e; } else if (len == 0) { @@ -724,7 +724,7 @@ inline void GetIndexRange(const mxnet::TShape& dshape, (*step)[i] = s; } - for (index_t i = param_begin.ndim(); i < dshape.ndim(); ++i) { + for (int i = param_begin.ndim(); i < dshape.ndim(); ++i) { (*begin)[i] = 0; (*end)[i] = dshape[i]; (*step)[i] = 1; @@ -739,28 +739,14 @@ inline void SetSliceOpOutputDimSize(const mxnet::TShape& dshape, (*oshape)[i] = -1; return; } - if (!Imperative::Get()->is_np_shape()) { // handle as ndarray - if (e != b) { - if (s > 0) { - CHECK_LT(b, e) << "slicing with begin=[" << i << "]=" << b << ", end[" << i << "]=" - << e << ", and step[" << i << "]=" << s << " is invalid"; - (*oshape)[i] = (e - b - 1) / s + 1; - } else { - CHECK_LT(e, b) << "slicing with begin=[" << i << "]=" << b << ", end[" << i << "]=" - << e << ", and step[" << i << "]=" << s << " is invalid"; - (*oshape)[i] = (b - e - 1) / (-s) + 1; - } - } // else leave oshape[i] as 0 for partial infer - } else { // handle as numpy compatible array - if (e != b && b >= 0) { - if (s > 0) { - (*oshape)[i] = e > b ? (e - b - 1) / s + 1 : 0; - } else { - (*oshape)[i] = e < b ? (b - e - 1) / (-s) + 1 : 0; - } + if (e != b && b >= 0) { + if (s > 0) { + (*oshape)[i] = e > b ? (e - b - 1) / s + 1 : 0; } else { - (*oshape)[i] = 0; + (*oshape)[i] = e < b ? (b - e - 1) / (-s) + 1 : 0; } + } else { + (*oshape)[i] = 0; } }