Skip to content

Commit

Permalink
multibox bug fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurawly committed Mar 12, 2019
1 parent 04b1255 commit 137a382
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 12 deletions.
2 changes: 1 addition & 1 deletion tests/python/relay/test_op_level5.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,8 +457,8 @@ def verify_yolo_reorg(shape, stride):
if __name__ == "__main__":
test_resize_infer_type()
test_resize()
test_multibox_prior()
test_multibox_transform_loc()
test_multibox_prior()
test_get_valid_counts()
test_roi_align()
test_proposal()
Expand Down
20 changes: 9 additions & 11 deletions topi/python/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ def get_valid_counts_pre(data, flag, idx, score_threshold):
3D Buffer with shape [batch_size, num_anchors, 6], output of nms.
flag : Buffer
1D Buffer of flag indicating valid data with [num_anchors].
2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].
idx : Buffer
1D Buffer of valid data indices with [num_anchors].
2D Buffer of valid data indices with shape [batch_size, num_anchors].
score_threshold: float32
Lower limit of score for valid bounding boxes.
Expand All @@ -43,8 +43,7 @@ def get_valid_counts_pre(data, flag, idx, score_threshold):
idx = ib.buffer_ptr(idx)
score_threshold = tvm.make.node("FloatImm", dtype="float32", value=score_threshold)

max_threads = int(math.sqrt(
tvm.target.current_target(allow_none=False).max_num_threads))
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = batch_size * num_anchors // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
Expand Down Expand Up @@ -85,10 +84,10 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
Input data. 3-D Buffer with shape [batch_size, num_anchors, 6].
flag : Buffer
1D Buffer of flag indicating valid data with [num_anchors].
2D Buffer of flag indicating valid data with shape [batch_size, num_anchors].
idx : Buffer
1D Buffer of valid data indices with [num_anchors].
2D Buffer of valid data indices with shape [batch_size, num_anchors].
valid_count : Buffer
1-D buffer for valid number of boxes.
Expand All @@ -113,22 +112,21 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out):
valid_count = ib.buffer_ptr(valid_count)
out = ib.buffer_ptr(out)

max_threads = int(math.sqrt(
tvm.target.current_target(allow_none=False).max_num_threads))
max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
nthread_tx = max_threads
nthread_bx = batch_size * num_anchors // max_threads + 1
nthread_bx = batch_size * num_anchors * elem_length // max_threads + 1
tx = tvm.thread_axis("threadIdx.x")
bx = tvm.thread_axis("blockIdx.x")
ib.scope_attr(tx, "thread_extent", nthread_tx)
ib.scope_attr(bx, "thread_extent", nthread_bx)
tid = bx * max_threads + tx

with ib.if_scope(tid < batch_size * num_anchors * elem_length):
out[tid] = -1.0
with ib.if_scope(tid < batch_size * num_anchors):
i = tid / num_anchors # number of batches
j = tid % num_anchors # number of anchors
base_idx = i * num_anchors * 6
with ib.for_range(0, elem_length) as k:
out[base_idx + j * 6 + k] = -1.0
with ib.if_scope(flag[tid] > 0):
with ib.for_range(0, elem_length) as k:
out[base_idx + (idx[tid] - 1) * 6 + k] = data[base_idx + j * 6 + k]
Expand Down

0 comments on commit 137a382

Please sign in to comment.