From 4e21982bc6be2a13eac70ba2896541cc1a9b93fa Mon Sep 17 00:00:00 2001 From: Wu Zhao Date: Thu, 6 Dec 2018 22:30:44 +0800 Subject: [PATCH] Add test case of argmax for detecting out of bound access (#2234) --- nnvm/tests/python/compiler/test_top_level4.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/nnvm/tests/python/compiler/test_top_level4.py b/nnvm/tests/python/compiler/test_top_level4.py index 46383e73657e..fc4e62fb7156 100644 --- a/nnvm/tests/python/compiler/test_top_level4.py +++ b/nnvm/tests/python/compiler/test_top_level4.py @@ -686,6 +686,28 @@ def test_where(): y = np.random.uniform(size=shape).astype("float32") verify_where(condition, x, y) +def test_argmax(): + dshape = (204800, 2) + oshape = (1, 320, 640) + + dtype = "float32" + x = sym.Variable("x", shape=dshape, dtype=dtype) + x = sym.reshape(x, shape=(1, 320, 640, 2)) + x = sym.transpose(x, axes=(0, 3, 1, 2)) + y = sym.argmax(x, axis=1) + target_str = "llvm" + target = tvm.target.create(target_str) + ctx = tvm.context(target_str, 0) + with nnvm.compiler.build_config(opt_level=2): + graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape}) + m = graph_runtime.create(graph, lib, ctx) + data = np.random.uniform(size=dshape).astype(dtype) + m.run(x=data) + np_reshape = np.reshape(data, (1, 320, 640, 2)) + np_transpose = np.transpose(np_reshape, axes=(0, 3, 1, 2)) + np_argmax = np.argmax(np_transpose, axis=1) + out = m.get_output(0) + np.testing.assert_allclose(out.asnumpy(), np_argmax, atol=1e-5, rtol=1e-5) if __name__ == "__main__": test_reshape() @@ -707,4 +729,5 @@ def test_where(): test_nms() test_slice_like() test_where() + test_argmax() print(nnvm.compiler.engine.dump())