From c0e3e5ec7130ada190d00a4e59e19550cf825068 Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Thu, 14 Mar 2019 22:38:13 -0700 Subject: [PATCH] ssd gluoncv ops merge with master ssd gluoncv gpu op updated tutorials and tests modified fix lint address comment multibox bug fixed space line added use less threads per block less threads per block for get valid count Revert "less threads per block for get valid count" This reverts commit 08896cfccc34b0b2a1646d01d01ea4cad73941c4. typo fixed elem length made to a variable fix lint error fix lint error lint fixed bug fixed lint fixed error fixed test ci separate argsort to be an independent op fix lint fix lint remove unsupported models ssd gluoncv gpu op updated tutorials and tests modified fix lint use less threads per block less threads per block for get valid count Revert "less threads per block for get valid count" This reverts commit 08896cfccc34b0b2a1646d01d01ea4cad73941c4. bug fixed error fixed test ci separate argsort to be an independent op typo fixed argsort added to relay solve conflicts with master fix lint fix lint test push Revert "test push" This reverts commit 6db00883fab6cc06bddf564c926bb27c874397d8. fix lint error fix more lint cpu test_sort updated debug ci nms fixed expose argsort to relay frontend test ci fix lint sort register error fixed fix nnvm adaptive pooling added to relay nms type fixed Revert "adaptive pooling added to relay" This reverts commit 1119f1f2c055753e0cc5611627597749134c5c8c. fix lint expose argsort op fix lint fix lint fix lint sort test updated sort bug fixed nnvm error fixed fix argsort default data type returned to be float instead of int fix lint fix lint test fixed fix valid count fix titanx bug tutorial add both targets titanx error fixed try to fix CI old gpu error try to solve CI GPU error get_valid_count added [AutoTVM] fix argument type for curve feature (#3004) --- include/tvm/relay/attrs/vision.h | 24 + nnvm/include/nnvm/top/nn.h | 6 + nnvm/python/nnvm/top/vision.py | 10 +- nnvm/tests/python/compiler/test_top_level4.py | 51 +- python/tvm/relay/frontend/mxnet.py | 46 +- python/tvm/relay/op/tensor.py | 24 + python/tvm/relay/op/transform.py | 23 - python/tvm/relay/op/vision/__init__.py | 2 + python/tvm/relay/op/vision/_sort.py | 29 + python/tvm/relay/op/vision/_vision.py | 5 +- python/tvm/relay/op/vision/nms.py | 11 +- python/tvm/relay/op/vision/sort.py | 31 + src/contrib/sort/sort.cc | 75 ++- src/relay/op/vision/nms.cc | 4 + src/relay/op/vision/sort_op.cc | 61 ++ tests/python/contrib/test_sort.py | 12 +- tests/python/relay/test_op_level5.py | 50 +- topi/python/topi/cuda/nms.py | 565 +++++++++++++----- topi/python/topi/cuda/sort.py | 230 +++++++ topi/python/topi/cuda/ssd/multibox.py | 195 +++--- topi/python/topi/cuda/vision.py | 37 +- topi/python/topi/generic/vision.py | 17 + topi/python/topi/vision/__init__.py | 1 + topi/python/topi/vision/nms.py | 41 +- topi/python/topi/vision/sort.py | 88 +++ topi/python/topi/vision/ssd/multibox.py | 6 +- topi/tests/python/test_topi_vision.py | 40 +- tutorials/frontend/deploy_ssd_gluoncv.py | 23 +- 28 files changed, 1331 insertions(+), 376 deletions(-) create mode 100644 python/tvm/relay/op/vision/_sort.py create mode 100644 python/tvm/relay/op/vision/sort.py create mode 100644 src/relay/op/vision/sort_op.cc create mode 100644 topi/python/topi/cuda/sort.py create mode 100644 topi/python/topi/vision/sort.py diff --git a/include/tvm/relay/attrs/vision.h b/include/tvm/relay/attrs/vision.h index 2b3eb4f32b458..0b57d99887555 100644 ---
a/include/tvm/relay/attrs/vision.h +++ b/include/tvm/relay/attrs/vision.h @@ -30,6 +30,24 @@ namespace tvm { namespace relay { +/*! \brief Attributes used in argsort operators */ +struct ArgsortAttrs : public tvm::AttrsNode { + int axis; + bool is_ascend; + std::string dtype; + + TVM_DECLARE_ATTRS(ArgsortAttrs, "relay.attrs.ArgsortAttrs") { + TVM_ATTR_FIELD(axis).set_default(-1) + .describe("Axis along which to sort the input tensor." + "If not given, the flattened array is used."); + TVM_ATTR_FIELD(is_ascend).set_default(true) + .describe("Whether to sort in ascending or descending order." + "By default, sort in ascending order"); + TVM_ATTR_FIELD(dtype).set_default("float32") + .describe("DType of the output indices."); + } +}; + /*! \brief Attributes used in multibox_prior operators */ struct MultiBoxPriorAttrs : public tvm::AttrsNode { Array sizes; @@ -92,6 +110,8 @@ struct NonMaximumSuppressionAttrs : public tvm::AttrsNode& lhs, } -// Argsort implemented C library sort. +// Argsort implemented C library sort for nms. // Return indices of sorted tensor. // By default, the last axis will be used to sort. // sort_num specify the number of elements to be sorted. // If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) // and sort axis is dk. sort_num should have dimension of // (d1, d2, ..., d(k-1), d(k+1), ..., dn). -TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") +TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort_nms") .set_body([](TVMArgs args, TVMRetValue *ret) { DLTensor *input = args[0]; DLTensor *sort_num = args[1]; DLTensor *output = args[2]; int32_t axis = args[3]; - bool is_descend = args[4]; + bool is_ascend = args[4]; auto dtype = input->dtype; auto data_ptr = static_cast(input->data); @@ -97,10 +97,10 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") int64_t full_idx = base_idx + k * axis_mul_after; sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx))); } - if (is_descend) { - std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); - } else { + if (is_ascend) { std::stable_sort(sorter.begin(), sorter.end(), CompareAscend); + } else { + std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); } for (int32_t k = 0; k < input->shape[axis]; ++k) { *(static_cast(output->data) + base_idx + k * axis_mul_after) @@ -110,5 +110,68 @@ TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") } }); + +// Argsort implemented C library sort. +// Return indices of sorted tensor. +// By default, the last axis will be used to sort. +// sort_num specify the number of elements to be sorted. +// If input tensor has dimension (d0, d1, ..., d(k-1), dk, d(k+1), ..., d(n-1)) +// and sort axis is dk. sort_num should have dimension of +// (d1, d2, ..., d(k-1), d(k+1), ..., dn). +TVM_REGISTER_GLOBAL("tvm.contrib.sort.argsort") +.set_body([](TVMArgs args, TVMRetValue *ret) { + DLTensor *input = args[0]; + DLTensor *output = args[1]; + int32_t axis = args[2]; + bool is_ascend = args[3]; + + auto dtype = input->dtype; + auto data_ptr = static_cast(input->data); + std::vector> sorter; + int64_t axis_mul_before = 1; + int64_t axis_mul_after = 1; + + if (axis < 0) { + axis = input->ndim + axis; + } + + // Currently only supports input dtype to be float32. 
+ CHECK_EQ(dtype.code, 2) << "Currently only supports input dtype " + "to be float32."; + CHECK_EQ(dtype.bits, 32) << "Currently only supports input dtype " + "to be float32."; + CHECK_LT(axis, input->ndim) << "Axis out of boundary for " + "input ndim " << input->ndim; + + for (int i = 0; i < input->ndim; ++i) { + if (i < axis) { + axis_mul_before *= input->shape[i]; + } else if (i > axis) { + axis_mul_after *= input->shape[i]; + } + } + + int32_t current_sort_num = input->shape[axis]; + for (int64_t i = 0 ; i < axis_mul_before; ++i) { + for (int64_t j = 0 ; j < axis_mul_after; ++j) { + sorter.clear(); + int64_t base_idx = i * input->shape[axis] * axis_mul_after + j; + for (int64_t k = 0; k < current_sort_num; ++k) { + int64_t full_idx = base_idx + k * axis_mul_after; + sorter.emplace_back(std::make_pair(k, *(data_ptr + full_idx))); + } + if (is_ascend) { + std::stable_sort(sorter.begin(), sorter.end(), CompareAscend); + } else { + std::stable_sort(sorter.begin(), sorter.end(), CompareDescend); + } + for (int32_t k = 0; k < input->shape[axis]; ++k) { + *(static_cast(output->data) + base_idx + k * axis_mul_after) + = k < static_cast(sorter.size()) ? sorter[k].first : k; + } + } + } +}); + } // namespace contrib } // namespace tvm diff --git a/src/relay/op/vision/nms.cc b/src/relay/op/vision/nms.cc index 5344bce3d6413..2e5661cdc4dc8 100644 --- a/src/relay/op/vision/nms.cc +++ b/src/relay/op/vision/nms.cc @@ -106,6 +106,8 @@ Expr MakeNMS(Expr data, double iou_threshold, bool force_suppress, int top_k, + int coord_start, + int score_index, int id_index, bool return_indices, bool invalid_to_bottom) { @@ -114,6 +116,8 @@ Expr MakeNMS(Expr data, attrs->iou_threshold = iou_threshold; attrs->force_suppress = force_suppress; attrs->top_k = top_k; + attrs->coord_start = coord_start; + attrs->score_index = score_index; attrs->id_index = id_index; attrs->return_indices = return_indices; attrs->invalid_to_bottom = invalid_to_bottom; diff --git a/src/relay/op/vision/sort_op.cc b/src/relay/op/vision/sort_op.cc new file mode 100644 index 0000000000000..194db6979f81e --- /dev/null +++ b/src/relay/op/vision/sort_op.cc @@ -0,0 +1,61 @@ +/*! + * Copyright (c) 2018 by Contributors + * \file nms.cc + * \brief Non-maximum suppression operators + */ +#include +#include + +namespace tvm { +namespace relay { + +TVM_REGISTER_NODE_TYPE(ArgsortAttrs); + +bool ArgsortRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + // `types` contains: [data, result] + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as(); + if (data == nullptr) { + CHECK(types[0].as()) + << "Argsort: expect input type to be TensorType but get " + << types[0]; + return false; + } + reporter->Assign(types[1], TensorTypeNode::make(data->shape, data->dtype)); + return true; +} + +Expr MakeArgsort(Expr data, + int axis, + bool is_ascend, + std::string dtype) { + auto attrs = make_node(); + attrs->axis = axis; + attrs->is_ascend = is_ascend; + CHECK_NE(dtype, "bool"); + attrs->dtype = dtype; + static const Op& op = Op::Get("vision.argsort"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + + +TVM_REGISTER_API("relay.op.vision._make.argsort") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call(MakeArgsort, args, rv); +}); + + +RELAY_REGISTER_OP("vision.argsort") +.describe(R"doc(Returns the indices that would sort an +input array along the given axis. 
+)doc" TVM_ADD_FILELINE) +.set_num_inputs(1) +.set_attrs_type_key("relay.attrs.ArgsortAttrs") +.add_argument("data", "Tensor", "Input data.") +.set_support_level(5) +.add_type_rel("Argsort", ArgsortRel); +} // namespace relay +} // namespace tvm diff --git a/tests/python/contrib/test_sort.py b/tests/python/contrib/test_sort.py index 856d3fa9cf832..87cdac01ce3a6 100644 --- a/tests/python/contrib/test_sort.py +++ b/tests/python/contrib/test_sort.py @@ -24,11 +24,11 @@ def test_sort(): data = tvm.placeholder((n, l, m), name='data') sort_num = tvm.placeholder((n, m), name="sort_num", dtype="int32") axis = 1 - is_descend = True + is_ascend = False out = tvm.extern(data.shape, [data, sort_num], lambda ins, outs: tvm.call_packed( - "tvm.contrib.sort.argsort", ins[0], - ins[1], outs[0], axis, is_descend), + "tvm.contrib.sort.argsort_nms", ins[0], + ins[1], outs[0], axis, is_ascend), dtype='int32', name="sort_tensor") input = [[[1, 2, 3], [2, 4.5, 3.5], [1.1, 0.5, 1], [3.2, -5, 0.5], [1.5, 0, 0]], [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]]] @@ -50,13 +50,13 @@ def test_sort_np(): dshape = (1, 2, 3, 4, 5, 6) axis = 4 reduced_shape = (1, 2, 3, 4, 6) - is_descend = False + is_ascend = True data = tvm.placeholder(dshape, name='data') sort_num = tvm.placeholder(reduced_shape, name="sort_num", dtype="int32") out = tvm.extern(data.shape, [data, sort_num], lambda ins, outs: tvm.call_packed( - "tvm.contrib.sort.argsort", ins[0], - ins[1], outs[0], axis, is_descend), + "tvm.contrib.sort.argsort_nms", ins[0], + ins[1], outs[0], axis, is_ascend), dtype='int32', name="sort_tensor") ctx = tvm.cpu(0) diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py index 7e1c371699782..7d0aa6a2beb45 100644 --- a/tests/python/relay/test_op_level5.py +++ b/tests/python/relay/test_op_level5.py @@ -177,12 +177,13 @@ def verify_get_valid_counts(dshape, score_threshold): assert "score_threshold" in z.astext() func = relay.Function([x], z.astuple()) func = relay.ir_pass.infer_type(func) - ctx_list = [("llvm", tvm.cpu(0))] - for target, ctx in ctx_list: + for target, ctx in ctx_list(): + if target == 'cuda': + return intrp = relay.create_executor("debug", ctx=ctx, target=target) out = intrp.evaluate(func)(np_data) - tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3) - tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3) + tvm.testing.assert_allclose(out[0].asnumpy(), np_out1, rtol=1e-3, atol=1e-04) + tvm.testing.assert_allclose(out[1].asnumpy(), np_out2, rtol=1e-3, atol=1e-04) verify_get_valid_counts((1, 2500, 6), 0) verify_get_valid_counts((1, 2500, 6), -1) @@ -195,9 +196,13 @@ def verify_nms(x0_data, x1_data, dshape, ref_res, ref_indices_res, iou_threshold=0.5, force_suppress=False, top_k=-1, check_type_only=False): x0 = relay.var("x0", relay.ty.TensorType(dshape, "float32")) - x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int")) - z = relay.vision.non_max_suppression(x0, x1, -1, iou_threshold, force_suppress, top_k, return_indices=False) - z_indices = relay.vision.non_max_suppression(x0, x1, -1, iou_threshold, force_suppress, top_k) + x1 = relay.var("x1", relay.ty.TensorType((dshape[0],), "int32")) + z = relay.vision.non_max_suppression(x0, x1, max_output_size = -1, \ + iou_threshold = iou_threshold, force_suppress = force_suppress, \ + top_k = top_k, return_indices=False) + z_indices = relay.vision.non_max_suppression(x0, x1, max_output_size = -1, \ + iou_threshold = iou_threshold, force_suppress = force_suppress, \ + top_k = top_k) 
assert "iou_threshold" in z.astext() assert "iou_threshold" in z_indices.astext() zz = relay.ir_pass.infer_type(z) @@ -212,8 +217,7 @@ def verify_nms(x0_data, x1_data, dshape, ref_res, ref_indices_res, func = relay.ir_pass.infer_type(func) func_indices = relay.Function([x0, x1], z_indices) func_indices = relay.ir_pass.infer_type(func_indices) - ctx_list = [("llvm", tvm.cpu(0))] - for target, ctx in ctx_list: + for target, ctx in ctx_list(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(x0_data, x1_data) op_indices_res1 = intrp1.evaluate(func_indices)(x0_data, x1_data) @@ -296,8 +300,7 @@ def test_default_value(): nms = relay.vision.non_max_suppression(mtl[0], mtl[1], return_indices=False) func = relay.Function([cls_prob, loc_pred, anchors], nms) func = relay.ir_pass.infer_type(func) - ctx_list = [("llvm", tvm.cpu(0))] - for target, ctx in ctx_list: + for target, ctx in ctx_list(): intrp1 = relay.create_executor("graph", ctx=ctx, target=target) op_res1 = intrp1.evaluate(func)(np_cls_prob, np_loc_preds, np_anchors) @@ -565,11 +568,33 @@ def test_run(batch, in_channel, size, out_channel, deformable_groups, groups): test_run(2, 4, 16, 4, 4, 1) +def test_argsort(): + def verify_argsort(shape, axis, is_ascend): + x = relay.var("x", relay.TensorType(shape, "float32")) + z = relay.vision.argsort(x, axis=axis, is_ascend=is_ascend) + zz = relay.ir_pass.infer_type(z) + func = relay.Function([x], z) + x_data = np.random.uniform(size=shape).astype("float32") + if is_ascend: + ref_res = np.argsort(x_data, axis=axis) + else: + ref_res = np.argsort(-x_data, axis=axis) + + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x_data) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.astype("float"), rtol=1e-5) + verify_argsort((2, 3, 4), axis=0, is_ascend=False) + verify_argsort((1, 4, 6), axis=1, is_ascend=True) + verify_argsort((3, 5, 6), axis=-1, is_ascend=False) + + if __name__ == "__main__": test_resize_infer_type() test_resize() - test_multibox_prior() test_multibox_transform_loc() + test_multibox_prior() test_get_valid_counts() test_roi_align() test_roi_pool() @@ -578,3 +603,4 @@ def test_run(batch, in_channel, size, out_channel, deformable_groups, groups): test_yolo_reorg() test_non_max_suppression() test_deformable_conv2d() + test_argsort() diff --git a/topi/python/topi/cuda/nms.py b/topi/python/topi/cuda/nms.py index e6377fa40c529..879a6d20c7362 100644 --- a/topi/python/topi/cuda/nms.py +++ b/topi/python/topi/cuda/nms.py @@ -20,77 +20,196 @@ import tvm from tvm import api -from topi.vision import non_max_suppression -from ..util import get_const_tuple +from tvm.intrin import if_then_else +from topi.vision import non_max_suppression, get_valid_counts +from .sort import argsort -def sort_ir(data, index, output): - """Low level IR to do sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. + +def get_valid_counts_pre(data, flag, idx, score_threshold): + """Low level IR to get valid count of bounding boxes + given a score threshold. Also moves valid boxes to the + top of input data. Parameters ---------- data: Buffer - 2D Buffer of input boxes' score with shape [batch_size, num_anchors]. + 3D Buffer with shape [batch_size, num_anchors, 6], output of nms. + + flag : Buffer + 2D Buffer of flag indicating valid data with shape [batch_size, num_anchors]. - index : Buffer - 1D Buffer of number of valid number of boxes. 
+ idx : Buffer + 2D Buffer of valid data indices with shape [batch_size, num_anchors]. - output : Buffer - 2D Output buffer of indicies of sorted tensor with shape [batch_size, num_anchors]. + score_threshold : float32 + Lower limit of score for valid bounding boxes. Returns ------- stmt : Stmt The result IR statement. """ + batch_size = data.shape[0] + num_anchors = data.shape[1] + box_data_length = data.shape[2] - assert data.dtype == "float32", "Currently only supports input dtype to be float32" - batch, num_anchors = get_const_tuple(data.shape) - max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) ib = tvm.ir_builder.create() - p_data = ib.buffer_ptr(data) - p_index = ib.buffer_ptr(index) - p_out = ib.buffer_ptr(output) + + data = ib.buffer_ptr(data) + flag = ib.buffer_ptr(flag) + idx = ib.buffer_ptr(idx) + score_threshold = tvm.make.node("FloatImm", dtype="float32", value=score_threshold) + + max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) nthread_tx = max_threads - nthread_bx = num_anchors // max_threads + 1 + nthread_bx = batch_size * num_anchors // max_threads + 1 tx = tvm.thread_axis("threadIdx.x") bx = tvm.thread_axis("vthread") ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "virtual_thread", nthread_bx) - tid = bx * nthread_tx + tx - temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local") - temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local") - - with ib.for_range(0, batch, for_type="unroll") as b: - start = b * num_anchors - with ib.if_scope(tid < num_anchors): - p_out[start + tid] = tid - # OddEvenTransposeSort - with ib.for_range(0, p_index[b]) as k: - with ib.if_scope(tid < (p_index[b] + 1) // 2): - offset = start + 2 * tid + (k % 2) - with ib.if_scope( \ - tvm.all(offset + 1 < p_index[0], p_data[offset] < p_data[offset + 1])): - temp_data[0] = p_data[offset] - p_data[offset] = p_data[offset + 1] - p_data[offset + 1] = temp_data[0] - temp_index[0] = p_out[offset] - p_out[offset] = p_out[offset + 1] - p_out[offset + 1] = temp_index[0] - ib.emit(tvm.make.Call(None, 'tvm_storage_sync', - tvm.convert(['shared']), - tvm.expr.Call.Intrinsic, None, 0)) + tid = bx * max_threads + tx + + with ib.if_scope(tid < batch_size * num_anchors): + i = tid / num_anchors # number of batches + j = tid % num_anchors # number of anchors + base_idx = i * num_anchors * box_data_length + with ib.if_scope(data[base_idx + j * box_data_length + 1] > score_threshold): + flag[tid] = 1 + idx[tid] = 1 + with ib.else_scope(): + flag[tid] = 0 + idx[tid] = 0 + + with ib.if_scope(tid < batch_size): + with ib.for_range(0, num_anchors) as k: + with ib.if_scope(k > 0): + idx[tid * num_anchors + k] += idx[tid * num_anchors + k - 1] + + return ib.get() + + +def get_valid_counts_ir(data, flag, idx, valid_count, out): + """Low level IR to get valid count of bounding boxes + given a score threshold. Also moves valid boxes to the + top of input data. + + Parameters + ---------- + data : Buffer + Input data. 3-D Buffer with shape [batch_size, num_anchors, 6]. + + flag : Buffer + 2D Buffer of flag indicating valid data with shape [batch_size, num_anchors]. + + idx : Buffer + 2D Buffer of valid data indices with shape [batch_size, num_anchors]. + + valid_count : Buffer + 1-D buffer for valid number of boxes. + + out : Buffer + Rearranged data buffer. + + Returns + ------- + stmt : Stmt + The result IR statement. 
+ """ + batch_size = data.shape[0] + num_anchors = data.shape[1] + elem_length = data.shape[2] + + ib = tvm.ir_builder.create() + + data = ib.buffer_ptr(data) + flag = ib.buffer_ptr(flag) + idx = ib.buffer_ptr(idx) + valid_count = ib.buffer_ptr(valid_count) + out = ib.buffer_ptr(out) + + max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + nthread_tx = max_threads + nthread_bx = batch_size * num_anchors * elem_length // max_threads + 1 + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + tid = bx * max_threads + tx + + with ib.if_scope(tid < batch_size * num_anchors): + i = tid / num_anchors # number of batches + j = tid % num_anchors # number of anchors + base_idx = i * num_anchors * elem_length + with ib.for_range(0, elem_length) as l: + out[tid * elem_length + l] = -1.0 + with ib.if_scope(flag[tid] > 0): + with ib.for_range(0, elem_length) as k: + out[base_idx + (idx[tid] - 1) * elem_length + k] =\ + data[base_idx + j * elem_length + k] + valid_count[i] = idx[i * num_anchors + num_anchors - 1] return ib.get() -def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, nms_topk): + +@get_valid_counts.register(["cuda", "gpu"]) +def get_valid_counts_gpu(data, score_threshold=0): + """Get valid count of bounding boxes given a score threshold. + Also moves valid boxes to the top of input data. + + Parameters + ---------- + data : tvm.Tensor + Input data. 3-D tensor with shape [batch_size, num_anchors, 6]. + + score_threshold : optional, float + Lower limit of score for valid bounding boxes. + + Returns + ------- + valid_count : tvm.Tensor + 1-D tensor for valid number of boxes. + + out_tensor : tvm.Tensor + Rearranged data tensor. + """ + batch_size = data.shape[0] + num_anchors = data.shape[1] + temp_flag_buf = api.decl_buffer( + (batch_size, num_anchors,), "int32", "temp_flag", data_alignment=8) + temp_idx_buf = api.decl_buffer( + (batch_size, num_anchors,), "int32", "temp_idx", data_alignment=8) + data_buf = api.decl_buffer( + data.shape, data.dtype, "data_buf", data_alignment=8) + temp_flag, temp_idx = \ + tvm.extern([(batch_size, num_anchors,), (batch_size, num_anchors,)], [data], + lambda ins, outs: get_valid_counts_pre( + ins[0], outs[0], outs[1], score_threshold), + dtype=["int32", "int32"], + out_buffers=[temp_flag_buf, temp_idx_buf], + name="get_valid_counts_phase_one") + + valid_count, out_tensor = \ + tvm.extern([(batch_size,), data.shape], [data, temp_flag, temp_idx], + lambda ins, outs: get_valid_counts_ir( + ins[0], ins[1], ins[2], outs[0], outs[1]), + dtype=["int32", data.dtype], + in_buffers=[data_buf, temp_flag_buf, temp_idx_buf], + tag="get_valid_counts") + + return [valid_count, out_tensor] + + +def nms_ir(data, sorted_index, valid_count, out, box_indices, + max_output_size, iou_threshold, force_suppress, + top_k, coord_start, id_index): """Low level IR routing for transform location in multibox_detection operator. Parameters ---------- - data: Buffer + data : Buffer Buffer of output boxes with class and score. - sort_result : Buffer + sort_index : Buffer Buffer of output box indexes sorted by score. valid_count : Buffer @@ -99,15 +218,25 @@ def nms_ir(data, sort_result, valid_count, out, nms_threshold, force_suppress, n out : Buffer Output buffer. - nms_threshold : float - Non-maximum suppression threshold. + max_output_size : int + Max number of output valid boxes for each instance. 
+ By default all valid boxes are returned. + + iou_threshold : float + Overlapping(IoU) threshold to suppress object with smaller score. force_suppress : boolean Whether to suppress all detections regardless of class_id. - nms_topk : int + top_k : int Keep maximum top k detections before nms, -1 for no limit. + coord_start : int + Start index of the consecutive 4 coordinates. + + id_index : int + index of the class categories, -1 to disable. + Returns ------- stmt : Stmt @@ -127,86 +256,217 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): (out_tensor[box_b_idx + 3] - out_tensor[box_b_idx + 1]) - i return tvm.expr.Select(u <= 0.0, 0.0, i / u) + batch_size = data.shape[0] + num_anchors = data.shape[1] + box_data_length = data.shape[2] + + ib = tvm.ir_builder.create() + + data = ib.buffer_ptr(data) + sorted_index = ib.buffer_ptr(sorted_index) + valid_count = ib.buffer_ptr(valid_count) + out = ib.buffer_ptr(out) + box_indices = ib.buffer_ptr(box_indices) + num_valid_boxes = ib.allocate("int32", (1,), name="num_valid_boxes", scope="local") + max_threads = int(math.sqrt( tvm.target.current_target(allow_none=False).max_num_threads)) - ib = tvm.ir_builder.create() - p_data = ib.buffer_ptr(data) - p_sort_result = ib.buffer_ptr(sort_result) - p_valid_count = ib.buffer_ptr(valid_count) - p_out = ib.buffer_ptr(out) - batch_size = out.shape[0] - num_anchors = out.shape[1] nthread_tx = max_threads nthread_bx = num_anchors // max_threads + 1 tx = tvm.thread_axis("threadIdx.x") bx = tvm.thread_axis("blockIdx.x") ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) - i = bx * max_threads + tx - - nms_threshold_node = tvm.make.node( - "FloatImm", dtype="float32", value=nms_threshold) - nms_topk_node = tvm.make.node("IntImm", dtype="int32", value=nms_topk) - force_suppress_node = tvm.make.node( - "IntImm", dtype="int32", value=1 if force_suppress else 0) - with ib.for_range(0, batch_size, for_type="unroll") as b: - base_idx = b * num_anchors * 6 - with ib.if_scope( \ - tvm.all(nms_threshold_node > 0, nms_threshold_node < 1, - p_valid_count[0] > 0)): + k = bx * max_threads + tx + + iou_threshold = tvm.make.node("FloatImm", dtype="float32", value=iou_threshold) + top_k = tvm.make.node("IntImm", dtype="int32", value=top_k) + coord_start = tvm.make.node("IntImm", dtype="int32", value=coord_start) + id_index = tvm.make.node("IntImm", dtype="int32", value=id_index) + force_suppress = tvm.make.node("IntImm", dtype="int32", value=1 if force_suppress else 0) + + with ib.for_range(0, batch_size, for_type="unroll") as i: + base_idx = i * num_anchors * box_data_length + with ib.if_scope(tvm.all(iou_threshold > 0, valid_count[i] > 0)): # Reorder output - nkeep = tvm.if_then_else( \ - tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[b]), - nms_topk, p_valid_count[b]) - with ib.for_range(0, nkeep) as l: - with ib.if_scope(i < 6): - p_out[(base_idx + l * 6 + i)] = \ - p_data[(base_idx + p_sort_result[b * num_anchors + l] * 6 + i)] - with ib.if_scope(tvm.all(nms_topk_node > 0, nms_topk < p_valid_count[b])): - with ib.for_range(0, p_valid_count[b] - nkeep) as l: - with ib.if_scope(i < 6): - p_out[(base_idx + (l + nkeep) * 6 + i)] = -1.0 + nkeep = if_then_else( \ + tvm.all(top_k > 0, top_k < valid_count[i]), + top_k, valid_count[i]) + with ib.for_range(0, nkeep) as j: + with ib.if_scope(k < box_data_length): + out[(base_idx + j * box_data_length + k)] = \ + data[(base_idx + sorted_index[i * num_anchors + j] \ + * box_data_length + k)] + box_indices[i * num_anchors + j] = 
sorted_index[i * num_anchors + j] + with ib.if_scope(tvm.all(top_k > 0, top_k < valid_count[i])): + with ib.for_range(0, valid_count[i] - nkeep) as j: + with ib.if_scope(k < box_data_length): + out[(base_idx + (j + nkeep) * box_data_length + k)] = -1.0 + box_indices[i * num_anchors + (j + nkeep)] = -1 # Apply nms - with ib.for_range(0, p_valid_count[b]) as l: - offset_l = l * 6 - with ib.if_scope(p_out[base_idx + offset_l] >= 0): - with ib.if_scope(i < p_valid_count[b]): - offset_i = i * 6 - with ib.if_scope(tvm.all(i > l, p_out[base_idx - + offset_i] >= 0)): - with ib.if_scope(tvm.any(force_suppress_node > 0, - p_out[base_idx + offset_l] == - p_out[base_idx + offset_i])): - # When force_suppress == True or class_id equals - iou = calculate_overlap(p_out, base_idx + offset_l + 2, - base_idx + offset_i + 2) - with ib.if_scope(iou >= nms_threshold): - p_out[base_idx + offset_i] = -1.0 + with ib.for_range(0, valid_count[i]) as j: + offset_j = j * box_data_length + with ib.if_scope(out[base_idx + offset_j] >= 0): + with ib.if_scope(k < valid_count[i]): + offset_k = k * box_data_length + with ib.if_scope(tvm.all(k > j, out[base_idx + offset_k] >= 0, \ + tvm.any(force_suppress > 0, id_index < 0, \ + out[base_idx + offset_j] == \ + out[base_idx + offset_k]))): + iou = calculate_overlap(out, base_idx + offset_k + coord_start, + base_idx + offset_j + coord_start) + with ib.if_scope(iou >= iou_threshold): + out[base_idx + offset_k] = -1.0 + box_indices[i * num_anchors + k] = -1 ib.emit(tvm.make.Call(None, 'tvm_storage_sync', tvm.convert(['shared']), tvm.expr.Call.Intrinsic, None, 0)) with ib.else_scope(): - with ib.for_range(0, p_valid_count[b]) as c: - with ib.if_scope(i < 6): - p_out[(base_idx + c * 6 + i)] = p_data[base_idx + c * 6 + i] + with ib.for_range(0, valid_count[i]) as j: + offset_j = j * box_data_length + with ib.if_scope(k < box_data_length): + out[(base_idx + offset_j + k)] = data[base_idx + offset_j + k] + box_indices[i * num_anchors + j] = j # Set invalid entry to be -1 - with ib.for_range(0, num_anchors - p_valid_count[b]) as c: - with ib.if_scope(i < 6): - p_out[base_idx + (c + p_valid_count[b]) * 6 + i] = -1.0 - body = ib.get() - return body + with ib.for_range(0, num_anchors - valid_count[i]) as j: + with ib.if_scope(k < box_data_length): + out[base_idx + (j + valid_count[i]) * box_data_length + k] = -1.0 + box_indices[i * num_anchors + j + valid_count[i]] = -1 + # Only return max_output_size number of valid boxes + num_valid_boxes[0] = 0 + with ib.if_scope(max_output_size > 0): + with ib.for_range(0, valid_count[i]) as j: + offset_j = j * box_data_length + with ib.if_scope(out[base_idx + offset_j] >= 0): + with ib.if_scope(num_valid_boxes[0] == max_output_size): + with ib.if_scope(k < box_data_length): + out[base_idx + offset_j + k] = -1.0 + box_indices[i * num_anchors + j] = -1 + with ib.else_scope(): + num_valid_boxes[0] += 1 + ib.emit(tvm.make.Call(None, 'tvm_storage_sync', + tvm.convert(['shared']), + tvm.expr.Call.Intrinsic, None, 0)) + + return ib.get() + + +def invalid_to_bottom_pre(data, flag, idx): + """Low level IR to rearrange nms output to move all valid entries to top. + + Parameters + ---------- + data: Buffer + 3D Buffer with shape [batch_size, num_anchors, 6], output of nms. + + flag : Buffer + 1D Buffer of flag indicating valid data with [num_anchors]. + + idx : Buffer + 1D Buffer of valid data indices with [num_anchors]. + + Returns + ------- + stmt : Stmt + The result IR statement. 
+ """ + batch_size = data.shape[0] + num_anchors = data.shape[1] + elem_length = data.shape[2] + + ib = tvm.ir_builder.create() + + data = ib.buffer_ptr(data) + flag = ib.buffer_ptr(flag) + idx = ib.buffer_ptr(idx) + + max_threads = int(math.sqrt( + tvm.target.current_target(allow_none=False).max_num_threads)) + nthread_tx = max_threads + nthread_bx = num_anchors // max_threads + 1 + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + j = bx * max_threads + tx + + with ib.for_range(0, batch_size, for_type="unroll") as i: + base_idx = i * num_anchors * elem_length + with ib.if_scope(j < num_anchors): + with ib.if_scope(data[base_idx + j * elem_length] >= 0): + flag[i * num_anchors + j] = 1 + idx[i * num_anchors + j] = 1 + with ib.else_scope(): + flag[i * num_anchors + j] = 0 + idx[i * num_anchors + j] = 0 + + with ib.if_scope(j < batch_size): + with ib.for_range(0, num_anchors) as k: + with ib.if_scope(k > 0): + idx[j * num_anchors + k] += idx[j * num_anchors + k - 1] + return ib.get() + + +def invalid_to_bottom_ir(data, flag, idx, out): + """Low level IR to rearrange nms output to move all valid entries to top. + + Parameters + ---------- + data: Buffer + 3D Buffer with shape [batch_size, num_anchors, 6], output of nms. + + flag : Buffer + 1D Buffer of flag indicating valid data with [num_anchors]. + + idx : Buffer + 1D Buffer of valid data indices with [num_anchors]. + + out : Buffer + 3D Buffer of rearranged nms output with shape [batch_size, num_anchors, 6]. + + Returns + ------- + stmt : Stmt + The result IR statement. + """ + batch_size = data.shape[0] + num_anchors = data.shape[1] + elem_length = data.shape[2] + + ib = tvm.ir_builder.create() + + data = ib.buffer_ptr(data) + flag = ib.buffer_ptr(flag) + idx = ib.buffer_ptr(idx) + out = ib.buffer_ptr(out) + + max_threads = int(math.sqrt( + tvm.target.current_target(allow_none=False).max_num_threads)) + nthread_tx = max_threads + nthread_bx = num_anchors // max_threads + 1 + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("blockIdx.x") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "thread_extent", nthread_bx) + j = bx * max_threads + tx + + with ib.for_range(0, batch_size, for_type="unroll") as i: + base_idx = i * num_anchors * elem_length + with ib.if_scope(j < num_anchors): + with ib.for_range(0, elem_length) as k: + out[base_idx + j * elem_length + k] = -1.0 + with ib.if_scope(flag[i * num_anchors + j] > 0): + with ib.for_range(0, elem_length) as k: + out[base_idx + (idx[i * num_anchors + j] - 1) * elem_length + k] \ + = data[base_idx + j * elem_length + k] + return ib.get() @non_max_suppression.register(["cuda", "gpu"]) -def nms_gpu(data, - valid_count, - max_output_size=-1, - iou_threshold=0.5, - force_suppress=False, - top_k=-1, - id_index=0, - return_indices=True, - invalid_to_bottom=False): +def non_max_suppression_gpu(data, valid_count, max_output_size=-1, + iou_threshold=0.5, force_suppress=False, top_k=-1, + coord_start=2, score_index=1, id_index=0, + return_indices=True, invalid_to_bottom=False): """Non-maximum suppression operator for object detection. Parameters @@ -219,8 +479,9 @@ def nms_gpu(data, valid_count : tvm.Tensor 1-D tensor for valid number of boxes. - return_indices : boolean - Whether to return box indices in input data. + max_output_size : optional, int + Max number of output valid boxes for each instance. + By default all valid boxes are returned. 
iou_threshold : optional, float Non-maximum suppression threshold. @@ -231,9 +492,18 @@ def nms_gpu(data, top_k : optional, int Keep maximum top k detections before nms, -1 for no limit. + coord_start : required, int + Start index of the consecutive 4 coordinates. + + score_index : optional, int + Index of the scores/confidence of boxes. + id_index : optional, int index of the class categories, -1 to disable. + return_indices : boolean + Whether to return box indices in input data. + invalid_to_bottom : optional, boolean Whether to move all valid bounding boxes to the top. @@ -253,12 +523,13 @@ def nms_gpu(data, iou_threshold = 0.7 force_suppress = True top_k = -1 - out = nms(data, valid_count, iou_threshold, force_suppress, topk) + out = non_max_suppression(data=data, valid_count=valid_count, iou_threshold=iou_threshold, + force_suppress=force_supress, top_k=top_k, return_indices=False) np_data = np.random.uniform(dshape) np_valid_count = np.array([4]) s = topi.generic.schedule_nms(out) - f = tvm.build(s, [data, valid_count, out], "llvm") - ctx = tvm.cpu() + f = tvm.build(s, [data, valid_count, out], "cuda") + ctx = tvm.gpu(0) tvm_data = tvm.nd.array(np_data, ctx) tvm_valid_count = tvm.nd.array(np_valid_count, ctx) tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx) @@ -266,38 +537,62 @@ def nms_gpu(data, """ batch_size = data.shape[0] num_anchors = data.shape[1] + valid_count_dtype = "int32" valid_count_buf = api.decl_buffer(valid_count.shape, valid_count_dtype, "valid_count_buf", data_alignment=4) - data_buf = api.decl_buffer( - data.shape, data.dtype, "data_buf", data_alignment=8) + score_axis = score_index score_shape = (batch_size, num_anchors) - score_tensor = tvm.compute( - score_shape, lambda i, j: data[i, j, 1], name="score_tensor") - score_tensor_buf = api.decl_buffer(score_tensor.shape, data.dtype, - "score_tensor_buf", data_alignment=8) + score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis]) + sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True) - sort_tensor_dtype = "int32" - sort_tensor_buf = api.decl_buffer(score_shape, sort_tensor_dtype, + sort_tensor_buf = api.decl_buffer(sort_tensor.shape, sort_tensor.dtype, "sort_tensor_buf", data_alignment=8) - sort_tensor = \ - tvm.extern(score_shape, - [score_tensor, valid_count], - lambda ins, outs: sort_ir( - ins[0], ins[1], outs[0]), - dtype=sort_tensor_dtype, - in_buffers=[score_tensor_buf, valid_count_buf], - out_buffers=sort_tensor_buf, - name="nms_sort") - - out = \ - tvm.extern(data.shape, + data_buf = api.decl_buffer( + data.shape, data.dtype, "data_buf", data_alignment=8) + + out_buf = api.decl_buffer( + data.shape, data.dtype, "out_buf", data_alignment=8) + + out, box_indices = \ + tvm.extern([data.shape, score_shape], [data, sort_tensor, valid_count], lambda ins, outs: nms_ir( - ins[0], ins[1], ins[2], outs[0], iou_threshold, - force_suppress, top_k), - dtype="float32", + ins[0], ins[1], ins[2], outs[0], outs[1], + max_output_size, iou_threshold, force_suppress, + top_k, coord_start, id_index), + dtype=[data.dtype, "int32"], in_buffers=[data_buf, sort_tensor_buf, valid_count_buf], + name="nms", tag="nms") + + if return_indices: + return box_indices + + if invalid_to_bottom: + output_buf = api.decl_buffer( + data.shape, data.dtype, "output_buf", data_alignment=8) + temp_flag_buf = api.decl_buffer( + score_shape, valid_count_dtype, "temp_flag", data_alignment=8) + temp_idx_buf = api.decl_buffer( + score_shape, valid_count_dtype, "temp_idx", 
data_alignment=8) + temp_flag, temp_idx = tvm.extern([score_shape, score_shape], [out], + lambda ins, outs: invalid_to_bottom_pre( + ins[0], outs[0], outs[1]), + dtype=["int32", "int32"], + in_buffers=[out_buf], + out_buffers=[temp_flag_buf, temp_idx_buf], + name="invalid_to_bottom_phase_one") + + output = tvm.extern([data.shape], [out, temp_flag, temp_idx], + lambda ins, outs: invalid_to_bottom_ir( + ins[0], ins[1], ins[2], outs[0]), + dtype=[data.dtype], + in_buffers=[out_buf, temp_flag_buf, temp_idx_buf], + out_buffers=[output_buf], + name="invalid_to_bottom", + tag="invalid_to_bottom") + return output + return out diff --git a/topi/python/topi/cuda/sort.py b/topi/python/topi/cuda/sort.py new file mode 100644 index 0000000000000..cdbb52f402090 --- /dev/null +++ b/topi/python/topi/cuda/sort.py @@ -0,0 +1,230 @@ +# pylint: disable=invalid-name, no-member, too-many-locals, too-many-arguments, too-many-statements, singleton-comparison, unused-argument +"""Argsort operator """ +import tvm + +from tvm import api +from topi.vision.sort import argsort + +def sort_ir(data, output, axis, is_ascend): + """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. + + Parameters + ---------- + data: Buffer + Buffer of input data. + + output : Buffer + Output buffer of indicies of sorted tensor with same shape as data. + + axis : Int + Axis long which to sort the input tensor. + + is_ascend : Boolean + Whether to sort in ascending or descending order. + + Returns + ------- + stmt : Stmt + The result IR statement. + """ + size = 1 + axis_mul_before = 1 + axis_mul_after = 1 + shape = data.shape + if axis < 0: + axis = len(shape) + axis + for i, value in enumerate(shape, 0): + size *= value + if i < axis: + axis_mul_before *= value + elif i > axis: + axis_mul_after *= value + max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + ib = tvm.ir_builder.create() + data = ib.buffer_ptr(data) + output = ib.buffer_ptr(output) + nthread_tx = max_threads + nthread_bx = size // max_threads + 1 + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("vthread") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "virtual_thread", nthread_bx) + tid = bx * nthread_tx + tx + temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local") + temp_index = ib.allocate("float32", (1,), name="temp_index", scope="local") + is_ascend = tvm.make.node("IntImm", dtype="int32", value=is_ascend) + + with ib.for_range(0, axis_mul_before) as i: + with ib.for_range(0, axis_mul_after) as j: + current_sort_num = shape[axis] + base_idx = i * shape[axis] * axis_mul_after + j + with ib.if_scope(tid < shape[axis]): + output[base_idx + tid * axis_mul_after] = tid.astype("float32") + # OddEvenTransposeSort + with ib.for_range(0, current_sort_num) as k: + with ib.if_scope(tid < (current_sort_num + 1) // 2): + offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after + with ib.if_scope(tvm.all(is_ascend == 1, \ + 2 * tid + (k % 2) + 1 < current_sort_num, \ + data[offset] > data[offset + axis_mul_after])): + temp_data[0] = data[offset] + data[offset] = data[offset + axis_mul_after] + data[offset + axis_mul_after] = temp_data[0] + temp_index[0] = output[offset] + output[offset] = output[offset + axis_mul_after] + output[offset + axis_mul_after] = temp_index[0] + with ib.if_scope(tvm.all(is_ascend == 0, \ + 2 * tid + (k % 2) + 1 < current_sort_num, \ + data[offset] < data[offset + axis_mul_after])): + temp_data[0] = data[offset] + data[offset] = 
data[offset + axis_mul_after] + data[offset + axis_mul_after] = temp_data[0] + temp_index[0] = output[offset] + output[offset] = output[offset + axis_mul_after] + output[offset + axis_mul_after] = temp_index[0] + ib.emit(tvm.make.Call(None, 'tvm_storage_sync', + tvm.convert(['shared']), + tvm.expr.Call.Intrinsic, None, 0)) + + return ib.get() + + + +def sort_nms_ir(data, valid_count, output, axis, is_ascend): + """Low level IR to do nms sorting on the GPU, same usage as tvm.contrib.sort.argsort on the CPU. + + Parameters + ---------- + data: Buffer + Buffer of input data. + + valid_count : Buffer + 1D Buffer of number of valid number of boxes. + + output : Buffer + Output buffer of indicies of sorted tensor with same shape as data. + + axis : Int + Axis long which to sort the input tensor. + + is_ascend : Boolean + Whether to sort in ascending or descending order. + + Returns + ------- + stmt : Stmt + The result IR statement. + """ + + size = 1 + axis_mul_before = 1 + axis_mul_after = 1 + shape = data.shape + if axis < 0: + axis = len(shape) + axis + for i, value in enumerate(shape, 0): + size *= value + if i < axis: + axis_mul_before *= value + elif i > axis: + axis_mul_after *= value + max_threads = int(tvm.target.current_target(allow_none=False).max_num_threads) + ib = tvm.ir_builder.create() + data = ib.buffer_ptr(data) + valid_count = ib.buffer_ptr(valid_count) + output = ib.buffer_ptr(output) + nthread_tx = max_threads + nthread_bx = size // max_threads + 1 + tx = tvm.thread_axis("threadIdx.x") + bx = tvm.thread_axis("vthread") + ib.scope_attr(tx, "thread_extent", nthread_tx) + ib.scope_attr(bx, "virtual_thread", nthread_bx) + tid = bx * nthread_tx + tx + temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local") + temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local") + is_ascend = tvm.make.node("IntImm", dtype="int32", value=is_ascend) + + with ib.for_range(0, axis_mul_before) as i: + with ib.for_range(0, axis_mul_after) as j: + current_sort_num = valid_count[i * axis_mul_after + j] + base_idx = i * shape[axis] * axis_mul_after + j + with ib.if_scope(tid < shape[axis]): + output[base_idx + tid * axis_mul_after] = tid + # OddEvenTransposeSort + with ib.for_range(0, current_sort_num) as k: + with ib.if_scope(tid < (current_sort_num + 1) // 2): + offset = base_idx + (2 * tid + (k % 2)) * axis_mul_after + with ib.if_scope(tvm.all(is_ascend == 1, \ + 2 * tid + (k % 2) + 1 < current_sort_num, \ + data[offset] > data[offset + axis_mul_after])): + temp_data[0] = data[offset] + data[offset] = data[offset + axis_mul_after] + data[offset + axis_mul_after] = temp_data[0] + temp_index[0] = output[offset] + output[offset] = output[offset + axis_mul_after] + output[offset + axis_mul_after] = temp_index[0] + with ib.if_scope(tvm.all(is_ascend == 0, \ + 2 * tid + (k % 2) + 1 < current_sort_num, \ + data[offset] < data[offset + axis_mul_after])): + temp_data[0] = data[offset] + data[offset] = data[offset + axis_mul_after] + data[offset + axis_mul_after] = temp_data[0] + temp_index[0] = output[offset] + output[offset] = output[offset + axis_mul_after] + output[offset + axis_mul_after] = temp_index[0] + ib.emit(tvm.make.Call(None, 'tvm_storage_sync', + tvm.convert(['shared']), + tvm.expr.Call.Intrinsic, None, 0)) + + return ib.get() + +@argsort.register(["cuda", "gpu"]) +def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0): + """Performs sorting along the given axis and returns an array of indicies + having same shape as an input array 
that index data in sorted order. + + Parameters + ---------- + data: tvm.Tensor + The input array. + + valid_count : tvm.Tensor + The number of valid elements to be sorted. + + axis : int + Axis long which to sort the input tensor. + + is_ascend : boolean + Whether to sort in ascending or descending order. + + Returns + ------- + out : tvm.Tensor + The output of this function. + """ + data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + if flag: + valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype, + "valid_count_buf", data_alignment=4) + out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4) + out = tvm.extern([data.shape], + [data, valid_count], + lambda ins, outs: sort_nms_ir( + ins[0], ins[1], outs[0], axis, is_ascend), + dtype="int32", + in_buffers=[data_buf, valid_count_buf], + out_buffers=[out_buf], + name="argsort_nms_gpu", + tag="argsort_nms_gpu") + else: + out_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) + out = tvm.extern([data.shape], + [data], + lambda ins, outs: sort_ir( + ins[0], outs[0], axis, is_ascend), + dtype=dtype, + in_buffers=[data_buf], + out_buffers=[out_buf], + name="argsort_gpu", + tag="argsort_gpu") + return out diff --git a/topi/python/topi/cuda/ssd/multibox.py b/topi/python/topi/cuda/ssd/multibox.py index 38b76f36801ee..847f35790e905 100644 --- a/topi/python/topi/cuda/ssd/multibox.py +++ b/topi/python/topi/cuda/ssd/multibox.py @@ -21,6 +21,7 @@ import tvm from tvm import api +from tvm.intrin import if_then_else import topi @@ -93,12 +94,11 @@ def multibox_prior_ir(data, out, sizes, ratios, steps, offsets): center_w = (j + offset_w) * steps_w for k in range(num_sizes + num_ratios - 1): - w = tvm.if_then_else(k < num_sizes, - size_ratio_concat[ - k] * in_height / in_width / 2.0, - size_ratio_concat[0] * in_height / in_width * - math.sqrt(size_ratio_concat[k + 1]) / 2.0) - h = tvm.if_then_else( + w = if_then_else(k < num_sizes, + size_ratio_concat[k] * in_height / in_width / 2.0, + size_ratio_concat[0] * in_height / in_width * + math.sqrt(size_ratio_concat[k + 1]) / 2.0) + h = if_then_else( k < num_sizes, size_ratio_concat[k] / 2.0, size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0) count = (i * in_width * (num_sizes + num_ratios - 1) + @@ -154,8 +154,7 @@ def multibox_prior_gpu(data, sizes=(1,), ratios=(1,), steps=(-1, -1), out = topi.clip(out, 0, 1) return out - -def transform_loc_pre(cls_prob, valid_count, temp_flag, temp_id, temp_score_out, threshold): +def transform_loc_pre(cls_prob, valid_count, temp_valid_count, temp_cls_id, temp_score, threshold): """Low level IR routing for transform location data preparation. Parameters @@ -166,13 +165,13 @@ def transform_loc_pre(cls_prob, valid_count, temp_flag, temp_id, temp_score_out, valid_count : Buffer Buffer of number of valid output boxes. 
- temp_flag : Buffer + temp_valid_count : Buffer Output intermediate result buffer - temp_id : Buffer + temp_cls_id : Buffer Output intermediate result buffer - temp_score_out : Buffer + temp_score : Buffer Output buffer threshold : float @@ -187,53 +186,56 @@ def transform_loc_pre(cls_prob, valid_count, temp_flag, temp_id, temp_score_out, num_classes = cls_prob.shape[1] num_anchors = cls_prob.shape[2] - max_threads = int( - tvm.target.current_target(allow_none=False).max_num_threads) ib = tvm.ir_builder.create() - score = ib.buffer_ptr(temp_score_out) - cls_id = ib.buffer_ptr(temp_id) - flag = ib.buffer_ptr(temp_flag) + + cls_prob = ib.buffer_ptr(cls_prob) + cls_id = ib.buffer_ptr(temp_cls_id) + valid_count = ib.buffer_ptr(valid_count) + temp_valid_count = ib.buffer_ptr(temp_valid_count) + score = ib.buffer_ptr(temp_score) + + threshold = tvm.make.node("FloatImm", dtype="float32", value=threshold) + + max_threads = int( + math.sqrt(tvm.target.current_target(allow_none=False).max_num_threads)) + nthread_tx = max_threads + nthread_bx = (batch_size * num_anchors) // max_threads + 1 tx = tvm.thread_axis("threadIdx.x") bx = tvm.thread_axis("blockIdx.x") - nthread_tx = max_threads - nthread_bx = (batch_size * num_anchors * num_classes) // max_threads + 1 ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * max_threads + tx - p_cls_prob = ib.buffer_ptr(cls_prob) - p_valid_count = ib.buffer_ptr(valid_count) with ib.if_scope(tid < batch_size * num_anchors): - n = tid / num_anchors # number of batches - i = tid % num_anchors # number of anchors - score[i] = -1.0 - cls_id[i] = 0 - p_valid_count[n] = 0 - with ib.for_range(0, num_classes-1, name="k") as k: - temp = p_cls_prob[n * num_anchors * num_classes + (k + 1) * num_anchors + i] - with ib.if_scope(temp > score[i]): - cls_id[i] = k + 1 - score[i] = temp - with ib.if_scope(tvm.all(cls_id[i] > 0, score[i] < threshold)): - cls_id[i] = 0 - with ib.if_scope(cls_id[i] > 0): - flag[i] = 1 + i = tid / num_anchors # number of batches + j = tid % num_anchors # number of anchors + valid_count[i] = 0 + score[i * num_anchors + j] = -1.0 + cls_id[i * num_anchors + j] = 0 + with ib.for_range(0, num_classes-1) as k: + temp = cls_prob[i * num_classes * num_anchors + (k + 1) * num_anchors + j] + cls_id[i * num_anchors + j] = if_then_else(temp > score[i * num_anchors + j], \ + k + 1, cls_id[i * num_anchors + j]) + score[i * num_anchors + j] = tvm.max(temp, score[i * num_anchors + j]) + with ib.if_scope(tvm.all(cls_id[i * num_anchors + j] > 0, \ + score[i * num_anchors + j] < threshold)): + cls_id[i * num_anchors + j] = 0 + with ib.if_scope(cls_id[i * num_anchors + j] > 0): + temp_valid_count[i * num_anchors + j] = 1 with ib.else_scope(): - flag[i] = 0 + temp_valid_count[i * num_anchors + j] = 0 with ib.if_scope(tid < batch_size): - with ib.for_range(0, num_anchors, name="k") as k: + with ib.for_range(0, num_anchors) as k: with ib.if_scope(k > 0): - flag[tid * num_anchors + - k] += flag[tid * num_anchors + k - 1] - p_valid_count[n] = flag[tid * num_anchors + num_anchors - 1] - - body = ib.get() - return body + temp_valid_count[tid * num_anchors +k] += \ + temp_valid_count[tid * num_anchors + k - 1] + valid_count[i] = temp_valid_count[tid * num_anchors + num_anchors - 1] + return ib.get() -def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \ - out, clip, variances, batch_size, num_classes, num_anchors): +def transform_loc_ir(loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score, out, \ + 
clip, variances, batch_size, num_anchors): """Low level IR routing for transform location in multibox_detection operator. Parameters @@ -244,13 +246,13 @@ def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \ anchor : Buffer Buffer of prior anchor boxes. - temp_flag : Buffer + temp_valid_count : Buffer Intermediate result buffer. - temp_id : Buffer + temp_cls_id : Buffer Intermediate result buffer. - temp_score_in : Buffer + temp_score : Buffer Input buffer which stores intermediate results. out : Buffer @@ -265,9 +267,6 @@ def transform_loc_ir(loc_pred, anchor, temp_flag, temp_id, temp_score_in, \ batch_size : int Batch size - num_classes : int - Number of classes - num_anchors : int Number of anchors @@ -300,40 +299,41 @@ def transform_loc(loc, loc_base_idx, anchor, anchor_base_idx, clip, vx, vy, vw, tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, ox + ow)), ox + ow), \ tvm.if_then_else(clip, tvm.make.Max(0.0, tvm.make.Min(1.0, oy + oh)), oy + oh) - max_threads = int( - tvm.target.current_target(allow_none=False).max_num_threads) ib = tvm.ir_builder.create() - score = ib.buffer_ptr(temp_score_in) - cls_id = ib.buffer_ptr(temp_id) - flag = ib.buffer_ptr(temp_flag) + + loc_pred = ib.buffer_ptr(loc_pred) + anchor = ib.buffer_ptr(anchor) + temp_valid_count = ib.buffer_ptr(temp_valid_count) + cls_id = ib.buffer_ptr(temp_cls_id) + score = ib.buffer_ptr(temp_score) + out_loc = ib.buffer_ptr(out) + + max_threads = int( + math.sqrt(tvm.target.current_target(allow_none=False).max_num_threads)) + nthread_tx = max_threads + nthread_bx = (batch_size * num_anchors) // max_threads + 1 tx = tvm.thread_axis("threadIdx.x") bx = tvm.thread_axis("blockIdx.x") - nthread_tx = max_threads - nthread_bx = (batch_size * num_anchors * num_classes) // max_threads + 1 ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) tid = bx * max_threads + tx - p_loc_pred = ib.buffer_ptr(loc_pred) - p_anchor = ib.buffer_ptr(anchor) - p_out = ib.buffer_ptr(out) with ib.if_scope(tid < batch_size * num_anchors): - n = tid / num_anchors # number of batches - i = tid % num_anchors # number of anchors + i = tid / num_anchors # number of batches + j = tid % num_anchors # number of anchors with ib.if_scope(cls_id[tid] > 0): with ib.if_scope(tid == 0): - out_base_idx = n * num_anchors * 6 + out_base_idx = i * num_anchors * 6 with ib.else_scope(): - out_base_idx = n * num_anchors * 6 + flag[tid - 1] * 6 - p_out[out_base_idx] = cls_id[tid] - 1.0 - p_out[out_base_idx + 1] = score[tid] - p_out[out_base_idx + 2], p_out[out_base_idx + 3], p_out[out_base_idx + 4], \ - p_out[out_base_idx + 5] = transform_loc(p_loc_pred, tid * 4, - p_anchor, i*4, clip, variances[0], - variances[1], variances[2], variances[3]) + out_base_idx = i * num_anchors * 6 + temp_valid_count[tid - 1] * 6 + out_loc[out_base_idx] = cls_id[tid] - 1.0 + out_loc[out_base_idx + 1] = score[tid] + out_loc[out_base_idx + 2], out_loc[out_base_idx + 3], out_loc[out_base_idx + 4], \ + out_loc[out_base_idx + 5] = transform_loc(loc_pred, tid * 4, + anchor, j * 4, clip, variances[0], + variances[1], variances[2], variances[3]) - body = ib.get() - return body + return ib.get() @multibox_transform_loc.register(["cuda", "gpu"]) @@ -372,44 +372,42 @@ def multibox_transform_loc_gpu(cls_prob, loc_pred, anchor, clip=True, \ 1-D tensor with shape (batch_size,), number of valid anchor boxes. 
""" batch_size = cls_prob.shape[0] - num_classes = cls_prob.shape[1] num_anchors = cls_prob.shape[2] oshape = (batch_size, num_anchors, 6) # Define data alignment for intermediate buffer valid_count_dtype = "int32" + out_loc_dtype = loc_pred.dtype + valid_count_buf = api.decl_buffer((batch_size,), valid_count_dtype, "valid_count_buf", data_alignment=4) - out_buf = api.decl_buffer( - oshape, cls_prob.dtype, "out_buf", data_alignment=8) - size = num_anchors - temp_flag_buf = api.decl_buffer( - (size,), valid_count_dtype, "flag", data_alignment=8) - temp_id_buf = api.decl_buffer( - (size,), valid_count_dtype, "cls_id", data_alignment=8) + + temp_valid_count_buf = api.decl_buffer( + (batch_size, num_anchors,), valid_count_dtype, "temp_valid_count", data_alignment=8) + temp_cls_id_buf = api.decl_buffer( + (batch_size, num_anchors,), valid_count_dtype, "temp_cls_id", data_alignment=8) temp_score_buf = api.decl_buffer( - (size,), cls_prob.dtype, "score", data_alignment=8) + (batch_size, num_anchors,), cls_prob.dtype, "temp_score", data_alignment=8) - valid_count, temp_flag, temp_id, temp_score = \ - tvm.extern([(batch_size,), (size,), (size,), (size,)], - [cls_prob], + valid_count, temp_valid_count, temp_cls_id, temp_score = \ + tvm.extern([(batch_size,), (batch_size, num_anchors,), (batch_size, num_anchors,), \ + (batch_size, num_anchors,)], [cls_prob], lambda ins, outs: transform_loc_pre( ins[0], outs[0], outs[1], outs[2], outs[3], threshold), - dtype=[valid_count_dtype, - valid_count_dtype, valid_count_dtype, cls_prob.dtype], - out_buffers=[valid_count_buf, - temp_flag_buf, temp_id_buf, temp_score_buf], - tag="multibox_transform_loc_first_step") + dtype=[valid_count_dtype, valid_count_dtype, valid_count_dtype, cls_prob.dtype], + out_buffers=[valid_count_buf, temp_valid_count_buf, \ + temp_cls_id_buf, temp_score_buf], + tag="multibox_transform_loc_phase_one") - out = \ + out_loc = \ tvm.extern([oshape], - [loc_pred, anchor, temp_flag, temp_id, temp_score], + [loc_pred, anchor, temp_valid_count, temp_cls_id, temp_score], lambda ins, outs: transform_loc_ir( - ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], clip, \ - variances, batch_size, num_classes, num_anchors), - dtype=[cls_prob.dtype], - out_buffers=[out_buf], + ins[0], ins[1], ins[2], ins[3], ins[4], outs[0], clip, variances, \ + batch_size, num_anchors), + dtype=[out_loc_dtype], tag="multibox_transform_loc") - return [out, valid_count] + + return [out_loc, valid_count] @multibox_detection.register(["cuda", "gpu"]) @@ -453,6 +451,7 @@ def multibox_detection_gpu(cls_prob, loc_pred, anchor, clip=True, threshold=0.01 """ inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor, clip, threshold, variances) - out = non_max_suppression( - inter_out[0], inter_out[1], nms_threshold, force_suppress, nms_topk) + out = non_max_suppression(inter_out[0], inter_out[1], max_output_size=-1, + iou_threshold=nms_threshold, force_suppress=force_suppress, + top_k=nms_topk, return_indices=False) return out diff --git a/topi/python/topi/cuda/vision.py b/topi/python/topi/cuda/vision.py index 5d7bc9e00da63..78f5c1f51ec6a 100644 --- a/topi/python/topi/cuda/vision.py +++ b/topi/python/topi/cuda/vision.py @@ -32,11 +32,15 @@ def _default_schedule(outs): def traverse(op): """inline all one-to-one-mapping operators except the last stage (output)""" - if "nms" in op.tag: - sort = op.input_tensors[1] + if op.tag in ["nms", "invalid_to_bottom"]: + if op.tag == "nms": + sort = op.input_tensors[1] + else: + out = op.input_tensors[0] + sort = s[out].op.input_tensors[1] 
score = s[sort].op.input_tensors[0] fused = s[score].fuse(*s[score].op.axis) - num_thread = tvm.target.current_target(allow_none=False).max_num_threads + num_thread = int(tvm.target.current_target(allow_none=False).max_num_threads) bx, tx = s[score].split(fused, factor=num_thread) s[score].bind(bx, tvm.thread_axis("blockIdx.x")) s[score].bind(tx, tvm.thread_axis("threadIdx.x")) @@ -199,3 +203,30 @@ def schedule_get_valid_counts(outs): The computation schedule for the op. """ return _default_schedule(outs) + +@generic.schedule_argsort.register(["cuda", "gpu"]) +def schedule_argsort(outs): + """Schedule for argsort operator. + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of argsort + in the format of an array of tensors. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs + s = tvm.create_schedule([x.op for x in outs]) + scheduled_ops = [] + from .injective import _schedule_injective + def traverse(op): + for tensor in op.input_tensors: + if tensor.op.input_tensors and tensor.op not in scheduled_ops: + traverse(tensor.op) + scheduled_ops.append(op) + traverse(outs[0].op) + return s diff --git a/topi/python/topi/generic/vision.py b/topi/python/topi/generic/vision.py index a1e096a858806..5d0eb9b2e9012 100644 --- a/topi/python/topi/generic/vision.py +++ b/topi/python/topi/generic/vision.py @@ -188,3 +188,20 @@ def schedule_proposal(outs): The computation schedule for the op. """ return _default_schedule(outs, False) + +@tvm.target.generic_func +def schedule_argsort(outs): + """Schedule for argsort operator. + + Parameters + ---------- + outs: Array of Tensor + The indices that would sort an input array along + the given axis. + + Returns + ------- + s: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) diff --git a/topi/python/topi/vision/__init__.py b/topi/python/topi/vision/__init__.py index c10f7c68bf362..b3db0c56d9a95 100644 --- a/topi/python/topi/vision/__init__.py +++ b/topi/python/topi/vision/__init__.py @@ -6,3 +6,4 @@ from .reorg import * from .nms import * from .rcnn import * +from .sort import * diff --git a/topi/python/topi/vision/nms.py b/topi/python/topi/vision/nms.py index d8b15aac42c6d..43efb09f43f5e 100644 --- a/topi/python/topi/vision/nms.py +++ b/topi/python/topi/vision/nms.py @@ -18,7 +18,8 @@ """Non-maximum suppression operator""" import tvm -from tvm import api, hybrid +from tvm import hybrid +from .sort import argsort @hybrid.script def hybrid_rearrange_out(data): @@ -129,7 +130,7 @@ def get_valid_counts(data, score_threshold=0): @hybrid.script def hybrid_nms(data, sorted_index, valid_count, max_output_size, iou_threshold, force_suppress, - top_k, id_index): + top_k, coord_start, id_index): """Hybrid routing for non-maximum suppression. Parameters @@ -158,6 +159,9 @@ def hybrid_nms(data, sorted_index, valid_count, top_k : tvm.const Keep maximum top k detections before nms, -1 for no limit. + coord_start : tvm.const + Start index of the consecutive 4 coordinates. + id_index : tvm.const index of the class categories, -1 to disable. 
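For reference, a minimal sketch (not part of the patch) of the detection-row layout these index parameters assume; the tensor `data` and the batch/box indices are illustrative, and the values below are the defaults added to non_max_suppression further down:

    # each detection row is [class_id, score, x_min, y_min, x_max, y_max]
    id_index = 0       # class id column, -1 to disable
    score_index = 1    # confidence column fed to argsort
    coord_start = 2    # first of the 4 consecutive box coordinates
    box = data[batch_idx, box_idx, coord_start:coord_start + 4]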
@@ -208,7 +212,7 @@ def hybrid_nms(data, sorted_index, valid_count, batch_idx = i box_a_idx = j box_b_idx = k - box_start_idx = 2 + box_start_idx = coord_start a_t = output[batch_idx, box_a_idx, box_start_idx + 1] a_b = output[batch_idx, box_a_idx, box_start_idx + 3] a_l = output[batch_idx, box_a_idx, box_start_idx] @@ -252,7 +256,8 @@ def hybrid_nms(data, sorted_index, valid_count, @tvm.target.generic_func def non_max_suppression(data, valid_count, max_output_size=-1, iou_threshold=0.5, force_suppress=False, top_k=-1, - id_index=0, return_indices=True, invalid_to_bottom=False): + coord_start=2, score_index=1, id_index=0, + return_indices=True, invalid_to_bottom=False): """Non-maximum suppression operator for object detection. Parameters @@ -278,6 +283,12 @@ def non_max_suppression(data, valid_count, max_output_size=-1, top_k : optional, int Keep maximum top k detections before nms, -1 for no limit. + coord_start : required, int + Start index of the consecutive 4 coordinates. + + score_index: optional, int + Index of the scores/confidence of boxes. + id_index : optional, int index of the class categories, -1 to disable. @@ -317,32 +328,16 @@ def non_max_suppression(data, valid_count, max_output_size=-1, """ batch_size = data.shape[0] num_anchors = data.shape[1] - valid_count_dtype = "int32" - valid_count_buf = api.decl_buffer(valid_count.shape, valid_count_dtype, - "valid_count_buf", data_alignment=4) - score_axis = 1 + score_axis = score_index score_shape = (batch_size, num_anchors) score_tensor = tvm.compute(score_shape, lambda i, j: data[i, j, score_axis]) - score_tensor_buf = api.decl_buffer(score_tensor.shape, data.dtype, - "score_tensor_buf", data_alignment=8) - sort_tensor_dtype = "int32" - sort_tensor_buf = api.decl_buffer(score_shape, sort_tensor_dtype, - "sort_tensor_buf", data_alignment=8) - sort_tensor = \ - tvm.extern(score_shape, - [score_tensor, valid_count], - lambda ins, outs: tvm.call_packed( - "tvm.contrib.sort.argsort", ins[0], ins[1], - outs[0], score_axis, True), - dtype=sort_tensor_dtype, - in_buffers=[score_tensor_buf, valid_count_buf], - out_buffers=sort_tensor_buf, - name="nms_sort") + sort_tensor = argsort(score_tensor, valid_count=valid_count, axis=1, is_ascend=False, flag=True) out, box_indices = hybrid_nms(data, sort_tensor, valid_count, tvm.const(max_output_size, dtype="int32"), tvm.const(iou_threshold, dtype="float32"), tvm.const(force_suppress, dtype="bool"), tvm.const(top_k, dtype="int32"), + tvm.const(coord_start, dtype="int32"), tvm.const(id_index, dtype="int32")) if not return_indices and invalid_to_bottom: out = hybrid_rearrange_out(out) diff --git a/topi/python/topi/vision/sort.py b/topi/python/topi/vision/sort.py new file mode 100644 index 0000000000000..afe6f45e14d31 --- /dev/null +++ b/topi/python/topi/vision/sort.py @@ -0,0 +1,88 @@ +"""Argsort operator""" +import tvm +from tvm import api + +@tvm.target.generic_func +def argsort(data, valid_count, axis=-1, is_ascend=1, dtype="float32", flag=0): + """Performs sorting along the given axis and returns an array + of indices having the same shape as an input array that index + data in sorted order. + + Parameters + ---------- + data : tvm.Tensor + The input tensor. + + valid_count : tvm.Tensor + 1-D tensor for valid number of boxes only for ssd. + + axis : optional, int + Axis along which to sort the input tensor. + By default the flattened array is used. + + is_ascend : optional, boolean + Whether to sort in ascending or descending order. + + dtype : optional, string + DType of the output indices. 
+ + flag : optional, boolean + Whether to use valid_count when sorting (the NMS-specific path). + + Returns + ------- + out : tvm.Tensor + Sorted index tensor. + + Example + -------- + .. code-block:: python + + # An example to use argsort + dshape = (1, 5, 6) + data = tvm.placeholder(dshape, name="data") + valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count") + axis = 0 + is_ascend = False + flag = False + out = argsort(data, valid_count, axis=axis, is_ascend=is_ascend, flag=flag) + np_data = np.random.uniform(size=dshape) + np_valid_count = np.array([4]) + s = topi.generic.schedule_argsort(out) + f = tvm.build(s, [data, valid_count, out], "llvm") + ctx = tvm.cpu() + tvm_data = tvm.nd.array(np_data, ctx) + tvm_valid_count = tvm.nd.array(np_valid_count, ctx) + tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx) + f(tvm_data, tvm_valid_count, tvm_out) + """ + data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) + if flag: + valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype, + "valid_count_buf", data_alignment=4) + out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=8) + out = \ + tvm.extern(data.shape, + [data, valid_count], + lambda ins, outs: tvm.call_packed( + "tvm.contrib.sort.argsort_nms", ins[0], ins[1], + outs[0], axis, is_ascend), + dtype="int32", + in_buffers=[data_buf, valid_count_buf], + out_buffers=out_buf, + name="argsort_nms_cpu", + tag="argsort_nms_cpu") + else: + out_buf = api.decl_buffer(data.shape, dtype, "out_buf", data_alignment=8) + out = \ + tvm.extern(data.shape, + [data], + lambda ins, outs: tvm.call_packed( + "tvm.contrib.sort.argsort", ins[0], + outs[0], axis, is_ascend, dtype), + dtype=dtype, + in_buffers=[data_buf], + out_buffers=out_buf, + name="argsort_cpu", + tag="argsort_cpu") + return out diff --git a/topi/python/topi/vision/ssd/multibox.py b/topi/python/topi/vision/ssd/multibox.py index 7996690037530..ca1b4a9eb2687 100644 --- a/topi/python/topi/vision/ssd/multibox.py +++ b/topi/python/topi/vision/ssd/multibox.py @@ -308,7 +308,7 @@ def multibox_detection(cls_prob, loc_pred, anchor, clip=True, threshold=0.01, nm """ inter_out = multibox_transform_loc(cls_prob, loc_pred, anchor, clip, threshold, variances) - out = non_max_suppression(inter_out[0], inter_out[1], -1, - nms_threshold, force_suppress, nms_topk, - return_indices=False) + out = non_max_suppression(inter_out[0], inter_out[1], max_output_size=-1, + iou_threshold=nms_threshold, force_suppress=force_suppress, + top_k=nms_topk, return_indices=False) return out diff --git a/topi/tests/python/test_topi_vision.py b/topi/tests/python/test_topi_vision.py index 6bb57b541c881..979caba5b63c6 100644 --- a/topi/tests/python/test_topi_vision.py +++ b/topi/tests/python/test_topi_vision.py @@ -24,7 +24,7 @@ from tvm.contrib.pickle_memoize import memoize from topi.util import get_const_tuple -from topi.vision import ssd, non_max_suppression, get_valid_counts +from topi.vision import ssd, non_max_suppression, get_valid_counts, argsort def verify_get_valid_counts(dshape, score_threshold): @@ -66,7 +66,7 @@ def check_device(device): tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3) tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3) - for device in ['llvm']: + for device in ['llvm', 'cuda', 'opencl']: check_device(device) @@ -124,7 +124,7 @@ def check_device(device): f(tvm_data, tvm_valid_count, tvm_indices_out) tvm.testing.assert_allclose(tvm_indices_out.asnumpy(), np_indices_result, rtol=1e-4) - for device in ['llvm']: + for device in
['llvm', 'cuda', 'opencl']: check_device(device) @@ -231,7 +231,7 @@ def check_device(device): f(tvm_cls_prob, tvm_loc_preds, tvm_anchors, tvm_out) tvm.testing.assert_allclose(tvm_out.asnumpy(), expected_np_out, rtol=1e-4) - for device in ['llvm', 'opencl']: + for device in ['llvm', 'opencl', 'cuda']: check_device(device) @@ -275,7 +275,7 @@ def check_device(device): f(tvm_a, tvm_rois, tvm_b) tvm.testing.assert_allclose(tvm_b.asnumpy(), b_np, rtol=1e-3) - for device in ['llvm', 'cuda']: + for device in ['llvm', 'cuda', 'opencl']: check_device(device) @@ -397,6 +397,35 @@ def test_proposal(): verify_proposal(np_cls_prob, np_bbox_pred, np_im_info, np_out, attrs) +def test_argsort(): + dshape = (1, 8) + valid_count_shape = (2,) + data = tvm.placeholder(dshape, name="data", dtype="float32") + valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count") + np_data = np.random.rand(dshape[0], dshape[1]).astype(data.dtype) + np_valid_count = np.array([4]).astype(valid_count.dtype) + np_result = np.argsort(-np_data) + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + out = argsort(data, valid_count, axis = -1, is_ascend = False, flag=False) + s = topi.generic.schedule_argsort(out) + + tvm_data = tvm.nd.array(np_data, ctx) + tvm_valid_count = tvm.nd.array(np_valid_count, ctx) + tvm_out = tvm.nd.array(np.zeros(dshape, dtype="float32"), ctx) + f = tvm.build(s, [data, valid_count, out], device) + f(tvm_data, tvm_valid_count, tvm_out) + tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result.astype("float32"), rtol=1e0) + + for device in ['llvm', 'cuda', 'opencl']: + check_device(device) + + if __name__ == "__main__": test_get_valid_counts() test_non_max_suppression() @@ -404,3 +433,4 @@ def test_proposal(): test_multibox_detection() test_roi_align() test_proposal() + test_argsort() diff --git a/tutorials/frontend/deploy_ssd_gluoncv.py b/tutorials/frontend/deploy_ssd_gluoncv.py index fe84283ad1918..ff7691c7bf558 100644 --- a/tutorials/frontend/deploy_ssd_gluoncv.py +++ b/tutorials/frontend/deploy_ssd_gluoncv.py @@ -18,6 +18,7 @@ Deploy Single Shot Multibox Detector(SSD) model =============================================== **Author**: `Yao Wang `_ +`Leyuan Wang `_ This article is an introductory tutorial to deploy SSD models with TVM. We will use GluonCV pre-trained SSD model and convert it to Relay IR @@ -37,30 +38,29 @@ # ------------------------------ # .. note:: # -# Currently we support compiling SSD on CPU only. -# GPU support is in progress. +# We support compiling SSD on both CPUs and GPUs now. # # To get best inference performance on CPU, change # target argument according to your device and # follow the :ref:`tune_relay_x86` to tune x86 CPU and # :ref:`tune_relay_arm` for arm cpu. # +# To get best performance for SSD on Intel graphics, +# change target argument to 'opencl -device=intel_graphics' +# # SSD with VGG as body network is not supported yet since # x86 conv2d schedule doesn't support dilation.
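# For illustration only (not part of this tutorial, which loops over
# ctx_list() below): to pin a single device you could set, e.g.,
#
#   target = "cuda"                              # NVIDIA GPU
#   target = "opencl -device=intel_graphics"     # Intel graphics, as noted above
#   ctx = tvm.context(target, 0)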
supported_model = [ - 'ssd_512_resnet18_v1_voc', - 'ssd_512_resnet18_v1_coco', 'ssd_512_resnet50_v1_voc', 'ssd_512_resnet50_v1_coco', 'ssd_512_resnet101_v2_voc', - 'ssd_512_mobilenet1_0_voc', - 'ssd_512_mobilenet1_0_coco', + 'ssd_512_mobilenet1.0_voc', + 'ssd_512_mobilenet1.0_coco', ] -model_name = "ssd_512_resnet50_v1_voc" +model_name = supported_model[0] dshape = (1, 3, 512, 512) -dtype = "float32" target_list = ctx_list() ###################################################################### @@ -76,7 +76,7 @@ block = model_zoo.get_model(model_name, pretrained=True) -def compile(target): +def build(target): net, params = relay.frontend.from_mxnet(block, {"data": dshape}) with relay.build_config(opt_level=3): graph, lib, params = relay.build(net, target, params=params) @@ -98,10 +98,7 @@ def run(graph, lib, params, ctx): return class_IDs, scores, bounding_boxs for target, ctx in target_list: - if target == "cuda": - print("GPU not supported yet, skip.") - continue - graph, lib, params = compile(target) + graph, lib, params = build(target) class_IDs, scores, bounding_boxs = run(graph, lib, params, ctx) ######################################################################
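The loop above calls build() and run() for each target in ctx_list(). As a point of reference, the sketch below shows one plausible shape of such a run() helper built on TVM's graph runtime; it is an illustration rather than the tutorial's own body, and `input_image` (an array matching dshape, fed to the "data" input name declared for from_mxnet above) is an assumption:

    import tvm
    from tvm.contrib import graph_runtime

    def run_sketch(graph, lib, params, ctx, input_image):
        # instantiate the compiled graph on the chosen context
        m = graph_runtime.create(graph, lib, ctx)
        # "data" matches the input name passed to relay.frontend.from_mxnet
        m.set_input("data", tvm.nd.array(input_image.astype("float32"), ctx))
        m.set_input(**params)
        m.run()
        # SSD networks return class ids, scores and bounding boxes
        class_IDs, scores, bounding_boxs = (m.get_output(0), m.get_output(1),
                                            m.get_output(2))
        return class_IDs, scores, bounding_boxs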