From d6bfd938f17d9b95ebcb42c5e8e7df6788ca2bde Mon Sep 17 00:00:00 2001 From: Leyuan Wang Date: Tue, 19 Mar 2019 22:05:01 +0000 Subject: [PATCH] fix lint --- topi/python/topi/cuda/sort.py | 7 +++-- topi/python/topi/sort.py | 49 +++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/topi/python/topi/cuda/sort.py b/topi/python/topi/cuda/sort.py index 6d13c319f352..8918d3ed78fc 100644 --- a/topi/python/topi/cuda/sort.py +++ b/topi/python/topi/cuda/sort.py @@ -42,7 +42,7 @@ def sort_ir(data, valid_count, output, axis, is_ascend, flag): axis_mul_after = 1 shape = data.shape if axis < 0: - axis = len(shape) + axis; + axis = len(shape) + axis for i in range(0, len(shape)): size *= shape[i] if i < axis: @@ -112,12 +112,12 @@ def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, flag=0): out : tvm.Tensor The output of this function. """ - data_buf = api.decl_buffer(data.shape, data.dtype,"data_buf", data_alignment=8) + data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8) valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype, "valid_count_buf", data_alignment=4) out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4) - out = tvm.extern([data.shape], + out = tvm.extern([data.shape], [data, valid_count], lambda ins, outs: sort_ir( ins[0], ins[1], outs[0], axis, bool(is_ascend), bool(flag)), @@ -127,4 +127,3 @@ def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, flag=0): name="argsort_gpu", tag="argsort_gpu") return out - diff --git a/topi/python/topi/sort.py b/topi/python/topi/sort.py index 8f7d70db62fc..653ed6f6a0a6 100644 --- a/topi/python/topi/sort.py +++ b/topi/python/topi/sort.py @@ -3,6 +3,55 @@ @tvm.target.generic_func def argsort(data, valid_count, axis=-1, is_ascend=1, flag=0): + """Performs sorting along the given axis and returns an array + of indices having the same shape as an input array that index + data in sorted order. + + Parameters + ---------- + data : tvm.Tensor + The input tensor. + + valid_count : tvm.Tensor + 1-D tensor for valid number of boxes only for ssd. + + axis : optional, int + Axis along which to sort the input tensor. + By default the flattened array is used. + + is_ascend : optional, boolean + Whether to sort in ascending or descending order. + + flag : optional, boolean + Whether valid_count is valid. + + Returns + ------- + out : tvm.Tensor + Sorted index tensor. + + Example + -------- + .. code-block:: python + + # An example to use argsort + dshape = (1, 5, 6) + data = tvm.placeholder(dshape, name="data") + valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count") + axis = 0 + is_ascend = False + flag = False + out = argsort(data, valid_count, axis, is_ascend, flag) + np_data = np.random.uniform(dshape) + np_valid_count = np.array([4]) + s = topi.generic.schedule_argsort(out) + f = tvm.build(s, [data, valid_count, out], "llvm") + ctx = tvm.cpu() + tvm_data = tvm.nd.array(np_data, ctx) + tvm_valid_count = tvm.nd.array(np_valid_count, ctx) + tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx) + f(tvm_data, tvm_valid_count, tvm_out) + """ data_buf = api.decl_buffer(data.shape, data.dtype, "sort_data_buf", data_alignment=8) valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,