From c999a840cb5579c493f5b5e7f20bc619260dad08 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Thu, 13 May 2021 09:26:27 -0700 Subject: [PATCH] [Relay][AlterOpLayout] Fix strided slice type change. (#8022) * Fixed strided_slice alteroplayout bug. * add test for non standard int8 conv2d padding. * Add test for large index slices. * Us same dtype as input in strided slice. --- python/tvm/topi/x86/conv2d_alter_op.py | 2 +- src/relay/op/tensor/transform.cc | 21 +++++++++++---------- tests/python/relay/test_op_level2.py | 2 +- tests/python/relay/test_op_level4.py | 2 ++ 4 files changed, 15 insertions(+), 12 deletions(-) diff --git a/python/tvm/topi/x86/conv2d_alter_op.py b/python/tvm/topi/x86/conv2d_alter_op.py index f05bac82ff0c..8e47dff37ce6 100644 --- a/python/tvm/topi/x86/conv2d_alter_op.py +++ b/python/tvm/topi/x86/conv2d_alter_op.py @@ -338,7 +338,7 @@ def _conv2d_legalize(attrs, inputs, arg_types): data = relay.cast(data, "uint8") # Do external padding as pad value has to be 128. - if not (padding[0] == 0 and padding[1] == 0): + if any(padding): data = relay.nn.pad(data, pad_width=pad_width, pad_value=128) new_attrs["padding"] = (0, 0) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 9111fa529e5b..df60aeb16bf3 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2570,16 +2570,17 @@ Array> StridedSliceInferCorrectLayout(const Attrs& attrs, if (params->begin && params->end && params->strides) { for (Integer i : params->strides.value()) { ICHECK(i.defined()); - strides.push_back(params->slice_mode == "size" ? 1 : i->value); + auto slice_val = Integer(IntImm(i->dtype, i->value)); + strides.push_back(params->slice_mode == "size" ? Integer(IntImm(i->dtype, 1)) : slice_val); } for (Integer i : params->begin.value()) { ICHECK(i.defined()); - begin.push_back(i->value); + begin.push_back(IntImm(i->dtype, i->value)); } for (Integer i : params->end.value()) { ICHECK(i.defined()); - end.push_back(i->value); + end.push_back(IntImm(i->dtype, i->value)); } } @@ -2619,9 +2620,9 @@ Array> StridedSliceInferCorrectLayout(const Attrs& attrs, ed = shape[new_index].as()->value; } - new_begin.push_back(bg); - new_end.push_back(ed); - new_strides.push_back(st); + new_begin.push_back(IntImm(begin[0]->dtype, bg)); + new_end.push_back(IntImm(end[0]->dtype, ed)); + new_strides.push_back(IntImm(strides[0]->dtype, st)); } params->begin = new_begin; params->end = new_end; @@ -2637,8 +2638,8 @@ Array> StridedSliceInferCorrectLayout(const Attrs& attrs, } auto factor = new_layout.FactorOf(axis); if (factor == -1) { - new_begin.push_back(begin[i]); - new_end.push_back(end[i]); + new_begin.push_back(IntImm(begin[i]->dtype, begin[i])); + new_end.push_back(IntImm(end[i]->dtype, end[i])); } else { if (strides.defined() && i < strides.size()) { auto stride = strides[i]; @@ -2665,8 +2666,8 @@ Array> StridedSliceInferCorrectLayout(const Attrs& attrs, // transform to original layout return {{Layout::Undef()}, {Layout::Undef()}}; } - new_begin.push_back(tvm::Integer(bg / factor)); - new_end.push_back(tvm::Integer(ed / factor)); + new_begin.push_back(IntImm(begin[0]->dtype, (bg / factor))); + new_end.push_back(IntImm(end[0]->dtype, (ed / factor))); } } diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index d0ff86bffcde..b76facae5aa3 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -1569,7 +1569,7 @@ def _compile(ic, oc, target, data_layout, kernel_layout, dtypes): weight, kernel_size=(ch, cw), channels=oc, - padding=(1, 1), + padding=(0, 0, 0, 1), dilation=(1, 1), data_layout=data_layout, kernel_layout=kernel_layout, diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index 036d4a0f6044..8de644999c9e 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -429,6 +429,8 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3)) # Test backwards slicing. verify((3, 4, 3), [-1, -1, -1], [-5, -5, -5], [-1, -1, -1], (3, 4, 3)) + # Test slicing with overlarge indices. + verify((3, 4, 3), [0, 0, 0], [np.iinfo(np.int64).max] * 3, [1, 1, 1], (3, 4, 3)) # Test slice mode. verify( (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False