From 17a447944016dfe125f8dcfea7e5578dd2648513 Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Tue, 26 Nov 2019 14:49:03 -0800 Subject: [PATCH] incorporate comments --- python/tvm/relay/op/image/image.py | 4 +- .../frontend/tensorflow/test_forward.py | 28 +- topi/python/topi/image/resize.py | 697 ++++++++++++------ topi/tests/python/test_topi_image.py | 41 +- 4 files changed, 498 insertions(+), 272 deletions(-) diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index 0f92c223490ca..385fb486b60d6 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -31,7 +31,7 @@ def resize(data, with data of shape (n, c, h, w) out will have a shape (n, c, size[0], size[1]) - method indicates the algorithm to be used while calculating ghe out value + method indicates the algorithm to be used while calculating the out value and method can be one of ("bilinear", "nearest_neighbor", "bicubic") Parameters @@ -72,7 +72,7 @@ def crop_and_resize(data, out_dtype=None): """Crop input images and resize them. - method indicates the algorithm to be used while calculating ghe out value + method indicates the algorithm to be used while calculating the out value and method can be either "bilinear" or "nearest_neighbor". Parameters diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 02548addca453..d1188baf449be 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1591,35 +1591,29 @@ def test_forward_crop_and_resize(): _test_forward_crop_and_resize([1, 6, 6, 3], [[0, 0, 1, 1]], [0], [3, 3]) _test_forward_crop_and_resize([1, 6, 6, 3], [[0, 0, 1, 1]], [0], [3, 3], 0.2) _test_forward_crop_and_resize([1, 6, 6, 3], [[0, 0, 1, 1]], [0], [3, 3], 0.2, 'nearest') - _test_forward_crop_and_resize([1, 11, 11, 3], [[.3, .3, 1, 1]], [0], [21, 23]) - _test_forward_crop_and_resize([1, 41, 41, 3], [[.2, .4, .8, .8]], [0], [50, 60]) + _test_forward_crop_and_resize([1, 11, 11, 3], [[.3, .3, 1, 1]], [0], [21, 21]) + _test_forward_crop_and_resize([1, 41, 41, 3], [[.2, .4, .8, .8]], [0], [21, 11]) _test_forward_crop_and_resize([1, 100, 100, 3], [[ 0, 0, .9, .9]], [0], [30, 30]) _test_forward_crop_and_resize([1, 224, 224, 3], [[.1, .2, 1, 1]], [0], [9, 9]) - _test_forward_crop_and_resize([1, 250, 250, 3], [[.2, .2, .9, .8]], [0], [6, 7]) - _test_forward_crop_and_resize([1, 200, 200, 3], [[.2, .3, .6, .8]], [0], [100, 150]) + _test_forward_crop_and_resize([1, 249, 249, 3], [[ 0, 0, 1, 1]], [0], [9, 9]) + _test_forward_crop_and_resize([1, 201, 301, 3], [[.2, .3, .7, .8]], [0], [51, 51]) _test_forward_crop_and_resize(img_shape=[10, 11, 11, 3], boxes=[[ 0, 0, .9, .9], [.2, .2, .8, .8]], box_idx=[0, 1], crop_size=[5, 5]) - _test_forward_crop_and_resize(img_shape=[20, 229, 229, 3], - boxes=[[ 0, 0, .9, .9], - [.3, .3, 1, 1], - [.2, .1, .7, .8], - [ 0, 0, 1, 1]], - box_idx=[0, 1, 2, 3], crop_size=[60, 90]) - _test_forward_crop_and_resize(img_shape=[20, 229, 229, 3], - boxes=[[ 0, 0, .9, .9], - [.3, .3, 1, 1], - [.2, .1, .7, .8], - [ 0, 0, 1, 1]], - box_idx=[3, 0, 2, 1], crop_size=[60, 90], + _test_forward_crop_and_resize(img_shape=[20, 576, 576, 3], + boxes=[[ 0, 0, 1, 1], + [ 0, 0, .8, .8], + [.1, .2, .9, 1], + [.2, 0, 1, 1]], + box_idx=[1, 0, 2, 3], crop_size=[24, 24], extrapolation_value=0.3) _test_forward_crop_and_resize(img_shape=[20, 229, 229, 3], boxes=[[ 0, 0, .9, .9], [.3, .3, 1, 1], [.2, .1, .7, .8], [ 0, 0, 1, 1]], - box_idx=[3, 0, 2, 1], crop_size=[60, 90], + box_idx=[3, 0, 2, 1], crop_size=[58, 58], extrapolation_value=0.2, method='nearest') diff --git a/topi/python/topi/image/resize.py b/topi/python/topi/image/resize.py index 886930bce4994..00224b6ba7cab 100644 --- a/topi/python/topi/image/resize.py +++ b/topi/python/topi/image/resize.py @@ -21,18 +21,46 @@ from .. import tag -def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out_dtype=None): - """Perform resize operation on the data. +def resize_nearest_neighbor(indices, data, image_height, image_width, + target_height, target_width, boxes=None, + box_indices=None, extrapolation_value=None, + layout='NCHW', align_corners=True, out_dtype=None): + """Perform resize operation with nearest neighbor method on the data. + For details about Nearest-neighbor interpolation please refer to + https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation. Parameters ---------- + indices : tuple + The indices of input data + data : tvm.Tensor inputs is a 4-D tensor with shape [batch, channel, in_height, in_width] or [batch, in_height, in_width, channel] - size: Tuple - Output resolution scale to + image_height : integer + Input image height + + image_width : integer + Input image width + + target_height : integer + The target resized image height + + target_width : integer + The target resized image width + + boxes : tvm.Tensor, optional + A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies + the coordinates of a box. + + box_indices : tvm.Tensor, optional + A 1-D tensor of shape [num_boxes], box_indices[i] specifies the data that + the i-th box refers to. + + extrapolation_value: float, optional + Value used for extrapolation, when applicable. layout: string, optional "NCHW", "NHWC", or "NCHWc". @@ -40,42 +68,37 @@ def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out align_corners: Boolean, optional To preserve the values at the corner pixels. - method: {"bilinear", "nearest_neighbor", "bicubic"} - Method to be used for resizing. - out_dtype: string, optional Type to return. If left None will be same as input type. Returns ------- - output : tvm.Tensor - 4-D with shape [batch, channel, in_height*scale, in_width*scale] - or [batch, in_height*scale, in_width*scale, channel] - or 5-D with shape [batch, channel-major, in_height*scale, in_width*scale, channel-minor] + output : out_dtype + The computed result with type out_dtype """ - method = method.lower() - if layout == 'NHWC': - in_n, in_h, in_w, in_c = data.shape - output_shape = [in_n, size[0], size[1], in_c] - elif layout == 'NCHW': - in_n, in_c, in_h, in_w = data.shape - output_shape = [in_n, in_c, size[0], size[1]] - # Otherwise layout must be NCHWxc - else: - in_n, in_c, in_h, in_w, in_cc = data.shape - output_shape = [in_n, in_c, size[0], size[1], in_cc] + def _cast_output(value, data_dtype="float32", out_dtype=None): + if out_dtype: + dtype = out_dtype + else: + dtype = data_dtype + return value.astype(dtype) - if align_corners: - y_ratio = (in_h - 1).astype('float') / (size[0] - 1) - x_ratio = (in_w - 1).astype('float') / (size[1] - 1) - else: - y_ratio = (in_h).astype('float') / (size[0]) - x_ratio = (in_w).astype('float') / (size[1]) + def _get_indices(indices, layout='NCHW'): + if layout == 'NHWC': + n, y, x, c = indices + cc = None + elif layout == 'NCHW': + n, c, y, x = indices + cc = None + else: + n, c, y, x, cc = indices + return n, c, y, x, cc - def _get_pixel(n, c, y, x, cc): - y = tvm.max(tvm.min(y, in_h - 1), 0) - x = tvm.max(tvm.min(x, in_w - 1), 0) + def _get_pixel(data, layout, n, c, y, x, cc): + if boxes is None: + y = tvm.max(tvm.min(y, image_height - 1), 0) + x = tvm.max(tvm.min(x, image_width - 1), 0) if layout == 'NHWC': return data(n, y, x, c).astype('float') if layout == 'NCHW': @@ -83,7 +106,122 @@ def _get_pixel(n, c, y, x, cc): # else must be NCHWxc return data(n, c, y, x, cc).astype('float') - def _get_indices(*indices): + n, c, y, x, cc = _get_indices(indices, layout) + box_idx = box_indices(n) if box_indices is not None else n + if boxes is not None: + y1, x1 = boxes(n, 0), boxes(n, 1) + y2, x2 = boxes(n, 2), boxes(n, 3) + + in_h = (image_height - 1) * (y2 - y1) + in_w = (image_width - 1) * (x2 - x1) + h_scale = in_h.astype('float') / (target_height - 1).astype('float') + w_scale = in_w.astype('float') / (target_width - 1).astype('float') + + in_y = y1 * (image_height - 1) + h_scale * y + in_x = x1 * (image_width - 1) + w_scale * x + else: + if align_corners: + h_scale = (image_height - 1).astype('float') / (target_height - 1) + w_scale = (image_width - 1).astype('float') / (target_width - 1) + else: + h_scale = image_height.astype('float') / target_height + w_scale = image_width.astype('float') / target_width + in_y = h_scale * y + in_x = w_scale * x + + if align_corners or boxes is not None: + closest_x_index = tvm.round(in_x).astype("int32") + closest_y_index = tvm.round(in_y).astype("int32") + else: + # Add epsilon to floor to prevent gpu rounding errors. + epsilon = 1e-5 + closest_y_index = tvm.floor(in_y + epsilon).astype('int32') + closest_x_index = tvm.floor(in_x + epsilon).astype('int32') + + value = _get_pixel(data, layout, box_idx, c, closest_y_index, closest_x_index, cc) + + if extrapolation_value is not None: + out = tvm.if_then_else(in_y < 0, + extrapolation_value, + tvm.if_then_else(in_y > image_height - 1, + extrapolation_value, + value)) + # use extrapolation_value if in_x is out of boundary + value = tvm.if_then_else(in_x < 0, + extrapolation_value, + tvm.if_then_else(in_x > image_width - 1, + extrapolation_value, + out)) + return _cast_output(value, data.dtype, out_dtype=out_dtype) + + +def resize_bilinear(indices, data, image_height, image_width, + target_height, target_width, boxes=None, + box_indices=None, extrapolation_value=None, + layout='NCHW', align_corners=True, out_dtype=None): + """Perform resize operation with bilinear method on the data. + For details about Bilinear interpolation please refer to + https://en.wikipedia.org/wiki/Bilinear_interpolation. + + Parameters + ---------- + indices : tuple + The indices of input data + + data : tvm.Tensor + inputs is a 4-D tensor with shape + [batch, channel, in_height, in_width] + or [batch, in_height, in_width, channel] + + image_height : integer + Input image height + + image_width : integer + Input image width + + target_height : integer + The target resized image height + + target_width : integer + The target resized image width + + boxes : tvm.Tensor, optional + A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies + the coordinates of a box. + + box_indices : tvm.Tensor, optional + A 1-D tensor of shape [num_boxes], box_indices[i] specifies the data that + the i-th box refers to. + + extrapolation_value: float, optional + Value used for extrapolation, when applicable. + + layout: string, optional + "NCHW", "NHWC", or "NCHWc". + + align_corners: Boolean, optional + To preserve the values at the corner pixels. + + out_dtype: string, optional + Type to return. If left None will be same as input type. + + Returns + ------- + output : out_dtype + The computed result with type out_dtype + """ + + def _cast_output(value, data_dtype="float32", out_dtype=None): + if out_dtype: + dtype = out_dtype + else: + dtype = data_dtype + return value.astype(dtype) + + def _lerp(A, B, t): + return A * (1.0 - t) + B * t + + def _get_indices(indices, layout='NCHW'): if layout == 'NHWC': n, y, x, c = indices cc = None @@ -92,112 +230,304 @@ def _get_indices(*indices): cc = None else: n, c, y, x, cc = indices - return n, c, y, x, cc - def _cast_output(value): + def _get_pixel(data, layout, n, c, y, x, cc): + if boxes is None: + y = tvm.max(tvm.min(y, image_height - 1), 0) + x = tvm.max(tvm.min(x, image_width - 1), 0) + if layout == 'NHWC': + return data(n, y, x, c).astype('float') + if layout == 'NCHW': + return data(n, c, y, x).astype('float') + # else must be NCHWxc + return data(n, c, y, x, cc).astype('float') + + n, c, y, x, cc = _get_indices(indices, layout=layout) + box_idx = box_indices(n) if box_indices is not None else n + + if boxes is not None: + y1, x1 = boxes(n, 0), boxes(n, 1) + y2, x2 = boxes(n, 2), boxes(n, 3) + + in_h = (image_height - 1) * (y2 - y1) + in_w = (image_width - 1) * (x2 - x1) + h_scale = in_h.astype('float') / (target_height - 1).astype('float') + w_scale = in_w.astype('float') / (target_width - 1).astype('float') + + in_y = y1 * (image_height - 1) + h_scale * y + in_x = x1 * (image_width - 1) + w_scale * x + else: + if align_corners: + h_scale = (image_height - 1).astype('float') / (target_height - 1) + w_scale = (image_width - 1).astype('float') / (target_width - 1) + else: + h_scale = image_height.astype('float') / target_height + w_scale = image_width.astype('float') / target_width + in_y = h_scale * y + in_x = w_scale * x + + top_y_index = tvm.floor(in_y).astype('int32') + bottom_y_index = tvm.ceil(in_y).astype('int32') + y_lerp = in_y - top_y_index + + left_x_index = tvm.floor(in_x).astype('int32') + right_x_index = tvm.ceil(in_x).astype('int32') + x_lerp = in_x - left_x_index + + top_left = _get_pixel(data, layout, box_idx, c, top_y_index, left_x_index, cc) + top_right = _get_pixel(data, layout, box_idx, c, top_y_index, right_x_index, cc) + bottom_left = _get_pixel(data, layout, box_idx, c, bottom_y_index, left_x_index, cc) + bottom_right = _get_pixel(data, layout, box_idx, c, bottom_y_index, right_x_index, cc) + + top = _lerp(top_left, top_right, x_lerp) + bottom = _lerp(bottom_left, bottom_right, x_lerp) + value = _lerp(top, bottom, y_lerp) + + # use extrapolation_value if in_y/in_x is out of boundary + if extrapolation_value is not None: + out = tvm.if_then_else(in_y < 0, + extrapolation_value, + tvm.if_then_else(in_y > image_height - 1, + extrapolation_value, + value)) + value = tvm.if_then_else(in_x < 0, + extrapolation_value, + tvm.if_then_else(in_x > image_width - 1, + extrapolation_value, + out)) + return _cast_output(value, data.dtype, out_dtype=out_dtype) + + +def resize_bicubic(indices, data, image_height, image_width, + target_height, target_width, boxes=None, + box_indices=None, extrapolation_value=None, + layout='NCHW', align_corners=True, out_dtype=None): + """Perform resize operation with bicubic method on the data. + More details about Bicubic interpolation please refer to + https://en.wikipedia.org/wiki/Bicubic_interpolation. + + Parameters + ---------- + indices : tuple + The indices of input data + + data : tvm.Tensor + inputs is a 4-D tensor with shape + [batch, channel, in_height, in_width] + or [batch, in_height, in_width, channel] + + image_height : integer + Input image height + + image_width : integer + Input image width + + target_height : integer + The target resized image height + + target_width : integer + The target resized image width + + boxes : tvm.Tensor, optional + A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies + the coordinates of a box. + + box_indices : tvm.Tensor, optional + A 1-D tensor of shape [num_boxes], box_indices[i] specifies the data that + the i-th box refers to. + + extrapolation_value: float, optional + Value used for extrapolation, when applicable. + + layout: string, optional + "NCHW", "NHWC", or "NCHWc". + + align_corners: Boolean, optional + To preserve the values at the corner pixels. + + out_dtype: string, optional + Type to return. If left None will be same as input type. + + Returns + ------- + output : out_dtype + The computed result with type out_dtype + """ + + def _cubic_kernel(A, B, C, D, t): + a = -A / 2.0 + (3.0 * B) / 2.0 - (3.0 * C) / 2.0 + D / 2.0 + b = A - (5.0 * B) / 2.0 + 2.0 * C - D / 2.0 + c = -A / 2.0 + C / 2.0 + d = B + return a * t * t * t + b * t * t + c * t + d + + def _cast_output(value, data_dtype="float32", out_dtype=None): if out_dtype: dtype = out_dtype else: - dtype = data.dtype + dtype = data_dtype return value.astype(dtype) - # Nearest neighbor computation - def _nearest_neighbor(*indices): - n, c, y, x, cc = _get_indices(*indices) + def _get_indices(indices, layout='NCHW'): + if layout == 'NHWC': + n, y, x, c = indices + cc = None + elif layout == 'NCHW': + n, c, y, x = indices + cc = None + else: + n, c, y, x, cc = indices + return n, c, y, x, cc - in_y = y_ratio * y - in_x = x_ratio * x + def _get_pixel(data, layout, n, c, y, x, cc): + if boxes is None: + y = tvm.max(tvm.min(y, image_height - 1), 0) + x = tvm.max(tvm.min(x, image_width - 1), 0) + if layout == 'NHWC': + return data(n, y, x, c).astype('float') + if layout == 'NCHW': + return data(n, c, y, x).astype('float') + # else must be NCHWxc + return data(n, c, y, x, cc).astype('float') + + n, c, y, x, cc = _get_indices(indices, layout) + box_idx = box_indices(n) if box_indices is not None else n + + if boxes is not None: + y1, x1 = boxes(n, 0), boxes(n, 1) + y2, x2 = boxes(n, 2), boxes(n, 3) + in_h = (image_height - 1) * (y2 - y1) + in_w = (image_width - 1) * (x2 - x1) + h_scale = in_h.astype('float') / (target_height - 1).astype('float') + w_scale = in_w.astype('float') / (target_width - 1).astype('float') + + in_y = y1 * (image_height - 1) + h_scale * y + in_x = x1 * (image_width - 1) + w_scale * x + else: if align_corners: - yint = tvm.round(in_y).astype('int32') - xint = tvm.round(in_x).astype('int32') + h_scale = (image_height - 1).astype('float') / (target_height - 1) + w_scale = (image_width - 1).astype('float') / (target_width - 1) else: - # Add epsilon to floor to prevent gpu rounding errors. - epsilon = 1e-5 - yint = tvm.floor(in_y + epsilon).astype('int32') - xint = tvm.floor(in_x + epsilon).astype('int32') + h_scale = image_height.astype('float') / target_height + w_scale = image_width.astype('float') / target_width + in_y = h_scale * y + in_x = w_scale * x + + xint = tvm.floor(in_x).astype('int32') + xfract = in_x - tvm.floor(in_x) + + yint = tvm.floor(in_y).astype('int32') + yfract = in_y - tvm.floor(in_y) + + # 1st row + p00 = _get_pixel(data, layout, box_idx, c, yint - 1, xint - 1, cc) + p10 = _get_pixel(data, layout, box_idx, c, yint - 1, xint + 0, cc) + p20 = _get_pixel(data, layout, box_idx, c, yint - 1, xint + 1, cc) + p30 = _get_pixel(data, layout, box_idx, c, yint - 1, xint + 2, cc) + + # 2nd row + p01 = _get_pixel(data, layout, box_idx, c, yint + 0, xint - 1, cc) + p11 = _get_pixel(data, layout, box_idx, c, yint + 0, xint + 0, cc) + p21 = _get_pixel(data, layout, box_idx, c, yint + 0, xint + 1, cc) + p31 = _get_pixel(data, layout, box_idx, c, yint + 0, xint + 2, cc) + + # 3rd row + p02 = _get_pixel(data, layout, box_idx, c, yint + 1, xint - 1, cc) + p12 = _get_pixel(data, layout, box_idx, c, yint + 1, xint + 0, cc) + p22 = _get_pixel(data, layout, box_idx, c, yint + 1, xint + 1, cc) + p32 = _get_pixel(data, layout, box_idx, c, yint + 1, xint + 2, cc) + + # 4th row + p03 = _get_pixel(data, layout, box_idx, c, yint + 2, xint - 1, cc) + p13 = _get_pixel(data, layout, box_idx, c, yint + 2, xint + 0, cc) + p23 = _get_pixel(data, layout, box_idx, c, yint + 2, xint + 1, cc) + p33 = _get_pixel(data, layout, box_idx, c, yint + 2, xint + 2, cc) + + # Interpolate bicubically + col0 = _cubic_kernel(p00, p10, p20, p30, xfract) + col1 = _cubic_kernel(p01, p11, p21, p31, xfract) + col2 = _cubic_kernel(p02, p12, p22, p32, xfract) + col3 = _cubic_kernel(p03, p13, p23, p33, xfract) + value = _cubic_kernel(col0, col1, col2, col3, yfract) + + # use extrapolation_value if in_y/in_x is out of boundary + if extrapolation_value is not None: + out = tvm.if_then_else(in_y < 0, + extrapolation_value, + tvm.if_then_else(in_y > image_height - 1, + extrapolation_value, + value)) + value = tvm.if_then_else(in_x < 0, + extrapolation_value, + tvm.if_then_else(in_x > image_width - 1, + extrapolation_value, + out)) + return _cast_output(value, data.dtype, out_dtype=out_dtype) - return _cast_output(_get_pixel(n, c, yint, xint, cc)) - # Bilinear helper functions and computation. - def _lerp(A, B, t): - return A * (1.0 - t) + B * t +def resize(data, size, layout="NCHW", method="bilinear", align_corners=True, out_dtype=None): + """Perform resize operation on the data. - def _bilinear(*indices): - n, c, y, x, cc = _get_indices(*indices) + Parameters + ---------- + data : tvm.Tensor + inputs is a 4-D tensor with shape + [batch, channel, in_height, in_width] + or [batch, in_height, in_width, channel] - in_y = y_ratio * y - in_x = x_ratio * x + size: Tuple + Output resolution scale to - xint = tvm.floor(in_x).astype('int32') - xfract = in_x - tvm.floor(in_x) + layout: string, optional + "NCHW", "NHWC", or "NCHWc". - yint = tvm.floor(in_y).astype('int32') - yfract = in_y - tvm.floor(in_y) + align_corners: Boolean, optional + To preserve the values at the corner pixels. - p00 = _get_pixel(n, c, yint, xint, cc) - p10 = _get_pixel(n, c, yint, xint + 1, cc) - p01 = _get_pixel(n, c, yint + 1, xint, cc) - p11 = _get_pixel(n, c, yint + 1, xint + 1, cc) + method: {"bilinear", "nearest_neighbor", "bicubic"} + Method to be used for resizing. - col0 = _lerp(p00, p10, xfract) - col1 = _lerp(p01, p11, xfract) - value = _lerp(col0, col1, yfract) - return _cast_output(value) + out_dtype: string, optional + Type to return. If left None will be same as input type. - # Bicubic helper function and computation. - def _cubic_kernel(A, B, C, D, t): - a = -A / 2.0 + (3.0*B) / 2.0 - (3.0*C) / 2.0 + D / 2.0 - b = A - (5.0*B) / 2.0 + 2.0*C - D / 2.0 - c = -A / 2.0 + C / 2.0 - d = B + Returns + ------- + output : tvm.Tensor + 4-D with shape [batch, channel, in_height*scale, in_width*scale] + or [batch, in_height*scale, in_width*scale, channel] + or 5-D with shape [batch, channel-major, in_height*scale, in_width*scale, channel-minor] + """ + method = method.lower() + + if layout == 'NHWC': + in_n, in_h, in_w, in_c = data.shape + output_shape = [in_n, size[0], size[1], in_c] + elif layout == 'NCHW': + in_n, in_c, in_h, in_w = data.shape + output_shape = [in_n, in_c, size[0], size[1]] + elif layout.startswith("NCHW"):# for NCHWxc + in_n, in_c, in_h, in_w, in_cc = data.shape + output_shape = [in_n, in_c, size[0], size[1], in_cc] + else: + raise ValueError('%s layout is not supported.' % layout) + + + def _nearest_neighbor(*indices): + return resize_nearest_neighbor(indices, data, in_h, in_w, size[0], size[1], + layout=layout, align_corners=align_corners, + out_dtype=out_dtype) - return a*t*t*t + b*t*t + c*t + d + def _bilinear(*indices): + return resize_bilinear(indices, data, in_h, in_w, size[0], size[1], + layout=layout, align_corners=align_corners, + out_dtype=out_dtype) def _bicubic(*indices): - n, c, y, x, cc = _get_indices(*indices) - - in_y = y_ratio * y - in_x = x_ratio * x - - xint = tvm.floor(in_x).astype('int32') - xfract = in_x - tvm.floor(in_x) - - yint = tvm.floor(in_y).astype('int32') - yfract = in_y - tvm.floor(in_y) - - # 1st row - p00 = _get_pixel(n, c, yint - 1, xint - 1, cc) - p10 = _get_pixel(n, c, yint - 1, xint + 0, cc) - p20 = _get_pixel(n, c, yint - 1, xint + 1, cc) - p30 = _get_pixel(n, c, yint - 1, xint + 2, cc) - - # 2nd row - p01 = _get_pixel(n, c, yint + 0, xint - 1, cc) - p11 = _get_pixel(n, c, yint + 0, xint + 0, cc) - p21 = _get_pixel(n, c, yint + 0, xint + 1, cc) - p31 = _get_pixel(n, c, yint + 0, xint + 2, cc) - - # 3rd row - p02 = _get_pixel(n, c, yint + 1, xint - 1, cc) - p12 = _get_pixel(n, c, yint + 1, xint + 0, cc) - p22 = _get_pixel(n, c, yint + 1, xint + 1, cc) - p32 = _get_pixel(n, c, yint + 1, xint + 2, cc) - - # 4th row - p03 = _get_pixel(n, c, yint + 2, xint - 1, cc) - p13 = _get_pixel(n, c, yint + 2, xint + 0, cc) - p23 = _get_pixel(n, c, yint + 2, xint + 1, cc) - p33 = _get_pixel(n, c, yint + 2, xint + 2, cc) - - # Interpolate bicubically - col0 = _cubic_kernel(p00, p10, p20, p30, xfract) - col1 = _cubic_kernel(p01, p11, p21, p31, xfract) - col2 = _cubic_kernel(p02, p12, p22, p32, xfract) - col3 = _cubic_kernel(p03, p13, p23, p33, xfract) - value = _cubic_kernel(col0, col1, col2, col3, yfract) - return _cast_output(value) + return resize_bicubic(indices, data, in_h, in_w, size[0], size[1], + layout, align_corners=align_corners, + out_dtype=out_dtype) # Determine which interpolation method to use then run it. if method == "nearest_neighbor": @@ -253,135 +583,34 @@ def crop_and_resize(data, boxes, box_indices, crop_size, layout="NCHW", or [num_boxes, crop_height, crop_width, channel] """ method = method.lower() - target_h = crop_size[0] - target_w = crop_size[1] + target_h = crop_size[0].astype("int32") + target_w = crop_size[1].astype("int32") if layout == 'NHWC': output_shape = [box_indices.shape[0], crop_size[0], crop_size[1], data.shape[3]] - image_height = data.shape[1] - image_width = data.shape[2] + image_h = data.shape[1].astype("int32") + image_w = data.shape[2].astype("int32") elif layout == 'NCHW': output_shape = [box_indices.shape[0], data.shape[1], crop_size[0], crop_size[1]] - image_height = data.shape[2] - image_width = data.shape[3] - # Otherwise layout must be NCHWxc - else: + image_h = data.shape[2].astype("int32") + image_w = data.shape[3].astype("int32") + elif layout.startswith("NCHW"):# for NCHWxc output_shape = [box_indices.shape[0], data.shape[1], crop_size[0], crop_size[1], data.shape[4]] - image_height = data.shape[2] - image_width = data.shape[3] - - def _get_pixel(n, c, y, x, cc): - if layout.lower() == 'nhwc': - return data(n, y.astype("int32"), x.astype("int32"), c).astype('float') - if layout.lower() == 'nchw': - return data(n, c, y.astype("int32"), x.astype("int32")).astype('float') - # else must be NCHWxc - return data(n, c, y.astype("int32"), x.astype("int32"), cc).astype('float') - - def _get_indices(*indices): - if layout == 'NHWC': - n, y, x, c = indices - cc = None - elif layout == 'NCHW': - n, c, y, x = indices - cc = None - else: - n, c, y, x, cc = indices - - return n, c, y, x, cc - - def _cast_output(value): - if out_dtype: - dtype = out_dtype - else: - dtype = data.dtype - return value.astype(dtype) - - # Nearest neighbor computation - def _nearest_neighbor(*indices): - n, c, y, x, cc = _get_indices(*indices) - box_idx = box_indices(n) - - y1, x1 = boxes(n, 0), boxes(n, 1) - y2, x2 = boxes(n, 2), boxes(n, 3) - - in_h = (image_height - 1) * (y2 - y1) - in_w = (image_width - 1) * (x2 - x1) - h_scale = tvm.div(in_h, target_h - 1) - w_scale = tvm.div(in_w, target_w - 1) - - in_y = y1 * (image_height - 1) + h_scale * y - in_x = x1 * (image_width - 1) + w_scale * x - closest_x_index = tvm.round(in_x) - closest_y_index = tvm.round(in_y) - - value = _get_pixel(box_idx, c, closest_y_index, closest_x_index, cc) - out_y = tvm.if_then_else(in_y < 0, - extrapolation_value, - tvm.if_then_else(in_y > image_height - 1, - extrapolation_value, - value)) - - # use extrapolation_value if in_x is out of boundary - out = tvm.if_then_else(in_x < 0, - extrapolation_value, - tvm.if_then_else(in_x > image_width - 1, - extrapolation_value, - out_y)) - return _cast_output(out) - - - # Bilinear helper functions and computation. - def _lerp(A, B, t): - return A * (1.0 - t) + B * t + image_h = data.shape[2].astype("int32") + image_w = data.shape[3].astype("int32") + else: + raise ValueError('%s layout is not supported.' % layout) def _bilinear(*indices): - n, c, y, x, cc = _get_indices(*indices) - box_idx = box_indices(n) + return resize_bilinear(indices, data, image_h, image_w, target_h, + target_w, boxes, box_indices, extrapolation_value, + layout, out_dtype=out_dtype) - y1, x1 = boxes(n, 0), boxes(n, 1) - y2, x2 = boxes(n, 2), boxes(n, 3) - - in_h = (image_height - 1) * (y2 - y1) - in_w = (image_width - 1) * (x2 - x1) - h_scale = tvm.div(in_h, target_h - 1) - w_scale = tvm.div(in_w, target_w - 1) - - in_y = y1 * (image_height - 1) + h_scale * y - in_x = x1 * (image_width - 1) + w_scale * x - - top_y_index = tvm.floor(in_y).astype('int32') - bottom_y_index = tvm.ceil(in_y).astype('int32') - y_lerp = in_y - top_y_index - - left_x_index = tvm.floor(in_x) - right_x_index = tvm.ceil(in_x) - x_lerp = in_x - left_x_index - - top_left = _get_pixel(box_idx, c, top_y_index, left_x_index, cc) - top_right = _get_pixel(box_idx, c, top_y_index, right_x_index, cc) - bottom_left = _get_pixel(box_idx, c, bottom_y_index, left_x_index, cc) - bottom_right = _get_pixel(box_idx, c, bottom_y_index, right_x_index, cc) - - top = _lerp(top_left, top_right, x_lerp) - bottom = _lerp(bottom_left, bottom_right, x_lerp) - value = _lerp(top, bottom, y_lerp) - - # use extrapolation_value if in_y is out of boundary - out_y = tvm.if_then_else(in_y < 0, - extrapolation_value, - tvm.if_then_else(in_y > image_height - 1, - extrapolation_value, - value)) - - # use extrapolation_value if in_x is out of boundary - out = tvm.if_then_else(in_x < 0, - extrapolation_value, - tvm.if_then_else(in_x > image_width - 1, - extrapolation_value, - out_y)) - return _cast_output(out) + def _nearest_neighbor(*indices): + return resize_nearest_neighbor(indices, data, image_h, image_w, target_h, + target_w, boxes, box_indices, extrapolation_value, + layout, out_dtype=out_dtype) # Determine which interpolation method to use then run it. if method == "nearest_neighbor": diff --git a/topi/tests/python/test_topi_image.py b/topi/tests/python/test_topi_image.py index db94dd1697c80..a86eab159a5e8 100644 --- a/topi/tests/python/test_topi_image.py +++ b/topi/tests/python/test_topi_image.py @@ -35,7 +35,7 @@ def verify_resize(batch, in_channel, in_height, in_width, out_height, out_width, a_np = np.random.uniform(size=(batch, in_height, in_width, in_channel)).astype(dtype) else: raise NotImplementedError( - 'Layout not supported {} '.format(layout)) + 'Layout {} is not supported.'.format(layout)) B = topi.image.resize(A, (out_height, out_width), layout=layout, align_corners=align_corners, method=method) @@ -84,18 +84,21 @@ def verify_crop_and_resize(image_shape, np_boxes, np_box_indices, np_crop_size, method="bilinear", extrapolation_value=0.0): images = tvm.placeholder(image_shape, name='images', dtype='float32') - dtype = images.dtype - np_images = np.random.uniform(size=image_shape).astype(dtype) + np_images = np.random.uniform(size=image_shape).astype("float32") boxes = tvm.placeholder(np_boxes.shape, name="boxes", dtype="float32") box_ind = tvm.placeholder(np_box_indices.shape, name="box_ind", dtype="int32") + batch = len(np_box_indices) + target_height, target_width = np_crop_size[0], np_crop_size[1] if layout == 'NHWC': - out_shape = (len(np_box_indices), np_crop_size[0], np_crop_size[1], image_shape[3]) + channel = image_shape[3] + out_shape = (batch, target_height, target_width, channel) elif layout == 'NCHW': - out_shape = (len(np_box_indices), image_shape[1], np_crop_size[0], np_crop_size[1]) + channel = image_shape[1] + out_shape = (batch, channel, target_height, target_width) else: raise NotImplementedError( - 'Layout not supported {} '.format(layout)) + 'Layout {} is not supported.'.format(layout)) out = topi.image.crop_and_resize(images, boxes, box_ind, np_crop_size, layout=layout, method=method, extrapolation_value=extrapolation_value) @@ -111,24 +114,24 @@ def check_device(device): print("Running on target: %s" % device) with tvm.target.create(device): s = topi.generic.schedule_injective(out) - nd_images = tvm.nd.array(np_images, ctx) - nd_boxes = tvm.nd.array(np_boxes, ctx) - nd_indices = tvm.nd.array(np_box_indices, ctx) - tvm_out = tvm.nd.array(np.zeros(out_shape, dtype=images.dtype), ctx) - f = tvm.build(s, [images, boxes, box_ind, out], device) - f(nd_images, nd_boxes, nd_indices, tvm_out) + tvm_images = tvm.nd.array(np_images, ctx) + tvm_boxes = tvm.nd.array(np_boxes, ctx) + tvm_indices = tvm.nd.array(np_box_indices, ctx) + tvm_out = tvm.nd.array(np.zeros(out_shape, dtype="float32"), ctx) + f = tvm.build(s, [images, boxes, box_ind, out], device, name="crop_and_resize") + f(tvm_images, tvm_boxes, tvm_indices, tvm_out) tvm.testing.assert_allclose(tvm_out.asnumpy(), baseline_np, rtol=1e-3, atol=1e-3) - for device in get_all_backend(): + for device in ['llvm', 'cuda', 'opencl']: check_device(device) - boxes_1 = np.array([[.2, .3, .7, .9]]).astype("float32") - boxes_2 = np.array([[.2, .3, .7, .9], [0, .1, .8, 1]]).astype("float32") - indices_1 = np.array([0]).astype("int32") - indices_2 = np.array([1, 0]).astype("int32") - size_1 = np.array([7, 11]).astype("int32") - size_2 = np.array([90, 60]).astype("int32") + boxes_1 = np.array([[.2, .3, .7, .9]], dtype="float32") + boxes_2 = np.array([[.2, .3, .7, .9], [0, .1, .8, 1]], dtype="float32") + indices_1 = np.array([0], dtype="int32") + indices_2 = np.array([1, 0], dtype="int32") + size_1 = np.array([7, 11], dtype="int32") + size_2 = np.array([90, 60], dtype="int32") verify_crop_and_resize((1, 255, 255, 3), boxes_1, indices_1, size_1, layout="NHWC") verify_crop_and_resize((10, 224, 224, 5), boxes_2, indices_2,