Skip to content

Commit

Permalink
fix lint
Browse files Browse the repository at this point in the history
  • Loading branch information
Laurawly committed Mar 19, 2019
1 parent 2ceac81 commit d6bfd93
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 4 deletions.
7 changes: 3 additions & 4 deletions topi/python/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def sort_ir(data, valid_count, output, axis, is_ascend, flag):
axis_mul_after = 1
shape = data.shape
if axis < 0:
axis = len(shape) + axis;
axis = len(shape) + axis
for i in range(0, len(shape)):
size *= shape[i]
if i < axis:
Expand Down Expand Up @@ -112,12 +112,12 @@ def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, flag=0):
out : tvm.Tensor
The output of this function.
"""
data_buf = api.decl_buffer(data.shape, data.dtype,"data_buf", data_alignment=8)
data_buf = api.decl_buffer(data.shape, data.dtype, "data_buf", data_alignment=8)
valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
"valid_count_buf", data_alignment=4)
out_buf = api.decl_buffer(data.shape, "int32", "out_buf", data_alignment=4)

out = tvm.extern([data.shape],
out = tvm.extern([data.shape],
[data, valid_count],
lambda ins, outs: sort_ir(
ins[0], ins[1], outs[0], axis, bool(is_ascend), bool(flag)),
Expand All @@ -127,4 +127,3 @@ def argsort_gpu(data, valid_count, axis=-1, is_ascend=1, flag=0):
name="argsort_gpu",
tag="argsort_gpu")
return out

49 changes: 49 additions & 0 deletions topi/python/topi/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,55 @@

@tvm.target.generic_func
def argsort(data, valid_count, axis=-1, is_ascend=1, flag=0):
"""Performs sorting along the given axis and returns an array
of indices having the same shape as an input array that index
data in sorted order.
Parameters
----------
data : tvm.Tensor
The input tensor.
valid_count : tvm.Tensor
1-D tensor for valid number of boxes only for ssd.
axis : optional, int
Axis along which to sort the input tensor.
By default the flattened array is used.
is_ascend : optional, boolean
Whether to sort in ascending or descending order.
flag : optional, boolean
Whether valid_count is valid.
Returns
-------
out : tvm.Tensor
Sorted index tensor.
Example
--------
.. code-block:: python
# An example to use argsort
dshape = (1, 5, 6)
data = tvm.placeholder(dshape, name="data")
valid_count = tvm.placeholder((dshape[0],), dtype="int32", name="valid_count")
axis = 0
is_ascend = False
flag = False
out = argsort(data, valid_count, axis, is_ascend, flag)
np_data = np.random.uniform(dshape)
np_valid_count = np.array([4])
s = topi.generic.schedule_argsort(out)
f = tvm.build(s, [data, valid_count, out], "llvm")
ctx = tvm.cpu()
tvm_data = tvm.nd.array(np_data, ctx)
tvm_valid_count = tvm.nd.array(np_valid_count, ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data.dtype), ctx)
f(tvm_data, tvm_valid_count, tvm_out)
"""
data_buf = api.decl_buffer(data.shape, data.dtype,
"sort_data_buf", data_alignment=8)
valid_count_buf = api.decl_buffer(valid_count.shape, valid_count.dtype,
Expand Down

0 comments on commit d6bfd93

Please sign in to comment.