From d0fb390a53c08023f7697a821279acd7707619b8 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Wed, 16 Jan 2019 14:50:02 +0800 Subject: [PATCH] Add missing global barrier in argsort --- topi/python/topi/cuda/rcnn/proposal.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/topi/python/topi/cuda/rcnn/proposal.py b/topi/python/topi/cuda/rcnn/proposal.py index 39bb60b4cc090..b606bb9225d5e 100644 --- a/topi/python/topi/cuda/rcnn/proposal.py +++ b/topi/python/topi/cuda/rcnn/proposal.py @@ -134,23 +134,30 @@ def argsort_ir(data_buf, out_index_buf): """ batch, num_bbox = get_const_tuple(data_buf.shape) max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) - tx = tvm.thread_axis("threadIdx.x") - bx = tvm.thread_axis("blockIdx.x") ib = tvm.ir_builder.create() - temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local") - temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local") p_data = ib.buffer_ptr(data_buf) index_out = ib.buffer_ptr(out_index_buf) nthread_tx = max_threads - nthread_bx = num_bbox // max_threads + 1 + nthread_bx = (num_bbox + 1) // 2 // max_threads + 1 + tx = tvm.thread_axis("threadIdx.x", (0, nthread_tx)) + bx = tvm.thread_axis("blockIdx.x", (0, nthread_bx)) ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * nthread_tx + tx + temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local") + temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local") + ib.emit(tvm.make.Call(None, 'tvm_global_barrier_kinit', None, tvm.expr.Call.Intrinsic, None, 0)) with ib.for_range(0, batch, for_type="unroll") as b: start = b * num_bbox - with ib.if_scope(tid < num_bbox): - index_out[start + tid] = tid + for i in range(2): + offset = start + 2 * tid + i + with ib.if_scope(offset < num_bbox): + index_out[offset] = offset + + ib.emit(tvm.make.Call(None, 'tvm_storage_sync', + tvm.convert(['shared']), + tvm.expr.Call.Intrinsic, None, 0)) with ib.for_range(0, num_bbox) as k: with ib.if_scope(tid < (num_bbox + 1) // 2): @@ -163,7 +170,9 @@ def argsort_ir(data_buf, out_index_buf): temp_index[0] = index_out[offset] index_out[offset] = index_out[offset + 1] index_out[offset + 1] = temp_index[0] - + ib.emit(tvm.make.Call(None, 'tvm_storage_sync', + tvm.convert(['global', True, nthread_bx]), + tvm.expr.Call.Intrinsic, None, 0)) return ib.get()