From 7e05d15bc79e60f5a237c64b8737d1c8806ef387 Mon Sep 17 00:00:00 2001 From: Matthew Brookhart Date: Tue, 8 Sep 2020 10:47:57 -0600 Subject: [PATCH] Dynamic Strided Slice (#6316) * Dynamic Strided Slice * fix clang-format lint * remove debug print * respond to review comments * respond to yongwww's comments * fix bad rebase * revert hybrid-script assert * reformat mxnet change * use new testing api * while getting test to work with the new testing API, refactor all of the tests iin the dyn directory --- python/tvm/relay/frontend/keras.py | 4 +- python/tvm/relay/frontend/mxnet.py | 16 +-- python/tvm/relay/frontend/onnx.py | 8 +- python/tvm/relay/frontend/pytorch.py | 18 +-- python/tvm/relay/op/_tensor_grad.py | 6 +- python/tvm/relay/op/_transform.py | 54 +++----- python/tvm/relay/op/dyn/_transform.py | 51 ++++++++ python/tvm/relay/op/transform.py | 19 ++- python/tvm/topi/cuda/conv2d_alter_op.py | 4 +- python/tvm/topi/x86/conv2d_alter_op.py | 4 +- src/relay/op/dyn/tensor/transform.cc | 113 ++++++++++++++++ src/relay/op/make_op.h | 3 +- src/relay/op/tensor/transform.cc | 121 ++++++++---------- .../combine_parallel_batch_matmul.cc | 12 +- .../transforms/combine_parallel_conv2d.cc | 17 +-- .../transforms/combine_parallel_dense.cc | 12 +- src/relay/transforms/dynamic_to_static.cc | 16 +++ src/relay/transforms/pass_util.h | 1 + .../relay/dyn/test_dynamic_op_level10.py | 19 ++- .../relay/dyn/test_dynamic_op_level2.py | 23 +++- .../relay/dyn/test_dynamic_op_level3.py | 15 ++- .../relay/dyn/test_dynamic_op_level4.py | 95 ++++++++++++++ .../relay/dyn/test_dynamic_op_level5.py | 3 +- .../relay/dyn/test_dynamic_op_level6.py | 4 +- tests/python/relay/test_op_level4.py | 68 ++++++++-- .../python/relay/test_pass_alter_op_layout.py | 12 +- ...test_pass_combine_parallel_batch_matmul.py | 36 +++--- .../test_pass_combine_parallel_conv2d.py | 54 ++++---- .../relay/test_pass_combine_parallel_dense.py | 36 +++--- .../relay/test_pass_dynamic_to_static.py | 64 +++++++++ 30 files changed, 634 insertions(+), 274 deletions(-) create mode 100644 tests/python/relay/dyn/test_dynamic_op_level4.py diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py index b469ed0045a1..d8bff8ca48d8 100644 --- a/python/tvm/relay/frontend/keras.py +++ b/python/tvm/relay/frontend/keras.py @@ -622,8 +622,8 @@ def _convert_cropping(inexpr, keras_layer, _): raise tvm.error.OpNotImplemented( 'Operator {} is not supported for frontend Keras.'.format(crop_type)) int32_max = np.iinfo(np.int32).max - return _op.strided_slice(inexpr, begin=_expr.const([0, 0, crop_t, crop_l]), \ - end=_expr.const([int32_max, int32_max, in_h-crop_b, in_w-crop_r])) + return _op.strided_slice(inexpr, begin=[0, 0, crop_t, crop_l], \ + end=[int32_max, int32_max, in_h-crop_b, in_w-crop_r]) def _convert_batchnorm(inexpr, keras_layer, etab): diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 1b49c1c7e4eb..faa62e18a193 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -500,11 +500,11 @@ def _mx_slice(inputs, attrs): for i, ed in enumerate(end): if ed is None: end[i] = input_shape[i] - new_attrs = {'begin': _expr.const(list(begin), dtype="int32"), - 'end': _expr.const(list(end), dtype="int32")} + new_attrs = {'begin': list(begin), + 'end': list(end)} if stride is not None: stride = (x if x is not None else 1 for x in stride) - new_attrs['strides'] = _expr.const(list(stride), dtype="int32") + new_attrs['strides'] = list(stride) return _op.strided_slice(inputs[0], 
**new_attrs) @@ -544,9 +544,7 @@ def _mx_slice_axis(inputs, attrs): else: begin.append(ax_beg) end.append(ax_end) - return _op.strided_slice(inputs[0], - _expr.const(begin, dtype="int32"), - _expr.const(end, dtype="int32")) + return _op.strided_slice(inputs[0], begin, end) def _mx_crop_like(inputs, attrs): @@ -566,9 +564,9 @@ def _mx_crop_like(inputs, attrs): return _op.slice_like(*inputs, **new_attrs) expr = _infer_type(inputs[1]) like_shape = expr.checked_type.shape - new_attrs['begin'] = _expr.const([0, 0, offset[0], offset[1]], dtype="int32") - new_attrs['end'] = _expr.const([like_shape[0], like_shape[1], offset[0]+like_shape[2], - offset[1]+like_shape[3]], dtype="int32") + new_attrs['begin'] = [0, 0, offset[0], offset[1]] + new_attrs['end'] = [like_shape[0], like_shape[1], offset[0]+like_shape[2], + offset[1]+like_shape[3]] return _op.strided_slice(inputs[0], **new_attrs) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 877174cc55fe..ea39010df066 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1049,8 +1049,8 @@ def _impl_v1(cls, inputs, attr, params): end = list(attr['ends']) return _op.strided_slice(inputs[0], - begin=_expr.const(begin, dtype="int64"), - end=_expr.const(end, dtype="int64")) + begin=begin, + end=end) @classmethod def _impl_v10(cls, inputs, attr, params): @@ -1070,8 +1070,8 @@ def _impl_v10(cls, inputs, attr, params): attrs['starts'] = new_starts attrs['ends'] = new_ends return _op.strided_slice(inputs[0], - begin=_expr.const(attrs['starts'], dtype="int64"), - end=_expr.const(attrs['ends'], dtype="int64")) + begin=list(attrs['starts']), + end=list(attrs['ends'])) class Gather(OnnxOpConverter): diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 8d850093f71b..51d90e1fb985 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -309,9 +309,9 @@ def _impl(inputs, input_types): strides[dim] = int(inputs[4]) return _op.transform.strided_slice(data, - begin=_expr.const(begin), - end=_expr.const(end), - strides=_expr.const(strides), + begin=begin, + end=end, + strides=strides, slice_mode="end") return _impl @@ -1373,9 +1373,9 @@ def _impl(inputs, input_types): stride = [1] * len(shape) chunk_out = _op.transform.strided_slice(data, - begin=_expr.const(begin), - end=_expr.const(end), - strides=_expr.const(stride)) + begin=begin, + end=end, + strides=stride) chunks.append(chunk_out) if dim % num_chunks: @@ -1386,9 +1386,9 @@ def _impl(inputs, input_types): stride = [1] * len(shape) chunk_out = _op.transform.strided_slice(data, - begin=_expr.const(begin), - end=_expr.const(end), - strides=_expr.const(stride)) + begin=begin, + end=end, + strides=stride) chunks.append(chunk_out) return chunks diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py index 46a45354a9cc..5069f79de10f 100644 --- a/python/tvm/relay/op/_tensor_grad.py +++ b/python/tvm/relay/op/_tensor_grad.py @@ -407,9 +407,9 @@ def conv2d_grad(orig, grad): assert padded_weight_grad_w >= filter_w if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w: backward_weight = strided_slice(backward_weight, - begin=const([0, 0, 0, 0], dtype="int64"), - end=const([out_channel, in_channel // attrs.groups, - filter_h, filter_w], dtype="int64")) + begin=[0, 0, 0, 0], + end=[out_channel, in_channel // attrs.groups, + filter_h, filter_w]) return [backward_data, backward_weight] diff --git a/python/tvm/relay/op/_transform.py 
b/python/tvm/relay/op/_transform.py index 937c36e60919..9d7c389ccc08 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -127,33 +127,6 @@ def arange_shape_func(attrs, inputs, _): """ return [_arange_shape_func(*inputs)] -@script -def _strided_slice_shape_func_input_data(data, begin, end, strides, - slice_mode): - ndim = len(data.shape) - out = output_tensor((ndim,), "int64") - for i in const_range(ndim): - cbegin = 0 - cend = data.shape[i] - cstride = 1 - if strides.shape[0] > i: - cstride = strides[i] - if begin.shape[0] > i: - cbegin = begin[i] - if end.shape[0] <= i: - cend = data.shape[i] - elif slice_mode != 0: - cstride = 1 - if end[i] < 0: - cend = data.shape[i] - else: - cend = cbegin + end[i] - else: - cend = end[i] - assert cstride != 0, "Strides can't be zero." - out[i] = int64(ceil_div((int64(cend) - int64(cbegin)), int64(cstride))) - return out - @script def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice_mode): ndim = data_shape.shape[0] @@ -166,6 +139,8 @@ def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice cstride = int64(strides[i]) if len(begin) > i: cbegin = int64(begin[i]) + if cbegin < 0: + cbegin += int64(data_shape[i]) if len(end) <= i: cend = int64(data_shape[i]) elif slice_mode != 0: @@ -175,23 +150,32 @@ def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice else: cend = cbegin + int64(end[i]) else: - cend = int64(end[i]) + if end[i] > data_shape[i]: + cend = int64(data_shape[i]) + else: + cend = int64(end[i]) + if cend < 0: + cend += int64(data_shape[i]) assert cstride != 0, "Strides can't be zero." - out[i] = int64(ceil_div((int64(cend) - int64(cbegin)), int64(cstride))) + if cstride < 0: + slice_range = cbegin - cend + step = -cstride + else: + slice_range = cend - cbegin + step = cstride + + out[i] = int64(ceil_div(slice_range, step)) return out -@_reg.register_shape_func("strided_slice", True) +@_reg.register_shape_func("strided_slice", False) def strided_slice_shape_func(attrs, inputs, _): """ Shape func for strided_slice """ slice_mode = convert(0 if attrs.slice_mode == "end" else 1) - # data independent if begin, end and strides exist - if attrs.begin and attrs.end and attrs.strides: - return [_strided_slice_shape_func_input_shape(inputs[0], attrs.begin, attrs.end, - attrs.strides, slice_mode)] - return [_strided_slice_shape_func_input_data(*inputs, slice_mode)] + return [_strided_slice_shape_func_input_shape(inputs[0], attrs.begin, attrs.end, + attrs.strides, slice_mode)] @script def _concatenate_shape_func(inputs, axis): diff --git a/python/tvm/relay/op/dyn/_transform.py b/python/tvm/relay/op/dyn/_transform.py index 46778fef8410..6bf02ecf31ea 100644 --- a/python/tvm/relay/op/dyn/_transform.py +++ b/python/tvm/relay/op/dyn/_transform.py @@ -27,6 +27,7 @@ _reg.register_broadcast_schedule("dyn.tile") _reg.register_injective_schedule("dyn.one_hot") _reg.register_injective_schedule("dyn.full") +_reg.register_injective_schedule("dyn.strided_slice") @script def _reshape_shape_func_input_data(data, newshape, ndim): @@ -145,3 +146,53 @@ def one_hot_shape_func(attrs, inputs, _): """ axis = len(inputs[0].shape) if attrs.axis == -1 else attrs.axis return [_onehot_shape_func(inputs[0].shape, inputs[3], convert(axis))] + + +@script +def _strided_slice_shape_func_input_data(data, begin, end, strides, + slice_mode): + ndim = len(data.shape) + out = output_tensor((ndim,), "int64") + for i in const_range(ndim): + cbegin = int64(0) + cend = 
int64(data.shape[i]) + cstride = int64(1) + if strides.shape[0] > i: + cstride = int64(strides[i]) + if begin.shape[0] > i: + cbegin = int64(begin[i]) + if cbegin < 0: + cbegin += int64(data.shape[i]) + if end.shape[0] <= i: + cend = int64(data.shape[i]) + elif slice_mode != 0: + cstride = int64(1) + if end[i] < 0: + cend = int64(data.shape[i]) + else: + cend = cbegin + int64(end[i]) + else: + if end[i] > data.shape[i]: + cend = int64(data.shape[i]) + else: + cend = int64(end[i]) + if cend < 0: + cend += int64(data.shape[i]) + assert cstride != 0, "Strides can't be zero." + if cstride < 0: + slice_range = cbegin - cend + step = -cstride + else: + slice_range = cend - cbegin + step = cstride + + out[i] = int64(ceil_div(slice_range, step)) + return out + +@_reg.register_shape_func("dyn.strided_slice", True) +def strided_slice_shape_func(attrs, inputs, _): + """ + Shape func for strided_slice + """ + slice_mode = convert(0 if attrs.slice_mode == "end" else 1) + return [_strided_slice_shape_func_input_data(*inputs, slice_mode)] diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 6d3c8be21fd0..01466f7dae7b 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -20,6 +20,7 @@ from . import _make from .dyn import _make as _dyn_make +from .tensor import shape_of from ..expr import TupleWrapper, const, Expr, Tuple from ...tir import expr as _expr @@ -827,13 +828,17 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"): ret : relay.Expr The computed result. """ - strides = strides or const([1], dtype="int32") - if isinstance(begin, (tuple, list)): - begin = const(list(begin)) - if isinstance(end, (tuple, list)): - end = const(list(end)) - if isinstance(strides, (tuple, list)): - strides = const(list(strides)) + strides = strides or [1] + if (isinstance(begin, Expr) or isinstance(end, Expr) or isinstance(strides, Expr)): + if isinstance(begin, (tuple, list)): + begin = const(list(begin)) + if isinstance(end, (tuple, list)): + end = const(list(end)) + if isinstance(strides, (tuple, list)): + strides = const(list(strides)) + normalized_begin = _make.where(begin < cast_like(const(0), begin), + begin + cast_like(shape_of(data), begin), begin) + return _dyn_make.strided_slice(data, normalized_begin, end, strides, slice_mode) return _make.strided_slice(data, begin, end, strides, slice_mode) diff --git a/python/tvm/topi/cuda/conv2d_alter_op.py b/python/tvm/topi/cuda/conv2d_alter_op.py index f07ef984025f..89a8569eaf8d 100644 --- a/python/tvm/topi/cuda/conv2d_alter_op.py +++ b/python/tvm/topi/cuda/conv2d_alter_op.py @@ -276,8 +276,8 @@ def _conv2d_legalize(attrs, inputs, arg_types): new_attrs['channels'] = new_out_channel out = tvm.relay.nn.conv2d(data, kernel, **new_attrs) original_out_shape = [x.value for x in output_tensor.shape] - out = relay.strided_slice(out, begin=relay.const([0, 0, 0, 0]), - end=relay.const(original_out_shape)) + out = relay.strided_slice(out, begin=[0, 0, 0, 0], + end=original_out_shape) else: out = relay.nn.conv2d(data, kernel, **new_attrs) return out diff --git a/python/tvm/topi/x86/conv2d_alter_op.py b/python/tvm/topi/x86/conv2d_alter_op.py index e9fc4223a9ea..992353eb0955 100644 --- a/python/tvm/topi/x86/conv2d_alter_op.py +++ b/python/tvm/topi/x86/conv2d_alter_op.py @@ -313,8 +313,8 @@ def _conv2d_legalize(attrs, inputs, arg_types): out = tvm.relay.nn.conv2d(data, kernel, **new_attrs) original_out_shape = [x.value for x in output_tensor.shape] out = relay.strided_slice(out, - 
begin=relay.const([0, 0, 0, 0], "int32"),
-                                      end=relay.const(original_out_shape, "int32"))
+                                      begin=[0, 0, 0, 0],
+                                      end=original_out_shape)
     else:
         out = relay.nn.conv2d(data, kernel, **new_attrs)
 
diff --git a/src/relay/op/dyn/tensor/transform.cc b/src/relay/op/dyn/tensor/transform.cc
index 06e1c579728f..de1cc5a4ed95 100644
--- a/src/relay/op/dyn/tensor/transform.cc
+++ b/src/relay/op/dyn/tensor/transform.cc
@@ -27,13 +27,17 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
+#include
 #include
 #include
+#include "../../../transforms/infer_layout_util.h"
+
 namespace tvm {
 namespace relay {
 namespace dyn {
@@ -430,6 +434,115 @@ RELAY_REGISTER_OP("dyn.full")
     .set_attr<FTVMCompute>("FTVMCompute", FullCompute)
     .set_attr<TOpPattern>("TOpPattern", kElemWise);
 
+bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                     const TypeReporter& reporter) {
+  // [data, begin, end, strides, out]
+  CHECK_EQ(types.size(), 5);
+  const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>();
+  if (param == nullptr) {
+    return false;
+  }
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) {
+    return false;
+  }
+  auto dshape = data->shape;
+  int64_t num_axis = dshape.size();
+
+  // calculate output shape
+  std::vector<IndexExpr> oshape(num_axis);
+  for (int64_t i = 0; i < num_axis; ++i) {
+    oshape[i] = Any();
+  }
+
+  reporter->Assign(types[4], TensorType(oshape, data->dtype));
+  return true;
+}
+
+inline te::Tensor DynamicStridedSlice(const te::Tensor& input, const te::Tensor& begin,
+                                      const te::Tensor& end, const te::Tensor& strides,
+                                      std::string name = "T_strided_slice_dynamic",
+                                      std::string tag = topi::kInjective) {
+  int64_t src_tensor_dim = input->shape.size();
+  Array<PrimExpr> out_shape;
+  for (int64_t i = 0; i < src_tensor_dim; ++i) {
+    out_shape.push_back(tvm::tir::Var("dim"));
+  }
+  // TODO(yongwww): move the compute into topi
+  return te::compute(
+      out_shape,
+      [&](const Array<tvm::tir::Var>& indices) {
+        Array<PrimExpr> real_indices;
+        for (int32_t i = 0; i < src_tensor_dim; ++i) {
+          real_indices.push_back(indices[i] * strides(i) + begin(i));
+        }
+        return input(real_indices);
+      },
+      name, tag);
+}
+
+Array<te::Tensor> StridedSliceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
+                                      const Type& out_type) {
+  te::Tensor data = inputs[0];
+  te::Tensor begin = inputs[1];
+  te::Tensor end = inputs[2];
+  te::Tensor strides = inputs[3];
+  // Dynamic computation
+  int64_t data_rank = data->shape.size();
+  CHECK(begin->shape[0].as<IntImmNode>()->value == data_rank &&
+        end->shape[0].as<IntImmNode>()->value == data_rank &&
+        strides->shape[0].as<IntImmNode>()->value == data_rank)
+      << "begin, end, and strides are required to have the same length"
+      << " if they are dynamic variables.";
+  return Array<te::Tensor>{DynamicStridedSlice(data, begin, end, strides)};
+}
+
+Expr MakeStridedSlice(Expr data, Expr begin, Expr end, Expr strides, String slice_mode) {
+  auto attrs = make_object<StridedSliceAttrs>();
+  attrs->slice_mode = slice_mode;
+  static const Op& op = Op::Get("dyn.strided_slice");
+  return Call(op, {data, begin, end, strides}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relay.op.dyn._make.strided_slice").set_body_typed(MakeStridedSlice);
+
+RELAY_REGISTER_OP("dyn.strided_slice")
+    .describe(R"code(Strided slice of an array.
+
+Examples::
+
+  x = [[ 1.,  4.,  7., 10.],
+       [ 2.,  5.,  8., 11.],
+       [ 3.,  6.,  9., 12.]]
+
+  strided_slice(x, begin=[0, 1], end=[2, 4], stride=[1, 1]) = [[ 4.,  7., 10.],
+                                                               [ 5.,  8., 11.]]
+
+  x = [[[ 1., 2.],
+        [ 3., 4.]],
+
+       [[ 5., 6.],
+        [ 7., 8.]]]
+
+  strided_slice(x, begin=[0, 0], end=[2, 2]) = [[[ 1., 2.],
+                                                 [ 3., 4.]],
+
+                                                [[ 5., 6.],
+                                                 [ 7., 8.]]]
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(4)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("begin", "Tensor", "The indices to begin with in the slicing.")
+    .add_argument("end", "Tensor", "Indices indicating end of the slice.")
+    .add_argument("strides", "Tensor", "The stride values.")
+    .add_argument("slice_mode", "Tensor", "The slice mode.")
+    .set_support_level(4)
+    .set_attrs_type<StridedSliceAttrs>()
+    .add_type_rel("DynStridedSlice", StridedSliceRel)
+    .set_attr<FTVMCompute>("FTVMCompute", StridedSliceCompute)
+    .set_attr<TOpPattern>("TOpPattern", kInjective)
+    .set_attr<AnyCodegenStrategy>("AnyCodegenStrategy", kVariableDimensions);
+
 }  // namespace dyn
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index dc9ddee0f0bb..631ec4c0d2f5 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -68,7 +68,8 @@ Expr MakeSqueeze(Expr data, Array<Integer> axis);
 
 Expr MakeStack(Expr data, int axis);
 
-Expr MakeStridedSlice(Expr data, Expr begin, Expr end, Expr strides, String slice_mode);
+Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides,
+                      String slice_mode);
 
 Expr MakeTile(Expr data, Array<Integer> reps);
 
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index 40051e43d57b..293875ebf6ea 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -32,6 +32,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -40,6 +41,7 @@
 #include
 
 #include "../../transforms/infer_layout_util.h"
+#include "../../transforms/pass_util.h"
 #include "../../transforms/pattern_util.h"
 #include "../make_op.h"
 #include "../op_common.h"
@@ -1985,7 +1987,7 @@ TVM_REGISTER_NODE_TYPE(StridedSliceAttrs);
 
 bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                      const TypeReporter& reporter) {
-  CHECK_EQ(types.size(), 5);
+  CHECK_EQ(types.size(), 2);
   const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>();
   CHECK(param != nullptr);
   const auto* data = types[0].as<TensorTypeNode>();
@@ -2079,12 +2081,11 @@ bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
       oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) / step);
     }
   } else {
-    for (int64_t i = 0; i < num_axis; ++i) {
-      oshape[i] = Any();
-    }
+    CHECK(param->begin) << "strided_slice received invalid begin " << param->begin;
+    CHECK(param->end) << "strided_slice received invalid end " << param->end;
+    CHECK(param->strides) << "strided_slice received invalid strides " << param->strides;
   }
