diff --git a/python/tvm/relay/op/dyn/image/_image.py b/python/tvm/relay/op/dyn/image/_image.py index fa528e9a202d..2d36708af04f 100644 --- a/python/tvm/relay/op/dyn/image/_image.py +++ b/python/tvm/relay/op/dyn/image/_image.py @@ -40,37 +40,31 @@ def compute_resize(attrs, inputs, out_type): reg.register_injective_schedule("dyn.image.resize") - @script -def _NCHW_resize_shape_func(dshape, size, ndim): +def _resize_shape_func(dshape, size, ndim, height_axis, width_axis): out = output_tensor((ndim, ), "int64") for i in const_range(ndim): out[i] = int64(dshape[i]) - out[2] = int64(size[0]) - out[3] = int64(size[1]) + out[height_axis] = int64(size[0]) + out[width_axis] = int64(size[1]) return out - -@script -def _NHWC_resize_shape_func(dshape, size, ndim): - out = output_tensor((ndim, ), "int64") - for i in const_range(ndim): - out[i] = int64(dshape[i]) - out[1] = int64(size[0]) - out[2] = int64(size[1]) - return out - - @reg.register_shape_func("dyn.image.resize", True) def resize_shape_func(attrs, inputs, _): """ Shape function for dyn.image.resize op. """ layout = attrs.layout - if layout == 'NHWC': - out = [_NHWC_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)))] - elif (layout == 'NCHW') or nchw_pack_layout(layout) or nchw_xc_layout(layout): - out = [_NCHW_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)))] + if nchw_pack_layout(layout) or nchw_xc_layout(layout): + out = [_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)), + convert(2), convert(3))] else: - raise ValueError("Resize Unsupported Layout", layout) + height_axis = width_axis = 1 + for i, letter in enumerate(layout): + if letter == "H": + height_axis = i + if letter == "W": + width_axis = i + out = [_resize_shape_func(inputs[0].shape, inputs[1], convert(len(inputs[0].shape)), + convert(height_axis), convert(width_axis))] return out