Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix slice inconsistency
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed Aug 9, 2019
1 parent c507da6 commit 61bb9c7
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -730,14 +730,17 @@ inline void GetIndexRange(const mxnet::TShape& dshape,
CHECK_NE(b, e) << "slicing with begin[" << i << "]=end[" << i << "]="
<< e << " results in an empty tensor and is not supported";
}
} else if (len == 0) {
b = 0;
e = 0;
}

if (Imperative::Get()->is_np_shape()) {
if (Imperative::Get()->is_np_shape() && len > 0) {
// move the begin and end to correct position for calculating dim size
b = b < 0 && s > 0 ? 0 : b;
b = b > len-1 && s < 0 ? len-1 : b;
b = b > len - 1 && s < 0 ? len-1 : b;
// if the start value lead to empty tensor under step s, use -1 for indication
b = b < 0 || b > len-1 ? -1 : b;
b = b < 0 || b > len - 1 ? -1 : b;
e = e > -1 ? e : -1;
e = e > len ? len : e;
}
Expand All @@ -753,9 +756,14 @@ inline void GetIndexRange(const mxnet::TShape& dshape,
}
}

inline void SetSliceOpOutputDimSize(const index_t i, const int b,
inline void SetSliceOpOutputDimSize(const mxnet::TShape& dshape,
const index_t i, const int b,
const int e, const int s,
mxnet::TShape* oshape) {
if (!mxnet::dim_size_is_known(dshape, i)) {
(*oshape)[i] = -1;
return;
}
if (!Imperative::Get()->is_np_shape()) { // handle as ndarray
if (e != b) {
if (s > 0) {
Expand Down Expand Up @@ -788,6 +796,7 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_attrs->size(), 1U);
const mxnet::TShape& dshape = (*in_attrs)[0];
if (!mxnet::ndim_is_known(dshape)) return false;
CHECK_GT(dshape.ndim(), 0) << "slice only works for ndim > 0";
const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
mxnet::TShape oshape = dshape;

Expand All @@ -796,12 +805,12 @@ inline bool SliceOpShape(const nnvm::NodeAttrs& attrs,
GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step);
for (int i = 0; i < param.begin.ndim(); ++i) {
const int b = begin[i], e = end[i], s = step[i];
SetSliceOpOutputDimSize(i, b, e, s, &oshape);
SetSliceOpOutputDimSize(dshape, i, b, e, s, &oshape);
}
})

SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
return shape_is_known(oshape);
return shape_is_known(dshape) && shape_is_known(oshape);
}

template<int ndim, int req, typename xpu>
Expand Down Expand Up @@ -879,7 +888,7 @@ void SliceOpForward(const nnvm::NodeAttrs& attrs,
Stream<xpu>* s = ctx.get_stream<xpu>();
const TBlob& data = inputs[0];
const TBlob& out = outputs[0];
if (Imperative::Get()->is_np_shape() && out.Size() == 0) return;
if (out.Size() == 0) return;
const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
MXNET_NDIM_SWITCH(data.ndim(), ndim, {
common::StaticArray<index_t, ndim> begin, end, step;
Expand Down Expand Up @@ -979,7 +988,7 @@ void SliceOpBackward(const nnvm::NodeAttrs& attrs,
} else if (req[0] == kWriteInplace) {
LOG(FATAL) << "_slice_backward does not support kWriteInplace";
}
if (Imperative::Get()->is_np_shape() && ograd.Size() == 0) return;
if (ograd.Size() == 0) return;
MXNET_NDIM_SWITCH(ograd.ndim(), ndim, {
common::StaticArray<index_t, ndim> begin, end, step;
GetIndexRange(igrad.shape_, param.begin, param.end, param.step, &begin, &end, &step);
Expand Down Expand Up @@ -1011,7 +1020,7 @@ inline bool SliceAssignOpShape(const nnvm::NodeAttrs& attrs,
GetIndexRange(dshape, param.begin, param.end, param.step, &begin, &end, &step);
for (int i = 0; i < param.begin.ndim(); ++i) {
const int b = begin[i], e = end[i], s = step[i];
SetSliceOpOutputDimSize(i, b, e, s, &vshape);
SetSliceOpOutputDimSize(dshape, i, b, e, s, &vshape);
}
})
SHAPE_ASSIGN_CHECK(*in_attrs, 1, vshape);
Expand Down Expand Up @@ -1150,7 +1159,7 @@ void SliceAssignScalarOpForward(const nnvm::NodeAttrs& attrs,
GetIndexRange(data.shape_, param.begin, param.end, param.step, &begin, &end, &step);
for (index_t i = 0; i < param.begin.ndim(); ++i) {
const int b = begin[i], e = end[i], s = step[i];
SetSliceOpOutputDimSize(i, b, e, s, &vshape);
SetSliceOpOutputDimSize(data.shape_, i, b, e, s, &vshape);
}
MSHADOW_TYPE_SWITCH(out.type_flag_, DType, {
mxnet_op::Kernel<slice_assign_scalar<ndim>, xpu>::Launch(s, vshape.FlatTo2D()[0],
Expand Down

0 comments on commit 61bb9c7

Please sign in to comment.