Skip to content

Commit

Permalink
respond to review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart committed Aug 21, 2020
1 parent 19c70a8 commit ad7a37f
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 23 deletions.
6 changes: 3 additions & 3 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,9 +836,9 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
end = const(list(end))
if isinstance(strides, (tuple, list)):
strides = const(list(strides))
begin = _make.where(begin < cast_like(const(0), begin),
begin + cast_like(shape_of(data), begin), begin)
return _dyn_make.strided_slice(data, begin, end, strides, slice_mode)
normalized_begin = _make.where(begin < cast_like(const(0), begin),
begin + cast_like(shape_of(data), begin), begin)
return _dyn_make.strided_slice(data, normalized_begin, end, strides, slice_mode)
return _make.strided_slice(data, begin, end, strides, slice_mode)


Expand Down
11 changes: 6 additions & 5 deletions src/relay/op/dyn/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ RELAY_REGISTER_OP("dyn.full")

bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// [data, begin, end, strides, out]
CHECK_EQ(types.size(), 5);
const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>();
if (param == nullptr) {
Expand Down Expand Up @@ -487,12 +488,12 @@ Array<te::Tensor> StridedSliceCompute(const Attrs& attrs, const Array<te::Tensor
te::Tensor end = inputs[2];
te::Tensor strides = inputs[3];
// Dynamic computation
int64_t attr_size = data->shape.size();
CHECK(begin->shape[0].as<IntImmNode>()->value == attr_size &&
end->shape[0].as<IntImmNode>()->value == attr_size &&
strides->shape[0].as<IntImmNode>()->value == attr_size)
int64_t data_rank = data->shape.size();
CHECK(begin->shape[0].as<IntImmNode>()->value == data_rank &&
end->shape[0].as<IntImmNode>()->value == data_rank &&
strides->shape[0].as<IntImmNode>()->value == data_rank)
<< "begin, end, and strides are required to have the same length"
<< " if they are non-constant.";
<< " if they are dynamic variables.";
return Array<te::Tensor>{DynamicStridedSlice(data, begin, end, strides)};
}

Expand Down
4 changes: 3 additions & 1 deletion src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2070,7 +2070,9 @@ bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) / step);
}
} else {
CHECK(false) << "strided_slice recieved invalid params";
CHECK(param->begin) << "strided_slice recieved invalid begin";
CHECK(param->end) << "strided_slice recieved invalid end";
CHECK(param->strides) << "strided_slice recieved invalid strides";
}
reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
Expand Down
24 changes: 11 additions & 13 deletions src/relay/transforms/dynamic_to_static.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,19 +141,17 @@ class DynamicToStaticMutator : public MixedModeMutator {
}},
{Op::Get("dyn.strided_slice"),
[](const CallNode* call_node) {
if (const ConstantNode* begin = call_node->args[1].as<ConstantNode>()) {
if (const ConstantNode* end = call_node->args[2].as<ConstantNode>()) {
if (const ConstantNode* stride = call_node->args[3].as<ConstantNode>()) {
CHECK_EQ(begin->data->ndim, 1);
CHECK_EQ(end->data->ndim, 1);
CHECK_EQ(stride->data->ndim, 1);
const StridedSliceAttrs* param = call_node->attrs.as<StridedSliceAttrs>();
CHECK(param);
return MakeStridedSlice(call_node->args[0], ToVector(begin->data),
ToVector(end->data), ToVector(stride->data),
param->slice_mode);
}
}
const ConstantNode* begin = call_node->args[1].as<ConstantNode>();
const ConstantNode* end = call_node->args[2].as<ConstantNode>();
const ConstantNode* stride = call_node->args[3].as<ConstantNode>();
if (begin && end && stride) {
CHECK_EQ(begin->data->ndim, 1);
CHECK_EQ(end->data->ndim, 1);
CHECK_EQ(stride->data->ndim, 1);
const StridedSliceAttrs* param = call_node->attrs.as<StridedSliceAttrs>();
CHECK(param);
return MakeStridedSlice(call_node->args[0], ToVector(begin->data), ToVector(end->data),
ToVector(stride->data), param->slice_mode);
}
return Expr(nullptr);
}},
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_op_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def verify(dshape, begin, end, strides, output, slice_mode="end",
text = func.astext()
assert "begin=" in text
assert "end=" in text

if output:
assert func.body.checked_type == relay.ty.TensorType(output, "float32")

Expand Down

0 comments on commit ad7a37f

Please sign in to comment.