Skip to content

Commit

Permalink
respond to yongwww's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart committed Aug 24, 2020
1 parent f98a29d commit 99ac920
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice
cbegin = int64(begin[i])
if cbegin < 0:
cbegin += int64(data_shape[i])
assert cbegin >= 0, "begin value is too negative"
if len(end) <= i:
cend = int64(data_shape[i])
elif slice_mode != 0:
Expand All @@ -155,6 +156,7 @@ def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice
cend = int64(end[i])
if cend < 0:
cend += int64(data_shape[i])
assert cend >= 0, "end value is too negative"
assert cstride != 0, "Strides can't be zero."
if cstride < 0:
slice_range = cbegin - cend
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/op/dyn/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def _strided_slice_shape_func_input_data(data, begin, end, strides,
cbegin = int64(begin[i])
if cbegin < 0:
cbegin += int64(data.shape[i])
assert cbegin >= 0, "begin value is too negative"
if end.shape[0] <= i:
cend = int64(data.shape[i])
elif slice_mode != 0:
Expand All @@ -178,6 +179,7 @@ def _strided_slice_shape_func_input_data(data, begin, end, strides,
cend = int64(end[i])
if cend < 0:
cend += int64(data.shape[i])
assert cend >= 0, "end value is too negative"
assert cstride != 0, "Strides can't be zero."
if cstride < 0:
slice_range = cbegin - cend
Expand Down
6 changes: 3 additions & 3 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2070,9 +2070,9 @@ bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) / step);
}
} else {
CHECK(param->begin) << "strided_slice recieved invalid begin";
CHECK(param->end) << "strided_slice recieved invalid end";
CHECK(param->strides) << "strided_slice recieved invalid strides";
CHECK(param->begin) << "strided_slice recieved invalid begin " << param->begin;
CHECK(param->end) << "strided_slice recieved invalid end " << param->end;
CHECK(param->strides) << "strided_slice recieved invalid strides " << param->strides;
}
reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
Expand Down

0 comments on commit 99ac920

Please sign in to comment.