Skip to content

Commit

Permalink
Dynamic Strided Slice (apache#6316)
Browse files Browse the repository at this point in the history
* Dynamic Strided Slice

* fix clang-format lint

* remove debug print

* respond to review comments

* respond to yongwww's comments

* fix bad rebase

* revert hybrid-script assert

* reformat mxnet change

* use new testing api

* while getting test to work with the new testing API, refactor all of the tests iin the dyn directory
  • Loading branch information
Matthew Brookhart authored and kevinthesun committed Sep 17, 2020
1 parent f399240 commit 9399aef
Show file tree
Hide file tree
Showing 30 changed files with 634 additions and 274 deletions.
4 changes: 2 additions & 2 deletions python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,8 +622,8 @@ def _convert_cropping(inexpr, keras_layer, _):
raise tvm.error.OpNotImplemented(
'Operator {} is not supported for frontend Keras.'.format(crop_type))
int32_max = np.iinfo(np.int32).max
return _op.strided_slice(inexpr, begin=_expr.const([0, 0, crop_t, crop_l]), \
end=_expr.const([int32_max, int32_max, in_h-crop_b, in_w-crop_r]))
return _op.strided_slice(inexpr, begin=[0, 0, crop_t, crop_l], \
end=[int32_max, int32_max, in_h-crop_b, in_w-crop_r])


def _convert_batchnorm(inexpr, keras_layer, etab):
Expand Down
16 changes: 7 additions & 9 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,11 +500,11 @@ def _mx_slice(inputs, attrs):
for i, ed in enumerate(end):
if ed is None:
end[i] = input_shape[i]
new_attrs = {'begin': _expr.const(list(begin), dtype="int32"),
'end': _expr.const(list(end), dtype="int32")}
new_attrs = {'begin': list(begin),
'end': list(end)}
if stride is not None:
stride = (x if x is not None else 1 for x in stride)
new_attrs['strides'] = _expr.const(list(stride), dtype="int32")
new_attrs['strides'] = list(stride)
return _op.strided_slice(inputs[0], **new_attrs)


Expand Down Expand Up @@ -544,9 +544,7 @@ def _mx_slice_axis(inputs, attrs):
else:
begin.append(ax_beg)
end.append(ax_end)
return _op.strided_slice(inputs[0],
_expr.const(begin, dtype="int32"),
_expr.const(end, dtype="int32"))
return _op.strided_slice(inputs[0], begin, end)


def _mx_crop_like(inputs, attrs):
Expand All @@ -566,9 +564,9 @@ def _mx_crop_like(inputs, attrs):
return _op.slice_like(*inputs, **new_attrs)
expr = _infer_type(inputs[1])
like_shape = expr.checked_type.shape
new_attrs['begin'] = _expr.const([0, 0, offset[0], offset[1]], dtype="int32")
new_attrs['end'] = _expr.const([like_shape[0], like_shape[1], offset[0]+like_shape[2],
offset[1]+like_shape[3]], dtype="int32")
new_attrs['begin'] = [0, 0, offset[0], offset[1]]
new_attrs['end'] = [like_shape[0], like_shape[1], offset[0]+like_shape[2],
offset[1]+like_shape[3]]
return _op.strided_slice(inputs[0], **new_attrs)


Expand Down
8 changes: 4 additions & 4 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,8 +1049,8 @@ def _impl_v1(cls, inputs, attr, params):
end = list(attr['ends'])

return _op.strided_slice(inputs[0],
begin=_expr.const(begin, dtype="int64"),
end=_expr.const(end, dtype="int64"))
begin=begin,
end=end)

@classmethod
def _impl_v10(cls, inputs, attr, params):
Expand All @@ -1070,8 +1070,8 @@ def _impl_v10(cls, inputs, attr, params):
attrs['starts'] = new_starts
attrs['ends'] = new_ends
return _op.strided_slice(inputs[0],
begin=_expr.const(attrs['starts'], dtype="int64"),
end=_expr.const(attrs['ends'], dtype="int64"))
begin=list(attrs['starts']),
end=list(attrs['ends']))


class Gather(OnnxOpConverter):
Expand Down
18 changes: 9 additions & 9 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,9 @@ def _impl(inputs, input_types):
strides[dim] = int(inputs[4])

return _op.transform.strided_slice(data,
begin=_expr.const(begin),
end=_expr.const(end),
strides=_expr.const(strides),
begin=begin,
end=end,
strides=strides,
slice_mode="end")
return _impl

