diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py
index 6118ff11f9ce6..53f413f837d6e 100644
--- a/topi/python/topi/cuda/nms.py
+++ b/topi/python/topi/cuda/nms.py
@@ -44,13 +44,13 @@ def get_valid_counts_pre(data, flag, idx, score_threshold):
     idx = ib.buffer_ptr(idx)
     score_threshold = tvm.make.node("FloatImm", dtype="float32", value=score_threshold)
 
-    max_threads = int(math.sqrt(tvm.target.current_target(allow_none=False).max_num_threads))
+    max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
     nthread_tx = max_threads
     nthread_bx = batch_size * num_anchors // max_threads + 1
     tx = tvm.thread_axis("threadIdx.x")
-    bx = tvm.thread_axis("blockIdx.x")
+    bx = tvm.thread_axis("vthread")
     ib.scope_attr(tx, "thread_extent", nthread_tx)
-    ib.scope_attr(bx, "thread_extent", nthread_bx)
+    ib.scope_attr(bx, "virtual_thread", nthread_bx)
     tid = bx * max_threads + tx
 
     with ib.if_scope(tid < batch_size * num_anchors):
@@ -64,10 +64,6 @@ def get_valid_counts_pre(data, flag, idx, score_threshold):
             flag[tid] = 0
             idx[tid] = 0
 
-    ib.emit(tvm.make.Call(None, 'tvm_storage_sync',
-                          tvm.convert(['shared']),
-                          tvm.expr.Call.Intrinsic, None, 0))
-
     with ib.if_scope(tid < batch_size):
         with ib.for_range(0, num_anchors) as k:
             with ib.if_scope(k > 0):
diff --git a/topi/python/topi/cuda/ssd/multibox.py b/topi/python/topi/cuda/ssd/multibox.py
index 599c729d95cae..0cb6f70d9ee6e 100644
--- a/topi/python/topi/cuda/ssd/multibox.py
+++ b/topi/python/topi/cuda/ssd/multibox.py
@@ -181,7 +181,7 @@ def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp
     threshold = tvm.make.node("FloatImm", dtype="float32", value=threshold)
 
     max_threads = int(
-        tvm.target.current_target(allow_none=False).max_num_threads)
+        math.sqrt(tvm.target.current_target(allow_none=False).max_num_threads))
     nthread_tx = max_threads
     nthread_bx = (batch_size * num_anchors) // max_threads + 1
     tx = tvm.thread_axis("threadIdx.x")
@@ -293,7 +293,7 @@ def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw,
     out_loc = ib.buffer_ptr(out)
 
     max_threads = int(
-        tvm.target.current_target(allow_none=False).max_num_threads)
+        math.sqrt(tvm.target.current_target(allow_none=False).max_num_threads))
     nthread_tx = max_threads
     nthread_bx = (batch_size * num_anchors) // max_threads + 1
     tx = tvm.thread_axis("threadIdx.x")
diff --git a/topi/python/topi/cuda/vision.py b/topi/python/topi/cuda/vision.py
index f32e21183f037..1e6aa87843275 100644
--- a/topi/python/topi/cuda/vision.py
+++ b/topi/python/topi/cuda/vision.py
@@ -17,10 +17,14 @@ def _default_schedule(outs):
     def traverse(op):
         """inline all one-to-one-mapping operators except the last stage (output)"""
         if op.tag in ["nms", "invalid_to_bottom"]:
-            sort = op.input_tensors[1]
+            if op.tag == "nms":
+                sort = op.input_tensors[1]
+            else:
+                out = op.input_tensors[0]
+                sort = s[out].op.input_tensors[1]
             score = s[sort].op.input_tensors[0]
             fused = s[score].fuse(*s[score].op.axis)
-            num_thread = tvm.target.current_target(allow_none=False).max_num_threads
+            num_thread = int(tvm.target.current_target(allow_none=False).max_num_threads)
             bx, tx = s[score].split(fused, factor=num_thread)
             s[score].bind(bx, tvm.thread_axis("blockIdx.x"))
             s[score].bind(tx, tvm.thread_axis("threadIdx.x"))
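
Note (reviewer sketch, not part of the patch): the nms.py hunk rebinds the outer axis from blockIdx.x to "vthread". An axis carrying the "virtual_thread" attribute is lowered by TVM's InjectVirtualThread pass onto the existing physical threads as a serialized loop rather than launched as a grid of CUDA blocks; presumably that is also why the block-local tvm_storage_sync (shared-scope sync, i.e. __syncthreads()) could be dropped. Below is a minimal sketch of the binding pattern, written against the same pre-0.6 TVM API the patch already uses; copy_ir and its arguments are illustrative names, with data/out assumed to be tvm.Buffer handles as in get_valid_counts_pre:

    import tvm

    def copy_ir(data, out, n):
        """Hypothetical elementwise copy with the outer axis on a virtual thread."""
        ib = tvm.ir_builder.create()
        data = ib.buffer_ptr(data)  # data/out: tvm.Buffer handles
        out = ib.buffer_ptr(out)

        max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads)
        tx = tvm.thread_axis("threadIdx.x")
        vx = tvm.thread_axis("vthread")
        ib.scope_attr(tx, "thread_extent", max_threads)
        # "virtual_thread" extent, not "thread_extent": the compiler serializes
        # these iterations onto the physical threads instead of CUDA blocks.
        ib.scope_attr(vx, "virtual_thread", n // max_threads + 1)

        tid = vx * max_threads + tx
        with ib.if_scope(tid < n):
            out[tid] = data[tid]
        return ib.get()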