From 1e9d014b3f648c6b977e86deb9275524d000eac4 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Thu, 27 Jun 2019 10:03:29 -0700 Subject: [PATCH] [Relay] Fix reduce axis bug (#3422) * fix relay reduce axis bug * add tests for reduce bug --- python/tvm/relay/op/reduce.py | 4 ++-- tests/python/relay/test_op_level4.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py index 0f2594600b0a..41e1fc041cce 100644 --- a/python/tvm/relay/op/reduce.py +++ b/python/tvm/relay/op/reduce.py @@ -107,7 +107,7 @@ def sum(data, axis=None, keepdims=False, exclude=False): result : relay.Expr The computed result. """ - axis = [axis] if axis and isinstance(axis, int) else axis + axis = [axis] if isinstance(axis, int) else axis return _make.sum(data, axis, keepdims, exclude) @@ -159,7 +159,7 @@ def all(data, axis=None, keepdims=False, exclude=False): # [False, True, False]] """ - axis = [axis] if axis and isinstance(axis, int) else axis + axis = [axis] if isinstance(axis, int) else axis return _make.all(data, axis, keepdims, exclude) diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index aac4a6d4af16..da0fe01063f4 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -202,7 +202,9 @@ def _wrapper(data, axis=None, keepdims=False): [relay.argmax, _with_keepdims(np.argmax)]]: verify_reduce(func, (d1, d2, d3, d4), None, False, False, ()) verify_reduce(func, (d1, d2, d3, d4), 2, True, False, (d1, d2, 1, d4)) + verify_reduce(func, (d1, d2, d3, d4), 0, True, False, (1, d2, d3, d4)) verify_reduce(func, (d1, d2, d3), 1, True, False, (d1, 1, d3)) + verify_reduce(func, (d1, d2, d3), 0, True, False, (1, d2, d3)) verify_reduce(func, (d1, d2, d3), None, True, False, (1, 1, 1)) verify_reduce(func, (d1, d2, d3), (0, 1), True, False, (1, 1, d3)) verify_reduce(func, (2, 3, 4), 1, True, False, (2, 1, 4))