-
-  reporter->Assign(types[4], TensorType(oshape, data->dtype));
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
   return true;
 }
 
@@ -2176,78 +2177,62 @@ Array<Array<Layout>> StridedSliceInferCorrectLayout(const Attrs& attrs,
       params->begin = new_begin;
       params->end = new_end;
     }
-  return {{layout, Layout("C"), Layout("C"), Layout("C")}, {layout}};
-}
-
-inline te::Tensor DynamicStridedSlice(const te::Tensor& input, const te::Tensor& begin,
-                                      const te::Tensor& end, const te::Tensor& strides,
-                                      std::string name = "T_strided_slice_dynamic",
-                                      std::string tag = topi::kInjective) {
-  int64_t src_tensor_dim = input->shape.size();
-  Array<PrimExpr> out_shape;
-  for (int64_t i = 0; i < src_tensor_dim; ++i) {
-    out_shape.push_back(tvm::tir::Var("dim"));
-  }
-  // TODO(yongwww): move the compute into topi
-  return te::compute(
-      out_shape,
-      [&](const Array<tvm::tir::Var>& indices) {
-        Array<PrimExpr> real_indices;
-        for (int32_t i = 0; i < src_tensor_dim; ++i) {
-          real_indices.push_back(indices[i] * strides(i) + begin(i));
-        }
-        return input(real_indices);
-      },
-      name, tag);
+  return {{layout}, {layout}};
 }
 
 Array<te::Tensor> StridedSliceCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
                                       const Type& out_type) {
   const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>();
   CHECK(param != nullptr);
-  if (param->begin && param->end && param->strides) {
-    Array<Integer> begin, end, strides;
-    begin = param->begin.value();
-    end = param->end.value();
-    strides = param->strides.value();
-    return Array<te::Tensor>{
-        topi::strided_slice(inputs[0], begin, end, strides, param->slice_mode)};
-  } else {
-    te::Tensor data = inputs[0];
-    te::Tensor begin = inputs[1];
-    te::Tensor end = inputs[2];
-    te::Tensor strides = inputs[3];
-    // Dynamic computation
-    int64_t attr_size = data->shape.size();
-    CHECK(begin->shape[0].as<IntImmNode>()->value == attr_size &&
-          end->shape[0].as<IntImmNode>()->value == attr_size &&
-          strides->shape[0].as<IntImmNode>()->value == attr_size)
-        << "begin, end, and strides are required to have the same length"
-        << " if they are non-constant.";
-    return Array<te::Tensor>{DynamicStridedSlice(data, begin, end, strides)};
+  Array<Integer> begin, end, strides;
+  begin = param->begin.value();
+  end = param->end.value();
+  strides = param->strides.value();
+  if (IsDynamic(out_type)) {
+    auto input = inputs[0];
+    size_t src_tensor_dim = input->shape.size();
+    CHECK(begin.size() == src_tensor_dim)
+        << "for dynamic inputs, len(begin) must equal the input dimension";
+    Array<PrimExpr> out_shape;
+    for (size_t i = 0; i < src_tensor_dim; ++i) {
+      out_shape.push_back(tvm::tir::Var("dim"));
+    }
+    Array<PrimExpr> begin_expr;
+    Array<PrimExpr> strides_expr;
+    for (size_t i = 0; i < src_tensor_dim; ++i) {
+      int64_t begin_i = begin[i]->value;
+      if (begin_i < 0) {
+        begin_i += topi::detail::GetConstInt(input->shape[i]);
+      }
+      begin_expr.push_back(tir::make_const(begin[0].dtype(), begin_i));
+      strides_expr.push_back(
+          tir::make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()),
+                          (i < strides.size() ? strides[i]->value : 1)));
+    }
+    return Array<te::Tensor>{te::compute(
+        out_shape,
+        [&](const Array<tvm::tir::Var>& indices) {
+          Array<PrimExpr> real_indices;
+          for (size_t i = 0; i < src_tensor_dim; ++i) {
+            real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]);
+          }
+          return input(real_indices);
+        },
+        std::string{"T_strided_slice_dynamic"}, std::string{topi::kInjective})};
   }
