From 6c0d24af4bddd57f2ca9b5d86ccf943eebd7397a Mon Sep 17 00:00:00 2001 From: padreofthegame Date: Fri, 10 Mar 2023 13:24:53 +0100 Subject: [PATCH] [Fix][Relay][TOPI] Bug fix in relay.sum and topi.sum functions when working with boolean tensor --- include/tvm/topi/reduction.h | 6 +++- tests/python/relay/test_op_level4.py | 36 +++++++++++++++++++- tests/python/topi/python/test_topi_reduce.py | 12 +++++-- 3 files changed, 50 insertions(+), 4 deletions(-) diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h index 5e79bd429d6f..169ae010aa98 100644 --- a/include/tvm/topi/reduction.h +++ b/include/tvm/topi/reduction.h @@ -325,7 +325,11 @@ inline PrimExpr ProdOp(PrimExpr source, Array axis, Array ini */ inline Tensor sum(const Tensor& data, const Array& axis, bool keepdims = false, bool atleast1d = false) { - return CommReduce(data, axis, tvm::sum, keepdims, atleast1d); + if (data->dtype.is_bool()) { + return CommReduce(data, axis, tvm::any, keepdims, atleast1d); + } else { + return CommReduce(data, axis, tvm::sum, keepdims, atleast1d); + } } inline Tensor collapse_sum(const Tensor& data, Array target_shape) { diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index c4207b158c94..c2877c5cda55 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -369,6 +369,41 @@ def test_reduce( tvm.testing.assert_allclose(op_res1.numpy(), ref_res, rtol=1e-5) +@tvm.testing.uses_gpu +def test_sum_with_bool_input(): + def verify(dshape, axis, keepdims, exclude): + x = relay.var("x", relay.TensorType(dshape, "bool")) + + y = relay.sum(x, axis, keepdims, exclude) + + func = relay.Function([x], y) + func = run_infer_type(func) + + text = func.astext() + assert "sum" in text + + data = np.random.choice([False, True], size=dshape) + + if exclude and axis is not None: + axis = tuple(set(range(len(dshape))) - set(axis)) + + ref_res = np.sum(data, axis, keepdims=keepdims, dtype="bool") + for target, dev in tvm.testing.enabled_targets(): + op_res = relay.create_executor("graph", device=dev, target=target).evaluate(func)(data) + tvm.testing.assert_allclose(op_res.numpy(), ref_res) + + verify((3, 5, 7, 9), None, False, False) + verify((3, 5, 7, 9), None, True, False) + verify((3, 5, 7, 9), (0,), False, False) + verify((3, 5, 7, 9), (1,), True, False) + verify((3, 5, 7, 9), (2, 3), False, True) + verify((3, 5, 7, 9), (0, 2), True, True) + verify((3, 5, 7, 9), (0, 1, 2, 3), False, False) + verify((3, 5, 7, 9), (0, 1, 2, 3), False, True) + verify((3, 5, 7, 9), (0, 1, 2, 3), True, False) + verify((3, 5, 7, 9), (0, 1, 2, 3), True, True) + + @tvm.testing.uses_gpu def test_argmin_argmax_get_last_elements(): def get_test_case(shape, gt_func, test_argmin=False): @@ -638,7 +673,6 @@ def verify(dshape, begin, end, strides, vshape, test_ref=True): func = run_infer_type(func) text = func.astext() assert "strided_set" in text - print(text) assert func.body.checked_type == relay.ty.TensorType(dshape, "float32") if not test_ref: return diff --git a/tests/python/topi/python/test_topi_reduce.py b/tests/python/topi/python/test_topi_reduce.py index e7f47ba0c4db..3c4c170d0dd9 100644 --- a/tests/python/topi/python/test_topi_reduce.py +++ b/tests/python/topi/python/test_topi_reduce.py @@ -44,6 +44,8 @@ ((32, 128, 24), None, True, "any", "bool"), ((1, 4, 7), 1, True, "any", "bool"), ((128, 24, 128, 24), 2, False, "any", "bool"), + ((128, 24, 128, 24), 2, False, "sum", "bool"), + ((128, 24, 128, 24), 0, True, "sum", "bool"), ) @@ -57,7 +59,10 @@ def ref_data(in_shape, axis, keepdims, reduce_type, dtype): in_npy_map = np.sqrt(np.exp(in_npy)).astype(dtype) if reduce_type == "sum": - out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims) + if dtype == "bool": + out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims, dtype="bool") + else: + out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims) elif reduce_type == "all" and dtype == "bool": out_npy = in_npy_map.all(axis=axis, keepdims=keepdims) elif reduce_type == "any" and dtype == "bool": @@ -113,7 +118,10 @@ def test_reduce_map(target, dev, ref_data, in_shape, axis, keepdims, reduce_type A1 = topi.sqrt(topi.exp(A)) out_dtype = dtype if reduce_type == "sum": - B = topi.sum(A1, axis=axis, keepdims=keepdims) + if dtype == "bool": + B = topi.sum(A, axis=axis, keepdims=keepdims) + else: + B = topi.sum(A1, axis=axis, keepdims=keepdims) elif reduce_type == "all": B = topi.all(A, axis=axis, keepdims=keepdims) elif reduce_type == "any":