Skip to content

Commit

Permalink
[DYN][RELAY] Resize support for NCHW-convertible layouts (#6293)
Browse files Browse the repository at this point in the history
* fix lint

* fix typo

* remove channel_axis from resize shape func

* fix lint
  • Loading branch information
electriclilies authored Aug 26, 2020
1 parent 617949d commit 942c90b
Showing 1 changed file with 14 additions and 20 deletions.
34 changes: 14 additions & 20 deletions python/tvm/relay/op/dyn/image/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,37 +40,31 @@ def compute_resize(attrs, inputs, out_type):

reg.register_injective_schedule("dyn.image.resize")


@script
def _NCHW_resize_shape_func(dshape, size, ndim):
def _resize_shape_func(dshape, size, ndim, height_axis, width_axis):
out = output_tensor((ndim, ), "int64")
for i in const_range(ndim):
out[i] = int64(dshape[i])
out[2] = int64(size[0])
out[3] = int64(size[1])
out[height_axis] = int64(size[0])
out[width_axis] = int64(size[1])
return out


@script
def _NHWC_resize_shape_func(dshape, size, ndim):
out = output_tensor((ndim, ), "int64")
for i in const_range(ndim):
out[i] = int64(dshape[i])
out[1] = int64(size[0])
out[2] = int64(size[1])
return out


@reg.register_shape_func("dyn.image.resize", True)
def resize_shape_func(attrs, inputs, _):
"""
Shape function for dyn.image.resize op.
"""
layout = attrs.layout
if layout == 'NHWC':
out = [_NHWC_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)))]
elif (layout == 'NCHW') or nchw_pack_layout(layout) or nchw_xc_layout(layout):
out = [_NCHW_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)))]
if nchw_pack_layout(layout) or nchw_xc_layout(layout):
out = [_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)),
convert(2), convert(3))]
else:
raise ValueError("Resize Unsupported Layout", layout)
height_axis = width_axis = 1
for i, letter in enumerate(layout):
if letter == "H":
height_axis = i
if letter == "W":
width_axis = i
out = [_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)),
convert(height_axis), convert(width_axis))]
return out

0 comments on commit 942c90b

Please sign in to comment.