diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index 73b7339e2edb..df059a6238e1 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -109,6 +109,44 @@ struct YoloReorgAttrs : public tvm::AttrsNode { } }; +/*! \brief Attributes used in proposal operators */ +struct ProposalAttrs : public tvm::AttrsNode { + Array scales; + Array ratios; + int feature_stride; + double threshold; + int rpn_pre_nms_top_n; + int rpn_post_nms_top_n; + int rpn_min_size; + bool iou_loss; + + TVM_DECLARE_ATTRS(ProposalAttrs, "relay.attrs.ProposalAttrs") { + TVM_ATTR_FIELD(scales) + .set_default(Array({4.0f, 8.0f, 16.0f, 32.0f})) + .describe("Used to generate anchor windows by enumerating scales"); + TVM_ATTR_FIELD(ratios) + .set_default(Array({0.5f, 1.0f, 2.0f})) + .describe("Used to generate anchor windows by enumerating ratios"); + TVM_ATTR_FIELD(feature_stride) + .set_default(16) + .describe( + "The size of the receptive field each unit in the convolution layer of the rpn," + "for example the product of all stride's prior to this layer."); + TVM_ATTR_FIELD(threshold) + .set_default(0.7) + .describe( + "IoU threshold of non-maximum suppresion (suppress boxes with IoU >= this threshold)"); + TVM_ATTR_FIELD(rpn_pre_nms_top_n) + .set_default(6000) + .describe("Number of top scoring boxes to apply NMS. 
-1 to use all boxes"); + TVM_ATTR_FIELD(rpn_post_nms_top_n) + .set_default(300) + .describe("Number of top scoring boxes to keep after applying NMS to RPN proposals"); + TVM_ATTR_FIELD(rpn_min_size).set_default(16).describe("Minimum height or width in proposal"); + TVM_ATTR_FIELD(iou_loss).set_default(false).describe("Usage of IoU Loss"); + } +}; + } // namespace relay } // namespace tvm #endif // TVM_RELAY_ATTRS_VISION_H_ diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 4d341c76043a..69fa5e719f30 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -351,6 +351,20 @@ def _mx_roi_align(inputs, attrs): return _op.vision.roi_align(inputs[0], inputs[1], **new_attrs) +def _mx_proposal(inputs, attrs): + new_attrs = {} + new_attrs["scales"] = attrs.get_float_tuple("scales", (4.0, 8.0, 16.0, 32.0)) + new_attrs["ratios"] = attrs.get_float_tuple("ratios", (0.5, 1.0, 2.0)) + new_attrs["feature_stride"] = attrs.get_int("feature_stride", 16) + new_attrs["threshold"] = attrs.get_float("threshold", 0.7) + new_attrs["rpn_pre_nms_top_n"] = attrs.get_int("rpn_pre_nms_top_n", 6000) + new_attrs["rpn_post_nms_top_n"] = attrs.get_int("rpn_post_nms_top_n", 300) + new_attrs["rpn_min_size"] = attrs.get_int("rpn_min_size", 16) + new_attrs["iou_loss"] = attrs.get_bool("iou_loss", False) + assert not attrs.get_bool("output_score", False), "proposal doesn't support output score" + return _op.vision.proposal(inputs[0], inputs[1], inputs[2], **new_attrs) + + # Note: due to attribute conversion constraint # ops in the identity set must be attribute free _identity_list = [ @@ -466,6 +480,8 @@ def _mx_roi_align(inputs, attrs): "_contrib_MultiBoxPrior" : _mx_multibox_prior, "_contrib_MultiBoxDetection" : _mx_multibox_detection, "_contrib_ROIAlign" : _mx_roi_align, + "_contrib_Proposal" : _mx_proposal, + "_contrib_MultiProposal" : _mx_proposal, # List of missing operators that are present in NNVMv1 # TODO(tvm-tvm): 
support all operators. # diff --git a/python/tvm/relay/op/vision/_rcnn.py b/python/tvm/relay/op/vision/_rcnn.py index 2617bf8562b9..9606ee64c7be 100644 --- a/python/tvm/relay/op/vision/_rcnn.py +++ b/python/tvm/relay/op/vision/_rcnn.py @@ -1,7 +1,7 @@ # pylint: disable=invalid-name, unused-argument """Faster R-CNN and Mask R-CNN operations.""" import topi -from topi.util import get_const_tuple +from topi.util import get_const_tuple, get_float_tuple, get_const_int from .. import op as reg from ..op import OpPattern @@ -21,3 +21,29 @@ def schedule_roi_align(_, outs, target): return topi.generic.vision.schedule_roi_align(outs) reg.register_pattern("vision.roi_align", OpPattern.OUT_ELEMWISE_FUSABLE) + +@reg.register_compute("vision.proposal") +def compute_proposal(attrs, inputs, _, target): + """Compute definition of proposal""" + scales = get_float_tuple(attrs.scales) + ratios = get_float_tuple(attrs.ratios) + feature_stride = attrs.feature_stride + threshold = attrs.threshold + rpn_pre_nms_top_n = attrs.rpn_pre_nms_top_n + rpn_post_nms_top_n = attrs.rpn_post_nms_top_n + rpn_min_size = attrs.rpn_min_size + iou_loss = bool(get_const_int(attrs.iou_loss)) + with target: + return [ + topi.vision.rcnn.proposal(inputs[0], inputs[1], inputs[2], scales, ratios, + feature_stride, threshold, rpn_pre_nms_top_n, + rpn_post_nms_top_n, rpn_min_size, iou_loss) + ] + +@reg.register_schedule("vision.proposal") +def schedule_proposal(_, outs, target): + """Schedule definition of proposal""" + with target: + return topi.generic.schedule_proposal(outs) + +reg.register_pattern("vision.proposal", OpPattern.OPAQUE) diff --git a/python/tvm/relay/op/vision/rcnn.py b/python/tvm/relay/op/vision/rcnn.py index 8bbafbe75c53..8e95435d0ecc 100644 --- a/python/tvm/relay/op/vision/rcnn.py +++ b/python/tvm/relay/op/vision/rcnn.py @@ -30,3 +30,63 @@ def roi_align(data, rois, pooled_size, spatial_scale, sample_ratio=-1, layout='N 4-D tensor with shape [num_roi, channel, pooled_size, pooled_size] """ 
return _make.roi_align(data, rois, pooled_size, spatial_scale, sample_ratio, layout) + + +def proposal(cls_prob, + bbox_pred, + im_info, + scales, + ratios, + feature_stride, + threshold, + rpn_pre_nms_top_n, + rpn_post_nms_top_n, + rpn_min_size, + iou_loss): + """Proposal operator. + + Parameters + ---------- + cls_prob : relay.Expr + 4-D tensor with shape [batch, 2 * num_anchors, height, width]. + + bbox_pred : relay.Expr + 4-D tensor with shape [batch, 4 * num_anchors, height, width]. + + im_info : relay.Expr + 2-D tensor with shape [batch, 3]. The last dimension should be in format of + [im_height, im_width, im_scale] + + scales : list/tuple of float + Scales of anchor windows. + + ratios : list/tuple of float + Ratios of anchor windows. + + feature_stride : int + The size of the receptive field each unit in the convolution layer of the rpn, for example + the product of all strides prior to this layer. + + threshold : float + Non-maximum suppression threshold. + + rpn_pre_nms_top_n : int + Number of top scoring boxes to apply NMS. -1 to use all boxes. + + rpn_post_nms_top_n : int + Number of top scoring boxes to keep after applying NMS to RPN proposals. + + rpn_min_size : int + Minimum height or width in proposal. + + iou_loss : bool + Usage of IoU loss. + + Returns + ------- + output : relay.Expr + 2-D tensor with shape [batch * rpn_post_nms_top_n, 5]. The last dimension is in format of + [batch_index, w_start, h_start, w_end, h_end].
+ """ + return _make.proposal(cls_prob, bbox_pred, im_info, scales, ratios, feature_stride, threshold, + rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_min_size, iou_loss) diff --git a/src/relay/op/vision/rcnn_op.cc b/src/relay/op/vision/rcnn_op.cc index e46eaf2207fb..6dbc76599708 100644 --- a/src/relay/op/vision/rcnn_op.cc +++ b/src/relay/op/vision/rcnn_op.cc @@ -63,5 +63,72 @@ RELAY_REGISTER_OP("vision.roi_align") .set_support_level(5) .add_type_rel("ROIAlign", ROIAlignRel); +TVM_REGISTER_NODE_TYPE(ProposalAttrs); + +bool ProposalRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + auto proposal_attrs = attrs.as(); + CHECK_EQ(types.size(), 4); + const auto* cls_prob = types[0].as(); + const auto* bbox_pred = types[1].as(); + const auto* im_info = types[2].as(); + + if (!cls_prob || !bbox_pred || !im_info) { + return false; + } + + CHECK_EQ(cls_prob->shape.size(), 4U) + << "The dimension of class probability should be 4, but received " << cls_prob->shape.size(); + CHECK_EQ(bbox_pred->shape.size(), 4U) + << "The dimension of box prediction should be 4, but received " << bbox_pred->shape.size(); + CHECK_EQ(im_info->shape.size(), 2U) + << "The dimension of image info should be 2, but received " << im_info->shape.size(); + CHECK(reporter->AssertEQ(im_info->shape[1], 3)); + + auto batch = cls_prob->shape[0]; + + std::vector oshape( + {batch * proposal_attrs->rpn_post_nms_top_n, 5}); + reporter->Assign(types[3], TensorTypeNode::make(oshape, cls_prob->dtype)); + return true; +} + +Expr MakeProposal(Expr cls_prob, Expr bbox_pred, Expr im_info, Array scales, + Array ratios, int feature_stride, double threshold, + int rpn_pre_nms_top_n, int rpn_post_nms_top_n, int rpn_min_size, + bool iou_loss) { + auto attrs = make_node(); + attrs->scales = scales; + attrs->ratios = ratios; + attrs->feature_stride = feature_stride; + attrs->threshold = threshold; + attrs->rpn_pre_nms_top_n = rpn_pre_nms_top_n; + attrs->rpn_post_nms_top_n = 
rpn_post_nms_top_n; + attrs->rpn_min_size = rpn_min_size; + attrs->iou_loss = iou_loss; + static const Op& op = Op::Get("vision.proposal"); + return CallNode::make(op, {cls_prob, bbox_pred, im_info}, Attrs(attrs), {}); +} + +TVM_REGISTER_API("relay.op.vision._make.proposal") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeProposal, args, rv); + }); + +RELAY_REGISTER_OP("vision.proposal") + .describe(R"code(Generate region proposals via RPN. + + - **cls_prob**: 4-D with shape [batch, 2 * num_anchors, height, width]. + - **bbox_pred**: 4-D with shape [batch, 4 * num_anchors, height, width]. + - **im_info**: 2-D with shape [batch, 3]. + - **out**: 2-D with shape [batch * rpn_post_nms_top_n, 5]. + )code" TVM_ADD_FILELINE) +.set_num_inputs(3) +.add_argument("cls_prob", "Tensor", "Score of how likely proposal is object") +.add_argument("bbox_pred", "Tensor", "BBox predicted deltas from anchors for proposals") +.add_argument("im_info", "Tensor", "Image size and scale") +.set_support_level(5) +.add_type_rel("Proposal", ProposalRel); + } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 8db6d747ef5e..003318f01a2f 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -306,6 +306,72 @@ def verify_roi_align(data_shape, rois_shape, pooled_size, spatial_scale, sample_ verify_roi_align((4, 4, 16, 16), (32, 5), pooled_size=7, spatial_scale=0.5, sample_ratio=2) +def test_proposal(): + def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs): + cls_prob = relay.var("cls_prob", relay.ty.TensorType(np_cls_prob.shape, "float32")) + bbox_pred = relay.var("bbox_pred", relay.ty.TensorType(np_bbox_pred.shape, "float32")) + im_info = relay.var("im_info", relay.ty.TensorType(np_im_info.shape, "float32")) + z = relay.vision.proposal(cls_prob, bbox_pred, im_info, **attrs) + zz = relay.ir_pass.infer_type(z) + + 
assert zz.checked_type == relay.ty.TensorType(np_out.shape, "float32") + + func = relay.Function([cls_prob, bbox_pred, im_info], z) + func = relay.ir_pass.infer_type(func) + for target in ['cuda']: + if not tvm.module.enabled(target): + print("Skip test because %s is not enabled." % target) + continue + ctx = tvm.context(target, 0) + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(np_cls_prob, np_bbox_pred, np_im_info) + tvm.testing.assert_allclose(op_res1.asnumpy(), np_out, rtol=1e-4) + intrp2 = relay.create_executor("debug", ctx=ctx, target=target) + op_res2 = intrp2.evaluate(func)(np_cls_prob, np_bbox_pred, np_im_info) + tvm.testing.assert_allclose(op_res2.asnumpy(), np_out, rtol=1e-4) + + attrs = { + 'scales': (0.5,), + 'ratios': (0.5,), + 'feature_stride': 16, + 'iou_loss': False, + 'rpn_min_size': 16, + 'threshold': 0.7, + 'rpn_pre_nms_top_n': 200, + 'rpn_post_nms_top_n': 4, + } + + np_cls_prob = np.array([[ + [[0.3, 0.6, 0.2], [0.4, 0.7, 0.5], [0.1, 0.4, 0.3]], + [[0.7, 0.5, 0.3], [0.6, 0.4, 0.8], [0.9, 0.2, 0.5]] + ]], dtype='float32') + np_bbox_pred = np.array([[ + [[0.5, 1.0, 0.6], [0.8, 1.2, 2.0], [0.9, 1.0, 0.8]], + [[0.5, 1.0, 0.7], [0.8, 1.2, 1.6], [2.1, 1.5, 0.7]], + [[1.0, 0.5, 0.7], [1.5, 0.9, 1.6], [1.4, 1.5, 0.8]], + [[1.0, 0.5, 0.6], [1.5, 0.9, 2.0], [1.8, 1.0, 0.9]], + ]], dtype='float32') + np_im_info = np.array([[48., 48., 1.]], dtype='float32') + np_out = np.array([ + [0., 0., 2.8451548,28.38012, 18.154846], + [0., 0., 15.354933, 41.96971, 41.245064], + [0., 18.019852, 1.0538368, 51.98015, 25.946163], + [0., 27.320923, -1.266357, 55., 24.666357] + ], dtype='float32') + + + verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs) + + np_out = np.array([ + [ 0., -5.25, -2.5, 21.75, 19.], + [ 0., 11.25, -2., 37.25, 18.5], + [ 0., 26.849998, -2.3000002, 53.45, 18.6], + [ 0., -4.95, 13.799999, 22.25, 35.5] + ], dtype='float32') + attrs['iou_loss'] = True + verify_proposal(np_cls_prob, 
np_bbox_pred, np_im_info, np_out, attrs) + + def test_yolo_reorg_infer_shape(): def verify_yolo_reorg(shape, stride, out_shape): x = relay.var("x", relay.TensorType(shape, "float32")) @@ -347,5 +413,6 @@ def verify_yolo_reorg(shape, stride): test_multibox_transform_loc() test_nms() test_roi_align() + test_proposal() test_yolo_reorg_infer_shape() test_yolo_reorg() diff --git a/topi/tests/python/test_topi_vision.py b/topi/tests/python/test_topi_vision.py index 135b3857df31..3c0c3aa854d7 100644 --- a/topi/tests/python/test_topi_vision.py +++ b/topi/tests/python/test_topi_vision.py @@ -210,7 +210,7 @@ def test_roi_align(): def verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs): cls_prob = tvm.placeholder(np_cls_prob.shape) bbox_pred = tvm.placeholder(np_bbox_pred.shape) - im_info = tvm.placeholder(np_im_info.shape, dtype='int32') + im_info = tvm.placeholder(np_im_info.shape) def check_device(device): ctx = tvm.context(device, 0) @@ -252,7 +252,7 @@ def test_proposal(): [[1.0, 0.5, 0.7], [1.5, 0.9, 1.6], [1.4, 1.5, 0.8]], [[1.0, 0.5, 0.6], [1.5, 0.9, 2.0], [1.8, 1.0, 0.9]], ]], dtype='float32') - np_im_info = np.array([[48, 48, 1]], dtype='int32') + np_im_info = np.array([[48., 48., 1.]], dtype='float32') np_out = np.array([ [0., 0., 2.8451548,28.38012, 18.154846], [0., 0., 15.354933, 41.96971, 41.245064],