Dynamic Strided Slice #6316
Conversation
Overall, looks good to me! I added a few style suggestions.
PS: that was a lot of relay.const calls!
@script
def _strided_slice_shape_func_input_data(data, begin, end, strides,
What's the difference between _strided_slice_shape_func_input_shape and _strided_slice_shape_func_input_data?
For the static op, we only require the input shapes, since the attributes are static; thus the input_shape variant. For the dynamic op, we need the input data, since the attributes come in as values; thus the input_data variant. Why the VM is set up to separate those instead of just sending in data all the time, I'm not sure. Optimization?
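For illustration, a minimal pure-Python sketch of that distinction (not the PR's hybrid-script code; the function names and the simplified slice arithmetic here are stand-ins of mine, assuming positive strides):

# Minimal pure-Python sketch (not the PR's hybrid-script shape funcs).
import numpy as np

def _slice_len(dim, b, e, s):
    # Output extent of one sliced axis, assuming a positive stride.
    return max(0, (min(e, dim) - max(b, 0) + s - 1) // s)

def shape_func_input_shape(data_shape, begin, end, strides):
    # "input_shape" variant: begin/end/strides are compile-time attributes,
    # so only the input *shape* is needed to compute the output shape.
    return [_slice_len(d, b, e, s)
            for d, b, e, s in zip(data_shape, begin, end, strides)]

def shape_func_input_data(data, begin, end, strides):
    # "input_data" variant: begin/end/strides arrive as runtime tensors, so
    # the shape function needs the input tensors to read their values.
    return shape_func_input_shape(data.shape, begin.tolist(),
                                  end.tolist(), strides.tolist())

x = np.zeros((10, 8))
print(shape_func_input_shape([10, 8], [1, 0], [9, 8], [2, 1]))  # [4, 8]
print(shape_func_input_data(x, np.array([1, 0]), np.array([9, 8]),
                            np.array([2, 1])))                  # [4, 8]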
python/tvm/relay/op/transform.py
Outdated
    end = const(list(end))
if isinstance(strides, (tuple, list)):
    strides = const(list(strides))
begin = _make.where(begin < cast_like(const(0), begin),
Can you rename this begin for clarity?
@@ -430,6 +434,114 @@ RELAY_REGISTER_OP("dyn.full")
    .set_attr<FTVMCompute>("FTVMCompute", FullCompute)
    .set_attr<TOpPattern>("TOpPattern", kElemWise);

bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                     const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 5);
It would be nice to add a comment saying what each of the input types is (i.e. types = [type1_description, ..., ret_type]).
Done.
src/relay/op/dyn/tensor/transform.cc
Outdated
te::Tensor end = inputs[2];
te::Tensor strides = inputs[3];
// Dynamic computation
int64_t attr_size = data->shape.size();
Does this stand for attribute size? If so, the name seems a bit inaccurate
src/relay/op/dyn/tensor/transform.cc
Outdated
      end->shape[0].as<IntImmNode>()->value == attr_size &&
      strides->shape[0].as<IntImmNode>()->value == attr_size)
    << "begin, end, and strides are required to have the same length"
    << " if they are non-constant.";
The wording of this error is a bit confusing; "begin, end, and strides are required to have the same length or must all be constants" might be better
src/relay/op/tensor/transform.cc
Outdated
for (int64_t i = 0; i < num_axis; ++i) {
  oshape[i] = Any();
}
CHECK(false) << "strided_slice recieved invalid params";
You could state in this error that strided_slice received an incorrect beginning, end, or strides tensor.
{Op::Get("dyn.strided_slice"), | ||
[](const CallNode* call_node) { | ||
if (const ConstantNode* begin = call_node->args[1].as<ConstantNode>()) { | ||
if (const ConstantNode* end = call_node->args[2].as<ConstantNode>()) { |
It would be cleaner to pull these definitions out of the if statements, and then check whether they are null or not in one if statement, though potentially slower
tests/python/relay/test_op_level4.py
Outdated
@@ -343,7 +337,7 @@ def verify(dshape, begin, end, strides, output, slice_mode="end",
    text = func.astext()
    assert "begin=" in text
    assert "end=" in text

white space!!
Force-pushed from 8a1719b to ad7a37f
The unit test error appears to be a bug with the cython backend; we're getting int64_t overflow if we use cython, but not with ctypes.
Force-pushed from ad7a37f to 311e79b
The failing unit test is a symptom of the bug fixed in #6321. I'll rebase this PR once that's merged.
Overall lgtm, thanks for migrating to dyn module.
@@ -165,6 +138,8 @@ def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice
        cstride = int64(strides[i])
    if len(begin) > i:
        cbegin = int64(begin[i])
        if cbegin < 0:
            cbegin += int64(data_shape[i])
What if cbegin is still less than 0 after adding data_shape[i]? We should probably raise an error for invalid data input.
I attempted to do this with an assert in the hybrid script, but something seems a little off in the compiler: even after these lines, it was still checking against the original value.
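For what it's worth, a plain-Python sketch of the check being suggested (my own illustration, not the hybrid-script code; normalize_begin is a hypothetical helper): wrap a negative begin once, then reject anything still out of range.

# Plain-Python sketch of the suggested check (not the hybrid-script code).
def normalize_begin(cbegin, dim):
    # Wrap a negative begin index once, mirroring `cbegin += data_shape[i]`.
    if cbegin < 0:
        cbegin += dim
    # If it is still out of range, the input was invalid, so raise.
    if not 0 <= cbegin <= dim:
        raise ValueError("begin index out of range for axis of size %d" % dim)
    return cbegin

print(normalize_begin(-3, 10))   # 7
# normalize_begin(-20, 10)       # would raise ValueError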
    end = const(list(end))
if isinstance(strides, (tuple, list)):
    strides = const(list(strides))
normalized_begin = _make.where(begin < cast_like(const(0), begin),
we could consider moving the normalization step into strided_slice
Hmm, yeah, seems a little odd to produce a subgraph as part of the constructor of an op, which is why I put it here. That being said, this makes the op less useful from other frontends, so...
Any other votes?
How does this make it less useful?
We'd have to reproduce this on any frontend that creates ops.
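For reference, a rough Relay-level sketch of the normalization subgraph discussed above, assuming the public relay.* wrappers mirror the _make.where / cast_like calls shown in the diff (variable names and shapes here are arbitrary):

# Rough sketch of the begin-normalization subgraph (not the PR's exact code).
import tvm
from tvm import relay

data = relay.var("data", shape=(4, 8), dtype="float32")
begin = relay.var("begin", shape=(2,), dtype="int64")

dims = relay.cast_like(relay.shape_of(data), begin)   # per-axis sizes as int64
zero = relay.cast_like(relay.const(0), begin)
# Negative begin indices are wrapped by adding the corresponding axis size.
normalized_begin = relay.where(relay.less(begin, zero),
                               relay.add(begin, dims),
                               begin)
print(normalized_begin)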
src/relay/op/tensor/transform.cc
Outdated
for (int64_t i = 0; i < num_axis; ++i) {
  oshape[i] = Any();
}
CHECK(param->begin) << "strided_slice recieved invalid begin";
print the received begin value in the message?
Force-pushed from 311e79b to 99ac920
@mbrookhart #6314 introduced another strided slice usage; you might need to update that too.
I don't think so, it's already targeting the dynamic op.
@yongwww @electriclilies Can you take another look?
lgtm
@mbrookhart please rebase
python/tvm/relay/frontend/mxnet.py
Outdated
_expr.const(begin, dtype="int32"),
_expr.const(end, dtype="int32"))
begin,
end)
I think this could be a one-liner.
@mbrookhart #6337 has been merged. You can try to enable the GPU tests now
I think I'd like to do it as a separate PR after we get this merged; I don't want to conflate the strided slice op with the others too much, but I'm excited to start on this :) I'll split a branch and start playing with that.
@mbrookhart no problem. BTW, I would suggest moving the compute to topi and making a dynamic version there in the next PR. Do we need to change the topi strided_slice definition to let it use the dynamic/static version?
@masahi @electriclilies could you take another look and approve/comment?
LGTM!
@zhiics Could you take another look?
* Dynamic Strided Slice
* fix clang-format lint
* remove debug print
* respond to review comments
* respond to yongwww's comments
* fix bad rebase
* revert hybrid-script assert
* reformat mxnet change
* use new testing api
* while getting the tests to work with the new testing API, refactor all of the tests in the dyn directory
This PR splits the dynamic-parameter version of strided slice from the constant-parameter version, improves the shape funcs to pass more unit tests, and supports executing the constant-parameter version of the op on dynamically shaped inputs.
I was able to pass most tests; the only exception is that on the static op with a dynamic input tensor, I can't pass tests with negative begin values. There's a lot of te::compute being thrown around here, but I haven't seen a clean way to simplify it yet. I would love suggestions.
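As a rough usage sketch of the two paths this PR separates (assuming, per the description above, that relay.strided_slice accepts either Python lists or Relay expressions for begin/end/strides, and that the expression case routes to dyn.strided_slice):

# Hedged usage sketch; the exact dispatch is assumed from the description above.
import tvm
from tvm import relay

x = relay.var("x", shape=(10, 10), dtype="float32")

# Constant-parameter path: attributes are known at compile time.
static_slice = relay.strided_slice(x, begin=[0, 2], end=[5, 10], strides=[1, 1])

# Dynamic-parameter path: begin/end are themselves Relay expressions.
begin = relay.var("begin", shape=(2,), dtype="int64")
end = relay.var("end", shape=(2,), dtype="int64")
dyn_slice = relay.strided_slice(x, begin=begin, end=end,
                                strides=relay.const([1, 1], "int64"))
print(dyn_slice)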
Thanks!
cc @zhiics @kevinthesun @yongwww @lixiaoquan @electriclilies