+  return Array<te::Tensor>{topi::strided_slice(inputs[0], begin, end, strides, param->slice_mode)};
 }
 
 // Positional relay function to create StridedSlice operator used by frontend FFI.
-Expr MakeStridedSlice(Expr data, Expr begin, Expr end, Expr strides, String slice_mode) {
+Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides,
+                      String slice_mode) {
   auto attrs = make_object<StridedSliceAttrs>();
-  const ConstantNode *cbegin, *cend, *cstrides;
-  if ((cbegin = begin.as<ConstantNode>()) && (cend = end.as<ConstantNode>()) &&
-      (cstrides = strides.as<ConstantNode>())) {
-    CHECK_EQ(cbegin->data->ndim, 1);
-    CHECK_EQ(cend->data->ndim, 1);
-    CHECK_EQ(cstrides->data->ndim, 1);
-    Array<Integer> begin, end, strides;
-    begin = ToVector(cbegin->data);
-    end = ToVector(cend->data);
-    strides = ToVector(cstrides->data);
-    attrs->begin = begin;
-    attrs->end = end;
-    attrs->strides = strides;
-  }
+  attrs->begin = std::move(begin);
+  attrs->end = std::move(end);
+  attrs->strides = std::move(strides);
   attrs->slice_mode = slice_mode;
   static const Op& op = Op::Get("strided_slice");
-  return Call(op, {data, begin, end, strides}, Attrs(attrs), {});
+  return Call(op, {data}, Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relay.op._make.strided_slice").set_body_typed(MakeStridedSlice);
 
@@ -2276,12 +2261,8 @@ Examples::
         [[ 5., 6.],
          [ 7., 8.]]]
 )code" TVM_ADD_FILELINE)
