diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index 550e24b8de26..2b905f5bd04b 100644 --- a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -88,7 +88,7 @@ struct GetValidCountsAttrs : public tvm::AttrsNode { /*! \brief Attributes used in non_maximum_suppression operator */ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode { - int max_output_size; + Optional max_output_size; double iou_threshold; bool force_suppress; int top_k; @@ -99,11 +99,7 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode 4 else 0.0 diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index a0dd6bfe7b15..e9feee669ea5 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -747,8 +747,10 @@ def get_valid_counts_strategy(attrs, inputs, out_type, target): def wrap_compute_nms(topi_compute): """wrap nms topi compute""" def _compute_nms(attrs, inputs, out_type): + max_output_size = inputs[3] + if attrs.max_output_size is not None: + max_output_size = attrs.max_output_size return_indices = bool(get_const_int(attrs.return_indices)) - max_output_size = get_const_int(attrs.max_output_size) iou_threshold = get_const_float(attrs.iou_threshold) force_suppress = bool(get_const_int(attrs.force_suppress)) top_k = get_const_int(attrs.top_k) diff --git a/python/tvm/relay/op/vision/nms.py b/python/tvm/relay/op/vision/nms.py index b60b49ab0ccd..60ff7a59f9f0 100644 --- a/python/tvm/relay/op/vision/nms.py +++ b/python/tvm/relay/op/vision/nms.py @@ -91,9 +91,9 @@ def non_max_suppression(data, second dimension are like the output of arange(num_anchors) if get_valid_counts is not used before non_max_suppression. - max_output_size : int, optional + max_output_size : int or relay.Expr, optional Max number of output valid boxes for each instance. - By default all valid boxes are returned. + Return all valid boxes if the value of max_output_size is less than 0. iou_threshold : float, optional Non-maximum suppression threshold. @@ -124,9 +124,11 @@ def non_max_suppression(data, out : relay.Expr or relay.Tuple return relay.Expr if return_indices is disabled, a 3-D tensor with shape [batch_size, num_anchors, 6] or [batch_size, num_anchors, 5]. - if return_indices is True, return relay.Tuple of two 2-D tensors, with + If return_indices is True, return relay.Tuple of two 2-D tensors, with shape [batch_size, num_anchors] and [batch_size, num_valid_anchors] respectively. """ + if isinstance(max_output_size, int): + max_output_size = expr.const(max_output_size, "int32") out = _make.non_max_suppression(data, valid_count, indices, diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 7486db790780..f9cdaf66e255 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -73,7 +73,7 @@ TVM_REGISTER_NODE_TYPE(NonMaximumSuppressionAttrs); bool NMSRel(const Array& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { - CHECK_EQ(types.size(), 4); + CHECK_EQ(types.size(), 5); const auto* data = types[0].as(); const auto* valid_count = types[1].as(); const NonMaximumSuppressionAttrs* param = attrs.as(); @@ -90,18 +90,17 @@ bool NMSRel(const Array& types, int num_inputs, const Attrs& attrs, fields.push_back(TensorType(oshape, DataType::Int(32))); std::vector countshape({dshape[0], 1}); fields.push_back(TensorType(countshape, DataType::Int(32))); - reporter->Assign(types[3], TupleType(Array(fields))); + reporter->Assign(types[4], TupleType(Array(fields))); } else { - reporter->Assign(types[3], TensorType(dshape, data->dtype)); + reporter->Assign(types[4], TensorType(dshape, data->dtype)); } return true; } -Expr MakeNMS(Expr data, Expr valid_count, Expr indices, int max_output_size, double iou_threshold, +Expr MakeNMS(Expr data, Expr valid_count, Expr indices, Expr max_output_size, double iou_threshold, bool force_suppress, int top_k, int coord_start, int score_index, int id_index, bool return_indices, bool invalid_to_bottom) { auto attrs = make_object(); - attrs->max_output_size = max_output_size; attrs->iou_threshold = iou_threshold; attrs->force_suppress = force_suppress; attrs->top_k = top_k; @@ -111,7 +110,7 @@ Expr MakeNMS(Expr data, Expr valid_count, Expr indices, int max_output_size, dou attrs->return_indices = return_indices; attrs->invalid_to_bottom = invalid_to_bottom; static const Op& op = Op::Get("vision.non_max_suppression"); - return Call(op, {data, valid_count, indices}, Attrs(attrs), {}); + return Call(op, {data, valid_count, indices, max_output_size}, Attrs(attrs), {}); } TVM_REGISTER_GLOBAL("relay.op.vision._make.non_max_suppression").set_body_typed(MakeNMS); @@ -122,10 +121,11 @@ be in the format of [class_id, score, left, top, right, bottom] or [score, left, top, right, bottom]. Set id_index to be -1 to ignore class_id axis. )doc" TVM_ADD_FILELINE) - .set_num_inputs(3) + .set_num_inputs(4) .add_argument("data", "Tensor", "Input data.") .add_argument("valid_count", "Tensor", "Number of valid anchor boxes.") .add_argument("indices", "Tensor", "Corresponding indices in original input tensor.") + .add_argument("max_output_size", "Tensor", "Max number of output valid boxes.") .set_support_level(5) .add_type_rel("NMS", NMSRel); diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 1a0baf8e2a80..182c2d72447a 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -2020,15 +2020,16 @@ def test_forward_crop_and_resize(): def _test_forward_nms_v3(bx_shape, score_shape, iou_threshold, score_threshold, out_size, dtype="float32"): boxes = np.random.uniform(0, 10, size=bx_shape).astype(dtype) scores = np.random.uniform(size=score_shape).astype(dtype) + max_output_size = np.int32(out_size) tf.reset_default_graph() in_data_1 = tf.placeholder(dtype, boxes.shape, name="in_data_1") in_data_2 = tf.placeholder(dtype, scores.shape, name="in_data_2") - tf.image.non_max_suppression(boxes=in_data_1, scores=in_data_2, - max_output_size=out_size, iou_threshold=iou_threshold, - score_threshold=score_threshold, name="nms") - compare_tf_with_tvm([boxes, scores], ['in_data_1:0', 'in_data_2:0'], + in_data_3 = tf.placeholder(tf.int32, name="in_data_3") + tf.image.non_max_suppression(boxes=in_data_1, scores=in_data_2, max_output_size=in_data_3, + iou_threshold=iou_threshold, score_threshold=score_threshold, name="nms") + compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'], 'nms/NonMaxSuppressionV3:0', mode='vm') - compare_tf_with_tvm([boxes, scores], ['in_data_1:0', 'in_data_2:0'], + compare_tf_with_tvm([boxes, scores, max_output_size], ['in_data_1:0', 'in_data_2:0', 'in_data_3:0'], 'nms/NonMaxSuppressionV3:0', mode='debug') def test_forward_nms_v3(): @@ -2036,6 +2037,7 @@ def test_forward_nms_v3(): _test_forward_nms_v3((5, 4), (5,), 0.7, 0.5, 5) _test_forward_nms_v3((20, 4), (20,), 0.5, 0.6, 10) _test_forward_nms_v3((1000, 4), (1000,), 0.3, 0.7, 1000) + _test_forward_nms_v3((2000, 4), (2000,), 0.4, 0.6, 7) ####################################################################### diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 14d43c0a5fca..265db43d9904 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -283,16 +283,17 @@ def verify_get_valid_counts(dshape, score_threshold, id_index, score_index): def test_non_max_suppression(): - def verify_nms(x0_data, x1_data, x2_data, dshape, ref_res, ref_indices_res, - iou_threshold=0.5, force_suppress=False, top_k=-1, - check_type_only=False): + def verify_nms(x0_data, x1_data, x2_data, x3_data, dshape, ref_res, + ref_indices_res, iou_threshold=0.5, force_suppress=False, + top_k=-1, check_type_only=False): x0 = relay.var("x0", relay.ty.TensorType(dshape, "float32")) x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int32")) x2 = relay.var("x2", relay.ty.TensorType((dshape[0], dshape[1]), "int32")) - z = relay.vision.non_max_suppression(x0, x1, x2, max_output_size=-1, \ + x3 = relay.var("x3", relay.ty.TensorType((), "int32")) + z = relay.vision.non_max_suppression(x0, x1, x2, x3, \ iou_threshold=iou_threshold, force_suppress=force_suppress, \ top_k=top_k, return_indices=False) - z_indices = relay.vision.non_max_suppression(x0, x1, x2, max_output_size=-1, \ + z_indices = relay.vision.non_max_suppression(x0, x1, x2, x3, \ iou_threshold=iou_threshold, force_suppress=force_suppress, \ top_k=top_k, return_indices=True) if isinstance(z_indices, relay.expr.TupleWrapper): @@ -309,30 +310,30 @@ def verify_nms(x0_data, x1_data, x2_data, dshape, ref_res, ref_indices_res, if check_type_only: return - func = relay.Function([x0, x1, x2], z) + func = relay.Function([x0, x1, x2, x3], z) func = run_infer_type(func) - func_indices = relay.Function([x0, x1, x2], z_indices) + func_indices = relay.Function([x0, x1, x2, x3], z_indices) func_indices = run_infer_type(func_indices) for target, ctx in ctx_list(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) - op_res1 = intrp1.evaluate(func)(x0_data, x1_data, x2_data) + op_res1 = intrp1.evaluate(func)(x0_data, x1_data, x2_data, x3_data) tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5) intrp2 = relay.create_executor("debug", ctx=ctx, target=target) - op_res2 = intrp2.evaluate(func)(x0_data, x1_data, x2_data) + op_res2 = intrp2.evaluate(func)(x0_data, x1_data, x2_data, x3_data) tvm.testing.assert_allclose(op_res2.asnumpy(), ref_res, rtol=1e-5) if target == 'cuda': return - op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data, x2_data) + op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data) tvm.testing.assert_allclose(op_indices_res1[0].asnumpy(), ref_indices_res, rtol=1e-5) - op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data, x2_data) + op_indices_res2 = intrp2.evaluate(func_indices)(x0_data, x1_data, x2_data, x3_data) tvm.testing.assert_allclose(op_indices_res2[0].asnumpy(), ref_indices_res, rtol=1e-5) np_data = np.array([[[0, 0.8, 1, 20, 25, 45], [1, 0.7, 30, 60, 50, 80], [0, 0.4, 4, 21, 19, 40], [2, 0.9, 35, 61, 52, 79], [1, 0.5, 100, 60, 70, 110]]]).astype("float32") np_valid_count = np.array([4]).astype("int32") - np_indices = np.array([[0, 1, 3, 4, -1]]).astype("int32") + np_max_output_size = -1 np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1], @@ -341,22 +342,23 @@ def verify_nms(x0_data, x1_data, x2_data, dshape, ref_res, ref_indices_res, num_anchors = 5 dshape = (te.size_var("n"), num_anchors, 6) - verify_nms(np_data, np_valid_count, np_indices, dshape, np_result, np_indices_result, + verify_nms(np_data, np_valid_count, np_indices, np_max_output_size, dshape, np_result, np_indices_result, force_suppress=True, top_k=2, check_type_only=True) dshape = (1, num_anchors, 6) - verify_nms(np_data, np_valid_count, np_indices, dshape, np_result, np_indices_result, + verify_nms(np_data, np_valid_count, np_indices, np_max_output_size, dshape, np_result, np_indices_result, force_suppress=True, top_k=2, check_type_only=False) np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], - [1, 0.7, 30, 60, 50, 80], [-1, -1, -1, -1, -1, -1], + [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1]]]) - np_indices_result = np.array([[4, 0, 1, -1, -1]]) + np_indices_result = np.array([[4, 0, -1, -1, -1]]) + np_max_output_size = 2 dshape = (te.size_var("n"), num_anchors, 6) - verify_nms(np_data, np_valid_count, np_indices, dshape, np_result, + verify_nms(np_data, np_valid_count, np_indices, np_max_output_size, dshape, np_result, np_indices_result, check_type_only=True) dshape = (1, num_anchors, 6) - verify_nms(np_data, np_valid_count, np_indices, dshape, np_result, - np_indices_result, top_k=3) + verify_nms(np_data, np_valid_count, np_indices, np_max_output_size, dshape, np_result, + np_indices_result, top_k=2) def test_multibox_transform_loc(): diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index 9e3200a0c418..9f46b95297c3 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -458,7 +458,7 @@ def non_max_suppression(data, valid_count, indices, max_output_size=-1, in_buffers=[data_buf, sort_tensor_buf, valid_count_buf], name="nms", tag="nms") - + # TODO(yongwww): Update cuda nms to be consistent with cpu version if return_indices: return box_indices diff --git a/topi/python/topi/vision/nms.py b/topi/python/topi/vision/nms.py index 269c876d647e..1ee9e836705c 100644 --- a/topi/python/topi/vision/nms.py +++ b/topi/python/topi/vision/nms.py @@ -257,9 +257,12 @@ def hybrid_nms(data, sorted_index, valid_count, indices, batch_size, num_anchors Batch size. We need to pass it in since hybrid script doesn't support binding variable to symbolic dim. - max_output_size : tvm.tir.const + num_anchors: tvm.tir.IntImm or tvm.tir.Var + The number of anchors. + + max_output_size : tvm.te.Tensor Max number of output valid boxes for each instance. - By default all valid boxes are returned. + Return all valid boxes if max_output_size < 0. iou_threshold : tvm.tir.const Overlapping(IoU) threshold to suppress object with smaller score. @@ -300,7 +303,7 @@ def hybrid_nms(data, sorted_index, valid_count, indices, batch_size, num_anchors box_data_length = data.shape[2] - # box_indices is the expected value, similar to TF & ONNX + # box_indices is the expected indices of boxes box_indices = output_tensor((batch_size, num_anchors), sorted_index.dtype) output = output_tensor((batch_size, num_anchors, @@ -326,13 +329,33 @@ def hybrid_nms(data, sorted_index, valid_count, indices, batch_size, num_anchors # Apply nms box_start_idx = coord_start batch_idx = i + num_valid_boxes = 0 for j in range(valid_count[i]): - if output[i, j, score_index] > 0 and (id_index < 0 or output[i, j, id_index] >= 0): + if num_valid_boxes == max_output_size: + for k in range(box_data_length): + output[i, j, k] = -one + box_indices[i, j] = -1 + + elif output[i, j, score_index] > 0: box_a_idx = j - for k in parallel(valid_count[i]): + is_valid_box = 1 + + # a_l: left, a_t: top, a_r: right, a_b: bottom + a_l = min(output[batch_idx, box_a_idx, box_start_idx], + output[batch_idx, box_a_idx, box_start_idx + 2]) + a_t = min(output[batch_idx, box_a_idx, box_start_idx + 1], + output[batch_idx, box_a_idx, box_start_idx + 3]) + a_r = max(output[batch_idx, box_a_idx, box_start_idx], + output[batch_idx, box_a_idx, box_start_idx + 2]) + a_b = max(output[batch_idx, box_a_idx, box_start_idx + 1], + output[batch_idx, box_a_idx, box_start_idx + 3]) + + # check if current box j is valid by calculating iou with + # all existing valid boxes + for k in range(j): check_iou = 0 - if k > j and output[i, k, score_index] > 0 \ + if is_valid_box == 1 and k < j and output[i, k, score_index] > 0 \ and (id_index < 0 or output[i, k, id_index] >= 0): if force_suppress: check_iou = 1 @@ -340,16 +363,6 @@ def hybrid_nms(data, sorted_index, valid_count, indices, batch_size, num_anchors check_iou = 1 if check_iou > 0: - # a_l: left, a_t: top, a_r: right, a_b: bottom - a_l = min(output[batch_idx, box_a_idx, box_start_idx], - output[batch_idx, box_a_idx, box_start_idx + 2]) - a_t = min(output[batch_idx, box_a_idx, box_start_idx + 1], - output[batch_idx, box_a_idx, box_start_idx + 3]) - a_r = max(output[batch_idx, box_a_idx, box_start_idx], - output[batch_idx, box_a_idx, box_start_idx + 2]) - a_b = max(output[batch_idx, box_a_idx, box_start_idx + 1], - output[batch_idx, box_a_idx, box_start_idx + 3]) - box_b_idx = k # b_l: left, b_t: top, b_r: right, b_b: bottom @@ -377,10 +390,14 @@ def hybrid_nms(data, sorted_index, valid_count, indices, batch_size, num_anchors iou = zero if u <= zero else area / u if iou >= iou_threshold: - output[i, k, score_index] = -one - if id_index >= 0: - output[i, k, id_index] = -one - box_indices[i, k] = -1 + is_valid_box = 0 + + if is_valid_box == 0: + for k in range(box_data_length): + output[i, j, k] = -one + box_indices[i, j] = -1 + else: + num_valid_boxes += 1 else: for j in parallel(valid_count[i]): @@ -394,18 +411,6 @@ def hybrid_nms(data, sorted_index, valid_count, indices, batch_size, num_anchors output[i, j + valid_count[i], k] = -one box_indices[i, j + valid_count[i]] = -1 - # Only return max_output_size valid boxes - num_valid_boxes = 0 - if max_output_size > 0: - for j in range(valid_count[i]): - if output[i, j, 0] >= zero: - if num_valid_boxes == max_output_size: - for k in range(box_data_length): - output[i, j, k] = -one - box_indices[i, j] = -1 - else: - num_valid_boxes += 1 - if return_indices: for j in range(valid_count[i]): idx = box_indices[i, j] @@ -432,9 +437,9 @@ def non_max_suppression(data, valid_count, indices, max_output_size=-1, indices : tvm.te.Tensor 2-D tensor with shape [batch_size, num_anchors]. - max_output_size : optional, int + max_output_size : optional, int or tvm.te.Tensor Max number of output valid boxes for each instance. - By default all valid boxes are returned. + Return all valid boxes if the value of max_output_size is less than 0. iou_threshold : optional, float Non-maximum suppression threshold. @@ -494,17 +499,20 @@ def non_max_suppression(data, valid_count, indices, max_output_size=-1, """ batch_size = data.shape[0] num_anchors = data.shape[1] + if isinstance(max_output_size, int): + max_output_size = tvm.tir.const(max_output_size, dtype="int32") score_axis = score_index score_shape = (batch_size, num_anchors) score_tensor = te.compute(score_shape, lambda i, j: data[i, j, score_axis]) sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False) + out, box_indices = hybrid_nms(data, sort_tensor, valid_count, indices, batch_size, num_anchors, - tvm.tir.const(max_output_size, dtype="int32"), + max_output_size, tvm.tir.const(iou_threshold, dtype=data.dtype), tvm.tir.const(force_suppress, dtype="bool"), tvm.tir.const(top_k, dtype="int32"), diff --git a/topi/tests/python/test_topi_vision.py b/topi/tests/python/test_topi_vision.py index d2331ee0c7f7..b74e19346f30 100644 --- a/topi/tests/python/test_topi_vision.py +++ b/topi/tests/python/test_topi_vision.py @@ -132,7 +132,7 @@ def test_get_valid_counts(): verify_get_valid_counts((16, 500, 5), 0.95, -1, 1) -def verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, +def verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, max_output_size, iou_threshold, force_suppress, top_k, coord_start, score_index, id_index): dshape = np_data.shape batch, num_anchors, _ = dshape @@ -149,11 +149,11 @@ def check_device(device): print("Running on target: %s" % device) with tvm.target.create(device): fcompute, fschedule = topi.testing.dispatch(device, _nms_implement) - out = fcompute(data, valid_count, indices, -1, iou_threshold, force_suppress, top_k, - coord_start=coord_start, score_index=score_index, id_index=id_index, + out = fcompute(data, valid_count, indices, max_output_size, iou_threshold, force_suppress, + top_k, coord_start=coord_start, score_index=score_index, id_index=id_index, return_indices=False) - indices_out = fcompute(data, valid_count, indices, -1, iou_threshold, force_suppress, top_k, - coord_start=coord_start, score_index=score_index, id_index=id_index, + indices_out = fcompute(data, valid_count, indices, max_output_size, iou_threshold, force_suppress, + top_k, coord_start=coord_start, score_index=score_index, id_index=id_index, return_indices=True) s = fschedule(out) indices_s = fschedule(indices_out) @@ -186,23 +186,27 @@ def test_non_max_suppression(): [1, 0.5, 100, 60, 70, 110]]]).astype("float32") np_valid_count = np.array([4]).astype("int32") np_indices = np.array([[0, 1, 2, 3, 4]]).astype("int32") + max_output_size = -1 np_result = np.array([[[2, 0.9, 35, 61, 52, 79], [0, 0.8, 1, 20, 25, 45], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1]]]) np_indices_result = np.array([[3, 0, -1, -1, -1]]) - verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, 0.7, True, 2, 2, 1, 0) + verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, + max_output_size, 0.7, True, 2, 2, 1, 0) np_data = np.array([[[0.8, 1, 20, 25, 45], [0.7, 30, 60, 50, 80], [0.4, 4, 21, 19, 40], [0.9, 35, 61, 52, 79], [0.5, 100, 60, 70, 110]]]).astype("float32") np_valid_count = np.array([4]).astype("int32") np_indices = np.array([[0, 1, 2, 3, 4]]).astype("int32") + max_output_size = 2 np_result = np.array([[[0.9, 35, 61, 52, 79], [0.8, 1, 20, 25, 45], [-1, -1, -1, -1, -1], [-1, -1, -1, -1, -1], [-1, -1, -1, -1, -1]]]) np_indices_result = np.array([[3, 0, -1, -1, -1]]) - verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, 0.7, False, 2, 1, 0, -1) + verify_non_max_suppression(np_data, np_valid_count, np_indices, np_result, np_indices_result, + max_output_size, 0.7, False, 2, 1, 0, -1) def verify_multibox_prior(dshape, sizes=(1,), ratios=(1,), steps=(-1, -1), offsets=(0.5, 0.5), clip=False):