Expand Down Expand Up @@ -1373,9 +1373,9 @@ def _impl(inputs, input_types):
stride = [1] * len(shape)

chunk_out = _op.transform.strided_slice(data,
begin=_expr.const(begin),
end=_expr.const(end),
strides=_expr.const(stride))
begin=begin,
end=end,
strides=stride)
chunks.append(chunk_out)

if dim % num_chunks:
Expand All @@ -1386,9 +1386,9 @@ def _impl(inputs, input_types):
stride = [1] * len(shape)

chunk_out = _op.transform.strided_slice(data,
begin=_expr.const(begin),
end=_expr.const(end),
strides=_expr.const(stride))
begin=begin,
end=end,
strides=stride)
chunks.append(chunk_out)

return chunks
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,9 +407,9 @@ def conv2d_grad(orig, grad):
assert padded_weight_grad_w >= filter_w
if padded_weight_grad_h > filter_h or padded_weight_grad_w > filter_w:
backward_weight = strided_slice(backward_weight,
begin=const([0, 0, 0, 0], dtype="int64"),
end=const([out_channel, in_channel // attrs.groups,
filter_h, filter_w], dtype="int64"))
begin=[0, 0, 0, 0],
end=[out_channel, in_channel // attrs.groups,
filter_h, filter_w])

return [backward_data, backward_weight]

Expand Down
54 changes: 19 additions & 35 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,33 +127,6 @@ def arange_shape_func(attrs, inputs, _):
"""
return [_arange_shape_func(*inputs)]

@script
def _strided_slice_shape_func_input_data(data, begin, end, strides,
slice_mode):
ndim = len(data.shape)
out = output_tensor((ndim,), "int64")
for i in const_range(ndim):
cbegin = 0
cend = data.shape[i]
cstride = 1
if strides.shape[0] > i:
cstride = strides[i]
if begin.shape[0] > i:
cbegin = begin[i]
if end.shape[0] <= i:
cend = data.shape[i]
elif slice_mode != 0:
cstride = 1
if end[i] < 0:
cend = data.shape[i]
else:
cend = cbegin + end[i]
else:
cend = end[i]
assert cstride != 0, "Strides can't be zero."
out[i] = int64(ceil_div((int64(cend) - int64(cbegin)), int64(cstride)))
return out

@script
def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice_mode):
ndim = data_shape.shape[0]
Expand All @@ -166,6 +139,8 @@ def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice
cstride = int64(strides[i])
if len(begin) > i:
cbegin = int64(begin[i])
if cbegin < 0:
cbegin += int64(data_shape[i])
if len(end) <= i:
cend = int64(data_shape[i])
elif slice_mode != 0:
Expand All @@ -175,23 +150,32 @@ def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice
else:
cend = cbegin + int64(end[i])
else:
cend = int64(end[i])
if end[i] > data_shape[i]:
cend = int64(data_shape[i])
else:
cend = int64(end[i])
if cend < 0:
cend += int64(data_shape[i])
assert cstride != 0, "Strides can't be zero."
out[i] = int64(ceil_div((int64(cend) - int64(cbegin)), int64(cstride)))
if cstride < 0:
slice_range = cbegin - cend
step = -cstride
else:
slice_range = cend - cbegin
step = cstride

out[i] = int64(ceil_div(slice_range, step))
return out


@_reg.register_shape_func("strided_slice", True)
@_reg.register_shape_func("strided_slice", False)
def strided_slice_shape_func(attrs, inputs, _):
"""
Shape func for strided_slice
"""
slice_mode = convert(0 if attrs.slice_mode == "end" else 1)
# data independent if begin, end and strides exist
if attrs.begin and attrs.end and attrs.strides:
return [_strided_slice_shape_func_input_shape(inputs[0], attrs.begin, attrs.end,
attrs.strides, slice_mode)]
return [_strided_slice_shape_func_input_data(*inputs, slice_mode)]
return [_strided_slice_shape_func_input_shape(inputs[0], attrs.begin, attrs.end,
attrs.strides, slice_mode)]

@script
def _concatenate_shape_func(inputs, axis):
Expand Down
51 changes: 51 additions & 0 deletions python/tvm/relay/op/dyn/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
_reg.register_broadcast_schedule("dyn.tile")
_reg.register_injective_schedule("dyn.one_hot")
_reg.register_injective_schedule("dyn.full")
_reg.register_injective_schedule("dyn.strided_slice")

@script
def _reshape_shape_func_input_data(data, newshape, ndim):
Expand Down Expand Up @@ -145,3 +146,53 @@ def one_hot_shape_func(attrs, inputs, _):
"""
axis = len(inputs[0].shape) if attrs.axis == -1 else attrs.axis
return [_onehot_shape_func(inputs[0].shape, inputs[3], convert(axis))]


