From 8b1d07ff7b33d522588e9360227fb2e218a93211 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 20 Feb 2019 09:03:34 +0800 Subject: [PATCH] [RELAY][OP] ROI Align (#2618) --- include/tvm/relay/attrs/vision.h | 24 ++++++++ python/tvm/relay/frontend/mxnet.py | 10 ++++ python/tvm/relay/op/vision/__init__.py | 2 + python/tvm/relay/op/vision/_rcnn.py | 23 ++++++++ python/tvm/relay/op/vision/rcnn.py | 32 +++++++++++ src/relay/op/vision/rcnn_op.cc | 67 +++++++++++++++++++++++ tests/python/relay/test_op_level5.py | 35 ++++++++++++ topi/python/topi/vision/rcnn/roi_align.py | 4 +- 8 files changed, 195 insertions(+), 2 deletions(-) create mode 100644 python/tvm/relay/op/vision/_rcnn.py create mode 100644 python/tvm/relay/op/vision/rcnn.py create mode 100644 src/relay/op/vision/rcnn_op.cc diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index b736bd9c06a0d..d1a5ea41bc694 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -74,6 +74,30 @@ struct NMSAttrs : public tvm::AttrsNode{ } }; +/*! \brief Attributes used in roi_align operators */ +struct ROIAlignAttrs : public tvm::AttrsNode { + Array pooled_size; + double spatial_scale; + int sample_ratio; + std::string layout; + TVM_DECLARE_ATTRS(ROIAlignAttrs, "relay.attrs.ROIAlignAttrs") { + TVM_ATTR_FIELD(pooled_size).describe("Output size of roi align."); + TVM_ATTR_FIELD(spatial_scale) + .describe( + "Ratio of input feature map height (or w) to raw image height (or w). " + "Equals the reciprocal of total stride in convolutional layers, which should be " + "in range (0.0, 1.0]"); + TVM_ATTR_FIELD(sample_ratio) + .set_default(-1) + .describe("Optional sampling ratio of ROI align, using adaptive size by default."); + TVM_ATTR_FIELD(layout).set_default("NCHW").describe( + "Dimension ordering of data and weight. Can be 'NCHW', 'NHWC', etc." + "'N', 'C', 'H', 'W' stands for batch, channel, height, and width" + "dimensions respectively. Convolution is applied on the 'H' and" + "'W' dimensions."); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_VISION_H_ diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index ea49a6642796e..f8b51c413193a 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -268,6 +268,15 @@ def _mx_multibox_detection(inputs, attrs): return _op.vision.nms(ret[0], ret[1], **new_attrs1) +def _mx_roi_align(inputs, attrs): + new_attrs = {} + new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size") + new_attrs["spatial_scale"] = attrs.get_float("spatial_scale") + new_attrs["sample_ratio"] = attrs.get_int("sample_ratio", -1) + new_attrs["layout"] = "NCHW" + return _op.vision.roi_align(inputs[0], inputs[1], **new_attrs) + + # Note: due to attribute conversion constraint # ops in the identity set must be attribute free _identity_list = [ @@ -357,6 +366,7 @@ def _mx_multibox_detection(inputs, attrs): # vision "_contrib_MultiBoxPrior" : _mx_multibox_prior, "_contrib_MultiBoxDetection" : _mx_multibox_detection, + "_contrib_ROIAlign" : _mx_roi_align, # List of missing operators that are present in NNVMv1 # TODO(tvm-tvm): support all operators. # diff --git a/python/tvm/relay/op/vision/__init__.py b/python/tvm/relay/op/vision/__init__.py index ea3ed69e8f384..710adfeb4955c 100644 --- a/python/tvm/relay/op/vision/__init__.py +++ b/python/tvm/relay/op/vision/__init__.py @@ -4,4 +4,6 @@ from .multibox import * from .nms import * +from .rcnn import * from . import _multibox +from . import _rcnn diff --git a/python/tvm/relay/op/vision/_rcnn.py b/python/tvm/relay/op/vision/_rcnn.py new file mode 100644 index 0000000000000..2617bf8562b9e --- /dev/null +++ b/python/tvm/relay/op/vision/_rcnn.py @@ -0,0 +1,23 @@ +# pylint: disable=invalid-name, unused-argument +"""Faster R-CNN and Mask R-CNN operations.""" +import topi +from topi.util import get_const_tuple +from .. import op as reg +from ..op import OpPattern + + +@reg.register_compute("vision.roi_align") +def compute_roi_align(attrs, inputs, _, target): + """Compute definition of roi_align""" + assert attrs.layout == "NCHW" + return [topi.vision.rcnn.roi_align_nchw( + inputs[0], inputs[1], pooled_size=get_const_tuple(attrs.pooled_size), + spatial_scale=attrs.spatial_scale, sample_ratio=attrs.sample_ratio)] + +@reg.register_schedule("vision.roi_align") +def schedule_roi_align(_, outs, target): + """Schedule definition of roi_align""" + with target: + return topi.generic.vision.schedule_roi_align(outs) + +reg.register_pattern("vision.roi_align", OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/python/tvm/relay/op/vision/rcnn.py b/python/tvm/relay/op/vision/rcnn.py new file mode 100644 index 0000000000000..8bbafbe75c538 --- /dev/null +++ b/python/tvm/relay/op/vision/rcnn.py @@ -0,0 +1,32 @@ +"""Faster R-CNN and Mask R-CNN operations.""" +from . import _make + + +def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout='NCHW'): + """ROI align operator. + + Parameters + ---------- + data : relay.Expr + 4-D tensor with shape [batch, channel, height, width] + + rois : relay.Expr + 2-D tensor with shape [num_roi, 5]. The last dimension should be in format of + [batch_index, w_start, h_start, w_end, h_end] + + pooled_size : list/tuple of two ints + output size + + spatial_scale : float + Ratio of input feature map height (or w) to raw image height (or w). Equals the reciprocal + of total stride in convolutional layers, which should be in range (0.0, 1.0] + + sample_ratio : int + Optional sampling ratio of ROI align, using adaptive size by default. + + Returns + ------- + output : relay.Expr + 4-D tensor with shape [num_roi, channel, pooled_size, pooled_size] + """ + return _make.roi_align(data, rois, pooled_size, spatial_scale, sample_ratio, layout) diff --git a/src/relay/op/vision/rcnn_op.cc b/src/relay/op/vision/rcnn_op.cc new file mode 100644 index 0000000000000..e46eaf2207fba --- /dev/null +++ b/src/relay/op/vision/rcnn_op.cc @@ -0,0 +1,67 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file rcnn_op.cc + * \brief Faster RCNN and Mask RCNN operators + */ +#include +#include +#include + +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(ROIAlignAttrs); + +bool ROIAlignRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + auto roi_align_attrs = attrs.as(); + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* rois = types[1].as(); + const auto& dshape = data->shape; + const auto& rshape = rois->shape; + CHECK(roi_align_attrs); + CHECK_EQ(dshape.size(), 4) << "Input data should be 4-D."; + CHECK_EQ(rshape.size(), 2) << "Input rois should be 2-D."; + CHECK_EQ(roi_align_attrs->layout, "NCHW") << "ROI Align only supports NCHW layout"; + // assign output type + std::vector oshape( + {rshape[0], dshape[1], roi_align_attrs->pooled_size[0], roi_align_attrs->pooled_size[1]}); + reporter->Assign(types[2], TensorTypeNode::make(oshape, data->dtype)); + return true; +} + +Expr MakeROIAlign(Expr data, Expr rois, Array pooled_size, double spatial_scale, + int sample_ratio, std::string layout) { + auto attrs = make_node(); + attrs->pooled_size = pooled_size; + attrs->spatial_scale = spatial_scale; + attrs->sample_ratio = sample_ratio; + attrs->layout = layout; + static const Op& op = Op::Get("vision.roi_align"); + return CallNode::make(op, {data, rois}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op.vision._make.roi_align") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeROIAlign, args, rv); + }); + +RELAY_REGISTER_OP("vision.roi_align") + .describe(R"doc(ROI Align operator. + + - **data**: This depends on the `layout` parameter. Input is 4D array of shape + (batch_size, channels, height, width) if `layout` is `NCHW`. + - **rois**: 2D array of shape (num_roi, 5). The last dimension should be in format of + [batch_index, w_start, h_start, w_end, h_end]. + - **out**: This depends on the `layout` parameter. Output is 4D array of shape + (num_roi, channels, pooled_height, pooled_width) if `layout` is `NCHW`. + )doc" TVM_ADD_FILELINE) +.set_num_inputs(2) +.add_argument("data", "Tensor", "The input tensor.") +.add_argument("rois", "Tensor", "The input rois") +.set_support_level(5) +.add_type_rel("ROIAlign", ROIAlignRel); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index aa31aa96ef45e..1d91d92a6abcc 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -273,9 +273,44 @@ def test_threshold(): test_threshold() +def test_roi_align(): + def verify_roi_align(data_shape, rois_shape, pooled_size, spatial_scale, sample_ratio): + data = relay.var("data", relay.ty.TensorType(data_shape, "float32")) + rois = relay.var("rois", relay.ty.TensorType(rois_shape, "float32")) + z = relay.vision.roi_align(data, rois, pooled_size=(pooled_size, pooled_size), + spatial_scale=spatial_scale, sample_ratio=sample_ratio, + layout="NCHW") + zz = relay.ir_pass.infer_type(z) + + batch, channel, in_size, _ = data_shape + num_roi = rois_shape[0] + assert zz.checked_type == relay.ty.TensorType( + (num_roi, channel, pooled_size, pooled_size), "float32") + + func = relay.Function([data, rois], z) + func = relay.ir_pass.infer_type(func) + np_data = np.random.uniform(size=data_shape).astype("float32") + np_rois = np.random.uniform(size=rois_shape).astype('float32') * in_size + np_rois[:, 0] = np.random.randint(low = 0, high = batch, size = num_roi) + ref_res = topi.testing.roi_align_nchw_python(np_data, np_rois, pooled_size=pooled_size, + spatial_scale=spatial_scale, + sample_ratio=sample_ratio) + for target, ctx in ctx_list(): + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(np_data, np_rois) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-4) + intrp2 = relay.create_executor("debug", ctx=ctx, target=target) + op_res2 = intrp2.evaluate(func)(np_data, np_rois) + tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-4) + + verify_roi_align((1, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=1.0, sample_ratio=-1) + verify_roi_align((4, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=0.5, sample_ratio=2) + + if __name__ == "__main__": test_resize_infer_type() test_resize() test_multibox_prior() test_multibox_transform_loc() test_nms() + test_roi_align() diff --git a/topi/python/topi/vision/rcnn/roi_align.py b/topi/python/topi/vision/rcnn/roi_align.py index 397341aa1d010..760dce10d5363 100644 --- a/topi/python/topi/vision/rcnn/roi_align.py +++ b/topi/python/topi/vision/rcnn/roi_align.py @@ -68,8 +68,8 @@ def _sample(i, c, ph, pw): if sample_ratio > 0: roi_bin_grid_h = roi_bin_grid_w = tvm.const(sample_ratio, 'int32') else: - roi_bin_grid_h = tvm.ceil(roi_h / pooled_size).astype('int32') - roi_bin_grid_w = tvm.ceil(roi_w / pooled_size).astype('int32') + roi_bin_grid_h = tvm.ceil(roi_h / pooled_size_h).astype('int32') + roi_bin_grid_w = tvm.ceil(roi_w / pooled_size_w).astype('int32') count = roi_bin_grid_h * roi_bin_grid_w rh = tvm.reduce_axis((0, roi_bin_grid_h))