diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index 3da23dada9d1d..644d366ee697b 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -73,6 +73,7 @@ def get_valid_counts_pre(data, flag, idx, score_threshold): return ib.get() + def get_valid_counts_ir(data, flag, idx, valid_count, out): """Low level IR to get valid count of bounding boxes given a score threshold. Also moves valid boxes to the @@ -133,6 +134,8 @@ def get_valid_counts_ir(data, flag, idx, valid_count, out): valid_count[i] = idx[i * num_anchors + num_anchors - 1] return ib.get() + + @get_valid_counts.register(["cuda", "gpu"]) def get_valid_counts_gpu(data, score_threshold=0): """Get valid count of bounding boxes given a score threshold. @@ -180,6 +183,7 @@ def get_valid_counts_gpu(data, score_threshold=0): return [valid_count, out_tensor] + def sort_ir(data, index, output): """Low level IR to do sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU.