-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Dynamic Strided Slice #6316
Dynamic Strided Slice #6316
Changes from all commits
625cc97
570f6f8
24ac192
f98a29d
99ac920
3e0c042
470819c
2f8e802
c1bddce
ae6bf1f
d2ddad2
742550e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ | |
_reg.register_broadcast_schedule("dyn.tile") | ||
_reg.register_injective_schedule("dyn.one_hot") | ||
_reg.register_injective_schedule("dyn.full") | ||
_reg.register_injective_schedule("dyn.strided_slice") | ||
|
||
@script | ||
def _reshape_shape_func_input_data(data, newshape, ndim): | ||
|
@@ -145,3 +146,53 @@ def one_hot_shape_func(attrs, inputs, _): | |
""" | ||
axis = len(inputs[0].shape) if attrs.axis == -1 else attrs.axis | ||
return [_onehot_shape_func(inputs[0].shape, inputs[3], convert(axis))] | ||
|
||
|
||
@script | ||
def _strided_slice_shape_func_input_data(data, begin, end, strides, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's the difference between There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For the static op, we only require the input shapes, since the tributes are static, thus the input_shape variant. For the dynamic op, we need the input data, since the attributes are coming in as values. Thus the input_data variant. Why the VM is set up to separate those instead of just sending in data all the time, I'm not sure. Optimization? |
||
slice_mode): | ||
ndim = len(data.shape) | ||
out = output_tensor((ndim,), "int64") | ||
for i in const_range(ndim): | ||
cbegin = int64(0) | ||
cend = int64(data.shape[i]) | ||
cstride = int64(1) | ||
if strides.shape[0] > i: | ||
cstride = int64(strides[i]) | ||
if begin.shape[0] > i: | ||
cbegin = int64(begin[i]) | ||
if cbegin < 0: | ||
cbegin += int64(data.shape[i]) | ||
if end.shape[0] <= i: | ||
cend = int64(data.shape[i]) | ||
elif slice_mode != 0: | ||
cstride = int64(1) | ||
if end[i] < 0: | ||
cend = int64(data.shape[i]) | ||
else: | ||
cend = cbegin + int64(end[i]) | ||
else: | ||
if end[i] > data.shape[i]: | ||
cend = int64(data.shape[i]) | ||
else: | ||
cend = int64(end[i]) | ||
if cend < 0: | ||
cend += int64(data.shape[i]) | ||
assert cstride != 0, "Strides can't be zero." | ||
if cstride < 0: | ||
slice_range = cbegin - cend | ||
step = -cstride | ||
else: | ||
slice_range = cend - cbegin | ||
step = cstride | ||
|
||
out[i] = int64(ceil_div(slice_range, step)) | ||
return out | ||
|
||
@_reg.register_shape_func("dyn.strided_slice", True) | ||
def strided_slice_shape_func(attrs, inputs, _): | ||
""" | ||
Shape func for strided_slice | ||
""" | ||
slice_mode = convert(0 if attrs.slice_mode == "end" else 1) | ||
return [_strided_slice_shape_func_input_data(*inputs, slice_mode)] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
|
||
from . import _make | ||
from .dyn import _make as _dyn_make | ||
from .tensor import shape_of | ||
from ..expr import TupleWrapper, const, Expr, Tuple | ||
from ...tir import expr as _expr | ||
|
||
|
@@ -827,13 +828,17 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"): | |
ret : relay.Expr | ||
The computed result. | ||
""" | ||
strides = strides or const([1], dtype="int32") | ||
if isinstance(begin, (tuple, list)): | ||
begin = const(list(begin)) | ||
if isinstance(end, (tuple, list)): | ||
end = const(list(end)) | ||
if isinstance(strides, (tuple, list)): | ||
strides = const(list(strides)) | ||
strides = strides or [1] | ||
if (isinstance(begin, Expr) or isinstance(end, Expr) or isinstance(strides, Expr)): | ||
if isinstance(begin, (tuple, list)): | ||
begin = const(list(begin)) | ||
if isinstance(end, (tuple, list)): | ||
end = const(list(end)) | ||
if isinstance(strides, (tuple, list)): | ||
strides = const(list(strides)) | ||
normalized_begin = _make.where(begin < cast_like(const(0), begin), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we could consider moving the normalization step into strided_slice There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, yeah, seems a little odd to produce a subgraph as part of the constructor of an op, which is why I put it here. That being said, this makes the op less useful from other frontends, so... Any other votes? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How does this make it less useful? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We'd have to reproduce this on any frontend that creates ops. |
||
begin + cast_like(shape_of(data), begin), begin) | ||
return _dyn_make.strided_slice(data, normalized_begin, end, strides, slice_mode) | ||
return _make.strided_slice(data, begin, end, strides, slice_mode) | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how about cbegin is still less than 0 after adding data_shape[i]? probably raise error for invalid data input.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I attempted to do this with an assert in the hybrid script, but something seems a little off in the compiler, even after these lines it was still checking against the original value.