From 8290eabafb22a8a54249165ce09dd31ff1f31c7c Mon Sep 17 00:00:00 2001 From: Neo Chien Date: Sat, 22 Feb 2020 05:03:41 +0800 Subject: [PATCH] [TEST][FLAKY] topi/tests/python/test_topi_sort.py::test_argsort (#4891) * [TEST][FLAKY] topi/tests/python/test_topi_sort.py::test_argsort * upadate test function of argsort like topk * Shuffle index and get data from shuffled index * Replace the random.uniform with np.arange --- topi/tests/python/test_topi_sort.py | 38 +++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 7 deletions(-) diff --git a/topi/tests/python/test_topi_sort.py b/topi/tests/python/test_topi_sort.py index c084a7c431b6..0ad4e987d17d 100644 --- a/topi/tests/python/test_topi_sort.py +++ b/topi/tests/python/test_topi_sort.py @@ -21,11 +21,26 @@ import topi import topi.testing -def test_argsort(): + +def verify_argsort(axis, is_ascend): dshape = (20, 100) - data = tvm.placeholder(dshape, name="data", dtype="float32") - np_data = np.random.rand(dshape[0], dshape[1]).astype(data.dtype) - np_result = np.argsort(-np_data) + data_dtype = "float32" + data = tvm.placeholder(dshape, name="data", dtype=data_dtype) + + perm = np.arange(dshape[0] * dshape[1], dtype=data_dtype) + np.random.shuffle(perm) + np_data = perm.reshape(dshape) + + if is_ascend: + np_indices = np.argsort(np_data, axis=axis) + else: + np_indices = np.argsort(-np_data, axis=axis) + + if axis == 0: + np_indices = np_indices[:dshape[axis], :] + else: + np_indices = np_indices[:, :dshape[axis]] + def check_device(device): ctx = tvm.context(device, 0) if not ctx.exist: @@ -33,18 +48,19 @@ def check_device(device): return print("Running on target: %s" % device) with tvm.target.create(device): - out = topi.argsort(data, axis=-1, is_ascend=False) + out = topi.argsort(data, axis=axis, is_ascend=is_ascend) s = topi.generic.schedule_argsort(out) tvm_data = tvm.nd.array(np_data, ctx) - tvm_out = tvm.nd.array(np.zeros(dshape, dtype="float32"), ctx) + tvm_out = tvm.nd.array(np.zeros(dshape, dtype=data_dtype), ctx) f = tvm.build(s, [data, out], device) f(tvm_data, tvm_out) - tvm.testing.assert_allclose(tvm_out.asnumpy(), np_result.astype("float32"), rtol=1e0) + tvm.testing.assert_allclose(tvm_out.asnumpy(), np_indices.astype(data_dtype), rtol=1e0) for device in ['llvm', 'cuda', 'opencl']: check_device(device) + def verify_topk(k, axis, ret_type, is_ascend, dtype): shape = (20, 100) data_dtype = "float32" @@ -95,6 +111,14 @@ def check_device(device): for device in ['llvm', 'cuda', 'opencl']: check_device(device) + +def test_argsort(): + np.random.seed(0) + for axis in [0, -1, 1]: + verify_argsort(axis, True) + verify_argsort(axis, False) + + def test_topk(): np.random.seed(0) for k in [0, 1, 5]: