diff --git a/python/tvm/relay/op/reduce.py b/python/tvm/relay/op/reduce.py index 71c7dea9c0dc..a4d5f66c009d 100644 --- a/python/tvm/relay/op/reduce.py +++ b/python/tvm/relay/op/reduce.py @@ -12,8 +12,8 @@ def argmax(data, axis=None, keepdims=False, exclude=False): The input data axis : None or int or tuple of int - Axis or axes along which a argmin operation is performed. - The default, axis=None, will find the indices of maximum element all of the elements of + Axis or axes along which a argmax operation is performed. + The default, axis=None, will find the indices of the maximum element of the elements of the input array. If axis is negative it counts from the last to the first axis. keepdims : bool @@ -73,14 +73,14 @@ def sum(data, axis=None, keepdims=False, exclude=False): The input data axis : None or int or tuple of int - Axis or axes along which a argmin operation is performed. - The default, axis=None, will find the indices of minimum element all of the elements of - the input array. If axis is negative it counts from the last to the first axis. + Axis or axes along which a sum is performed. The default, axis=None, + will sum all of the elements of the input array. If axis is + negative it counts from the last to the first axis. keepdims : bool - If this is set to True, the axes which are reduced are left in the result as dimensions - with size one. - With this option, the result will broadcast correctly against the input array. + If this is set to True, the axes which are reduced are left in the result as + dimensions with size one. With this option, the result will broadcast + correctly against the input array. exclude : bool If `exclude` is true, reduction will be performed on the axes that are @@ -91,7 +91,7 @@ def sum(data, axis=None, keepdims=False, exclude=False): result : relay.Expr The computed result. """ - axis = [axis] if isinstance(axis, int) else axis + axis = [axis] if axis and isinstance(axis, int) else axis return _make.sum(data, axis, keepdims, exclude) @@ -104,9 +104,9 @@ def max(data, axis=None, keepdims=False, exclude=False): The input data axis : None or int or tuple of int - Axis or axes along which a argmin operation is performed. - The default, axis=None, will find the indices of minimum element all of the elements of - the input array. If axis is negative it counts from the last to the first axis. + Axis or axes along which the max operation is performed. + The default, axis=None, will find the max element from all of the elements of the input + array. If axis is negative it counts from the last to the first axis. keepdims : bool If this is set to True, the axes which are reduced are left in the result as dimensions @@ -135,9 +135,10 @@ def min(data, axis=None, keepdims=False, exclude=False): The input data axis : None or int or tuple of int - Axis or axes along which a argmin operation is performed. - The default, axis=None, will find the indices of minimum element all of the elements of - the input array. If axis is negative it counts from the last to the first axis. + Axis or axes along which a minimum operation is performed. + The default, axis=None, will find the minimum element from all + of the elements of the input array. If axis is negative it counts from + the last to the first axis. keepdims : bool If this is set to True, the axes which are reduced are left in the result as dimensions @@ -166,7 +167,7 @@ def mean(data, axis=None, keepdims=False, exclude=False): The input data axis : None or int or tuple of int - Axis or axes along which a argmin operation is performed. + Axis or axes along which a mean operation is performed. The default, axis=None, will find the indices of minimum element all of the elements of the input array. If axis is negative it counts from the last to the first axis. @@ -197,7 +198,7 @@ def prod(data, axis=None, keepdims=False, exclude=False): The input data axis : None or int or tuple of int - Axis or axes along which a argmin operation is performed. + Axis or axes along which a product is performed. The default, axis=None, will find the indices of minimum element all of the elements of the input array. If axis is negative it counts from the last to the first axis. diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index 45d6d36fdc20..ae7fe320940a 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -180,6 +180,7 @@ def _wrapper(data, axis=None, keepdims=False): [relay.prod, np.prod], [relay.argmin, _with_keepdims(np.argmin)], [relay.argmax, _with_keepdims(np.argmax)]]: + verify_reduce(func, (d1, d2, d3, d4), None, False, False, ()) verify_reduce(func, (d1, d2, d3, d4), 2, True, False, (d1, d2, 1, d4)) verify_reduce(func, (d1, d2, d3), 1, True, False, (d1, 1, d3)) verify_reduce(func, (d1, d2, d3), None, True, False, (1, 1, 1))