@script
def _strided_slice_shape_func_input_data(data, begin, end, strides,
slice_mode):
ndim = len(data.shape)
out = output_tensor((ndim,), "int64")
for i in const_range(ndim):
cbegin = int64(0)
cend = int64(data.shape[i])
cstride = int64(1)
if strides.shape[0] > i:
cstride = int64(strides[i])
if begin.shape[0] > i:
cbegin = int64(begin[i])
if cbegin < 0:
cbegin += int64(data.shape[i])
if end.shape[0] <= i:
cend = int64(data.shape[i])
elif slice_mode != 0:
cstride = int64(1)
if end[i] < 0:
cend = int64(data.shape[i])
else:
cend = cbegin + int64(end[i])
else:
if end[i] > data.shape[i]:
cend = int64(data.shape[i])
else:
cend = int64(end[i])
if cend < 0:
cend += int64(data.shape[i])
assert cstride != 0, "Strides can't be zero."
if cstride < 0:
slice_range = cbegin - cend
step = -cstride
else:
slice_range = cend - cbegin
step = cstride

out[i] = int64(ceil_div(slice_range, step))
return out

@_reg.register_shape_func("dyn.strided_slice", True)
def strided_slice_shape_func(attrs, inputs, _):
"""
Shape func for strided_slice
"""
slice_mode = convert(0 if attrs.slice_mode == "end" else 1)
return [_strided_slice_shape_func_input_data(*inputs, slice_mode)]
19 changes: 12 additions & 7 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from . import _make
from .dyn import _make as _dyn_make
from .tensor import shape_of
from ..expr import TupleWrapper, const, Expr, Tuple
from ...tir import expr as _expr

Expand Down Expand Up @@ -827,13 +828,17 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
ret : relay.Expr
The computed result.
"""
strides = strides or const([1], dtype="int32")
if isinstance(begin, (tuple, list)):
begin = const(list(begin))
if isinstance(end, (tuple, list)):
end = const(list(end))
if isinstance(strides, (tuple, list)):
strides = const(list(strides))
strides = strides or [1]
if (isinstance(begin, Expr) or isinstance(end, Expr) or isinstance(strides, Expr)):
if isinstance(begin, (tuple, list)):
begin = const(list(begin))
if isinstance(end, (tuple, list)):
end = const(list(end))
if isinstance(strides, (tuple, list)):
strides = const(list(strides))
normalized_begin = _make.where(begin < cast_like(const(0), begin),
begin + cast_like(shape_of(data), begin), begin)
return _dyn_make.strided_slice(data, normalized_begin, end, strides, slice_mode)
return _make.strided_slice(data, begin, end, strides, slice_mode)


Expand Down
4 changes: 2 additions & 2 deletions python/tvm/topi/cuda/conv2d_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,8 +276,8 @@ def _conv2d_legalize(attrs, inputs, arg_types):
new_attrs['channels'] = new_out_channel
out = tvm.relay.nn.conv2d(data, kernel, **new_attrs)
original_out_shape = [x.value for x in output_tensor.shape]
out = relay.strided_slice(out, begin=relay.const([0, 0, 0, 0]),
end=relay.const(original_out_shape))
out = relay.strided_slice(out, begin=[0, 0, 0, 0],
end=original_out_shape)
else:
out = relay.nn.conv2d(data, kernel, **new_attrs)
return out
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/topi/x86/conv2d_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,8 @@ def _conv2d_legalize(attrs, inputs, arg_types):
out = tvm.relay.nn.conv2d(data, kernel, **new_attrs)
original_out_shape = [x.value for x in output_tensor.shape]
out = relay.strided_slice(out,
begin=relay.const([0, 0, 0, 0], "int32"),
end=relay.const(original_out_shape, "int32"))
begin=[0, 0, 0, 0],
end=original_out_shape)
else:
out = relay.nn.conv2d(data, kernel, **new_attrs)

Expand Down
Loading

0 comments on commit 9399aef

Please sign in to comment.