Skip to content

Commit

Permalink
expr max_output_size for cuda nms
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww committed Jun 25, 2020
1 parent f8dd14f commit def4ac1
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
7 changes: 5 additions & 2 deletions python/tvm/relay/op/strategy/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,9 @@ def get_valid_counts_strategy(attrs, inputs, out_type, target):
def wrap_compute_nms(topi_compute):
"""wrap nms topi compute"""
def _compute_nms(attrs, inputs, out_type):
max_output_size = inputs[3]
if attrs.max_output_size is not None:
max_output_size = attrs.max_output_size
return_indices = bool(get_const_int(attrs.return_indices))
iou_threshold = get_const_float(attrs.iou_threshold)
force_suppress = bool(get_const_int(attrs.force_suppress))
Expand All @@ -733,10 +736,10 @@ def _compute_nms(attrs, inputs, out_type):
id_index = get_const_int(attrs.id_index)
invalid_to_bottom = bool(get_const_int(attrs.invalid_to_bottom))
if return_indices:
return topi_compute(inputs[0], inputs[1], inputs[2], inputs[3], iou_threshold,
return topi_compute(inputs[0], inputs[1], inputs[2], max_output_size, iou_threshold,
force_suppress, top_k, coord_start, score_index, id_index,
return_indices, invalid_to_bottom)
return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3], iou_threshold,
return [topi_compute(inputs[0], inputs[1], inputs[2], max_output_size, iou_threshold,
force_suppress, top_k, coord_start, score_index, id_index,
return_indices, invalid_to_bottom)]
return _compute_nms
Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_op_level5.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def verify_nms(x0_data, x1_data, x2_data, x3_data, dshape, ref_res,
np_indices_result, check_type_only=True)
dshape = (1, num_anchors, 6)
verify_nms(np_data, np_valid_count, np_indices, np_max_output_size, dshape, np_result,
np_indices_result, top_k=3)
np_indices_result, top_k=2)


def test_multibox_transform_loc():
Expand Down
3 changes: 2 additions & 1 deletion topi/python/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,8 +457,9 @@ def non_max_suppression(data, valid_count, indices, max_output_size=-1,
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
name="nms",
tag="nms")

# TODO(yongwww): Update cuda nms to be consistent with cpu version
if return_indices:
return box_indices

return out

0 comments on commit def4ac1

Please sign in to comment.