diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 3fd0f7f5884ce..4a914b71760ea 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -724,6 +724,9 @@ def get_valid_counts_strategy(attrs, inputs, out_type, target): def wrap_compute_nms(topi_compute): """wrap nms topi compute""" def _compute_nms(attrs, inputs, out_type): + max_output_size = inputs[3] + if attrs.max_output_size is not None: + max_output_size = attrs.max_output_size return_indices = bool(get_const_int(attrs.return_indices)) iou_threshold = get_const_float(attrs.iou_threshold) force_suppress = bool(get_const_int(attrs.force_suppress)) @@ -733,10 +736,10 @@ def _compute_nms(attrs, inputs, out_type): id_index = get_const_int(attrs.id_index) invalid_to_bottom = bool(get_const_int(attrs.invalid_to_bottom)) if return_indices: - return topi_compute(inputs[0], inputs[1], inputs[2], inputs[3], iou_threshold, + return topi_compute(inputs[0], inputs[1], inputs[2], max_output_size, iou_threshold, force_suppress, top_k, coord_start, score_index, id_index, return_indices, invalid_to_bottom) - return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3], iou_threshold, + return [topi_compute(inputs[0], inputs[1], inputs[2], max_output_size, iou_threshold, force_suppress, top_k, coord_start, score_index, id_index, return_indices, invalid_to_bottom)] return _compute_nms diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 845326e96f129..265db43d99045 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -358,7 +358,7 @@ def verify_nms(x0_data, x1_data, x2_data, x3_data, dshape, ref_res, np_indices_result, check_type_only=True) dshape = (1, num_anchors, 6) verify_nms(np_data, np_valid_count, np_indices, np_max_output_size, dshape, np_result, - np_indices_result, top_k=3) + np_indices_result, top_k=2) def test_multibox_transform_loc(): diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index f2c1143b5fb82..8729891b27cdc 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -457,8 +457,9 @@ def non_max_suppression(data, valid_count, indices, max_output_size=-1, in_buffers=[data_buf, sort_tensor_buf, valid_count_buf], name="nms", tag="nms") - + # TODO(yongwww): Update cuda nms to be consistent with cpu version if return_indices: return box_indices return out +