From ad7a37f4b128da786d76e2752d6ea23b2836e810 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Fri, 21 Aug 2020 09:10:03 -0700 Subject: [PATCH] respond to review comments --- python/tvm/relay/op/transform.py | 6 +++--- src/relay/op/dyn/tensor/transform.cc | 11 ++++++----- src/relay/op/tensor/transform.cc | 4 +++- src/relay/transforms/dynamic_to_static.cc | 24 +++++++++++------------ tests/python/relay/test_op_level4.py | 2 +- 5 files changed, 24 insertions(+), 23 deletions(-) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index d63bfa9bfae1b..3193c9e21ad6e 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -836,9 +836,9 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"): end = const(list(end)) if isinstance(strides, (tuple, list)): strides = const(list(strides)) - begin = _make.where(begin < cast_like(const(0), begin), - begin + cast_like(shape_of(data), begin), begin) - return _dyn_make.strided_slice(data, begin, end, strides, slice_mode) + normalized_begin = _make.where(begin < cast_like(const(0), begin), + begin + cast_like(shape_of(data), begin), begin) + return _dyn_make.strided_slice(data, normalized_begin, end, strides, slice_mode) return _make.strided_slice(data, begin, end, strides, slice_mode) diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc index b18cfa196ee1b..de1cc5a4ed953 100644 --- a/src/relay/op/dyn/tensor/transform.cc +++ b/src/relay/op/dyn/tensor/transform.cc @@ -436,6 +436,7 @@ RELAY_REGISTER_OP("dyn.full") bool StridedSliceRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { + // [data, begin, end, strides, out] CHECK_EQ(types.size(), 5); const StridedSliceAttrs* param = attrs.as(); if (param == nullptr) { @@ -487,12 +488,12 @@ Array StridedSliceCompute(const Attrs& attrs, const Arrayshape.size(); - CHECK(begin->shape[0].as()->value == attr_size && - end->shape[0].as()->value == attr_size && - strides->shape[0].as()->value == attr_size) + int64_t data_rank = data->shape.size(); + CHECK(begin->shape[0].as()->value == data_rank && + end->shape[0].as()->value == data_rank && + strides->shape[0].as()->value == data_rank) << "begin, end, and strides are required to have the same length" - << " if they are non-constant."; + << " if they are dynamic variables."; return Array{DynamicStridedSlice(data, begin, end, strides)}; } diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 40a70f42c3148..fb8ea51fff04f 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2070,7 +2070,9 @@ bool StridedSliceRel(const Array& types, int num_inputs, const Attrs& attr oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) / step); } } else { - CHECK(false) << "strided_slice recieved invalid params"; + CHECK(param->begin) << "strided_slice recieved invalid begin"; + CHECK(param->end) << "strided_slice recieved invalid end"; + CHECK(param->strides) << "strided_slice recieved invalid strides"; } reporter->Assign(types[1], TensorType(oshape, data->dtype)); return true; diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc index 67aa66230e1b4..3421898e9d3a7 100644 --- a/src/relay/transforms/dynamic_to_static.cc +++ b/src/relay/transforms/dynamic_to_static.cc @@ -141,19 +141,17 @@ class DynamicToStaticMutator : public MixedModeMutator { }}, {Op::Get("dyn.strided_slice"), [](const CallNode* call_node) { - if (const ConstantNode* begin = call_node->args[1].as()) { - if (const ConstantNode* end = call_node->args[2].as()) { - if (const ConstantNode* stride = call_node->args[3].as()) { - CHECK_EQ(begin->data->ndim, 1); - CHECK_EQ(end->data->ndim, 1); - CHECK_EQ(stride->data->ndim, 1); - const StridedSliceAttrs* param = call_node->attrs.as(); - CHECK(param); - return MakeStridedSlice(call_node->args[0], ToVector(begin->data), - ToVector(end->data), ToVector(stride->data), - param->slice_mode); - } - } + const ConstantNode* begin = call_node->args[1].as(); + const ConstantNode* end = call_node->args[2].as(); + const ConstantNode* stride = call_node->args[3].as(); + if (begin && end && stride) { + CHECK_EQ(begin->data->ndim, 1); + CHECK_EQ(end->data->ndim, 1); + CHECK_EQ(stride->data->ndim, 1); + const StridedSliceAttrs* param = call_node->attrs.as(); + CHECK(param); + return MakeStridedSlice(call_node->args[0], ToVector(begin->data), ToVector(end->data), + ToVector(stride->data), param->slice_mode); } return Expr(nullptr); }}, diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index 8751491fc5563..e739aa9d9bda8 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -337,7 +337,7 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", text = func.astext() assert "begin=" in text assert "end=" in text - + if output: assert func.body.checked_type == relay.ty.TensorType(output, "float32")