-    .set_num_inputs(4)
+    .set_num_inputs(1)
     .add_argument("data", "Tensor", "The input tensor.")
-    .add_argument("begin", "Tensor", "The indices to begin with in the slicing.")
-    .add_argument("end", "Tensor", "Indices indicating end of the slice.")
-    .add_argument("strides", "Tensor", "The stride values.")
-    .add_argument("slice_mode", "Tensor", "The slice mode.")
     .set_support_level(4)
     .set_attrs_type<StridedSliceAttrs>()
     .add_type_rel("StridedSlice", StridedSliceRel)
diff --git a/src/relay/transforms/combine_parallel_batch_matmul.cc b/src/relay/transforms/combine_parallel_batch_matmul.cc
index 1529631d5ec1..b2b9703c28bc 100644
--- a/src/relay/transforms/combine_parallel_batch_matmul.cc
+++ b/src/relay/transforms/combine_parallel_batch_matmul.cc
@@ -116,8 +116,8 @@ class ParallelBatchMatmulCombiner : public ParallelOpCombiner {
     auto feature_dim = batch_matmul->args[1]->type_as<TensorTypeNode>()->shape[1];
     auto fpp = tir::as_const_int(feature_dim);
     int64_t features = *fpp;
-    std::vector<int64_t> begin;
-    std::vector<int64_t> end;
+    Array<Integer> begin;
+    Array<Integer> end;
     for (size_t i = 0; i < 2; i++) {
       begin.push_back(0);
       end.push_back(-1);
@@ -125,12 +125,8 @@ class ParallelBatchMatmulCombiner : public ParallelOpCombiner {
       begin.push_back(index);
       index += features;
       end.push_back(features);
-      std::vector<int64_t> strides(begin.size(), 1);
-      std::vector<int64_t> ndarray_shape = {static_cast<int64_t>(begin.size())};
-      Constant begin_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, begin);
-      Constant end_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, end);
-      Constant strides_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, strides);
-      auto slice = MakeStridedSlice(data, begin_const, end_const, strides_const, "size");
+      Array<Integer> strides(begin.size(), 1);
+      auto slice = MakeStridedSlice(data, begin, end, strides, "size");
       subst_map->insert({GetRef<Expr>(branch[depth]), slice});
     }
   }
diff --git a/src/relay/transforms/combine_parallel_conv2d.cc b/src/relay/transforms/combine_parallel_conv2d.cc
index 0bf9e7fd38a6..a639fcd60af6 100644
--- a/src/relay/transforms/combine_parallel_conv2d.cc
+++ b/src/relay/transforms/combine_parallel_conv2d.cc
@@ -168,24 +168,17 @@ class ParallelConv2DCombiner : public ParallelOpCombiner {
     for (const auto& branch : branches) {
       const CallNode* conv2d = branch[0];
       int64_t channels = GetConv2DSuperChannelsDim(conv2d);
-      std::vector<int64_t> begin;
-      std::vector<int64_t> end;
+      Array<Integer> begin;
+      Array<Integer> end;
       for (size_t i = 0; i < channel_pos_; i++) {
         begin.push_back(0);
         end.push_back(-1);
       }
       begin.push_back(index);
       index += channels;
-      end.push_back(index);
-      std::vector<int64_t> strides(begin.size(), 1);
-      for (size_t i = 0; i < begin.size(); ++i) {
-        end[i] -= begin[i];
-      }
-      std::vector<int64_t> ndarray_shape = {static_cast<int64_t>(begin.size())};
-      Constant begin_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, begin);
-      Constant end_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, end);
-      Constant strides_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, strides);
-      auto slice = MakeStridedSlice(data, begin_const, end_const, strides_const, "size");
+      end.push_back(channels);
+      Array<Integer> strides(begin.size(), 1);
+      auto slice = MakeStridedSlice(data, begin, end, strides, "size");
       subst_map->insert({GetRef<Expr>(branch[depth]), slice});
     }
   }
