Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Register Shape Func for Some Operators to Handle Dynamic Shapes #5955

Merged
merged 13 commits into from
Jul 23, 2020
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,5 @@ def elemwise_shape_func(attrs, inputs, _):
register_shape_func("fast_exp", False, elemwise_shape_func)
register_shape_func("fast_tanh", False, elemwise_shape_func)
register_shape_func("fast_erf", False, elemwise_shape_func)
register_shape_func("floor", False, elemwise_shape_func)
register_shape_func("log", False, elemwise_shape_func)
19 changes: 19 additions & 0 deletions python/tvm/relay/op/image/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
"""Backend compiler related feature registration"""
from __future__ import absolute_import

from tvm.te.hybrid import script
from tvm.runtime import convert

import topi
from topi.util import get_const_tuple
from .. import op as reg
Expand Down Expand Up @@ -64,6 +67,22 @@ def compute_crop_and_resize(attrs, inputs, out_type):

reg.register_injective_schedule("image.crop_and_resize")

@script
def _crop_and_resize_func(image_shape, boxes_shape, crop_size):
out = output_tensor((4,), "int64")
out[0] = boxes_shape[0]
out[1] = int64(crop_size[0])
out[2] = int64(crop_size[1])
out[3] = image_shape[3]

return out

@reg.register_shape_func("image.crop_and_resize", False)
def crop_and_resize_func(attrs, inputs, _):
crop_size = get_const_tuple(attrs.crop_size)
icemelon marked this conversation as resolved.
Show resolved Hide resolved

return [_crop_and_resize_func(inputs[0], inputs[1], convert(crop_size))]


# dilation2d
reg.register_strategy("image.dilation2d", strategy.dilation2d_strategy)
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,19 @@ def compute_mirror_pad(attrs, inputs, out_dtype):
reg.register_broadcast_schedule("nn.mirror_pad")


@script
def _mirror_pad_func(data_shape, pad_width):
out = output_tensor((data_shape.shape[0],), "int64")
for i in const_range(data_shape.shape[0]):
out[i] = data_shape[i] + int64(pad_width[i][0]) + int64(pad_width[i][1])
return out

@reg.register_shape_func("nn.mirror_pad", False)
def mirror_pad_func(attrs, inputs, _):
pad_width_tuple = [get_const_tuple(p) for p in attrs.pad_width]
return [_mirror_pad_func(inputs[0], convert(pad_width_tuple))]


# conv2d_winograd related operators
reg.register_strategy("nn.contrib_conv2d_winograd_without_weight_transform",
strategy.conv2d_winograd_without_weight_transfrom_strategy)
Expand Down