Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Register Shape Func for Some Operators to Handle Dynamic Shapes #5955

Merged
merged 13 commits into from
Jul 23, 2020
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,5 @@ def elemwise_shape_func(attrs, inputs, _):
# Elementwise ops keep the shape of their input, so all of them share the
# generic elemwise_shape_func. The False flag presumably means the shape
# function needs only input shapes, not input data — confirm against the
# register_shape_func convention used elsewhere in this module.
register_shape_func("fast_exp", False, elemwise_shape_func)
register_shape_func("fast_tanh", False, elemwise_shape_func)
register_shape_func("fast_erf", False, elemwise_shape_func)
register_shape_func("floor", False, elemwise_shape_func)
register_shape_func("log", False, elemwise_shape_func)
31 changes: 31 additions & 0 deletions python/tvm/relay/op/image/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
"""Backend compiler related feature registration"""
from __future__ import absolute_import

from tvm.te.hybrid import script
from tvm.runtime import convert

import topi
from topi.util import get_const_tuple
from .. import op as reg
Expand Down Expand Up @@ -64,6 +67,34 @@ def compute_crop_and_resize(attrs, inputs, out_type):

reg.register_injective_schedule("image.crop_and_resize")

# Hybrid-script shape function for image.crop_and_resize.
# The output rank is fixed at 4: the batch dimension comes from the boxes
# tensor (one output slice per box), the two spatial axes take the static
# crop_size values, and the channel axis is copied through from the input
# image shape. Comments only here — hybrid script restricts Python syntax,
# so no docstring is added inside the function body.
@script
def _crop_and_resize_func(image_shape, boxes_shape, crop_size,
                          height_axis, width_axis, channel_axis):
    out = output_tensor((4,), "int64")
    out[0] = boxes_shape[0]
    out[height_axis] = int64(crop_size[0])
    out[width_axis] = int64(crop_size[1])
    out[channel_axis] = image_shape[channel_axis]
    return out

@reg.register_shape_func("image.crop_and_resize", False)
def crop_and_resize_func(attrs, inputs, _):
    """
    Shape function for crop_and_resize op.

    Locates the H/W/C axes in the layout string (each defaulting to axis 1
    when the letter is absent, matching the previous behavior) and delegates
    to the hybrid-script shape function with the static crop size.
    """
    # Map each layout letter to its position; any letter not present keeps
    # the default axis index of 1.
    axis_of = {"H": 1, "W": 1, "C": 1}
    for pos, letter in enumerate(attrs.layout):
        if letter in axis_of:
            axis_of[letter] = pos
    crop_size = get_const_tuple(attrs.crop_size)
    return [_crop_and_resize_func(inputs[0], inputs[1], convert(crop_size),
                                  convert(axis_of["H"]), convert(axis_of["W"]),
                                  convert(axis_of["C"]))]


# dilation2d
reg.register_strategy("image.dilation2d", strategy.dilation2d_strategy)
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,19 @@ def compute_mirror_pad(attrs, inputs, out_dtype):
reg.register_broadcast_schedule("nn.mirror_pad")


# Hybrid-script shape function for nn.mirror_pad: the output rank matches
# the input rank, and each output dimension is the input dimension plus the
# leading and trailing pad amounts for that axis. Comments only — hybrid
# script restricts Python syntax, so no docstring is added.
@script
def _mirror_pad_func(data_shape, pad_width):
    out = output_tensor((data_shape.shape[0],), "int64")
    for i in const_range(data_shape.shape[0]):
        out[i] = data_shape[i] + int64(pad_width[i][0]) + int64(pad_width[i][1])
    return out

@reg.register_shape_func("nn.mirror_pad", False)
def mirror_pad_func(attrs, inputs, _):
    """
    Shape function for mirror_pad op.

    Extracts the per-axis (before, after) pad amounts as constants and hands
    them to the hybrid-script shape function together with the input shape.
    """
    pads = []
    for pair in attrs.pad_width:
        pads.append(get_const_tuple(pair))
    return [_mirror_pad_func(inputs[0], convert(pads))]


