From 71c11062ae51eec428e76e5fb9ad886a96329d1f Mon Sep 17 00:00:00 2001 From: Yong Wu Date: Fri, 10 Jan 2020 19:52:52 -0800 Subject: [PATCH] [TOPI][RELAY][OP] add op crop_and_resize (#4417) * [TOPI][RELAY][OP] add op crop_and_resize * fix pylint * incorporate comments * fix ci --- docs/api/python/topi.rst | 2 + docs/langref/relay_op.rst | 2 + include/tvm/relay/attrs/image.h | 28 + python/tvm/relay/frontend/tensorflow.py | 43 +- python/tvm/relay/op/image/_image.py | 16 +- python/tvm/relay/op/image/image.py | 52 +- python/tvm/relay/op/op_attrs.py | 3 + src/lang/expr_operator.cc | 12 +- src/relay/op/image/resize.cc | 84 +++ .../frontend/tensorflow/test_forward.py | 58 +- tests/python/relay/test_op_level5.py | 44 +- topi/python/topi/image/resize.py | 696 ++++++++++++++---- topi/python/topi/testing/__init__.py | 1 + .../topi/testing/crop_and_resize_python.py | 114 +++ ...test_topi_resize.py => test_topi_image.py} | 65 +- 15 files changed, 1018 insertions(+), 202 deletions(-) create mode 100644 topi/python/topi/testing/crop_and_resize_python.py rename topi/tests/python/{test_topi_resize.py => test_topi_image.py} (66%) diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst index df31ca24403d4..75a4271291bfa 100644 --- a/docs/api/python/topi.rst +++ b/docs/api/python/topi.rst @@ -104,6 +104,7 @@ List of operators topi.ndarray_size topi.layout_transform topi.image.resize + topi.image.crop_and_resize topi.argsort topi.topk topi.sequence_mask @@ -207,6 +208,7 @@ topi.nn topi.image ~~~~~~~~~~ .. autofunction:: topi.image.resize +.. autofunction:: topi.image.crop_and_resize topi.sparse ~~~~~~~~~~~ diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index 1fabd704482cf..e1e25a95485d9 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -169,6 +169,7 @@ This level enables additional math and transform operators. :nosignatures: tvm.relay.image.resize + tvm.relay.image.crop_and_resize tvm.relay.vision.multibox_prior tvm.relay.vision.multibox_transform_loc tvm.relay.vision.nms @@ -335,6 +336,7 @@ Level 4 Definitions Level 5 Definitions ------------------- .. autofunction:: tvm.relay.image.resize +.. autofunction:: tvm.relay.image.crop_and_resize .. autofunction:: tvm.relay.vision.multibox_prior .. autofunction:: tvm.relay.vision.multibox_transform_loc .. autofunction:: tvm.relay.vision.nms diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h index 87ad82d0293f7..22d657d0933fc 100644 --- a/include/tvm/relay/attrs/image.h +++ b/include/tvm/relay/attrs/image.h @@ -63,6 +63,34 @@ struct ResizeAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes used in image crop_and_resize operator */ +struct CropAndResizeAttrs : public tvm::AttrsNode { + Array crop_size; + std::string layout; + std::string method; + double extrapolation_value; + DataType out_dtype; + + TVM_DECLARE_ATTRS(CropAndResizeAttrs, "relay.attrs.CropAndResizeAttrs") { + TVM_ATTR_FIELD(crop_size).set_default(NullValue >()) + .describe("Target Size."); + TVM_ATTR_FIELD(layout).set_default("NCHW") + .describe("Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Resize is applied on the 'H' and" + "'W' dimensions."); + TVM_ATTR_FIELD(method).set_default("bilinear") + .describe("Specify the mode to use for scaling." + "nearest_neighbor - Nearest Neighbor" + "bilinear - Bilinear Interpolation"); + TVM_ATTR_FIELD(extrapolation_value).set_default(0.0) + .describe("Specify value for extrapolation."); + TVM_ATTR_FIELD(out_dtype) + .set_default(NullValue()) + .describe("Output data type."); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_IMAGE_H_ diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index dceadbf6dbe14..7e22d72131ac4 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -546,47 +546,20 @@ def _impl(inputs, attr, params): # input image is a 4-D tensor of shape [batch, image_height, image_width, depth] # boxes is a 2-D tensor of shape [num_boxes, 4], 4 is for [y1, x1, y2, x2] try: - boxes = _get_list_param(params, inputs[1]) - box_ind = _get_list_param(params, inputs[2]) crop_size = _get_list_param(params, inputs[3]) except (IndexError, KeyError): - boxes = _infer_value(inputs[1], params).asnumpy().tolist() - box_ind = _infer_value(inputs[2], params).asnumpy().tolist() crop_size = _infer_value(inputs[3], params).asnumpy().tolist() - data_shape = attr['_input_shapes'][inputs[0]] - data_dim = len(data_shape) method = attr['method'].decode() - - attrs = {} - attrs['size'] = crop_size - attrs['layout'] = 'NHWC' - if method.lower() == 'nearest': + method = 'nearest_neighbor' if method == 'nearest' else method + if method not in ['bilinear', 'nearest_neighbor']: raise tvm.error.OpAttributeUnImplemented( - 'Attribute method=nearest is not supported') - else: - attrs['coordinate_transformation_mode'] = 'align_corners' - attrs['method'] = 'bilinear' - - out = None - begin = [0] * data_dim - size = data_shape[:] - for idx in box_ind: - # 1) Crop - # y is mapped to the image coordinate at y * (image_height - 1) - # x is mapped to the image coordinate at x * (image_width - 1) - begin[0] = idx - begin[1] = int(round(boxes[idx][0] * (data_shape[1] - 1))) - begin[2] = int(round(boxes[idx][1] * (data_shape[2] - 1))) - size[0] = idx + 1 - size[1] = int(round((data_shape[1] - 1) * boxes[idx][2])) + 1 - size[2] = int(round((data_shape[2] - 1) * boxes[idx][3])) + 1 - res_crop = _op.strided_slice(inputs[0], begin=begin, end=size) - - # 2) Resize - res_resize = get_relay_op('resize')(res_crop, **attrs) - out = _op.concatenate([out, res_resize], axis=0) if out else res_resize - return out + 'Method {} is not supported'.format(method)) + layout = attr['layout'] if 'layout' in attr else 'NHWC' + extrapolation_value = attr['extrapolation_value'] + + return get_relay_op("crop_and_resize")(inputs[0], inputs[1], inputs[2], crop_size, + layout, method, extrapolation_value) return _impl def _cast(): diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index 776435ada4979..89fde6dc17383 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -25,7 +25,6 @@ # resize reg.register_schedule("image.resize", schedule_injective) - @reg.register_compute("image.resize") def compute_resize(attrs, inputs, out_type, target): size = attrs.size @@ -34,3 +33,18 @@ def compute_resize(attrs, inputs, out_type, target): coord_trans = attrs.coordinate_transformation_mode out_dtype = attrs.out_dtype return [topi.image.resize(inputs[0], size, layout, method, coord_trans, out_dtype)] + + +# crop and resize +reg.register_schedule("image.crop_and_resize", schedule_injective) + +@reg.register_compute("image.crop_and_resize") +def compute_crop_and_resize(attrs, inputs, out_type, target): + crop_size = attrs.crop_size + layout = attrs.layout + method = attrs.method + extrapolation_value = attrs.extrapolation_value + out_dtype = attrs.out_dtype + return [topi.image.crop_and_resize(inputs[0], inputs[1], inputs[2], + crop_size, layout, method, + extrapolation_value, out_dtype)] diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py index e0475a06025a3..284d6023db6f9 100644 --- a/python/tvm/relay/op/image/image.py +++ b/python/tvm/relay/op/image/image.py @@ -31,7 +31,7 @@ def resize(data, with data of shape (n, c, h, w) out will have a shape (n, c, size[0], size[1]) - method indicates the algorithm to be used while calculating ghe out value + method indicates the algorithm to be used while calculating the out value and method can be one of ("bilinear", "nearest_neighbor", "bicubic") Parameters @@ -63,3 +63,53 @@ def resize(data, The resized result. """ return _make.resize(data, size, layout, method, coordinate_transformation_mode, out_dtype) + + +def crop_and_resize(data, + boxes, + box_indices, + crop_size, + layout, + method="bilinear", + extrapolation_value=0, + out_dtype=None): + """Crop input images and resize them. + + method indicates the algorithm to be used while calculating the out value + and method can be either "bilinear" or "nearest_neighbor". + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + boxes : relay.Expr + A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies + the coordinates of a box. + + box_indices : relay.Expr + A 1-D tensor of shape [num_boxes], box_ind[i] specifies the data that + the i-th box refers to. + + crop_size : Tuple of Expr + The target size to which each box will be resized. + + layout : str, optional + Layout of the input. + + method : str, optional + Scale method, it can be either "nearest_neighbor" or "bilinear". + + extrapolation_value : float, optional + Value used for extrapolation, when applicable. + + out_dtype : str, optional + Type to return. If left None returns the same type as input. + + Returns + ------- + result: relay.Expr + The computed result. + """ + return _make.crop_and_resize(data, boxes, box_indices, crop_size, + layout, method, extrapolation_value, out_dtype) diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py index 37e62cc0455d1..e5a9a11fb012d 100644 --- a/python/tvm/relay/op/op_attrs.py +++ b/python/tvm/relay/op/op_attrs.py @@ -114,6 +114,9 @@ class DeformableConv2DAttrs(Attrs): class ResizeAttrs(Attrs): """Attributes for image.resize""" +@register_relay_attr_node +class CropAndResizeAttrs(Attrs): + """Attributes for image.crop_and_resize""" @register_relay_attr_node class ArgsortAttrs(Attrs): diff --git a/src/lang/expr_operator.cc b/src/lang/expr_operator.cc index 078ca628ad24e..d3875e28c8874 100644 --- a/src/lang/expr_operator.cc +++ b/src/lang/expr_operator.cc @@ -246,8 +246,8 @@ PrimExpr div(PrimExpr a, PrimExpr b) { } PrimExpr truncdiv(PrimExpr a, PrimExpr b) { - CHECK(a.dtype().is_int() || a.dtype().is_uint()); - CHECK(b.dtype().is_int() || b.dtype().is_uint()); + CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; + CHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; return div(a, b); } @@ -276,8 +276,8 @@ PrimExpr indexmod(PrimExpr a, PrimExpr b) { } PrimExpr floordiv(PrimExpr a, PrimExpr b) { - CHECK(a.dtype().is_int() || a.dtype().is_uint()); - CHECK(b.dtype().is_int() || b.dtype().is_uint()); + CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; + CHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; @@ -285,8 +285,8 @@ PrimExpr floordiv(PrimExpr a, PrimExpr b) { } PrimExpr floormod(PrimExpr a, PrimExpr b) { - CHECK(a.dtype().is_int() || a.dtype().is_uint()); - CHECK(b.dtype().is_int() || b.dtype().is_uint()); + CHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; + CHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; BinaryOpMatchTypes(a, b); PrimExpr ret = arith::TryConstFold(a, b); if (ret.defined()) return ret; diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index b7169b1b2e6ed..e387a712435f2 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -109,5 +109,89 @@ RELAY_REGISTER_OP("image.resize") .add_type_rel("Resize", ResizeRel) .set_attr("TOpPattern", kInjective); + +TVM_REGISTER_NODE_TYPE(CropAndResizeAttrs); + +bool CropAndResizeRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 4); + const auto* data = types[0].as(); + const auto* boxes = types[1].as(); + const auto* box_indices = types[2].as(); + if (data == nullptr || boxes == nullptr || + box_indices == nullptr) return false; + + const CropAndResizeAttrs* param = attrs.as(); + CHECK(param != nullptr); + auto crop_size = param->crop_size; + + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + + // 4-D tensor of shape [num_boxes, crop_height, crop_width, depth] + static const Layout kNCHW("NCHW"); + const Layout in_layout(param->layout); + auto layout_converter = BijectiveLayoutNode::make(in_layout, kNCHW); + auto oshape = layout_converter.ForwardShape(data->shape); + oshape.Set(0, box_indices->shape[0]); + oshape.Set(2, crop_size[0]); + oshape.Set(3, crop_size[1]); + auto bshape = layout_converter.BackwardShape(oshape); + // assign output type + reporter->Assign(types[3], + TensorTypeNode::make(layout_converter.BackwardShape(oshape), + out_dtype)); + return true; +} + +Expr MakeCropAndResize(Expr data, + Expr boxes, + Expr box_indices, + Array crop_size, + std::string layout, + std::string method, + double extrapolation_value, + DataType out_dtype) { + auto attrs = make_object(); + attrs->crop_size = std::move(crop_size); + attrs->layout = std::move(layout); + attrs->method = std::move(method); + attrs->extrapolation_value = std::move(extrapolation_value); + attrs->out_dtype = out_dtype; + static const Op& op = Op::Get("image.crop_and_resize"); + return CallNode::make(op, {data, boxes, box_indices}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relay.op.image._make.crop_and_resize") +.set_body_typed(MakeCropAndResize); + + +RELAY_REGISTER_OP("image.crop_and_resize") + .describe(R"code(Perform crop and resize to input array with nearest neighbour or bilinear interpolation. + +- **data**: data is 4D array of shape + (batch_size, channels, in_height, in_width) for NCHW + (batch_size, in_height, in_width, channels) for NHWC + +- **out**: Output is 4D array of shape + for layout NCHW + (batch_size, channels, crop_size[0], crop_size[1]) + + for layout NHWC + (batch_size, crop_size[0], crop_size[1], channels) +)code" TVM_ADD_FILELINE) +.set_num_inputs(3) +.add_argument("data", "Tensor", "The input tensor.") +.add_argument("boxes", "Tensor", "The boxes tensor.") +.add_argument("box_indices", "Tensor", "The box indices tensor.") +.set_attrs_type() +.set_support_level(5) +.add_type_rel("CropAndResize", CropAndResizeRel) +.set_attr("TOpPattern", kInjective); + } // namespace relay } // namespace tvm diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index b3940817be890..19e5b1ff9c3c7 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -1706,39 +1706,47 @@ def test_forward_crop(): # CropAndResize # ------------- -def _test_forward_crop_and_resize(img_shape, boxes, box_idx, crop_size, method='bilinear', dtype="float32"): +def _test_forward_crop_and_resize(img_shape, boxes, box_idx, crop_size, + extrapolation_value=0.0, method='bilinear', dtype="float32"): image = np.random.uniform(0, 10, size=img_shape).astype(dtype) tf.reset_default_graph() in_data = tf.placeholder(dtype, image.shape, name="in_data") - tf.image.crop_and_resize(in_data, boxes=boxes, box_ind=box_idx, crop_size=crop_size, - method=method, name="crop_and_resize") + tf.image.crop_and_resize(in_data, boxes=boxes, box_ind=box_idx, + crop_size=crop_size, method=method, + extrapolation_value=extrapolation_value, + name="crop_and_resize") compare_tf_with_tvm([image], ['in_data:0'], 'crop_and_resize:0') def test_forward_crop_and_resize(): """ CropAndResize """ - _test_forward_crop_and_resize([1, 11, 11, 3], [[0, 0, 1, 1]], [0], [5, 5]) - _test_forward_crop_and_resize( - [1, 11, 11, 3], [[0, 0, .9, .9]], [0], [5, 5]) - _test_forward_crop_and_resize( - [1, 11, 11, 3], [[.1, .2, 1, 1]], [0], [5, 5]) - _test_forward_crop_and_resize( - [1, 21, 21, 3], [[.2, .3, .7, .9]], [0], [3, 4]) - _test_forward_crop_and_resize( - [1, 41, 41, 3], [[0.2, 0.4, 0.8, 0.8]], [0], [3, 3]) - _test_forward_crop_and_resize([10, 11, 11, 3], - [[0, 0, 0.9, 0.9], [0.2, 0.2, 0.8, 0.8]], - [0, 1], - [5, 5]) - _test_forward_crop_and_resize([3, 11, 11, 3], - [[0, 0, 0.9, 0.9], [ - 0.2, 0.2, 0.8, 0.8], [0, 0, 1, 1]], - [0, 1, 2], - [3, 3]) - _test_forward_crop_and_resize([3, 11, 11, 3], - [[0, 0, 1, 0.8], [0, 0, 0.9, 0.9], [0, 0, 1, 0.8]], - [2, 1, 0], - [3, 3]) + _test_forward_crop_and_resize([1, 6, 6, 3], [[0, 0, 1, 1]], [0], [3, 3]) + _test_forward_crop_and_resize([1, 6, 6, 3], [[0, 0, 1, 1]], [0], [3, 3], 0.2) + _test_forward_crop_and_resize([1, 6, 6, 3], [[0, 0, 1, 1]], [0], [3, 3], 0.2, 'nearest') + _test_forward_crop_and_resize([1, 11, 11, 3], [[.3, .3, 1, 1]], [0], [21, 21]) + _test_forward_crop_and_resize([1, 41, 41, 3], [[.2, .4, .8, .8]], [0], [21, 11]) + _test_forward_crop_and_resize([1, 100, 100, 3], [[ 0, 0, .9, .9]], [0], [30, 30]) + _test_forward_crop_and_resize([1, 224, 224, 3], [[.1, .2, 1, 1]], [0], [9, 9]) + _test_forward_crop_and_resize([1, 249, 249, 3], [[ 0, 0, 1, 1]], [0], [9, 9]) + _test_forward_crop_and_resize([1, 201, 301, 3], [[.2, .3, .7, .8]], [0], [51, 51]) + _test_forward_crop_and_resize(img_shape=[10, 11, 11, 3], + boxes=[[ 0, 0, .9, .9], + [.2, .2, .8, .8]], + box_idx=[0, 1], crop_size=[5, 5]) + _test_forward_crop_and_resize(img_shape=[20, 576, 576, 3], + boxes=[[ 0, 0, 1, 1], + [ 0, 0, .8, .8], + [.1, .2, .9, 1], + [.2, 0, 1, 1]], + box_idx=[1, 0, 2, 3], crop_size=[24, 24], + extrapolation_value=0.3) + _test_forward_crop_and_resize(img_shape=[20, 229, 229, 3], + boxes=[[ 0, 0, .9, .9], + [.3, .3, 1, 1], + [.2, .1, .7, .8], + [ 0, 0, 1, 1]], + box_idx=[3, 0, 2, 1], crop_size=[58, 58], + extrapolation_value=0.2, method='nearest') ####################################################################### diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 2f2e8523161cf..808fc49c29bbb 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -72,6 +72,47 @@ def verify_resize(dshape, scale, method, layout): for layout in ["NHWC", "NCHW"]: verify_resize((1, 4, 4, 4), 2, method, layout) +def test_crop_and_resize(): + def verify_crop_and_resize(img_shape, boxes, box_indices, crop_size, + layout, method, extrapolation_value=0.0): + + image_data = np.random.uniform(size=img_shape).astype("float32") + + ref_res = topi.testing.crop_and_resize_python(image_data, + boxes, + box_indices, + crop_size, + layout, method, + extrapolation_value) + + img = relay.var("img", relay.TensorType(img_shape, 'float32')) + bx = relay.var('bx', relay.TensorType(boxes.shape, 'float32')) + bx_idx = relay.var('bx_idx', relay.TensorType(box_indices.shape, 'int32')) + + z = relay.image.crop_and_resize(img, bx, bx_idx, list(crop_size), + layout, method, extrapolation_value) + zz = run_infer_type(z) + assert zz.checked_type == relay.TensorType(ref_res.shape, "float32") + func = relay.Function([img, bx, bx_idx], z) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(image_data, boxes, box_indices) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-3, atol=1e-04) + + boxes_nhwc = np.array([[.1, .2, .8, .7], [.2, 0, 1, .6]]).astype("float32") + indices_nhwc = np.array([1, 0]).astype("int32") + size_nhwc = np.array([20, 30]).astype("int32") + boxes_nchw = np.array([[0, 0, 1, 1], [.2, .1, 1, .9]]).astype("float32") + indices_nchw = np.array([0, 1]).astype("int32") + size_nchw = np.array([30, 30]).astype("int32") + + for method in ["bilinear", "nearest_neighbor"]: + verify_crop_and_resize((10, 224, 224, 3), boxes_nhwc, indices_nhwc, + size_nhwc, 'NHWC', method) + verify_crop_and_resize((5, 3, 255, 255), boxes_nchw, indices_nchw, + size_nchw, 'NCHW', method, 0.1) def test_multibox_prior(): def get_ref_result(dshape, sizes=(1.0,), @@ -639,6 +680,7 @@ def verify_space_to_depth(dshape, block_size, layout): if __name__ == "__main__": test_resize_infer_type() test_resize() + test_crop_and_resize() test_multibox_prior() test_multibox_transform_loc() test_get_valid_counts() @@ -650,4 +692,4 @@ def verify_space_to_depth(dshape, block_size, layout): test_non_max_suppression() test_deformable_conv2d() test_depth_to_space() - test_space_to_depth() \ No newline at end of file + test_space_to_depth() diff --git a/topi/python/topi/image/resize.py b/topi/python/topi/image/resize.py index 004e04a604e5f..00ae5d6532d79 100644 --- a/topi/python/topi/image/resize.py +++ b/topi/python/topi/image/resize.py @@ -21,19 +21,48 @@ from .. import tag -def resize(data, size, layout="NCHW", method="bilinear", - coordinate_transformation_mode="half_pixel", out_dtype=None): - """Perform resize operation on the data. +def resize_nearest_neighbor(indices, data, image_height, image_width, + target_height, target_width, boxes=None, + box_indices=None, extrapolation_value=None, layout='NCHW', + coordinate_transformation_mode="align_corners", + out_dtype=None): + + """Perform resize operation with nearest neighbor method on the data. + For details about Nearest-neighbor interpolation please refer to + https://en.wikipedia.org/wiki/Nearest-neighbor_interpolation. Parameters ---------- - inputs : tvm.Tensor + indices : tuple + The indices of input data + + data : tvm.Tensor inputs is a 4-D tensor with shape [batch, channel, in_height, in_width] or [batch, in_height, in_width, channel] - size: Tuple - Output resolution scale to + image_height : integer + Input image height + + image_width : integer + Input image width + + target_height : integer + The target resized image height + + target_width : integer + The target resized image width + + boxes : tvm.Tensor, optional + A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies + the coordinates of a box. + + box_indices : tvm.Tensor, optional + A 1-D tensor of shape [num_boxes], box_indices[i] specifies the data that + the i-th box refers to. + + extrapolation_value: float, optional + Value used for extrapolation, when applicable. layout: string, optional "NCHW", "NHWC", or "NCHWc". @@ -44,45 +73,37 @@ def resize(data, size, layout="NCHW", method="bilinear", Refer to the ONNX Resize operator specification for details. Available options are "half_pixel", "align_corners" and "asymmetric". - method: {"bilinear", "nearest_neighbor", "bicubic"} - Method to be used for resizing. - out_dtype: string, optional Type to return. If left None will be same as input type. Returns ------- - output : tvm.Tensor - 4-D with shape [batch, channel, in_height*scale, in_width*scale] - or [batch, in_height*scale, in_width*scale, channel] - or 5-D with shape [batch, channel-major, in_height*scale, in_width*scale, channel-minor] + output : out_dtype + The computed result with type out_dtype """ - method = method.lower() - if layout == 'NHWC': - in_n, in_h, in_w, in_c = data.shape - output_shape = [in_n, size[0], size[1], in_c] - elif layout == 'NCHW': - in_n, in_c, in_h, in_w = data.shape - output_shape = [in_n, in_c, size[0], size[1]] - # Otherwise layout must be NCHWxc - else: - in_n, in_c, in_h, in_w, in_cc = data.shape - output_shape = [in_n, in_c, size[0], size[1], in_cc] + def _cast_output(value, data_dtype="float32", out_dtype=None): + if out_dtype: + dtype = out_dtype + else: + dtype = data_dtype + return value.astype(dtype) - if coordinate_transformation_mode == "align_corners": - y_ratio = (in_h - 1).astype('float') / (size[0] - 1) - x_ratio = (in_w - 1).astype('float') / (size[1] - 1) - elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: - y_ratio = (in_h).astype('float') / (size[0]) - x_ratio = (in_w).astype('float') / (size[1]) - else: - raise ValueError("Unsupported coordinate_transformation_mode: {}".format( - coordinate_transformation_mode)) + def _get_indices(indices, layout='NCHW'): + if layout == 'NHWC': + n, y, x, c = indices + cc = None + elif layout == 'NCHW': + n, c, y, x = indices + cc = None + else: + n, c, y, x, cc = indices + return n, c, y, x, cc - def _get_pixel(n, c, y, x, cc): - y = tvm.max(tvm.min(y, in_h - 1), 0) - x = tvm.max(tvm.min(x, in_w - 1), 0) + def _get_pixel(data, layout, n, c, y, x, cc): + if boxes is None: + y = tvm.max(tvm.min(y, image_height - 1), 0) + x = tvm.max(tvm.min(x, image_width - 1), 0) if layout == 'NHWC': return data(n, y, x, c).astype('float') if layout == 'NCHW': @@ -90,7 +111,130 @@ def _get_pixel(n, c, y, x, cc): # else must be NCHWxc return data(n, c, y, x, cc).astype('float') - def _get_indices(*indices): + n, c, y, x, cc = _get_indices(indices, layout) + box_idx = box_indices(n) if box_indices is not None else n + if boxes is not None: + y1, x1 = boxes(n, 0), boxes(n, 1) + y2, x2 = boxes(n, 2), boxes(n, 3) + + in_h = (image_height - 1) * (y2 - y1) + in_w = (image_width - 1) * (x2 - x1) + h_scale = in_h.astype('float') / (target_height - 1) + w_scale = in_w.astype('float') / (target_width - 1) + + in_y = y1 * (image_height - 1) + h_scale * y + in_x = x1 * (image_width - 1) + w_scale * x + else: + if coordinate_transformation_mode == "align_corners": + h_scale = (image_height - 1).astype('float') / (target_height - 1) + w_scale = (image_width - 1).astype('float') / (target_width - 1) + elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: + h_scale = image_height.astype('float') / target_height + w_scale = image_width.astype('float') / target_width + else: + raise ValueError("Unsupported coordinate_transformation_mode: {}".format( + coordinate_transformation_mode)) + in_y = h_scale * y + in_x = w_scale * x + + if coordinate_transformation_mode == "align_corners" or boxes is not None: + closest_x_index = tvm.round(in_x).astype("int32") + closest_y_index = tvm.round(in_y).astype("int32") + else: + # Add epsilon to floor to prevent gpu rounding errors. + epsilon = 1e-5 + closest_y_index = tvm.floor(in_y + epsilon).astype('int32') + closest_x_index = tvm.floor(in_x + epsilon).astype('int32') + + value = _get_pixel(data, layout, box_idx, c, closest_y_index, closest_x_index, cc) + + if extrapolation_value is not None: + out = tvm.if_then_else(in_y < 0, + extrapolation_value, + tvm.if_then_else(in_y > image_height - 1, + extrapolation_value, + value)) + # use extrapolation_value if in_x is out of boundary + value = tvm.if_then_else(in_x < 0, + extrapolation_value, + tvm.if_then_else(in_x > image_width - 1, + extrapolation_value, + out)) + return _cast_output(value, data.dtype, out_dtype=out_dtype) + + +def resize_bilinear(indices, data, image_height, image_width, + target_height, target_width, boxes=None, + box_indices=None, extrapolation_value=None, layout='NCHW', + coordinate_transformation_mode="align_corners", + out_dtype=None): + + """Perform resize operation with bilinear method on the data. + For details about Bilinear interpolation please refer to + https://en.wikipedia.org/wiki/Bilinear_interpolation. + + Parameters + ---------- + indices : tuple + The indices of input data + + data : tvm.Tensor + inputs is a 4-D tensor with shape + [batch, channel, in_height, in_width] + or [batch, in_height, in_width, channel] + + image_height : integer + Input image height + + image_width : integer + Input image width + + target_height : integer + The target resized image height + + target_width : integer + The target resized image width + + boxes : tvm.Tensor, optional + A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies + the coordinates of a box. + + box_indices : tvm.Tensor, optional + A 1-D tensor of shape [num_boxes], box_indices[i] specifies the data that + the i-th box refers to. + + extrapolation_value: float, optional + Value used for extrapolation, when applicable. + + layout: string, optional + "NCHW", "NHWC", or "NCHWc". + + coordinate_transformation_mode: string, optional + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. + Refer to the ONNX Resize operator specification for details. + Available options are "half_pixel", "align_corners" and "asymmetric". + + out_dtype: string, optional + Type to return. If left None will be same as input type. + + Returns + ------- + output : out_dtype + The computed result with type out_dtype + """ + + def _cast_output(value, data_dtype="float32", out_dtype=None): + if out_dtype: + dtype = out_dtype + else: + dtype = data_dtype + return value.astype(dtype) + + def _lerp(A, B, t): + return A * (1.0 - t) + B * t + + def _get_indices(indices, layout='NCHW'): if layout == 'NHWC': n, y, x, c = indices cc = None @@ -99,120 +243,334 @@ def _get_indices(*indices): cc = None else: n, c, y, x, cc = indices - return n, c, y, x, cc - def _cast_output(value): - if out_dtype: - dtype = out_dtype - else: - dtype = data.dtype - return value.astype(dtype) + def _get_pixel(data, layout, n, c, y, x, cc): + if boxes is None: + y = tvm.max(tvm.min(y, image_height - 1), 0) + x = tvm.max(tvm.min(x, image_width - 1), 0) + if layout == 'NHWC': + return data(n, y, x, c).astype('float') + if layout == 'NCHW': + return data(n, c, y, x).astype('float') + # else must be NCHWxc + return data(n, c, y, x, cc).astype('float') - # Nearest neighbor computation - def _nearest_neighbor(*indices): - n, c, y, x, cc = _get_indices(*indices) + n, c, y, x, cc = _get_indices(indices, layout=layout) + box_idx = box_indices(n) if box_indices is not None else n - in_y = y_ratio * y - in_x = x_ratio * x + if boxes is not None: + y1, x1 = boxes(n, 0), boxes(n, 1) + y2, x2 = boxes(n, 2), boxes(n, 3) + + in_h = (image_height - 1) * (y2 - y1) + in_w = (image_width - 1) * (x2 - x1) + h_scale = in_h.astype('float') / (target_height - 1) + w_scale = in_w.astype('float') / (target_width - 1) + in_y = y1 * (image_height - 1) + h_scale * y + in_x = x1 * (image_width - 1) + w_scale * x + else: if coordinate_transformation_mode == "align_corners": - yint = tvm.round(in_y).astype('int32') - xint = tvm.round(in_x).astype('int32') + h_scale = (image_height - 1).astype('float') / (target_height - 1) + w_scale = (image_width - 1).astype('float') / (target_width - 1) + elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: + h_scale = image_height.astype('float') / target_height + w_scale = image_width.astype('float') / target_width else: - # Add epsilon to floor to prevent gpu rounding errors. - epsilon = 1e-5 - yint = tvm.floor(in_y + epsilon).astype('int32') - xint = tvm.floor(in_x + epsilon).astype('int32') + raise ValueError("Unsupported coordinate_transformation_mode: {}".format( + coordinate_transformation_mode)) - return _cast_output(_get_pixel(n, c, yint, xint, cc)) + if coordinate_transformation_mode == "half_pixel": + in_y = h_scale * (y + 0.5) - 0.5 + in_x = w_scale * (x + 0.5) - 0.5 + else: + in_y = h_scale * y + in_x = w_scale * x + + top_y_index = tvm.floor(in_y).astype('int32') + bottom_y_index = tvm.ceil(in_y).astype('int32') + y_lerp = in_y - top_y_index + + left_x_index = tvm.floor(in_x).astype('int32') + right_x_index = tvm.ceil(in_x).astype('int32') + x_lerp = in_x - left_x_index + + top_left = _get_pixel(data, layout, box_idx, c, top_y_index, left_x_index, cc) + top_right = _get_pixel(data, layout, box_idx, c, top_y_index, right_x_index, cc) + bottom_left = _get_pixel(data, layout, box_idx, c, bottom_y_index, left_x_index, cc) + bottom_right = _get_pixel(data, layout, box_idx, c, bottom_y_index, right_x_index, cc) + + top = _lerp(top_left, top_right, x_lerp) + bottom = _lerp(bottom_left, bottom_right, x_lerp) + value = _lerp(top, bottom, y_lerp) + + # use extrapolation_value if in_y/in_x is out of boundary + if extrapolation_value is not None: + out = tvm.if_then_else(in_y < 0, + extrapolation_value, + tvm.if_then_else(in_y > image_height - 1, + extrapolation_value, + value)) + value = tvm.if_then_else(in_x < 0, + extrapolation_value, + tvm.if_then_else(in_x > image_width - 1, + extrapolation_value, + out)) + return _cast_output(value, data.dtype, out_dtype=out_dtype) + + +def resize_bicubic(indices, data, image_height, image_width, + target_height, target_width, boxes=None, + box_indices=None, extrapolation_value=None, layout='NCHW', + coordinate_transformation_mode="align_corners", + out_dtype=None): + """Perform resize operation with bicubic method on the data. + More details about Bicubic interpolation please refer to + https://en.wikipedia.org/wiki/Bicubic_interpolation. - # Bilinear helper functions and computation. - def _lerp(A, B, t): - return A * (1.0 - t) + B * t + Parameters + ---------- + indices : tuple + The indices of input data - def _bilinear(*indices): - n, c, y, x, cc = _get_indices(*indices) + data : tvm.Tensor + inputs is a 4-D tensor with shape + [batch, channel, in_height, in_width] + or [batch, in_height, in_width, channel] - if coordinate_transformation_mode == "half_pixel": - in_y = y_ratio * (y + 0.5) - 0.5 - in_x = x_ratio * (x + 0.5) - 0.5 - else: - in_y = y_ratio * y - in_x = x_ratio * x + image_height : integer + Input image height - xint = tvm.floor(in_x).astype('int32') - xfract = in_x - tvm.floor(in_x) + image_width : integer + Input image width - yint = tvm.floor(in_y).astype('int32') - yfract = in_y - tvm.floor(in_y) + target_height : integer + The target resized image height - p00 = _get_pixel(n, c, yint, xint, cc) - p10 = _get_pixel(n, c, yint, xint + 1, cc) - p01 = _get_pixel(n, c, yint + 1, xint, cc) - p11 = _get_pixel(n, c, yint + 1, xint + 1, cc) + target_width : integer + The target resized image width - col0 = _lerp(p00, p10, xfract) - col1 = _lerp(p01, p11, xfract) - value = _lerp(col0, col1, yfract) - return _cast_output(value) + boxes : tvm.Tensor, optional + A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies + the coordinates of a box. + + box_indices : tvm.Tensor, optional + A 1-D tensor of shape [num_boxes], box_indices[i] specifies the data that + the i-th box refers to. + + extrapolation_value: float, optional + Value used for extrapolation, when applicable. + + layout: string, optional + "NCHW", "NHWC", or "NCHWc". + + coordinate_transformation_mode: string, optional + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. + Refer to the ONNX Resize operator specification for details. + Available options are "half_pixel", "align_corners" and "asymmetric". + + out_dtype: string, optional + Type to return. If left None will be same as input type. + + Returns + ------- + output : out_dtype + The computed result with type out_dtype + """ - # Bicubic helper function and computation. def _cubic_kernel(A, B, C, D, t): - a = -A / 2.0 + (3.0*B) / 2.0 - (3.0*C) / 2.0 + D / 2.0 - b = A - (5.0*B) / 2.0 + 2.0*C - D / 2.0 + a = -A / 2.0 + (3.0 * B) / 2.0 - (3.0 * C) / 2.0 + D / 2.0 + b = A - (5.0 * B) / 2.0 + 2.0 * C - D / 2.0 c = -A / 2.0 + C / 2.0 d = B + return a * t * t * t + b * t * t + c * t + d - return a*t*t*t + b*t*t + c*t + d + def _cast_output(value, data_dtype="float32", out_dtype=None): + if out_dtype: + dtype = out_dtype + else: + dtype = data_dtype + return value.astype(dtype) - def _bicubic(*indices): - n, c, y, x, cc = _get_indices(*indices) + def _get_indices(indices, layout='NCHW'): + if layout == 'NHWC': + n, y, x, c = indices + cc = None + elif layout == 'NCHW': + n, c, y, x = indices + cc = None + else: + n, c, y, x, cc = indices + return n, c, y, x, cc + + def _get_pixel(data, layout, n, c, y, x, cc): + if boxes is None: + y = tvm.max(tvm.min(y, image_height - 1), 0) + x = tvm.max(tvm.min(x, image_width - 1), 0) + if layout == 'NHWC': + return data(n, y, x, c).astype('float') + if layout == 'NCHW': + return data(n, c, y, x).astype('float') + # else must be NCHWxc + return data(n, c, y, x, cc).astype('float') + + n, c, y, x, cc = _get_indices(indices, layout) + box_idx = box_indices(n) if box_indices is not None else n + + if boxes is not None: + y1, x1 = boxes(n, 0), boxes(n, 1) + y2, x2 = boxes(n, 2), boxes(n, 3) + + in_h = (image_height - 1) * (y2 - y1) + in_w = (image_width - 1) * (x2 - x1) + h_scale = in_h.astype('float') / (target_height - 1) + w_scale = in_w.astype('float') / (target_width - 1) + + in_y = y1 * (image_height - 1) + h_scale * y + in_x = x1 * (image_width - 1) + w_scale * x + else: + if coordinate_transformation_mode == "align_corners": + h_scale = (image_height - 1).astype('float') / (target_height - 1) + w_scale = (image_width - 1).astype('float') / (target_width - 1) + elif coordinate_transformation_mode in ["asymmetric", "half_pixel"]: + h_scale = image_height.astype('float') / target_height + w_scale = image_width.astype('float') / target_width + else: + raise ValueError("Unsupported coordinate_transformation_mode: {}".format( + coordinate_transformation_mode)) if coordinate_transformation_mode == "half_pixel": - in_y = y_ratio * (y + 0.5) - 0.5 - in_x = x_ratio * (x + 0.5) - 0.5 + in_y = h_scale * (y + 0.5) - 0.5 + in_x = w_scale * (x + 0.5) - 0.5 else: - in_y = y_ratio * y - in_x = x_ratio * x + in_y = h_scale * y + in_x = w_scale * x + + xint = tvm.floor(in_x).astype('int32') + xfract = in_x - tvm.floor(in_x) + + yint = tvm.floor(in_y).astype('int32') + yfract = in_y - tvm.floor(in_y) + + # 1st row + p00 = _get_pixel(data, layout, box_idx, c, yint - 1, xint - 1, cc) + p10 = _get_pixel(data, layout, box_idx, c, yint - 1, xint + 0, cc) + p20 = _get_pixel(data, layout, box_idx, c, yint - 1, xint + 1, cc) + p30 = _get_pixel(data, layout, box_idx, c, yint - 1, xint + 2, cc) + + # 2nd row + p01 = _get_pixel(data, layout, box_idx, c, yint + 0, xint - 1, cc) + p11 = _get_pixel(data, layout, box_idx, c, yint + 0, xint + 0, cc) + p21 = _get_pixel(data, layout, box_idx, c, yint + 0, xint + 1, cc) + p31 = _get_pixel(data, layout, box_idx, c, yint + 0, xint + 2, cc) + + # 3rd row + p02 = _get_pixel(data, layout, box_idx, c, yint + 1, xint - 1, cc) + p12 = _get_pixel(data, layout, box_idx, c, yint + 1, xint + 0, cc) + p22 = _get_pixel(data, layout, box_idx, c, yint + 1, xint + 1, cc) + p32 = _get_pixel(data, layout, box_idx, c, yint + 1, xint + 2, cc) + + # 4th row + p03 = _get_pixel(data, layout, box_idx, c, yint + 2, xint - 1, cc) + p13 = _get_pixel(data, layout, box_idx, c, yint + 2, xint + 0, cc) + p23 = _get_pixel(data, layout, box_idx, c, yint + 2, xint + 1, cc) + p33 = _get_pixel(data, layout, box_idx, c, yint + 2, xint + 2, cc) + + # Interpolate bicubically + col0 = _cubic_kernel(p00, p10, p20, p30, xfract) + col1 = _cubic_kernel(p01, p11, p21, p31, xfract) + col2 = _cubic_kernel(p02, p12, p22, p32, xfract) + col3 = _cubic_kernel(p03, p13, p23, p33, xfract) + value = _cubic_kernel(col0, col1, col2, col3, yfract) + + # use extrapolation_value if in_y/in_x is out of boundary + if extrapolation_value is not None: + out = tvm.if_then_else(in_y < 0, + extrapolation_value, + tvm.if_then_else(in_y > image_height - 1, + extrapolation_value, + value)) + value = tvm.if_then_else(in_x < 0, + extrapolation_value, + tvm.if_then_else(in_x > image_width - 1, + extrapolation_value, + out)) + return _cast_output(value, data.dtype, out_dtype=out_dtype) - xint = tvm.floor(in_x).astype('int32') - xfract = in_x - tvm.floor(in_x) - yint = tvm.floor(in_y).astype('int32') - yfract = in_y - tvm.floor(in_y) +def resize(data, size, layout="NCHW", method="bilinear", + coordinate_transformation_mode="half_pixel", out_dtype=None): + """Perform resize operation on the data. - # 1st row - p00 = _get_pixel(n, c, yint - 1, xint - 1, cc) - p10 = _get_pixel(n, c, yint - 1, xint + 0, cc) - p20 = _get_pixel(n, c, yint - 1, xint + 1, cc) - p30 = _get_pixel(n, c, yint - 1, xint + 2, cc) - - # 2nd row - p01 = _get_pixel(n, c, yint + 0, xint - 1, cc) - p11 = _get_pixel(n, c, yint + 0, xint + 0, cc) - p21 = _get_pixel(n, c, yint + 0, xint + 1, cc) - p31 = _get_pixel(n, c, yint + 0, xint + 2, cc) - - # 3rd row - p02 = _get_pixel(n, c, yint + 1, xint - 1, cc) - p12 = _get_pixel(n, c, yint + 1, xint + 0, cc) - p22 = _get_pixel(n, c, yint + 1, xint + 1, cc) - p32 = _get_pixel(n, c, yint + 1, xint + 2, cc) - - # 4th row - p03 = _get_pixel(n, c, yint + 2, xint - 1, cc) - p13 = _get_pixel(n, c, yint + 2, xint + 0, cc) - p23 = _get_pixel(n, c, yint + 2, xint + 1, cc) - p33 = _get_pixel(n, c, yint + 2, xint + 2, cc) - - # Interpolate bicubically - col0 = _cubic_kernel(p00, p10, p20, p30, xfract) - col1 = _cubic_kernel(p01, p11, p21, p31, xfract) - col2 = _cubic_kernel(p02, p12, p22, p32, xfract) - col3 = _cubic_kernel(p03, p13, p23, p33, xfract) - value = _cubic_kernel(col0, col1, col2, col3, yfract) - return _cast_output(value) + Parameters + ---------- + data : tvm.Tensor + inputs is a 4-D tensor with shape + [batch, channel, in_height, in_width] + or [batch, in_height, in_width, channel] + + size: Tuple + Output resolution scale to + + layout: string, optional + "NCHW", "NHWC", or "NCHWc". + + coordinate_transformation_mode: string, optional + Describes how to transform the coordinate in the resized tensor + to the coordinate in the original tensor. + Refer to the ONNX Resize operator specification for details. + Available options are "half_pixel", "align_corners" and "asymmetric". + + method: {"bilinear", "nearest_neighbor", "bicubic"} + Method to be used for resizing. + + out_dtype: string, optional + Type to return. If left None will be same as input type. + + Returns + ------- + output : tvm.Tensor + 4-D with shape [batch, channel, in_height*scale, in_width*scale] + or [batch, in_height*scale, in_width*scale, channel] + or 5-D with shape [batch, channel-major, in_height*scale, in_width*scale, channel-minor] + """ + method = method.lower() + + if layout == 'NHWC': + in_n, in_h, in_w, in_c = data.shape + output_shape = [in_n, size[0], size[1], in_c] + elif layout == 'NCHW': + in_n, in_c, in_h, in_w = data.shape + output_shape = [in_n, in_c, size[0], size[1]] + elif layout.startswith("NCHW"):# for NCHWxc + in_n, in_c, in_h, in_w, in_cc = data.shape + output_shape = [in_n, in_c, size[0], size[1], in_cc] + else: + raise ValueError('%s layout is not supported.' % layout) + + + def _nearest_neighbor(*indices): + return resize_nearest_neighbor(indices, data, in_h, in_w, + size[0], size[1], layout=layout, + coordinate_transformation_mode= \ + coordinate_transformation_mode, + out_dtype=out_dtype) + + def _bilinear(*indices): + return resize_bilinear(indices, data, in_h, in_w, + size[0], size[1], layout=layout, + coordinate_transformation_mode= \ + coordinate_transformation_mode, + out_dtype=out_dtype) + + def _bicubic(*indices): + return resize_bicubic(indices, data, in_h, in_w, + size[0], size[1], layout, + coordinate_transformation_mode= \ + coordinate_transformation_mode, + out_dtype=out_dtype) # Determine which interpolation method to use then run it. if method == "nearest_neighbor": @@ -226,35 +584,111 @@ def _bicubic(*indices): return tvm.compute(output_shape, compute_func, name='resize', tag=tag.INJECTIVE) + +def crop_and_resize(data, boxes, box_indices, crop_size, layout="NCHW", + method="bilinear", extrapolation_value=0, out_dtype=None): + """Perform crop and resize operation on the data. + + Parameters + ---------- + data : tvm.Tensor + inputs is a 4-D tensor with shape + [batch, channel, in_height, in_width] + or [batch, in_height, in_width, channel] + + boxes : tvm.Tensor + A 2-D tensor of shape [num_boxes, 4]. Each row of the tensor specifies + the coordinates of a box. + + box_indices : tvm.Tensor + A 1-D tensor of shape [num_boxes], box_indices[i] specifies the data that + the i-th box refers to. + + crop_size : Tuple + The target size of each box. + + layout : string, optional + "NCHW", "NHWC" + + method : {"bilinear", "nearest_neighbor"} + Method to be used for resizing. + + extrapolation_value: float, optional + Value used for extrapolation, when applicable. + + out_dtype : string, optional + Type to return. If left None will be same as input type. + + Returns + ------- + output : tvm.Tensor + 4-D with shape [num_boxes, channel, crop_height, crop_width] + or [num_boxes, crop_height, crop_width, channel] + """ + method = method.lower() + target_h = crop_size[0] + target_w = crop_size[1] + + if layout == 'NHWC': + output_shape = [box_indices.shape[0], crop_size[0], crop_size[1], data.shape[3]] + image_h = data.shape[1].astype("int32") + image_w = data.shape[2].astype("int32") + elif layout == 'NCHW': + output_shape = [box_indices.shape[0], data.shape[1], crop_size[0], crop_size[1]] + image_h = data.shape[2].astype("int32") + image_w = data.shape[3].astype("int32") + elif layout.startswith("NCHW"):# for NCHWxc + output_shape = [box_indices.shape[0], data.shape[1], + crop_size[0], crop_size[1], data.shape[4]] + image_h = data.shape[2].astype("int32") + image_w = data.shape[3].astype("int32") + else: + raise ValueError('%s layout is not supported.' % layout) + + def _bilinear(*indices): + return resize_bilinear(indices, data, image_h, image_w, target_h, + target_w, boxes, box_indices, extrapolation_value, + layout, out_dtype=out_dtype) + + def _nearest_neighbor(*indices): + return resize_nearest_neighbor(indices, data, image_h, image_w, target_h, + target_w, boxes, box_indices, extrapolation_value, + layout, out_dtype=out_dtype) + + # Determine which interpolation method to use then run it. + if method == "nearest_neighbor": + compute_func = _nearest_neighbor + elif method == "bilinear": + compute_func = _bilinear + else: + raise ValueError('%s method is not supported.' % method) + + return tvm.compute(output_shape, compute_func, name='crop_and_resize', tag=tag.INJECTIVE) + + + def resize3d(data, size, layout="NCDHW", method="nearest_neighbor", coordinate_transformation_mode="align_corners", out_dtype=None): """Perform resize operation on the data. - Parameters ---------- inputs: tvm.Tensor inputs is a 5-D tensor with shape [batch, channel, in_depth, in_height, in_width] or [batch, in_depth, in_height, in_width, channel] - size: Tuple Output resolution scale to - layout: string, optional "NCDHW", "NDHWC", or "NCDHWc". - coordinate_transformation_mode: string, optional Describes how to transform the coordinate in the resized tensor to the coordinate in the original tensor. Refer to the ONNX Resize operator specification for details. Available options are "half_pixel", "align_corners" and "asymmetric". - method: {"trilinear", "nearest_neighbor"} Method to be used for resizing. - out_dtype: string, optional Type to return. If left None will be same as input type. - Returns ------- output : tvm.Tensor diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py index 7546d7cd15a16..87e48ff006006 100644 --- a/topi/python/topi/testing/__init__.py +++ b/topi/python/topi/testing/__init__.py @@ -51,3 +51,4 @@ from .one_hot import one_hot from .depth_to_space import depth_to_space_python from .space_to_depth import space_to_depth_python +from .crop_and_resize_python import crop_and_resize_python diff --git a/topi/python/topi/testing/crop_and_resize_python.py b/topi/python/topi/testing/crop_and_resize_python.py new file mode 100644 index 0000000000000..a5f2cc0a614b1 --- /dev/null +++ b/topi/python/topi/testing/crop_and_resize_python.py @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals, too-many-nested-blocks +"""crop and resize in python""" +import math +import numpy as np + +def crop_and_resize_python(image, boxes, box_indices, crop_size, layout, + method='bilinear', extrapolation_value=0): + """Crop and resize using python""" + (target_h, target_w) = crop_size + + if layout == 'NHWC': + batch = boxes.shape[0] + image_height, image_width, channel = image.shape[1], image.shape[2], image.shape[3] + scaled_image = np.ones((batch, target_h, target_w, channel)) + else: + batch = boxes.shape[0] + channel, image_height, image_width = image.shape[1], image.shape[2], image.shape[3] + scaled_image = np.ones((batch, channel, target_h, target_w)) + + for n, box in enumerate(boxes): + b_in = box_indices[n] + y1, x1 = boxes[n][0], boxes[n][1] + y2, x2 = boxes[n][2], boxes[n][3] + + in_h = (image_height - 1) * (y2 - y1) + in_w = (image_width - 1) * (x2 - x1) + h_scale = np.float32(in_h)/np.float32(target_h - 1) + w_scale = np.float32(in_w)/np.float32(target_w - 1) + + for y in range(target_h): + + in_y = y1 * (image_height - 1) + h_scale * y + + if in_y < 0 or in_y > image_height - 1: + for x in range(target_w): + for d in range(channel): + if layout == 'NHWC': + scaled_image[n][y][x][d] = extrapolation_value + else: + scaled_image[n][d][y][x] = extrapolation_value + continue + + if method == 'bilinear': + top_y_index = math.floor(in_y) + bottom_y_index = math.ceil(in_y) + y_lerp = in_y - top_y_index + + for x in range(target_w): + in_x = x1 * (image_width - 1) + x * w_scale + if in_x < 0 or in_x > image_width - 1: + for d in range(channel): + if layout == 'NHWC': + scaled_image[n][y][x][d] = extrapolation_value + else: + scaled_image[n][d][y][x] = extrapolation_value + continue + + left_x_index = math.floor(in_x) + right_x_index = math.ceil(in_x) + x_lerp = in_x - left_x_index + + for d in range(channel): + if layout == "NHWC": + top_left = image[b_in][top_y_index][left_x_index][d] + top_right = image[b_in][top_y_index][right_x_index][d] + bottom_left = image[b_in][bottom_y_index][left_x_index][d] + bottom_right = image[b_in][bottom_y_index][right_x_index][d] + top = top_left + (top_right - top_left) * x_lerp + bottom = bottom_left + (bottom_right - bottom_left) * x_lerp + scaled_image[n][y][x][d] = top + (bottom - top) * y_lerp + else: + top_left = image[b_in][d][top_y_index][left_x_index] + top_right = image[b_in][d][top_y_index][right_x_index] + bottom_left = image[b_in][d][bottom_y_index][left_x_index] + bottom_right = image[b_in][d][bottom_y_index][right_x_index] + top = top_left + (top_right - top_left) * x_lerp + bottom = bottom_left + (bottom_right - bottom_left) * x_lerp + scaled_image[n][d][y][x] = top + (bottom - top) * y_lerp + + elif method == 'nearest_neighbor': + for x in range(target_w): + in_x = x1 * (image_width - 1) + x * w_scale + if in_x < 0 or in_x > image_width - 1: + for d in range(channel): + if layout == 'NHWC': + scaled_image[n][y][x][d] = extrapolation_value + else: + scaled_image[n][d][y][x] = extrapolation_value + continue + closest_x_index = np.round(in_x).astype("int32") + closest_y_index = np.round(in_y).astype("int32") + for d in range(channel): + if layout == "NHWC": + scaled_image[n][y][x][d] = image[b_in][closest_y_index][closest_x_index][d] + else: + scaled_image[n][d][y][x] = image[b_in][d][closest_y_index][closest_x_index] + + return scaled_image diff --git a/topi/tests/python/test_topi_resize.py b/topi/tests/python/test_topi_image.py similarity index 66% rename from topi/tests/python/test_topi_resize.py rename to topi/tests/python/test_topi_image.py index 206903ff1dc1b..21935cb911da1 100644 --- a/topi/tests/python/test_topi_resize.py +++ b/topi/tests/python/test_topi_image.py @@ -19,7 +19,6 @@ import tvm import topi import topi.testing -import math from common import get_all_backend @@ -99,7 +98,7 @@ def verify_resize3d(batch, in_channel, in_depth, in_height, in_width, out_depth, 'Layout not supported {} '.format(layout)) B = topi.image.resize3d(A, (out_depth, out_height, out_width), layout=layout, - coordinate_transformation_mode=coordinate_transformation_mode, method=method) + coordinate_transformation_mode=coordinate_transformation_mode, method=method) if method == "trilinear": b_np = topi.testing.trilinear_resize3d_python(a_np, (out_depth, out_height, out_width), layout, @@ -143,6 +142,68 @@ def test_resize3d(): verify_resize3d(4, 8, 16, 16, 16, 25, 25, 25, 'NDHWC', method="nearest_neighbor") +def test_crop_and_resize(): + def verify_crop_and_resize(image_shape, np_boxes, np_box_indices, np_crop_size, layout='NHWC', + method="bilinear", extrapolation_value=0.0): + + images = tvm.placeholder(image_shape, name='images', dtype='float32') + np_images = np.random.uniform(size=image_shape).astype("float32") + boxes = tvm.placeholder(np_boxes.shape, name="boxes", dtype="float32") + box_ind = tvm.placeholder(np_box_indices.shape, name="box_ind", dtype="int32") + + batch = len(np_box_indices) + target_height, target_width = np_crop_size[0], np_crop_size[1] + if layout == 'NHWC': + channel = image_shape[3] + out_shape = (batch, target_height, target_width, channel) + elif layout == 'NCHW': + channel = image_shape[1] + out_shape = (batch, channel, target_height, target_width) + else: + raise NotImplementedError( + 'Layout {} is not supported.'.format(layout)) + + out = topi.image.crop_and_resize(images, boxes, box_ind, np_crop_size, layout=layout, + method=method, extrapolation_value=extrapolation_value) + + baseline_np = topi.testing.crop_and_resize_python(np_images, np_boxes, np_box_indices, + np_crop_size, layout, method, + extrapolation_value) + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + s = topi.generic.schedule_injective(out) + tvm_images = tvm.nd.array(np_images, ctx) + tvm_boxes = tvm.nd.array(np_boxes, ctx) + tvm_indices = tvm.nd.array(np_box_indices, ctx) + tvm_out = tvm.nd.array(np.zeros(out_shape, dtype="float32"), ctx) + f = tvm.build(s, [images, boxes, box_ind, out], device, name="crop_and_resize") + f(tvm_images, tvm_boxes, tvm_indices, tvm_out) + + tvm.testing.assert_allclose(tvm_out.asnumpy(), baseline_np, rtol=1e-3, atol=1e-3) + + for device in get_all_backend(): + check_device(device) + + boxes_1 = np.array([[.2, .3, .7, .9]], dtype="float32") + boxes_2 = np.array([[.2, .3, .7, .9], [0, .1, .8, 1]], dtype="float32") + indices_1 = np.array([0], dtype="int32") + indices_2 = np.array([1, 0], dtype="int32") + size_1 = (7, 11) + size_2 = (90, 60) + + verify_crop_and_resize((1, 255, 255, 3), boxes_1, indices_1, size_1, layout="NHWC") + verify_crop_and_resize((10, 224, 224, 5), boxes_2, indices_2, + size_2, extrapolation_value=0.3, layout="NHWC") + verify_crop_and_resize((1, 100, 100, 3), boxes_1, indices_1, + size_1, method='nearest_neighbor') + verify_crop_and_resize((1, 3, 224, 224), boxes_1, indices_1, size_1, layout="NCHW") + if __name__ == "__main__": test_resize() test_resize3d() + test_crop_and_resize()