Dynamic Strided Slice #6316

Merged, 12 commits, Sep 8, 2020
4 changes: 2 additions & 2 deletions python/tvm/relay/frontend/keras.py
@@ -617,8 +617,8 @@ def _convert_cropping(inexpr, keras_layer, _):
        raise tvm.error.OpNotImplemented(
            'Operator {} is not supported for frontend Keras.'.format(crop_type))
    int32_max = np.iinfo(np.int32).max
-    return _op.strided_slice(inexpr, begin=_expr.const([0, 0, crop_t, crop_l]), \
-                             end=_expr.const([int32_max, int32_max, in_h-crop_b, in_w-crop_r]))
+    return _op.strided_slice(inexpr, begin=[0, 0, crop_t, crop_l], \
+                             end=[int32_max, int32_max, in_h-crop_b, in_w-crop_r])


def _convert_batchnorm(inexpr, keras_layer, etab):
16 changes: 8 additions & 8 deletions python/tvm/relay/frontend/mxnet.py
@@ -500,11 +500,11 @@ def _mx_slice(inputs, attrs):
    for i, ed in enumerate(end):
        if ed is None:
            end[i] = input_shape[i]
-    new_attrs = {'begin': _expr.const(list(begin), dtype="int32"),
-                 'end': _expr.const(list(end), dtype="int32")}
+    new_attrs = {'begin': list(begin),
+                 'end': list(end)}
    if stride is not None:
        stride = (x if x is not None else 1 for x in stride)
-        new_attrs['strides'] = _expr.const(list(stride), dtype="int32")
+        new_attrs['strides'] = list(stride)
    return _op.strided_slice(inputs[0], **new_attrs)


@@ -545,8 +545,8 @@ def _mx_slice_axis(inputs, attrs):
            begin.append(ax_beg)
            end.append(ax_end)
    return _op.strided_slice(inputs[0],
-                             _expr.const(begin, dtype="int32"),
-                             _expr.const(end, dtype="int32"))
+                             begin,
+                             end)
Member: I think this could be a one-liner.
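For reference, the suggested collapse might look like this (a sketch, not part of the diff; `begin` and `end` are the lists built just above):

    return _op.strided_slice(inputs[0], begin, end)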



def _mx_crop_like(inputs, attrs):
@@ -566,9 +566,9 @@ def _mx_crop_like(inputs, attrs):
        return _op.slice_like(*inputs, **new_attrs)
    expr = _infer_type(inputs[1])
    like_shape = expr.checked_type.shape
-    new_attrs['begin'] = _expr.const([0, 0, offset[0], offset[1]], dtype="int32")
-    new_attrs['end'] = _expr.const([like_shape[0], like_shape[1], offset[0]+like_shape[2],
-                                    offset[1]+like_shape[3]], dtype="int32")
+    new_attrs['begin'] = [0, 0, offset[0], offset[1]]
+    new_attrs['end'] = [like_shape[0], like_shape[1], offset[0]+like_shape[2],
+                        offset[1]+like_shape[3]]
    return _op.strided_slice(inputs[0], **new_attrs)


8 changes: 4 additions & 4 deletions python/tvm/relay/frontend/onnx.py
@@ -1048,8 +1048,8 @@ def _impl_v1(cls, inputs, attr, params):
        end = list(attr['ends'])

        return _op.strided_slice(inputs[0],
-                                 begin=_expr.const(begin, dtype="int64"),
-                                 end=_expr.const(end, dtype="int64"))
+                                 begin=begin,
+                                 end=end)

    @classmethod
    def _impl_v10(cls, inputs, attr, params):
@@ -1069,8 +1069,8 @@ def _impl_v10(cls, inputs, attr, params):
        attrs['starts'] = new_starts
        attrs['ends'] = new_ends
        return _op.strided_slice(inputs[0],
-                                 begin=_expr.const(attrs['starts'], dtype="int64"),
-                                 end=_expr.const(attrs['ends'], dtype="int64"))
+                                 begin=list(attrs['starts']),
+                                 end=list(attrs['ends']))


class Gather(OnnxOpConverter):
18 changes: 9 additions & 9 deletions python/tvm/relay/frontend/pytorch.py
@@ -282,9 +282,9 @@ def _impl(inputs, input_types):

        strides.append(int(inputs[4]))
        return _op.transform.strided_slice(data,
-                                           begin=_expr.const(begin),
-                                           end=_expr.const(end),
-                                           strides=_expr.const(strides),
+                                           begin=begin,
+                                           end=end,
+                                           strides=strides,
                                            slice_mode="end")
    return _impl

@@ -1346,9 +1346,9 @@ def _impl(inputs, input_types):
            stride = [1] * len(shape)

            chunk_out = _op.transform.strided_slice(data,
-                                                    begin=_expr.const(begin),
-                                                    end=_expr.const(end),
-                                                    strides=_expr.const(stride))
+                                                    begin=begin,
+                                                    end=end,
+                                                    strides=stride)
            chunks.append(chunk_out)

        if dim % num_chunks:
@@ -1359,9 +1359,9 @@ def _impl(inputs, input_types):
            stride = [1] * len(shape)

            chunk_out = _op.transform.strided_slice(data,
-                                                    begin=_expr.const(begin),
-                                                    end=_expr.const(end),
-                                                    strides=_expr.const(stride))
+                                                    begin=begin,
+                                                    end=end,
+                                                    strides=stride)
            chunks.append(chunk_out)

        return chunks
6 changes: 3 additions & 3 deletions python/tvm/relay/op/_tensor_grad.py
@@ -407,9 +407,9 @@ def conv2d_grad(orig, grad):
    assert padded_weight_grad_w >= filter_w
    if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
        backward_weight = strided_slice(backward_weight,
-                                        begin=const([0, 0, 0, 0], dtype="int64"),
-                                        end=const([out_channel, in_channel // attrs.groups,
-                                                   filter_h, filter_w], dtype="int64"))
+                                        begin=[0, 0, 0, 0],
+                                        end=[out_channel, in_channel // attrs.groups,
+                                             filter_h, filter_w])

    return [backward_data, backward_weight]

54 changes: 19 additions & 35 deletions python/tvm/relay/op/_transform.py
@@ -126,33 +126,6 @@ def arange_shape_func(attrs, inputs, _):
"""
return [_arange_shape_func(*inputs)]

@script
def _strided_slice_shape_func_input_data(data, begin, end, strides,
slice_mode):
ndim = len(data.shape)
out = output_tensor((ndim,), "int64")
for i in const_range(ndim):
cbegin = 0
cend = data.shape[i]
cstride = 1
if strides.shape[0] > i:
cstride = strides[i]
if begin.shape[0] > i:
cbegin = begin[i]
if end.shape[0] <= i:
cend = data.shape[i]
elif slice_mode != 0:
cstride = 1
if end[i] < 0:
cend = data.shape[i]
else:
cend = cbegin + end[i]
else:
cend = end[i]
assert cstride != 0, "Strides can't be zero."
out[i] = int64(ceil_div((int64(cend) - int64(cbegin)), int64(cstride)))
return out

@script
def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice_mode):
ndim = data_shape.shape[0]
@@ -165,6 +138,8 @@ def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice
            cstride = int64(strides[i])
        if len(begin) > i:
            cbegin = int64(begin[i])
+            if cbegin < 0:
+                cbegin += int64(data_shape[i])
Member: What if cbegin is still less than 0 after adding data_shape[i]? We should probably raise an error for invalid input.

Contributor (author): I attempted to do this with an assert in the hybrid script, but something seems a little off in the compiler; even after these lines it was still checking against the original value.

        if len(end) <= i:
            cend = int64(data_shape[i])
        elif slice_mode != 0:
@@ -174,23 +149,32 @@ def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice
            else:
                cend = cbegin + int64(end[i])
        else:
-            cend = int64(end[i])
+            if end[i] > data_shape[i]:
+                cend = int64(data_shape[i])
+            else:
+                cend = int64(end[i])
+            if cend < 0:
+                cend += int64(data_shape[i])
        assert cstride != 0, "Strides can't be zero."
-        out[i] = int64(ceil_div((int64(cend) - int64(cbegin)), int64(cstride)))
+        if cstride < 0:
+            slice_range = cbegin - cend
+            step = -cstride
+        else:
+            slice_range = cend - cbegin
+            step = cstride
+
+        out[i] = int64(ceil_div(slice_range, step))
    return out
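To make the extent arithmetic concrete, here is a standalone Python sketch of the same per-axis computation for the slice_mode="end" path (the helper name and example values are illustrative only, not part of the diff):

def ceil_div(a, b):
    return (a + b - 1) // b

def slice_out_len(dim, begin, end, stride):
    # Mirrors the per-axis logic above: clamp end to the axis length, normalize
    # negative indices, then count how many |stride| steps fit in the slice range.
    if begin < 0:
        begin += dim
    if end > dim:
        end = dim
    elif end < 0:
        end += dim
    if stride < 0:
        slice_range, step = begin - end, -stride
    else:
        slice_range, step = end - begin, stride
    return ceil_div(slice_range, step)

# dim=10, begin=-8 (normalized to 2), end=9, stride=3  ->  ceil(7/3) = 3 elements
assert slice_out_len(10, -8, 9, 3) == 3
# dim=10, begin=8, end=2, stride=-2                    ->  ceil(6/2) = 3 elements
assert slice_out_len(10, 8, 2, -2) == 3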


@_reg.register_shape_func("strided_slice", True)
@_reg.register_shape_func("strided_slice", False)
def strided_slice_shape_func(attrs, inputs, _):
"""
Shape func for strided_slice
"""
slice_mode = convert(0 if attrs.slice_mode == "end" else 1)
# data independent if begin, end and strides exist
if attrs.begin and attrs.end and attrs.strides:
return [_strided_slice_shape_func_input_shape(inputs[0], attrs.begin, attrs.end,
attrs.strides, slice_mode)]
return [_strided_slice_shape_func_input_data(*inputs, slice_mode)]
return [_strided_slice_shape_func_input_shape(inputs[0], attrs.begin, attrs.end,
attrs.strides, slice_mode)]

@script
def _concatenate_shape_func(inputs, axis):
51 changes: 51 additions & 0 deletions python/tvm/relay/op/dyn/_transform.py
@@ -27,6 +27,7 @@
_reg.register_broadcast_schedule("dyn.tile")
_reg.register_injective_schedule("dyn.one_hot")
_reg.register_injective_schedule("dyn.full")
_reg.register_injective_schedule("dyn.strided_slice")

@script
def _reshape_shape_func_input_data(data, newshape, ndim):
@@ -145,3 +146,53 @@ def one_hot_shape_func(attrs, inputs, _):
    """
    axis = len(inputs[0].shape) if attrs.axis == -1 else attrs.axis
    return [_onehot_shape_func(inputs[0].shape, inputs[3], convert(axis))]


+@script
+def _strided_slice_shape_func_input_data(data, begin, end, strides,
Contributor: What's the difference between _strided_slice_shape_func_input_shape and _strided_slice_shape_func_input_data?

Contributor (author): For the static op, we only require the input shapes, since the attributes are static, hence the input_shape variant. For the dynamic op, we need the input data, since the attributes are coming in as values, hence the input_data variant. Why the VM is set up to separate those instead of just sending in data all the time, I'm not sure. Optimization?

+                                         slice_mode):
+    ndim = len(data.shape)
+    out = output_tensor((ndim,), "int64")
+    for i in const_range(ndim):
+        cbegin = int64(0)
+        cend = int64(data.shape[i])
+        cstride = int64(1)
+        if strides.shape[0] > i:
+            cstride = int64(strides[i])
+        if begin.shape[0] > i:
+            cbegin = int64(begin[i])
+            if cbegin < 0:
+                cbegin += int64(data.shape[i])
+        if end.shape[0] <= i:
+            cend = int64(data.shape[i])
+        elif slice_mode != 0:
+            cstride = int64(1)
+            if end[i] < 0:
+                cend = int64(data.shape[i])
+            else:
+                cend = cbegin + int64(end[i])
+        else:
+            if end[i] > data.shape[i]:
+                cend = int64(data.shape[i])
+            else:
+                cend = int64(end[i])
+            if cend < 0:
+                cend += int64(data.shape[i])
+        assert cstride != 0, "Strides can't be zero."
+        if cstride < 0:
+            slice_range = cbegin - cend
+            step = -cstride
+        else:
+            slice_range = cend - cbegin
+            step = cstride
+
+        out[i] = int64(ceil_div(slice_range, step))
+    return out

@_reg.register_shape_func("dyn.strided_slice", True)
def strided_slice_shape_func(attrs, inputs, _):
"""
Shape func for strided_slice
"""
slice_mode = convert(0 if attrs.slice_mode == "end" else 1)
return [_strided_slice_shape_func_input_data(*inputs, slice_mode)]
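To make the static/dynamic split described in the thread above concrete, here is a rough sketch of how the two paths get selected from the Python API (variable names and shapes are illustrative, not part of the diff):

from tvm import relay

x = relay.var("x", shape=(4, 6), dtype="float32")

# Static path: begin/end are plain Python lists, so relay.strided_slice builds the
# static op and only the attribute-based _strided_slice_shape_func_input_shape is needed.
static_out = relay.strided_slice(x, begin=[0, 1], end=[4, 5])

# Dynamic path: begin is a relay expression whose values are unknown at compile time,
# so the call is routed to dyn.strided_slice and the shape function has to read the
# begin/end/strides tensors themselves (_strided_slice_shape_func_input_data).
begin = relay.var("begin", shape=(2,), dtype="int64")
dyn_out = relay.strided_slice(x, begin=begin, end=relay.const([4, 5], "int64"))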
19 changes: 12 additions & 7 deletions python/tvm/relay/op/transform.py
@@ -20,6 +20,7 @@

from . import _make
from .dyn import _make as _dyn_make
+from .tensor import shape_of
from ..expr import TupleWrapper, const, Expr, Tuple
from ...tir import expr as _expr

@@ -827,13 +828,17 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
    ret : relay.Expr
        The computed result.
    """
-    strides = strides or const([1], dtype="int32")
-    if isinstance(begin, (tuple, list)):
-        begin = const(list(begin))
-    if isinstance(end, (tuple, list)):
-        end = const(list(end))
-    if isinstance(strides, (tuple, list)):
-        strides = const(list(strides))
+    strides = strides or [1]
+    if (isinstance(begin, Expr) or isinstance(end, Expr) or isinstance(strides, Expr)):
+        if isinstance(begin, (tuple, list)):
+            begin = const(list(begin))
+        if isinstance(end, (tuple, list)):
+            end = const(list(end))
+        if isinstance(strides, (tuple, list)):
+            strides = const(list(strides))
+        normalized_begin = _make.where(begin < cast_like(const(0), begin),
Member: We could consider moving the normalization step into strided_slice.

Contributor (author): Hmm, yeah, it seems a little odd to produce a subgraph as part of the constructor of an op, which is why I put it here. That being said, this makes the op less useful from other frontends, so... Any other votes?

Contributor: How does this make it less useful?

Contributor (author): We'd have to reproduce this on any frontend that creates ops.

+                                       begin + cast_like(shape_of(data), begin), begin)
+        return _dyn_make.strided_slice(data, normalized_begin, end, strides, slice_mode)
    return _make.strided_slice(data, begin, end, strides, slice_mode)
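The normalization above simply rewrites negative begin indices relative to the data shape; a NumPy analogue with made-up values (not part of the change):

import numpy as np

begin = np.array([-2, 0])          # a negative entry means "count from the end"
data_shape = np.array([4, 6])
normalized_begin = np.where(begin < 0, begin + data_shape, begin)
assert list(normalized_begin) == [2, 0]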


4 changes: 2 additions & 2 deletions python/tvm/topi/cuda/conv2d_alter_op.py
@@ -276,8 +276,8 @@ def _conv2d_legalize(attrs, inputs, arg_types):
            new_attrs['channels'] = new_out_channel
            out = tvm.relay.nn.conv2d(data, kernel, **new_attrs)
            original_out_shape = [x.value for x in output_tensor.shape]
-            out = relay.strided_slice(out, begin=relay.const([0, 0, 0, 0]),
-                                      end=relay.const(original_out_shape))
+            out = relay.strided_slice(out, begin=[0, 0, 0, 0],
+                                      end=original_out_shape)
        else:
            out = relay.nn.conv2d(data, kernel, **new_attrs)
        return out
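For context, the legalization pads the output channels so the specialized schedule applies and then slices the conv output back to the original shape. A rough Relay-level sketch with hypothetical sizes (not taken from the diff):

from tvm import relay

# Say out_channel was padded from 30 to 32; the padded conv output is trimmed back.
padded_out = relay.var("padded_out", shape=(1, 32, 56, 56), dtype="float32")
trimmed = relay.strided_slice(padded_out, begin=[0, 0, 0, 0], end=[1, 30, 56, 56])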
4 changes: 2 additions & 2 deletions python/tvm/topi/x86/conv2d_alter_op.py
@@ -313,8 +313,8 @@ def _conv2d_legalize(attrs, inputs, arg_types):
            out = tvm.relay.nn.conv2d(data, kernel, **new_attrs)
            original_out_shape = [x.value for x in output_tensor.shape]
            out = relay.strided_slice(out,
-                                      begin=relay.const([0, 0, 0, 0], "int32"),
-                                      end=relay.const(original_out_shape, "int32"))
+                                      begin=[0, 0, 0, 0],
+                                      end=original_out_shape)
        else:
            out = relay.nn.conv2d(data, kernel, **new_attrs)
