From 137a3825489d3a014138fb3ed07ec85878ca5cb0 Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Tue, 12 Mar 2019 11:58:55 -0700 Subject: [PATCH] multibox bug fixed --- tests/python/relay/test_op_level5.py | 2 +- topi/python/topi/cuda/nms.py | 20 +++++++++----------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index e32f18cba9a59..2192ee86f87fa 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -457,8 +457,8 @@ def verify_yolo_reorg(shape, stride): if __name__ == "__main__": test_resize_infer_type() test_resize() - test_multibox_prior() test_multibox_transform_loc() + test_multibox_prior() test_get_valid_counts() test_roi_align() test_proposal() diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index 0b4477b79a337..3da23dada9d1d 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -20,10 +20,10 @@ def get_valid_counts_pre(data, flag, idx, score_threshold): 3D Buffer with shape [batch_size, num_anchors, 6], output of nms. flag : Buffer - 1D Buffer of flag indicating valid data with [num_anchors]. + 2D Buffer of flag indicating valid data with shape [batch_size, num_anchors]. idx : Buffer - 1D Buffer of valid data indices with [num_anchors]. + 2D Buffer of valid data indices with shape [batch_size, num_anchors]. score_threshold: float32 Lower limit of score for valid bounding boxes. @@ -43,8 +43,7 @@ def get_valid_counts_pre(data, flag, idx, score_threshold): idx = ib.buffer_ptr(idx) score_threshold = tvm.make.node("FloatImm", dtype="float32", value=score_threshold) - max_threads = int(math.sqrt( - tvm.target.current_target(allow_none=False).max_num_threads)) + max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) nthread_tx = max_threads nthread_bx = batch_size * num_anchors // max_threads + 1 tx = tvm.thread_axis("threadIdx.x") @@ -85,10 +84,10 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out): Input data. 3-D Buffer with shape [batch_size, num_anchors, 6]. flag : Buffer - 1D Buffer of flag indicating valid data with [num_anchors]. + 2D Buffer of flag indicating valid data with shape [batch_size, num_anchors]. idx : Buffer - 1D Buffer of valid data indices with [num_anchors]. + 2D Buffer of valid data indices with shape [batch_size, num_anchors]. valid_count : Buffer 1-D buffer for valid number of boxes. @@ -113,22 +112,21 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out): valid_count = ib.buffer_ptr(valid_count) out = ib.buffer_ptr(out) - max_threads = int(math.sqrt( - tvm.target.current_target(allow_none=False).max_num_threads)) + max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) nthread_tx = max_threads - nthread_bx = batch_size * num_anchors // max_threads + 1 + nthread_bx = batch_size * num_anchors * elem_length // max_threads + 1 tx = tvm.thread_axis("threadIdx.x") bx = tvm.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * max_threads + tx + with ib.if_scope(tid < batch_size * num_anchors * elem_length): + out[tid] = -1.0 with ib.if_scope(tid < batch_size * num_anchors): i = tid / num_anchors # number of batches j = tid % num_anchors # number of anchors base_idx = i * num_anchors * 6 - with ib.for_range(0, elem_length) as k: - out[base_idx + j * 6 + k] = -1.0 with ib.if_scope(flag[tid] > 0): with ib.for_range(0, elem_length) as k: out[base_idx + (idx[tid] - 1) * 6 + k] = data[base_idx + j * 6 + k]