Skip to content

Commit

Permalink
[TEST][FLAKY] topi/tests/python/test_topi_sort.py::test_argsort (apac…
Browse files Browse the repository at this point in the history
…he#4891)

* [TEST][FLAKY] topi/tests/python/test_topi_sort.py::test_argsort

* upadate test function of argsort like topk

* Shuffle index and get data from shuffled index

* Replace the random.uniform with np.arange
  • Loading branch information
cchung100m authored Feb 21, 2020
1 parent f47c38d commit 8290eab
Showing 1 changed file with 31 additions and 7 deletions.
38 changes: 31 additions & 7 deletions topi/tests/python/test_topi_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,46 @@
import topi
import topi.testing

def test_argsort():

def verify_argsort(axis, is_ascend):
dshape = (20, 100)
data = tvm.placeholder(dshape, name="data", dtype="float32")
np_data = np.random.rand(dshape[0], dshape[1]).astype(data.dtype)
np_result = np.argsort(-np_data)
data_dtype = "float32"
data = tvm.placeholder(dshape, name="data", dtype=data_dtype)

perm = np.arange(dshape[0] * dshape[1], dtype=data_dtype)
np.random.shuffle(perm)
np_data = perm.reshape(dshape)

if is_ascend:
np_indices = np.argsort(np_data, axis=axis)
else:
np_indices = np.argsort(-np_data, axis=axis)

if axis == 0:
np_indices = np_indices[:dshape[axis], :]
else:
np_indices = np_indices[:, :dshape[axis]]

def check_device(device):
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)
return
print("Running on target: %s" % device)
with tvm.target.create(device):
out = topi.argsort(data, axis=-1, is_ascend=False)
out = topi.argsort(data, axis=axis, is_ascend=is_ascend)
s = topi.generic.schedule_argsort(out)

tvm_data = tvm.nd.array(np_data, ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype="float32"), ctx)
tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data_dtype), ctx)
f = tvm.build(s, [data, out], device)
f(tvm_data, tvm_out)
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result.astype("float32"), rtol=1e0)
tvm.testing.assert_allclose(tvm_out.asnumpy(), np_indices.astype(data_dtype), rtol=1e0)

for device in ['llvm', 'cuda', 'opencl']:
check_device(device)


def verify_topk(k, axis, ret_type, is_ascend, dtype):
shape = (20, 100)
data_dtype = "float32"
Expand Down Expand Up @@ -95,6 +111,14 @@ def check_device(device):
for device in ['llvm', 'cuda', 'opencl']:
check_device(device)


def test_argsort():
np.random.seed(0)
for axis in [0, -1, 1]:
verify_argsort(axis, True)
verify_argsort(axis, False)


def test_topk():
np.random.seed(0)
for k in [0, 1, 5]:
Expand Down

0 comments on commit 8290eab

Please sign in to comment.