
Dynamic Strided Slice #6316

Merged
merged 12 commits into apache:master from mbrookhart/dynamic_strided_slice
Sep 8, 2020

Conversation

Contributor

@mbrookhart mbrookhart commented Aug 20, 2020

This PR splits the dynamic-parameter version of strided_slice from the constant-parameter version, improves the shape funcs to pass more unit tests, and supports executing the constant-parameter version of the op on dynamically shaped inputs.

I was able to pass most tests; the only exception is that for the static op with a dynamic input tensor, I can't pass tests with negative begin values.

There's a lot of te::compute being thrown around here, but I haven't seen a clean way to simplify it yet. I would love suggestions.

Thanks!

cc @zhiics @kevinthesun @yongwww @lixiaoquan @electriclilies
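For context, a rough sketch of the two paths after this split (illustrative only; the shapes, dtypes, and variable names below are assumptions, not code from this PR): constant begin/end/strides keep using the static strided_slice op, while Relay-expression arguments are routed to dyn.strided_slice.

    from tvm import relay

    # Static path: begin/end/strides are Python constants, so they become op
    # attributes and only the input's shape is needed for shape inference.
    x = relay.var("x", shape=(relay.Any(), 4), dtype="float32")
    static_out = relay.strided_slice(x, begin=[0, 1], end=[2, 4], strides=[1, 1])

    # Dynamic path: begin/end arrive as Relay expressions, so the call is
    # expected to lower to dyn.strided_slice, whose shape function reads their
    # runtime values.
    begin = relay.var("begin", shape=(2,), dtype="int64")
    end = relay.var("end", shape=(2,), dtype="int64")
    dyn_out = relay.strided_slice(x, begin=begin, end=end,
                                  strides=relay.const([1, 1], dtype="int64"))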

Contributor

@electriclilies electriclilies left a comment

Overall, looks good to me! I added a few style suggestions.

PS that was a lot of relay.const calls!



@script
def _strided_slice_shape_func_input_data(data, begin, end, strides,
Contributor

What's the difference between _strided_slice_shape_func_input_shape and _strided_slice_shape_func_input_data?

Contributor Author

For the static op, we only require the input shapes, since the attributes are static; hence the input_shape variant. For the dynamic op, we need the input data, since the attributes come in as values; hence the input_data variant. Why the VM is set up to separate those instead of just sending in data all the time, I'm not sure. Optimization?
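To make that concrete, here is a minimal sketch of how the two shape functions might be registered (the decorator flags and wrapper signatures are assumptions for illustration, not copied from this PR); the second argument to register_shape_func tells the VM whether the shape function needs the runtime values of the inputs or only their shapes.

    from tvm.runtime import convert
    from tvm.relay.op import op as _reg

    @_reg.register_shape_func("strided_slice", False)  # False: only input shapes are passed
    def _static_shape_func(attrs, inputs, _):
        # begin/end/strides live in the attributes, so the shape-only variant suffices
        slice_mode = convert(0 if attrs.slice_mode == "end" else 1)
        return [_strided_slice_shape_func_input_shape(inputs[0], attrs.begin,
                                                      attrs.end, attrs.strides,
                                                      slice_mode)]

    @_reg.register_shape_func("dyn.strided_slice", True)  # True: runtime input data is passed
    def _dyn_shape_func(attrs, inputs, _):
        # begin/end/strides are input tensors, so the VM feeds their values here
        slice_mode = convert(0 if attrs.slice_mode == "end" else 1)
        return [_strided_slice_shape_func_input_data(*inputs, slice_mode)]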

end = const(list(end))
if isinstance(strides, (tuple, list)):
strides = const(list(strides))
begin = _make.where(begin < cast_like(const(0), begin),
Contributor

Can you rename this begin for clarity?

@@ -430,6 +434,114 @@ RELAY_REGISTER_OP("dyn.full")
.set_attr<FTVMCompute>("FTVMCompute", FullCompute)
.set_attr<TOpPattern>("TOpPattern", kElemWise);

bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 5);
Contributor

It would be nice to add a comment saying what each of the input types is (i.e. types = [type1_description, ..., ret_type]).

Contributor Author

Done.

te::Tensor end = inputs[2];
te::Tensor strides = inputs[3];
// Dynamic computation
int64_t attr_size = data->shape.size();
Contributor

Does this stand for attribute size? If so, the name seems a bit inaccurate

end->shape[0].as<IntImmNode>()->value == attr_size &&
strides->shape[0].as<IntImmNode>()->value == attr_size)
<< "begin, end, and strides are required to have the same length"
<< " if they are non-constant.";
Contributor

The wording of this error is a bit confusing; "begin, end, and strides are required to have the same length or must all be constants" might be better

for (int64_t i = 0; i < num_axis; ++i) {
oshape[i] = Any();
}
CHECK(false) << "strided_slice received invalid params";
Contributor

You could state in this error that strided_slice received an incorrect beginning, end, or strides tensor.

{Op::Get("dyn.strided_slice"),
[](const CallNode* call_node) {
if (const ConstantNode* begin = call_node->args[1].as<ConstantNode>()) {
if (const ConstantNode* end = call_node->args[2].as<ConstantNode>()) {
Contributor

It would be cleaner to pull these definitions out of the if statements and then check whether they are null in a single if statement, though that is potentially slower.

@@ -343,7 +337,7 @@ def verify(dshape, begin, end, strides, output, slice_mode="end",
text = func.astext()
assert "begin=" in text
assert "end=" in text

Contributor

white space!!

@mbrookhart mbrookhart force-pushed the mbrookhart/dynamic_strided_slice branch from 8a1719b to ad7a37f on August 21, 2020 16:23
@mbrookhart
Contributor Author

The unit test error appears to be a bug in the Cython backend: we get int64_t overflow with Cython, but not with ctypes.

@mbrookhart mbrookhart force-pushed the mbrookhart/dynamic_strided_slice branch from ad7a37f to 311e79b on August 21, 2020 16:53
@mbrookhart
Contributor Author

The failing unit test is a symptom of the bug fixed here: #6321

I'll rebase this PR once that's merged.

Member

@yongwww yongwww left a comment

Overall LGTM, thanks for migrating to the dyn module.

@@ -165,6 +138,8 @@ def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice
cstride = int64(strides[i])
if len(begin) > i:
cbegin = int64(begin[i])
if cbegin < 0:
cbegin += int64(data_shape[i])
Member

What if cbegin is still less than 0 after adding data_shape[i]? It should probably raise an error for invalid input data.

Contributor Author

I attempted to do this with an assert in the hybrid script, but something seems a little off in the compiler; even after these lines, it was still checking against the original value.
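For reference, the intended bounds handling in plain Python (an illustration of the behavior being discussed, not code from the PR; the hybrid-script assert itself was later reverted):

    def normalize_begin_index(cbegin, dim_size):
        # Negative indices count back from the end of the dimension
        if cbegin < 0:
            cbegin += dim_size
        # If it is still negative, the original index was out of bounds
        if cbegin < 0:
            raise ValueError("strided_slice begin index out of bounds for dimension of size %d" % dim_size)
        return cbegin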

end = const(list(end))
if isinstance(strides, (tuple, list)):
strides = const(list(strides))
normalized_begin = _make.where(begin < cast_like(const(0), begin),
Member

we could consider moving the normalization step into strided_slice

Contributor Author

Hmm, yeah, seems a little odd to produce a subgraph as part of the constructor of an op, which is why I put it here. That being said, this makes the op less useful from other frontends, so...

Any other votes?

Contributor

How does this make it less useful?

Contributor Author

We'd have to reproduce this on any frontend that creates ops.
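For context, the normalization being discussed is roughly the following Relay subgraph (a sketch of the idea under discussion, not the PR's exact code):

    from tvm import relay

    def normalize_begin(data, begin):
        # Shift negative begin indices by the corresponding dimension size
        zero = relay.cast_like(relay.const(0), begin)
        ishape = relay.cast_like(relay.shape_of(data), begin)
        return relay.where(relay.less(begin, zero), relay.add(begin, ishape), begin)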

for (int64_t i = 0; i < num_axis; ++i) {
oshape[i] = Any();
}
CHECK(param->begin) << "strided_slice received invalid begin";
Member

print the received begin value in the message?

@yongwww
Member

yongwww commented Aug 24, 2020

cc @kevinthesun @zhiics

@mbrookhart mbrookhart force-pushed the mbrookhart/dynamic_strided_slice branch from 311e79b to 99ac920 on August 24, 2020 16:52
@masahi
Member

masahi commented Aug 24, 2020

@mbrookhart #6314 introduced another strided slice usage, you might need to update that too.

@mbrookhart
Contributor Author

I don't think so; it's already targeting the dynamic op:

        # squeeze the two outputs of nms for strided_slice
        size = get_relay_op("squeeze")(nms_ret[1], axis=[1])
        data_slice = get_relay_op("squeeze")(nms_ret[0], axis=[0])

        # strided slice to get the dynamic result
        return get_relay_op("strided_slice")(data_slice, begin=_expr.const([0]),
                                             end=size, slice_mode="size")

@mbrookhart
Contributor Author

@yongwww @electriclilies Can you take another look?

Member

@yongwww yongwww left a comment

lgtm

@zhiics
Member

zhiics commented Sep 3, 2020

@mbrookhart please rebase

_expr.const(begin, dtype="int32"),
_expr.const(end, dtype="int32"))
begin,
end)
Member

I think this could be a one-liner.

Member

@zhiics zhiics left a comment

@mbrookhart #6337 has been merged. You can try to enable the GPU tests now

@mbrookhart
Contributor Author

@mbrookhart #6337 has been merged. You can try to enable the GPU tests now

I think I'd like to do it as a separate PR after we get this merged; I don't want to conflate the strided_slice op with the others too much, but I'm excited to start on this :) I'll split a branch and start playing with that.

@zhiics
Member

zhiics commented Sep 3, 2020

@mbrookhart no problem

BTW, I would suggest moving the compute to topi and making a dynamic version there in the next PR. Do we need to change the topi strided_slice definition to let it use the dynamic/static version?

@zhiics
Member

zhiics commented Sep 3, 2020

@masahi @electriclilies could you take another look and approve/comment?

Contributor

@electriclilies electriclilies left a comment

LGTM!

@mbrookhart
Contributor Author

@zhiics Could you take another look?

@zhiics zhiics merged commit eee413f into apache:master Sep 8, 2020
@zhiics
Member

zhiics commented Sep 8, 2020

kevinthesun pushed a commit to kevinthesun/tvm that referenced this pull request Sep 17, 2020
* Dynamic Strided Slice

* fix clang-format lint

* remove debug print

* respond to review comments

* respond to yongwww's comments

* fix bad rebase

* revert hybrid-script assert

* reformat mxnet change

* use new testing api

* while getting tests to work with the new testing API, refactor all of the tests in the dyn directory
kevinthesun pushed a commit to kevinthesun/tvm that referenced this pull request Sep 18, 2020
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Sep 18, 2020