Skip to content

Commit

Permalink
Make topi cuda nms_gpu method signature similar to non_max_suppression (
Browse files Browse the repository at this point in the history
  • Loading branch information
apivovarov authored and tqchen committed Mar 12, 2019
1 parent d8abc73 commit a3f3dc7
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions topi/python/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,15 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):


@non_max_suppression.register(["cuda", "gpu"])
def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress=False,
topk=-1, id_index=0, invalid_to_bottom=False):
def nms_gpu(data,
valid_count,
max_output_size=-1,
iou_threshold=0.5,
force_suppress=False,
top_k=-1,
id_index=0,
return_indices=True,
invalid_to_bottom=False):
"""Non-maximum suppression operator for object detection.
Parameters
Expand All @@ -205,7 +212,7 @@ def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress
force_suppress : optional, boolean
Whether to suppress all detections regardless of class_id.
topk : optional, int
top_k : optional, int
Keep maximum top k detections before nms, -1 for no limit.
id_index : optional, int
Expand All @@ -229,7 +236,7 @@ def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
iou_threshold = 0.7
force_suppress = True
topk = -1
top_k = -1
out = nms(data, valid_count, iou_threshold, force_suppress, topk)
np_data = np.random.uniform(dshape)
np_valid_count = np.array([4])
Expand Down Expand Up @@ -273,7 +280,7 @@ def nms_gpu(data, valid_count, return_indices, iou_threshold=0.5, force_suppress
[data, sort_tensor, valid_count],
lambda ins, outs: nms_ir(
ins[0], ins[1], ins[2], outs[0], iou_threshold,
force_suppress, topk),
force_suppress, top_k),
dtype="float32",
in_buffers=[data_buf, sort_tensor_buf, valid_count_buf],
tag="nms")
Expand Down

0 comments on commit a3f3dc7

Please sign in to comment.