diff --git a/src/relay/transforms/combine_parallel_dense.cc b/src/relay/transforms/combine_parallel_dense.cc
index aec4315cb083..76b26d0e085b 100644
--- a/src/relay/transforms/combine_parallel_dense.cc
+++ b/src/relay/transforms/combine_parallel_dense.cc
@@ -183,9 +183,9 @@ class ParallelDenseToDenseCombiner : public ParallelOpCombiner {
       auto& out_shape = call->type_as<TensorTypeNode>()->shape;
       auto out_dims = tir::as_const_int(out_shape[out_shape.size() - 1]);
       CHECK(out_dims != nullptr);
-      std::vector<int64_t> begin;
-      std::vector<int64_t> end;
-      std::vector<int64_t> strides;
+      Array<Integer> begin;
+      Array<Integer> end;
+      Array<Integer> strides;
       for (size_t k = 0; k < out_shape.size() - 1; ++k) {
         begin.push_back(0);
         end.push_back(-1);
@@ -195,11 +195,7 @@ class ParallelDenseToDenseCombiner : public ParallelOpCombiner {
       end.push_back(*out_dims);
       strides.push_back(1);
       index += *out_dims;
-      std::vector<int64_t> ndarray_shape = {static_cast<int64_t>(begin.size())};
-      Constant begin_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, begin);
-      Constant end_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, end);
-      Constant strides_const = MakeConstantTensor(DataType::Int(64), ndarray_shape, strides);
-      auto slice = MakeStridedSlice(data, begin_const, end_const, strides_const, "size");
+      auto slice = MakeStridedSlice(data, begin, end, strides, "size");
       subst_map->insert({GetRef<Expr>(branch[depth]), slice});
     }
   }
diff --git a/src/relay/transforms/dynamic_to_static.cc b/src/relay/transforms/dynamic_to_static.cc
index 0c417ad857a2..113b599579ab 100644
--- a/src/relay/transforms/dynamic_to_static.cc
+++ b/src/relay/transforms/dynamic_to_static.cc
@@ -173,6 +173,22 @@ class DynamicToStaticMutator : public MixedModeMutator {
          }
          return Expr(nullptr);
        }},
+      {Op::Get("dyn.strided_slice"),
+       [](const CallNode* call_node) {
+         const ConstantNode* begin = call_node->args[1].as<ConstantNode>();
+         const ConstantNode* end = call_node->args[2].as<ConstantNode>();
+         const ConstantNode* stride = call_node->args[3].as<ConstantNode>();
+         if (begin && end && stride) {
+           CHECK_EQ(begin->data->ndim, 1);
+           CHECK_EQ(end->data->ndim, 1);
+           CHECK_EQ(stride->data->ndim, 1);
+           const StridedSliceAttrs* param = call_node->attrs.as<StridedSliceAttrs>();
+           CHECK(param);
+           return MakeStridedSlice(call_node->args[0], ToVector(begin->data), ToVector(end->data),
+                                   ToVector(stride->data), param->slice_mode);
+         }
+         return Expr(nullptr);
+       }},
   };
 }
diff --git a/src/relay/transforms/pass_util.h b/src/relay/transforms/pass_util.h
index 63708c45bfe3..f3c99ccfa120 100644
--- a/src/relay/transforms/pass_util.h
+++ b/src/relay/transforms/pass_util.h
@@ -27,6 +27,7 @@
 #include
 #include
+#include
 #include
 #include