# conv2d_winograd related operators
reg.register_strategy("nn.contrib_conv2d_winograd_without_weight_transform",
strategy.conv2d_winograd_without_weight_transfrom_strategy)
Expand Down
52 changes: 52 additions & 0 deletions tests/python/relay/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,6 +816,55 @@ def test_mixed_input_type():
assert result.asnumpy().shape == ref_out_shape, \
"Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape))

def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, crop_size,
                               static_boxes, static_box_indices_shape, ref_out_shape):
    """Build a crop_and_resize module whose box inputs have dynamic shapes,
    run it with concrete (static) inputs, and check the result shape."""
    dtype = "float32"
    indices_dtype = "int32"
    data = relay.var('data', shape=data_shape, dtype=dtype)
    boxes = relay.var('boxes', shape=boxes_shape, dtype=dtype)
    box_indices = relay.var('box_indices', shape=box_indices_shape, dtype=indices_dtype)
    mod = tvm.IRModule()
    mod["main"] = relay.Function(
        [data, boxes, box_indices],
        relay.image.crop_and_resize(data, boxes, box_indices, crop_size, 'NHWC'))
    # Concrete inputs: the dynamic (Any) dimensions are filled in by the
    # static_* shapes at run time.
    np_inputs = (np.random.uniform(size=data_shape).astype(dtype),
                 np.random.uniform(size=static_boxes).astype(dtype),
                 np.random.uniform(size=static_box_indices_shape).astype(indices_dtype))
    for kind in ["debug", "vm"]:
        executor = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
        result = executor.evaluate()(*np_inputs)
        assert result.asnumpy().shape == ref_out_shape, \
            "Shape mismatch: expect %s but got %s." % (str(ref_out_shape),
                                                       str(result.asnumpy().shape))

def test_any_crop_and_resize():
    """crop_and_resize with a dynamic number of boxes in NHWC layout."""
    verify_any_crop_and_resize((1, 234, 234, 256),  # data_shape
                               (relay.Any(), 4),    # boxes_shape
                               (relay.Any(),),      # box_indices_shape
                               (14, 14),            # crop_size
                               (128, 4),            # static_boxes
                               (128,),              # static_box_indices_shape
                               (128, 14, 14, 256))  # ref_out_shape

def verify_any_mirror_pad(data_shape, pad_width, static_data_shape, ref_out_shape):
    """Build a mirror_pad module over a (possibly dynamic) input shape, run
    it with a concrete input, and check the result shape."""
    dtype = "float32"
    data = relay.var('data', shape=data_shape, dtype=dtype)
    mod = tvm.IRModule()
    mod["main"] = relay.Function([data], relay.nn.mirror_pad(data, pad_width))
    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
    for kind in ["debug", "vm"]:
        executor = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
        result = executor.evaluate()(data_np)
        assert result.asnumpy().shape == ref_out_shape, \
            "Shape mismatch: expect %s but got %s." % (str(ref_out_shape),
                                                       str(result.asnumpy().shape))

def test_any_mirror_pad():
    """mirror_pad shape check on an NCHW-shaped input padded by 1 on H and W."""
    verify_any_mirror_pad((1, 256, 232, 232),                # data_shape
                          ((0, 0), (0, 0), (1, 1), (1, 1)),  # pad_width
                          (1, 256, 232, 232),                # static_data_shape
                          (1, 256, 234, 234))                # ref_out_shape

if __name__ == "__main__":
test_any_full()
test_any_full_like()
Expand Down Expand Up @@ -850,3 +899,6 @@ def test_mixed_input_type():
test_recursive_concat_with_wrong_annotation()
test_tuple_get_item()
test_mixed_input_type()
test_any_crop_and_resize()
test_any_mirror_pad()