Skip to content

Commit

Permalink
Add test case of argmax for detecting out of bound access (apache#2234)
Browse files Browse the repository at this point in the history
  • Loading branch information
FrozenGene authored and tqchen committed Dec 6, 2018
1 parent 374918f commit 4e21982
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions nnvm/tests/python/compiler/test_top_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,28 @@ def test_where():
y = np.random.uniform(size=shape).astype("float32")
verify_where(condition, x, y)

def test_argmax():
dshape = (204800, 2)
oshape = (1, 320, 640)

dtype = "float32"
x = sym.Variable("x", shape=dshape, dtype=dtype)
x = sym.reshape(x, shape=(1, 320, 640, 2))
x = sym.transpose(x, axes=(0, 3, 1, 2))
y = sym.argmax(x, axis=1)
target_str = "llvm"
target = tvm.target.create(target_str)
ctx = tvm.context(target_str, 0)
with nnvm.compiler.build_config(opt_level=2):
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = graph_runtime.create(graph, lib, ctx)
data = np.random.uniform(size=dshape).astype(dtype)
m.run(x=data)
np_reshape = np.reshape(data, (1, 320, 640, 2))
np_transpose = np.transpose(np_reshape, axes=(0, 3, 1, 2))
np_argmax = np.argmax(np_transpose, axis=1)
out = m.get_output(0)
np.testing.assert_allclose(out.asnumpy(), np_argmax, atol=1e-5, rtol=1e-5)

if __name__ == "__main__":
test_reshape()
Expand All @@ -707,4 +729,5 @@ def test_where():
test_nms()
test_slice_like()
test_where()
test_argmax()
print(nnvm.compiler.engine.dump())

0 comments on commit 4e21982

Please sign in to comment.