diff --git a/tests/python/relay/dyn/test_dynamic_op_level10.py b/tests/python/relay/dyn/test_dynamic_op_level10.py
index 8bc551be0ff1..e3c8c9eb0bea 100644
---
a/tests/python/relay/dyn/test_dynamic_op_level10.py +++ b/tests/python/relay/dyn/test_dynamic_op_level10.py @@ -27,8 +27,8 @@ import random import tvm.testing - -@tvm.testing.uses_gpu +# TODO(mbrookhart): Enable when VM supports heterogenus execution +# @tvm.testing.uses_gpu def test_dyn_broadcast_to(): dtype = 'uint8' rank = 3 @@ -53,8 +53,8 @@ def test_dyn_broadcast_to(): op_res = intrp.evaluate(func)(x, np.array(dyn_shape).astype(shape_type)) tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) - -@tvm.testing.uses_gpu +# TODO(mbrookhart): Enable when VM supports heterogenus execution +# @tvm.testing.uses_gpu def test_dyn_one_hot(): def _get_oshape(indices_shape, depth, axis): oshape = [] @@ -80,12 +80,11 @@ def _verify(indices_shape, depth, on_value, off_value, axis, dtype): indices_np = np.random.randint(0, depth, size=indices_shape).astype("int32") out_np = tvm.topi.testing.one_hot(indices_np, on_value, off_value, depth, axis, dtype) for target, ctx in tvm.testing.enabled_targets(): - if (target != 'cuda'): #skip cuda because we don't have dynamic support for GPU - for kind in ["vm", "debug"]: - mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) - out_relay = intrp.evaluate()(indices_np, np.array(depth).astype("int32")) - tvm.testing.assert_allclose(out_relay.asnumpy(), out_np) + for kind in ["vm", "debug"]: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + out_relay = intrp.evaluate()(indices_np, np.array(depth).astype("int32")) + tvm.testing.assert_allclose(out_relay.asnumpy(), out_np) _verify((3, ), 3, 1, 0, -1, "int32") _verify((3, ), 3, 1.0, 0.0, -1, "float32") diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py index b863d09db0a5..63dfd1075b6d 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level2.py +++ b/tests/python/relay/dyn/test_dynamic_op_level2.py @@ -27,6 +27,8 @@ import tvm.topi.testing from tvm.relay.testing import run_infer_type +# TODO(mbrookhart): Enable when VM supports heterogenus execution +# @tvm.testing.uses_gpu def test_dyn_upsampling_run(): def verify_upsampling(dshape, scale_h, scale_w, layout, method, align_corners=False): @@ -51,19 +53,21 @@ def verify_upsampling(dshape, scale_h, scale_w, layout, method, align_corners=Fa zz = run_infer_type(z) func = relay.Function([x, scale_h_var, scale_w_var], z) - for target, ctx in enabled_targets(): - for kind in ["vm", "debug"]: - mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) - op_res = intrp.evaluate()(x_data, np.array(scale_h).astype("float32"), np.array(scale_w).astype("float32")) - tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6) + for target, ctx in tvm.testing.enabled_targets(): + for kind in ["vm", "debug"]: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x_data, np.array(scale_h).astype("float32"), np.array(scale_w).astype("float32")) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-4, atol=1e-6) - verify_upsampling((1, 16, 32, 32), 3, 2.0,"NCHW", "nearest_neighbor") + verify_upsampling((1, 16, 32, 32), 3, 2.0, "NCHW", "nearest_neighbor") verify_upsampling((1, 16, 32, 32), 5, 2.0, "NCHW", "bilinear", True) verify_upsampling((1, 16, 32, 32), 2.0, 6, "NHWC", "nearest_neighbor") verify_upsampling((1, 16, 32, 32), 2.0, 
2.0,"NHWC", "bilinear", True) #tests upsampling type inference with scale_h passed in as a constant and scale_w as a variable +# TODO(mbrookhart): Enable when VM supports heterogenus execution +# @tvm.testing.uses_gpu def test_dyn_upsampling_infer_type_const(): n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), te.size_var("w") @@ -74,6 +78,8 @@ def test_dyn_upsampling_infer_type_const(): zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, relay.Any(), relay.Any()), "int8") +# TODO(mbrookhart): Enable when VM supports heterogenus execution +# @tvm.testing.uses_gpu def test_dyn_upsampling3d_run(): def verify_upsampling3d(dshape, scale_d, scale_h, scale_w, layout, method, coord_trans="half_pixel"): @@ -126,6 +132,9 @@ def test_dyn_upsampling3d_infer_type_const(): zz = run_infer_type(z) assert zz.checked_type == relay.TensorType((n, c, relay.Any(), relay.Any(), relay.Any()), "int8") + +# TODO(mbrookhart): Enable when VM supports heterogenus execution +# @tvm.testing.uses_gpu def test_dyn_pad(): def verify_pad(dshape, pad_width, pad_val, dtype): x = relay.var("x", relay.TensorType(dshape, dtype)) diff --git a/tests/python/relay/dyn/test_dynamic_op_level3.py b/tests/python/relay/dyn/test_dynamic_op_level3.py index d6a2806719ab..74b4e106e1f6 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level3.py +++ b/tests/python/relay/dyn/test_dynamic_op_level3.py @@ -35,7 +35,8 @@ def verify_func(func, data, ref_res): tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) relay.backend.compile_engine.get().clear() -@tvm.testing.uses_gpu +# TODO(mbrookhart): Enable when VM supports heterogenus execution +# @tvm.testing.uses_gpu def test_dyn_reshape(): def verify_reshape(shape, newshape, oshape): x = relay.var("x", relay.TensorType(shape, "float32")) @@ -60,7 +61,8 @@ def verify_reshape(shape, newshape, oshape): verify_reshape((2, 3, 4, 5), (-3, -3), (6, 20)) verify_reshape((2, 3, 4), (0, -3), (2, 12)) -@tvm.testing.uses_gpu +# TODO(mbrookhart): Enable when VM supports heterogenus execution +# @tvm.testing.uses_gpu def test_dyn_shape_reshape(): def verify_reshape(shape, newshape, oshape): x = relay.var("x", relay.TensorType(shape, "float32")) @@ -77,7 +79,8 @@ def verify_reshape(shape, newshape, oshape): verify_reshape((2, 3, 4), (8, 3), (8, 3)) verify_reshape((4, 7), (2, 7, 2), (2, 7, 2)) -@tvm.testing.uses_gpu +# TODO(mbrookhart): Enable when VM supports heterogenus execution +# @tvm.testing.uses_gpu def test_dyn_tile(): def verify_tile(dshape, reps): x = relay.var("x", relay.TensorType(dshape, "float32")) @@ -94,7 +97,8 @@ def verify_tile(dshape, reps): verify_tile((2, 3), (3, 2, 1)) -@tvm.testing.uses_gpu +# TODO(mbrookhart): Enable when VM supports heterogenus execution +# @tvm.testing.uses_gpu def test_dyn_zeros_ones(): def verify_zeros_ones(shape, dtype): for op, ref in [(relay.zeros, np.zeros), (relay.ones, np.ones)]: @@ -110,7 +114,8 @@ def verify_zeros_ones(shape, dtype): verify_zeros_ones((1, 3), 'int64') verify_zeros_ones((8, 9, 1, 2), 'float32') -@tvm.testing.uses_gpu +# TODO(mbrookhart): Enable when VM supports heterogenus execution +# @tvm.testing.uses_gpu def test_dyn_full(): def verify_full(fill_value, src_shape, dtype): x = relay.var("x", relay.scalar_type(dtype)) diff --git a/tests/python/relay/dyn/test_dynamic_op_level4.py b/tests/python/relay/dyn/test_dynamic_op_level4.py new file mode 100644 index 000000000000..b739a0e59285 --- /dev/null +++ b/tests/python/relay/dyn/test_dynamic_op_level4.py @@ -0,0 +1,95 @@ +# Licensed to the Apache 
Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import te +import numpy as np +from tvm import relay +from tvm.relay import transform +from tvm.relay.testing import run_infer_type +import tvm.topi.testing + + +# TODO(mbrookhart): Enable when VM supports heterogenus execution +# @tvm.testing.uses_gpu +def test_dynamic_strided_slice(): + def verify(dshape, begin, end, strides, output, slice_mode="end", + test_ref=True, dtype="int32"): + x = relay.var("x", relay.TensorType(dshape, "float32")) + ndim = len(dshape) + begin = begin if begin else [0] * ndim + end = end if end else list(dshape) + if strides: + if len(strides) == 1: + strides = strides * ndim + else: + strides = [1] * ndim + + # target numpy result + x_data = np.random.uniform(size=dshape).astype("float32") + ref_res = tvm.topi.testing.strided_slice_python( + x_data, begin, end, strides, slice_mode) + data = [x_data, np.array(begin), np.array(end)] + + begin = relay.const(begin, dtype=dtype) + end = relay.const(end, dtype=dtype) + + + if strides: + data.append(np.array(strides)) + strides = relay.const(strides, dtype=dtype) + z = relay.strided_slice(x, + begin=begin, + end=end, + strides=strides, + slice_mode=slice_mode) + else: + z = relay.strided_slice(x, + begin=begin, + end=end, + slice_mode=slice_mode) + func = relay.Function([x], z) + + func = run_infer_type(func) + text = func.astext() + + if not test_ref: + return + for target, ctx in tvm.testing.enabled_targets(): + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor("vm", mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x_data) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res) + + verify((1, 3, 10, 10), [0, 0, 0, 0], [-1, 3, 10, 10], [1], (0, 3, 10, 10), dtype="int64") + verify((1, 224, 224, 3), [0, 20, 20, 0], [1, 140, 140, 3], + [1, 1, 1, 1], (1, 120, 120, 3), dtype="int64") + verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3), dtype="int16") + verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2)) + verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)) + verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3)) + verify((3, 4, 3), [1, 1, 0], [4, 4, 4], None, (2, 3, 3)) + verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)) + verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)) + verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)) + verify((3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], + (2, 4, 3), slice_mode="size", test_ref=False) + verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], + (2, 2, 3), slice_mode="size", test_ref=True) + + +if __name__ == "__main__": + test_dynamic_strided_slice() diff --git a/tests/python/relay/dyn/test_dynamic_op_level5.py b/tests/python/relay/dyn/test_dynamic_op_level5.py index eb804fe430e3..a6e5b61b1c70 100644 --- 
a/tests/python/relay/dyn/test_dynamic_op_level5.py +++ b/tests/python/relay/dyn/test_dynamic_op_level5.py @@ -36,7 +36,8 @@ def test_resize_infer_type(): assert zz.checked_type == relay.TensorType((n, c, relay.Any(), relay.Any()), "int8") -@tvm.testing.uses_gpu +# TODO(mbrookhart): Enable when VM supports heterogenus execution +# @tvm.testing.uses_gpu def test_resize(): def verify_resize(dshape, scale, method, layout): if layout == "NHWC": diff --git a/tests/python/relay/dyn/test_dynamic_op_level6.py b/tests/python/relay/dyn/test_dynamic_op_level6.py index 6dcde953710d..58bf53ce1117 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level6.py +++ b/tests/python/relay/dyn/test_dynamic_op_level6.py @@ -23,7 +23,8 @@ from tvm import relay import tvm.testing -@tvm.testing.uses_gpu +# TODO(mbrookhart): Enable when VM supports heterogenus execution +# @tvm.testing.uses_gpu def test_dynamic_topk(): def verify_topk(k, axis, ret_type, is_ascend, dtype): shape = (20, 100) @@ -53,7 +54,6 @@ def verify_topk(k, axis, ret_type, is_ascend, dtype): np_indices = np_indices.astype(dtype) for target, ctx in tvm.testing.enabled_targets(): - if "llvm" not in target: continue for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index 8c62f8c0727f..4f74d72277d9 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -335,7 +335,7 @@ def test_mean_var_std(): @tvm.testing.uses_gpu def test_strided_slice(): def verify(dshape, begin, end, strides, output, slice_mode="end", - attr_const=True, test_ref=True, dtype="int32"): + test_ref=True, dtype="int32"): x = relay.var("x", relay.TensorType(dshape, "float32")) ndim = len(dshape) begin = begin if begin else [0] * ndim @@ -346,13 +346,7 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", ref_res = tvm.topi.testing.strided_slice_python( x_data, begin, end, strides, slice_mode) - if attr_const: - begin = relay.const(begin, dtype=dtype) - end = relay.const(end, dtype=dtype) - if strides: - if attr_const: - strides = relay.const(strides, dtype=dtype) z = relay.strided_slice(x, begin=begin, end=end, @@ -385,7 +379,6 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", [1, 1, 1, 1], (1, 120, 120, 3), dtype="int64") verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3), dtype="int16") verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2)) - verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2), attr_const=False) verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)) verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3)) verify((3, 4, 3), [1, 1, 0], [4, 4], None, (2, 3, 3)) @@ -397,6 +390,65 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], (2, 2, 3), slice_mode="size", test_ref=True) +#TODO(mbrookhart): enable once vm supports heterogenous execution +#@tvm.testing.uses_gpu +def test_dyn_strided_slice(): + def verify(dshape, begin, end, strides, output, slice_mode="end", + test_ref=True, dtype="int32"): + ndim = len(dshape) + begin = begin if begin else [0] * ndim + end = end if end else list(dshape) + + # target numpy result + x_data = np.random.uniform(size=dshape).astype("float32") + ref_res = tvm.topi.testing.strided_slice_python( + x_data, begin, end, strides, slice_mode) + + x = relay.var("x", relay.TensorType((relay.Any(), ) * 
ndim, "float32")) + if strides: + z = relay.strided_slice(x, + begin=begin, + end=end, + strides=strides, + slice_mode=slice_mode) + else: + z = relay.strided_slice(x, + begin=begin, + end=end, + slice_mode=slice_mode) + func = relay.Function([x], z) + + func = run_infer_type(func) + text = func.astext() + assert "begin=" in text + assert "end=" in text + + if not test_ref: + return + for target, ctx in tvm.testing.enabled_targets(): + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor("vm", mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x_data) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res) + + verify((1, 3, 10, 10), [0, 0, 0, 0], [-1, 3, 10, 10], [1], (0, 3, 10, 10), dtype="int64") + verify((1, 224, 224, 3), [0, 20, 20, 0], [1, 140, 140, 3], + [1, 1, 1, 1], (1, 120, 120, 3), dtype="int64") + verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3), dtype="int16") + verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2)) + verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)) + verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3)) + verify((3, 4, 3), [1, 1, 0], [4, 4, 4], None, (2, 3, 3)) + verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)) + #TODO(mbrookhart): fix static strided_slice with dynamic input and negative begin + #verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)) + #verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)) + verify((3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], + (2, 4, 3), slice_mode="size", test_ref=False) + verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], + (2, 2, 3), slice_mode="size", test_ref=True) + + @tvm.testing.uses_gpu def test_strided_set(): def verify(dshape, begin, end, strides, vshape, test_ref=True): diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 0e0ab570ec10..3bd82b2a9cf4 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -624,9 +624,9 @@ def before(): weight = relay.var('weight', shape=(32, 32, 3, 3)) y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1)) y = relay.strided_slice(y, - begin=relay.const([0, 16], "int32"), - end=relay.const([1, 33], "int32"), - strides=relay.const([1, 1], "int32")) + begin=[0, 16], + end=[1, 33], + strides=[1, 1]) y = relay.Function(analysis.free_vars(y), y) return y @@ -645,9 +645,9 @@ def expected(): data_layout="NCHW4c") y = relay.strided_slice(y, - begin=relay.const([0, 4], "int32"), - end=relay.const([1, 21], "int32"), - strides=relay.const([1, 1], "int32")) + begin=[0, 4], + end=[1, 21], + strides=[1, 1]) y = relay.layout_transform(y, "NCHW4c", "NCHW") y = relay.Function(analysis.free_vars(y), y) diff --git a/tests/python/relay/test_pass_combine_parallel_batch_matmul.py b/tests/python/relay/test_pass_combine_parallel_batch_matmul.py index 00d8ac40a129..edede97293b7 100644 --- a/tests/python/relay/test_pass_combine_parallel_batch_matmul.py +++ b/tests/python/relay/test_pass_combine_parallel_batch_matmul.py @@ -47,19 +47,19 @@ def expected(x, w1, w2, w3): w = relay.concatenate((w1, w2, w3), axis=1) y = relay.nn.batch_matmul(x, w) y1 = relay.strided_slice(y, - begin=relay.const([0, 0, 0], "int64"), - end=relay.const([-1, -1, s1], "int64"), - strides=relay.const([1, 1, 1], 'int64'), + begin=[0, 0, 0], + end=[-1, -1, s1], + strides=[1, 1, 1], slice_mode="size") y2 = relay.strided_slice(y, - begin=relay.const([0, 0, s1], "int64"), - end=relay.const([-1, -1, s2], 
"int64"), - strides=relay.const([1, 1, 1], 'int64'), + begin=[0, 0, s1], + end=[-1, -1, s2], + strides=[1, 1, 1], slice_mode="size") y3 = relay.strided_slice(y, - begin=relay.const([0, 0, s1+s2], "int64"), - end=relay.const([-1, -1, s3], "int64"), - strides=relay.const([1, 1, 1], 'int64'), + begin=[0, 0, s1+s2], + end=[-1, -1, s3], + strides=[1, 1, 1], slice_mode="size") y = relay.Tuple((y1, y2, y3)) return relay.Function(args, y) @@ -104,19 +104,19 @@ def expected(x, w1, w2, w3, b1, b2, b3): y = relay.nn.batch_matmul(x, w) y = relay.add(y, b) y1 = relay.strided_slice(y, - begin=relay.const([0, 0, 0], "int64"), - end=relay.const([-1, -1, s1], "int64"), - strides=relay.const([1, 1, 1], 'int64'), + begin=[0, 0, 0], + end=[-1, -1, s1], + strides=[1, 1, 1], slice_mode="size") y2 = relay.strided_slice(y, - begin=relay.const([0, 0, s1], "int64"), - end=relay.const([-1, -1, s2], "int64"), - strides=relay.const([1, 1, 1], 'int64'), + begin=[0, 0, s1], + end=[-1, -1, s2], + strides=[1, 1, 1], slice_mode="size") y3 = relay.strided_slice(y, - begin=relay.const([0, 0, s1+s2], "int64"), - end=relay.const([-1, -1, s3], "int64"), - strides=relay.const([1, 1, 1], 'int64'), + begin=[0, 0, s1+s2], + end=[-1, -1, s3], + strides=[1, 1, 1], slice_mode="size") y = relay.Tuple((y1, y2, y3)) return relay.Function(args, y) diff --git a/tests/python/relay/test_pass_combine_parallel_conv2d.py b/tests/python/relay/test_pass_combine_parallel_conv2d.py index 68e7fece7e98..f48cdd608242 100644 --- a/tests/python/relay/test_pass_combine_parallel_conv2d.py +++ b/tests/python/relay/test_pass_combine_parallel_conv2d.py @@ -50,20 +50,20 @@ def expected(x, w1, w2, w3, w4, channels1, channels2, channels3, channels4): w = relay.concatenate((w1, w2, w4), axis=0) y = relay.nn.conv2d(x, w, channels=channels1 + channels2 + channels4) y1 = relay.strided_slice(y, - begin=relay.const([0, 0], "int64"), - end=relay.const([-1, channels1], "int64"), - strides=relay.const([1, 1], 'int64'), + begin=[0, 0], + end=[-1, channels1], + strides=[1, 1], slice_mode="size") y2 = relay.strided_slice(y, - begin=relay.const([0, channels1], "int64"), - end=relay.const([-1, channels2], "int64"), - strides=relay.const([1, 1], 'int64'), + begin=[0, channels1], + end=[-1, channels2], + strides=[1, 1], slice_mode="size") y3 = relay.nn.conv2d(x, w3) y4 = relay.strided_slice(y, - begin=relay.const([0, channels1 + channels2], "int64"), - end=relay.const([-1, channels4], "int64"), - strides=relay.const([1, 1], 'int64'), + begin=[0, channels1 + channels2], + end=[-1, channels4], + strides=[1, 1], slice_mode="size") y5 = relay.nn.max_pool2d(x) y = relay.Tuple((y1, y2, y3, y4, y5)) @@ -110,14 +110,14 @@ def expected(x, w1, w2, scale1, scale2, bias, channels1, channels2): y = relay.multiply(y, scale) y = relay.nn.relu(y) y1 = relay.strided_slice(y, - begin=relay.const([0, 0], "int64"), - end=relay.const([-1, channels1], "int64"), - strides=relay.const([1, 1], "int64"), + begin=[0, 0], + end=[-1, channels1], + strides=[1, 1], slice_mode="size") y2 = relay.strided_slice(y, - begin=relay.const([0, channels1], "int64"), - end=relay.const([-1, channels2], "int64"), - strides=relay.const([1, 1], "int64"), + begin=[0, channels1], + end=[-1, channels2], + strides=[1, 1], slice_mode="size") y2 = relay.add(y2, bias) y = relay.Tuple((y1, y2)) @@ -157,14 +157,14 @@ def expected(x, w1, w2, scale1, scale2, channels1, channels2): w = relay.concatenate((w1, w2), axis=0) y = relay.nn.conv2d(x, w, channels=channels1 + channels2) y1 = relay.strided_slice(y, - begin=relay.const([0, 0], 
"int64"), - end=relay.const([-1, channels1], "int64"), - strides=relay.const([1, 1], "int64"), + begin=[0, 0], + end=[-1, channels1], + strides=[1, 1], slice_mode="size") y2 = relay.strided_slice(y, - begin=relay.const([0, channels1], "int64"), - end=relay.const([-1, channels2], "int64"), - strides=relay.const([1, 1], "int64"), + begin=[0, channels1], + end=[-1, channels2], + strides=[1, 1], slice_mode="size") y1 = relay.multiply(y1, scale1) y2 = relay.multiply(y2, scale2) @@ -205,14 +205,14 @@ def expected(x, w, channels, repeat): w_concat = relay.concatenate((w, w), axis=0) y = relay.nn.conv2d(y, w_concat, channels=channels*2) y1 = relay.strided_slice(y, - begin=relay.const([0, 0], "int64"), - end=relay.const([-1, channels], "int64"), - strides=relay.const([1, 1], "int64"), + begin=[0, 0], + end=[-1, channels], + strides=[1, 1], slice_mode="size") y2 = relay.strided_slice(y, - begin=relay.const([0, channels], "int64"), - end=relay.const([-1, channels], "int64"), - strides=relay.const([1, 1], "int64"), + begin=[0, channels], + end=[-1, channels], + strides=[1, 1], slice_mode="size") y = relay.concatenate((y1, y2), axis=1) return relay.Function(args, y) diff --git a/tests/python/relay/test_pass_combine_parallel_dense.py b/tests/python/relay/test_pass_combine_parallel_dense.py index 535f97a39d08..3c5cd9d8b054 100644 --- a/tests/python/relay/test_pass_combine_parallel_dense.py +++ b/tests/python/relay/test_pass_combine_parallel_dense.py @@ -204,18 +204,18 @@ def expected(x, w1, w2, w3, j): args = [x, w1, w2, w3] w_stacked = relay.concatenate((w1, w2, w3), axis=0) y = relay.nn.dense(x, w_stacked, units=6 * j) - strides = relay.const([1, 1], 'int64') + strides = [1, 1] y1 = relay.strided_slice(y, - begin=relay.const([0, 0], "int64"), - end=relay.const([-1, j], "int64"), + begin=[0, 0], + end=[-1, j], strides=strides, slice_mode="size") y2 = relay.strided_slice(y, - begin=relay.const([0, j], "int64"), - end=relay.const([-1, 2 * j], "int64"), + begin=[0, j], + end=[-1, 2 * j], strides=strides, slice_mode="size") y3 = relay.strided_slice(y, - begin=relay.const([0, 3 * j], "int64"), - end=relay.const([-1, 3 * j], "int64"), + begin=[0, 3 * j], + end=[-1, 3 * j], strides=strides, slice_mode="size") y = relay.Tuple((y1, y2, y3)) return relay.Function(args, y) @@ -268,14 +268,14 @@ def expected(x, w1, w2, b1, b2, j, bias_shape1, bias_shape2): end = [-1 for _ in range(n_out_dims - 1)] strides = [1 for _ in range(n_out_dims)] y1 = relay.strided_slice(y, - begin=relay.const(begin + [0], "int64"), - end=relay.const(end + [j], "int64"), - strides=relay.const(strides, "int64"), + begin=begin + [0], + end=end + [j], + strides=strides, slice_mode="size") y2 = relay.strided_slice(y, - begin=relay.const(begin + [j], "int64"), - end=relay.const(end + [2 * j], "int64"), - strides=relay.const(strides, "int64"), + begin=begin + [j], + end=end + [2 * j], + strides=strides, slice_mode="size") return relay.Function(args, relay.Tuple((y1, y2))) @@ -335,14 +335,14 @@ def expected(x, w1, w2, b1, b2, scale1, scale2, newshape1, newshape2, j): scale2 = relay.repeat(scale2, 2 * j, 0) scale = relay.concatenate((scale1, scale2), axis=0) y = relay.multiply(y, scale) - strides = relay.const([1, 1], 'int64') + strides = [1, 1] y1 = relay.strided_slice(y, - begin=relay.const([0, 0], "int64"), - end=relay.const([-1, j], "int64"), + begin=[0, 0], + end=[-1, j], strides=strides, slice_mode="size") y2 = relay.strided_slice(y, - begin=relay.const([0, j], "int64"), - end=relay.const([-1, 2 * j], "int64"), + begin=[0, j], + end=[-1, 2 * 
j], strides=strides, slice_mode="size") y1 = relay.reshape(y1, newshape=newshape1) y2 = relay.reshape(y2, newshape=newshape2) diff --git a/tests/python/relay/test_pass_dynamic_to_static.py b/tests/python/relay/test_pass_dynamic_to_static.py index d1bf846d8aec..210dfc88125a 100644 --- a/tests/python/relay/test_pass_dynamic_to_static.py +++ b/tests/python/relay/test_pass_dynamic_to_static.py @@ -24,6 +24,7 @@ import tvm.topi.testing import tvm.testing + def run_opt_pass(expr, opt_pass): assert isinstance(opt_pass, tvm.transform.Pass) @@ -395,6 +396,68 @@ def verify_pad(data_shape, pad_width, pad_val, dtype): verify_pad((4, 10, 7, 7), ((1, 1), (2, 2), (3, 3), (4, 4)), 2.0, "int32") verify_pad((2, 7), ((1, 4), (2, 2)), 4.0, "float64") + +def test_dynamic_to_static_strided_slice(): + def verify(dshape, begin, end, strides, output, slice_mode="end", + test_ref=True, dtype="int32"): + x = relay.var("x", relay.TensorType(dshape, "float32")) + ndim = len(dshape) + begin = begin if begin else [0] * ndim + end = end if end else list(dshape) + if strides: + if len(strides) == 1: + strides = strides * ndim + else: + strides = [1] * ndim + + # target numpy result + x_data = np.random.uniform(size=dshape).astype("float32") + ref_res = tvm.topi.testing.strided_slice_python( + x_data, begin, end, strides, slice_mode) + data = [x_data, np.array(begin), np.array(end)] + + begin = relay.const(begin, dtype=dtype) + end = relay.const(end, dtype=dtype) + + + if strides: + data.append(np.array(strides)) + strides = relay.const(strides, dtype=dtype) + z = relay.strided_slice(x, + begin=begin, + end=end, + strides=strides, + slice_mode=slice_mode) + else: + z = relay.strided_slice(x, + begin=begin, + end=end, + slice_mode=slice_mode) + func = relay.Function([x], z) + + func = run_infer_type(func) + func2 = run_opt_pass(run_opt_pass(func, transform.DynamicToStatic()), transform.InferType()) + assert isinstance(func2.body, relay.Call) + assert func2.body.op == relay.op.get("strided_slice") + verify_func(func2, [x_data], ref_res) + + verify((1, 3, 10, 10), [0, 0, 0, 0], [1, 3, 10, 10], [1], (0, 3, 10, 10), dtype="int64") + verify((1, 224, 224, 3), [0, 20, 20, 0], [1, 140, 140, 3], + [1, 1, 1, 1], (1, 120, 120, 3), dtype="int64") + verify((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], (1, 3, 3), dtype="int16") + verify((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], (3, 1, 2)) + verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)) + verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3)) + verify((3, 4, 3), [1, 1, 0], [4, 4, 4], None, (2, 3, 3)) + verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3)) + verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3)) + verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)) + verify((3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], + (2, 4, 3), slice_mode="size", test_ref=False) + verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], + (2, 2, 3), slice_mode="size", test_ref=True) + + if __name__ == "__main__": test_dynamic_to_static_reshape() test_dynamic_to_static_double_reshape() @@ -408,3 +471,4 @@ def verify_pad(data_shape, pad_width, pad_val, dtype): test_dynamic_to_static_full() test_dynamic_to_static_upsampling() test_dynamic_to_static_pad() + test_dynamic_to_static_strided_slice()