Skip to content

Commit

Permalink
[Relay][AlterOpLayout] Fix strided slice type change. (apache#8022)
Browse files Browse the repository at this point in the history
* Fixed strided_slice alteroplayout bug.

* add test for non standard int8 conv2d padding.

* Add test for large index slices.

* Us same dtype as input in strided slice.
  • Loading branch information
Josh Fromm authored May 13, 2021
1 parent 158aedd commit c999a84
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 12 deletions.
2 changes: 1 addition & 1 deletion python/tvm/topi/x86/conv2d_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def _conv2d_legalize(attrs, inputs, arg_types):
data = relay.cast(data, "uint8")

# Do external padding as pad value has to be 128.
if not (padding[0] == 0 and padding[1] == 0):
if any(padding):
data = relay.nn.pad(data, pad_width=pad_width, pad_value=128)
new_attrs["padding"] = (0, 0)

Expand Down
21 changes: 11 additions & 10 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2570,16 +2570,17 @@ Array<Array<Layout>> StridedSliceInferCorrectLayout(const Attrs& attrs,
if (params->begin && params->end && params->strides) {
for (Integer i : params->strides.value()) {
ICHECK(i.defined());
strides.push_back(params->slice_mode == "size" ? 1 : i->value);
auto slice_val = Integer(IntImm(i->dtype, i->value));
strides.push_back(params->slice_mode == "size" ? Integer(IntImm(i->dtype, 1)) : slice_val);
}

for (Integer i : params->begin.value()) {
ICHECK(i.defined());
begin.push_back(i->value);
begin.push_back(IntImm(i->dtype, i->value));
}
for (Integer i : params->end.value()) {
ICHECK(i.defined());
end.push_back(i->value);
end.push_back(IntImm(i->dtype, i->value));
}
}

Expand Down Expand Up @@ -2619,9 +2620,9 @@ Array<Array<Layout>> StridedSliceInferCorrectLayout(const Attrs& attrs,
ed = shape[new_index].as<IntImmNode>()->value;
}

new_begin.push_back(bg);
new_end.push_back(ed);
new_strides.push_back(st);
new_begin.push_back(IntImm(begin[0]->dtype, bg));
new_end.push_back(IntImm(end[0]->dtype, ed));
new_strides.push_back(IntImm(strides[0]->dtype, st));
}
params->begin = new_begin;
params->end = new_end;
Expand All @@ -2637,8 +2638,8 @@ Array<Array<Layout>> StridedSliceInferCorrectLayout(const Attrs& attrs,
}
auto factor = new_layout.FactorOf(axis);
if (factor == -1) {
new_begin.push_back(begin[i]);
new_end.push_back(end[i]);
new_begin.push_back(IntImm(begin[i]->dtype, begin[i]));
new_end.push_back(IntImm(end[i]->dtype, end[i]));
} else {
if (strides.defined() && i < strides.size()) {
auto stride = strides[i];
Expand All @@ -2665,8 +2666,8 @@ Array<Array<Layout>> StridedSliceInferCorrectLayout(const Attrs& attrs,
// transform to original layout
return {{Layout::Undef()}, {Layout::Undef()}};
}
new_begin.push_back(tvm::Integer(bg / factor));
new_end.push_back(tvm::Integer(ed / factor));
new_begin.push_back(IntImm(begin[0]->dtype, (bg / factor)));
new_end.push_back(IntImm(end[0]->dtype, (ed / factor)));
}
}

Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_op_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1569,7 +1569,7 @@ def _compile(ic, oc, target, data_layout, kernel_layout, dtypes):
weight,
kernel_size=(ch, cw),
channels=oc,
padding=(1, 1),
padding=(0, 0, 0, 1),
dilation=(1, 1),
data_layout=data_layout,
kernel_layout=kernel_layout,
Expand Down
2 changes: 2 additions & 0 deletions tests/python/relay/test_op_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,8 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True,
verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))
# Test backwards slicing.
verify((3, 4, 3), [-1, -1, -1], [-5, -5, -5], [-1, -1, -1], (3, 4, 3))
# Test slicing with overlarge indices.
verify((3, 4, 3), [0, 0, 0], [np.iinfo(np.int64).max] * 3, [1, 1, 1], (3, 4, 3))
# Test slice mode.
verify(
(3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False
Expand Down

0 comments on commit c999a84

Please sign in to comment.