diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 20eb95ba7c00..b03df22a8bc3 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1166,6 +1166,125 @@ def _impl(inputs, attr, params, mod): return _impl +def _math_segment_sum(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 2, "There should be 2 input tensors" + return get_relay_op("segment_sum")(inputs[0], inputs[1]) + + return _impl + + +def _sparse_segment_sum(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 3, "There should be 3 input tensors" + data = _op.take(inputs[0], inputs[1], axis=0) + return _op.segment_sum(data, inputs[2]) + + return _impl + + +def _sparse_segment_sum_with_num_segments(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 4, "There should be 4 input tensors" + data = _op.take(inputs[0], inputs[1], axis=0) + num_segments = int(inputs[3].data.asnumpy().item()) + return _op.segment_sum(data, inputs[2], num_segments) + + return _impl + + +def row_wise_divide(multi_dim_tensor, one_dim_vector): + """ + This function enables row-wise division of multi_dim_tensor and one_dim_vector. 
+ To achieve this, one_dim_vector is first tiled and transposed to match the shape of multi_dim_tensor, and the two are then divided elementwise + """ + multi_dim_tensor_offrow_shape = _op.strided_slice( + _op.shape_of(multi_dim_tensor, "int32"), [1], [-1], slice_mode="size" + ) + one_dim_vector_tiled_shape = _op.concatenate( + [_op.reverse(multi_dim_tensor_offrow_shape, 0), _expr.const([1])], axis=0 + ) + one_dim_vector_tiled = _op.transpose(_op.tile(one_dim_vector, one_dim_vector_tiled_shape)) + return _op.divide(multi_dim_tensor, one_dim_vector_tiled) + + +def count_all_indices(segment_ids, counts_dtype, num_segments=None): + """ + This snippet calculates the count of each index among all valid indices (clipped to a minimum of 1) + Valid indices are from 0 to max of [segment ids, num_segments] + """ + + max_segments = _op.reshape(_op.max(segment_ids), -1) + _expr.const([1]) + if num_segments: + max_segments = _op.maximum(max_segments, _expr.const([num_segments])) + max_ones = _op.maximum(max_segments, _op.shape_of(segment_ids)) + counts = _op.segment_sum( + _op.ones(max_ones, counts_dtype), segment_ids, num_segments=num_segments + ) + real_counts = _op.clip(counts, 1, 2147483647) # Clip max doesn't work over int32 + return real_counts + + +def _sparse_segment_sum_sqrtn(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 3, "There should be 3 input tensors" + data = _op.take(inputs[0], inputs[1], axis=0) + real_counts = count_all_indices(inputs[2], attr["T"].name) + real_sqrt_counts = _op.sqrt(_op.cast_like(real_counts, data)) + + # Calculate regular segment sum + segment_sum = _op.segment_sum(data, inputs[2]) + + return row_wise_divide(segment_sum, real_sqrt_counts) + + return _impl + + +def _sparse_segment_sum_sqrtn_with_num_segments(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 4, "There should be 4 input tensors" + data = _op.take(inputs[0], inputs[1], axis=0) + num_segments = int(inputs[3].data.asnumpy().item()) + real_counts = count_all_indices(inputs[2], attr["T"].name, num_segments=num_segments) + 
real_sqrt_counts = _op.sqrt(_op.cast_like(real_counts, data)) + + # Calculate regular segment sum + segment_sum = _op.segment_sum(data, inputs[2], num_segments=num_segments) + + return row_wise_divide(segment_sum, real_sqrt_counts) + + return _impl + + +def _sparse_segment_mean(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 3, "There should be 3 input tensors" + data = _op.take(inputs[0], inputs[1], axis=0) + real_counts = count_all_indices(inputs[2], attr["T"].name) + + # Calculate regular segment sum + segment_sum = _op.segment_sum(data, inputs[2]) + + return row_wise_divide(segment_sum, real_counts) + + return _impl + + +def _sparse_segment_mean_with_num_segments(): + def _impl(inputs, attr, params, mod): + assert len(inputs) == 4, "There should be 4 input tensors" + data = _op.take(inputs[0], inputs[1], axis=0) + num_segments = int(inputs[3].data.asnumpy().item()) + real_counts = count_all_indices(inputs[2], attr["T"].name, num_segments=num_segments) + + # Calculate regular segment sum + segment_sum = _op.segment_sum(data, inputs[2], num_segments=num_segments) + + return row_wise_divide(segment_sum, real_counts) + + return _impl + + def _identity(): def _impl(inputs, attr, params, mod): return inputs[0] @@ -2660,6 +2779,13 @@ def _impl(inputs, attr, params, mod): "SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(), "SparseFillEmptyRows": _sparse_fill_empty_rows(), "SparseReshape": _sparse_reshape(), + "SegmentSum": _math_segment_sum(), + "SparseSegmentSum": _sparse_segment_sum(), + "SparseSegmentSumWithNumSegments": _sparse_segment_sum_with_num_segments(), + "SparseSegmentSqrtN": _sparse_segment_sum_sqrtn(), + "SparseSegmentSqrtNWithNumSegments": _sparse_segment_sum_sqrtn_with_num_segments(), + "SparseSegmentMean": _sparse_segment_mean(), + "SparseSegmentMeanWithNumSegments": _sparse_segment_mean_with_num_segments(), "Split": _split(False), "SplitV": _split(True), "Sqrt": AttrCvt("sqrt"), diff --git a/python/tvm/relay/op/transform.py 
b/python/tvm/relay/op/transform.py index 73508ddd2603..f2e3850a8f67 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -1450,6 +1450,75 @@ def sparse_reshape(sparse_indices, prev_shape, new_shape): return TupleWrapper(_make.sparse_reshape(sparse_indices, prev_shape, new_shape), 2) +def segment_sum(data, segment_ids, num_segments=None): + """ + Computes the sum along segment_ids along axis 0. If multiple segment_ids reference the same + location their contributions add up. + result[index, j, k, ...] = Σi... data[i, j, k,..] where index = segment_ids[i] + This op is much better understood with visualization articulated in the following links and + examples at the end of this docstring. + + https://www.tensorflow.org/api_docs/python/tf/math/unsorted_segment_sum + https://caffe2.ai/docs/sparse-operations.html#null__unsorted-segment-reduction-ops + + Parameters + ---------- + data : relay.Expr + Input Tensor. It can be of any type and multi-dimensional + segment_ids : relay.Expr + A 1-D int32/int64 tensor containing the segment_ids of the rows to calculate the output + sum upon. It defines a mapping from the zeroth dimension of data onto segment_ids. 
The + segment_ids tensor should be the size of the first dimension, d0, with consecutive IDs + in the range 0 to k, where k= 0 @@ -72,7 +72,7 @@ def _scatter_add_3d(data, indices, updates, axis): elif axis == 1: for i in range(indices.shape[0]): for j in range(indices.shape[1]): - for k in const_range(indices.shape[2]): + for k in range(indices.shape[2]): out[ i, indices[i, j, k] @@ -83,7 +83,7 @@ def _scatter_add_3d(data, indices, updates, axis): else: for i in range(indices.shape[0]): for j in range(indices.shape[1]): - for k in const_range(indices.shape[2]): + for k in range(indices.shape[2]): out[ i, j, @@ -98,17 +98,17 @@ def _scatter_add_3d(data, indices, updates, axis): @hybrid.script def _scatter_add_4d(data, indices, updates, axis): out = output_tensor(data.shape, data.dtype) - for i in const_range(data.shape[0]): - for j in const_range(data.shape[1]): - for k in const_range(data.shape[2]): - for l in const_range(data.shape[3]): + for i in range(data.shape[0]): + for j in range(data.shape[1]): + for k in range(data.shape[2]): + for l in range(data.shape[3]): out[i, j, k, l] = data[i, j, k, l] if axis == 0: for i in range(indices.shape[0]): for j in range(indices.shape[1]): - for k in const_range(indices.shape[2]): - for l in const_range(indices.shape[3]): + for k in range(indices.shape[2]): + for l in range(indices.shape[3]): out[ indices[i, j, k, l] if indices[i, j, k, l] >= 0 @@ -120,8 +120,8 @@ def _scatter_add_4d(data, indices, updates, axis): elif axis == 1: for i in range(indices.shape[0]): for j in range(indices.shape[1]): - for k in const_range(indices.shape[2]): - for l in const_range(indices.shape[3]): + for k in range(indices.shape[2]): + for l in range(indices.shape[3]): out[ i, indices[i, j, k, l] @@ -133,8 +133,8 @@ def _scatter_add_4d(data, indices, updates, axis): elif axis == 2: for i in range(indices.shape[0]): for j in range(indices.shape[1]): - for k in const_range(indices.shape[2]): - for l in const_range(indices.shape[3]): + for k in 
range(indices.shape[2]): + for l in range(indices.shape[3]): out[ i, j, @@ -146,8 +146,8 @@ def _scatter_add_4d(data, indices, updates, axis): else: for i in range(indices.shape[0]): for j in range(indices.shape[1]): - for k in const_range(indices.shape[2]): - for l in const_range(indices.shape[3]): + for k in range(indices.shape[2]): + for l in range(indices.shape[3]): out[ i, j, diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 41145bf77218..81aeb5ef886c 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -2080,6 +2080,181 @@ def test_forward_sparse_reshape( _test_sparse_reshape(sparse_indices_np, sparse_values_np, prev_shape_np, new_shape_np, use_dyn) +####################################################################### +# Sparse Segment Variants +# ------------ + + +def _test_sparse_segment_variant( + tf_op, data_np, indices_np, segment_ids_np, num_segments, use_dyn=False +): + with tf.Graph().as_default(): + if use_dyn: + data = tf.placeholder( + shape=[None for _ in data_np.shape], dtype=data_np.dtype, name="data" + ) + indices = tf.placeholder(shape=[None], dtype=indices_np.dtype, name="indices") + segment_ids = tf.placeholder( + shape=(None), dtype=segment_ids_np.dtype, name="segment_ids" + ) + else: + data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name="data") + indices = tf.placeholder(shape=indices_np.shape, dtype=indices_np.dtype, name="indices") + segment_ids = tf.placeholder( + shape=segment_ids_np.shape, dtype=segment_ids_np.dtype, name="segment_ids" + ) + + _ = tf_op( + data, indices, segment_ids, num_segments=num_segments, name="sparse_segment_variant" + ) + compare_tf_with_tvm( + [data_np, indices_np, segment_ids_np], + [data.name, indices.name, segment_ids.name], + ["sparse_segment_variant:0"], + mode="vm", + ) + + +@pytest.mark.parametrize( + "data_np, indices_np, segment_ids_np, 
num_segments", + [ + ( + np.array([5, 1, 7, 2, 3, 4], dtype=np.float32), + np.array([0, 3, 4], dtype=np.int32), + np.array([0, 1, 1], dtype=np.int32), + None, + ), + ( + np.array([[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]], dtype=np.float64), + np.array([0, 1], dtype=np.int32), + np.array([0, 2], dtype=np.int32), + 4, + ), + ( + np.random.random((6, 4, 5)), + np.array([0, 2, 4, 3, 1], dtype=np.int32), + np.array([0, 0, 1, 5, 5], dtype=np.int32), + 100, + ), + ( + np.random.random((6, 4, 5)), + np.array([0, 2, 4, 3, 1], dtype=np.int32), + np.array([0, 0, 1, 5, 5], dtype=np.int32), + None, + ), + ( + np.array([[[1, 7]], [[3, 8]], [[2, 9]]], dtype=np.float64), + np.array([0, 1, 2], dtype=np.int32), + np.array([0, 0, 1], dtype=np.int32), + None, + ), + ( + np.random.random((9, 4, 5, 7)), + np.array([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=np.int32), + np.array([0, 0, 1, 3, 5, 6, 7, 7, 8], dtype=np.int32), + 9, + ), + ( + np.random.random((9, 4, 5, 7)), + np.array([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=np.int32), + np.array([0, 0, 1, 3, 5, 6, 7, 7, 8], dtype=np.int32), + None, + ), + ( + np.array([[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]], dtype=np.float64), + np.array([0, 1], dtype=np.int32), + np.array([0, 2], dtype=np.int32), + None, + ), + ( + np.random.random((9, 4, 5, 7)), + np.array([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=np.int32), + np.array([0, 0, 1, 3, 5, 5, 5, 5, 5], dtype=np.int32), + 6, + ), + ], +) +@pytest.mark.parametrize("use_dyn", [True, False]) +@pytest.mark.parametrize( + "tf_op", + [ + tf.sparse.segment_sum, + tf.sparse.segment_sqrt_n, + tf.sparse.segment_mean, + ], +) +def test_forward_sparse_segment_sum_variants( + tf_op, + data_np, + indices_np, + segment_ids_np, + num_segments, + use_dyn, +): + """sparse segment sum variants tests""" + _test_sparse_segment_variant(tf_op, data_np, indices_np, segment_ids_np, num_segments, use_dyn) + + +####################################################################### +# Math SegmentSum +# ------------ + + +def 
_test_math_segment_sum(data_np, segment_ids_np, use_dyn=False): + with tf.Graph().as_default(): + if use_dyn: + data = tf.placeholder( + shape=[None for _ in data_np.shape], dtype=data_np.dtype, name="data" + ) + segment_ids = tf.placeholder( + shape=(None), dtype=segment_ids_np.dtype, name="segment_ids" + ) + else: + data = tf.placeholder(shape=data_np.shape, dtype=data_np.dtype, name="data") + segment_ids = tf.placeholder( + shape=segment_ids_np.shape, dtype=segment_ids_np.dtype, name="segment_ids" + ) + + _ = tf.math.segment_sum(data, segment_ids, name="segment_sum") + compare_tf_with_tvm( + [data_np, segment_ids_np], + [data.name, segment_ids.name], + ["segment_sum:0"], + mode="vm", + ) + + +@pytest.mark.parametrize( + "data_np, segment_ids_np", + [ + ( + np.array([5, 1, 7, 2, 3, 4], dtype=np.float32), + np.array([0, 0, 0, 1, 1, 1], dtype=np.int32), + ), + ( + np.array([[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]], dtype=np.float64), + np.array([0, 0, 1], dtype=np.int32), + ), + ( + np.random.random((6, 4, 5)), + np.array([0, 0, 1, 2, 2, 3], dtype=np.int64), + ), + ( + np.array([[[1, 7]], [[3, 8]], [[2, 9]]], dtype=np.float32), + np.array([0, 0, 1], dtype=np.int32), + ), + ( + np.random.random((9, 4, 5, 7)), + np.array([0, 0, 0, 1, 2, 3, 4, 4, 5], dtype=np.int64), + ), + ], +) +@pytest.mark.parametrize("use_dyn", [True, False]) +def test_forward_math_segment_sum(data_np, segment_ids_np, use_dyn): + """math segment sum test""" + _test_math_segment_sum(data_np, segment_ids_np, use_dyn) + + # tensorflow.compat.v1.sparse_to_dense # --------------- def _test_sparse_to_dense(sparse_indices, sparse_values, default_value, output_shape): diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index c9ed975c3b9b..31b95b0b49ae 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -24,6 +24,7 @@ from tvm.error import TVMError from tvm.relay import create_executor, transform from tvm.relay.testing 
import check_grad, run_infer_type +from typing import Optional import tvm.testing @@ -1023,7 +1024,25 @@ def verify_dynamic_scatter(dshape, ishape, axis=0): @tvm.testing.uses_gpu -def test_scatter_add(): +@pytest.mark.parametrize( + "dshape, ishape, axis, dtype", + [ + ((10,), (10,), 0, "int32"), + ((1000,), (1000,), 0, "int32"), + ((10, 5), (10, 5), -2, "float32"), + ((10, 5), (10, 5), -1, "float32"), + ((10, 5), (3, 5), 0, "float32"), + ((12, 4), (7, 2), 1, "float32"), + ((2, 3, 4), (1, 3, 4), 0, "float32"), + ((2, 3, 4), (2, 1, 4), 1, "float32"), + ((2, 3, 4), (2, 3, 1), 2, "float32"), + ((2, 3, 4, 5), (1, 3, 4, 5), 0, "float32"), + ((6, 3, 4, 5), (2, 3, 4, 5), 1, "float32"), + ((2, 3, 8, 5), (2, 3, 1, 1), 2, "float32"), + ((16, 16, 4, 5), (16, 16, 4, 5), 3, "float32"), + ], +) +def test_scatter_add(dshape, ishape, axis, dtype): def ref_scatter_add(data, indices, updates, axis=0): output = np.copy(data) for index in np.ndindex(*indices.shape): @@ -1033,9 +1052,9 @@ def ref_scatter_add(data, indices, updates, axis=0): return output def verify_scatter_add(dshape, ishape, axis=0, dtype="float32"): - d = relay.var("d", relay.TensorType(dshape, dtype)) - i = relay.var("i", relay.TensorType(ishape, "int64")) - u = relay.var("u", relay.TensorType(ishape, dtype)) + d = relay.var("d", relay.TensorType(shape=[relay.Any() for _ in dshape], dtype=dtype)) + i = relay.var("i", relay.TensorType(shape=[relay.Any() for _ in ishape], dtype="int64")) + u = relay.var("u", relay.TensorType(shape=[relay.Any() for _ in ishape], dtype=dtype)) z = relay.op.scatter_add(d, i, u, axis) func = relay.Function([d, i, u], z) @@ -1045,31 +1064,14 @@ def verify_scatter_add(dshape, ishape, axis=0, dtype="float32"): indices_np = np.random.randint(-dshape[axis], dshape[axis] - 1, ishape).astype("int64") ref_res = ref_scatter_add(data_np, indices_np, updates_np, axis) - for target, ctx in tvm.testing.enabled_targets(): - for kind in ["graph", "debug"]: - if target == "nvptx" and dtype == "float32" 
and len(dshape) == 1: - # scatter_add 1D on GPU is implemented via atomic. - # Floating point atomic requires LLVM 9 or newer for nvptx backend. - # But LLVM on CI is LLVM 8. - continue - intrp = relay.create_executor(kind, ctx=ctx, target=target) - op_res = intrp.evaluate(func)(data_np, indices_np, updates_np) - tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) - verify_scatter_add((10,), (10,), 0, dtype="int32") - verify_scatter_add((1000,), (1000,)) - verify_scatter_add((1000,), (1000,), 0, dtype="int32") - verify_scatter_add((10, 5), (10, 5), -2) - verify_scatter_add((10, 5), (10, 5), -1) - verify_scatter_add((10, 5), (3, 5), 0) - verify_scatter_add((12, 4), (7, 2), 1) - verify_scatter_add((2, 3, 4), (1, 3, 4), 0) - verify_scatter_add((2, 3, 4), (2, 1, 4), 1) - verify_scatter_add((2, 3, 4), (2, 3, 1), 2) - verify_scatter_add((2, 3, 4, 5), (1, 3, 4, 5), 0) - verify_scatter_add((6, 3, 4, 5), (2, 3, 4, 5), 1) - verify_scatter_add((2, 3, 8, 5), (2, 3, 1, 1), 2) - verify_scatter_add((16, 16, 4, 5), (16, 16, 4, 5), 3) + verify_func( + func, + [data_np, indices_np, updates_np], + ref_res, + ) + + verify_scatter_add(dshape, ishape, axis, dtype) @tvm.testing.uses_gpu @@ -1515,6 +1517,105 @@ def verify_sparse_reshape( ) +@tvm.testing.uses_gpu +@pytest.mark.parametrize( + "data_np, segment_ids_np, num_segments", + [ + ( + np.array([5, 1, 7, 2, 3, 4], dtype=np.float32), + np.array([0, 0, 1, 1, 0, 1], dtype=np.int32), + None, + ), + ( + np.array([[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]], dtype=np.float64), + np.array([0, 0, 1], dtype=np.int32), + None, + ), + ( + np.random.random((6, 4, 5)), + np.array([2, 0, 1, 0, 3, 2], dtype=np.int64), + None, + ), + ( + np.array([[[1, 7]], [[3, 8]], [[2, 9]]], dtype=np.float32), + np.array([0, 0, 1], dtype=np.int32), + None, + ), + ( + np.random.random((9, 4, 5, 7)), + np.array([5, 0, 1, 0, 3, 6, 8, 7, 7], dtype=np.int64), + 9, + ), + ( + np.array([[1, 2, 3, 4], [-1, -2, -3, -4], [5, 6, 7, 8]], dtype=np.float64), + 
np.array([0, 2], dtype=np.int32), + 4, + ), + ( + np.random.random((6, 4, 5)), + np.array([0, 0, 1, 5, 5], dtype=np.int32), + 100, + ), + ], +) +@pytest.mark.parametrize("use_dyn", [True, False]) +def test_segment_sum(data_np, segment_ids_np, num_segments, use_dyn): + def ref_segment_sum( + data: np.ndarray, + segment_ids: np.ndarray, + num_segments: Optional[int] = None, + ): + """ + This function calculates the expected output of segment_sum operator given the inputs. + """ + if not num_segments: + num_segments = np.unique(segment_ids).shape[0] + + result = np.zeros((num_segments,) + data.shape[1:], data.dtype) + for i, index in enumerate(segment_ids): + result[index] += data[i] + return result + + def verify_segment_sum( + data_np: np.ndarray, segment_ids_np: np.ndarray, num_segments: Optional[int] + ): + """ + This function verifies the relay output of segment_sum with its expected output. + """ + if use_dyn: + data = relay.var( + "data", + shape=[relay.Any() for _ in data_np.shape], + dtype=str(data_np.dtype), + ) + segment_ids = relay.var( + "segment_ids", + shape=[relay.Any()], + dtype=str(segment_ids_np.dtype), + ) + else: + data = relay.var( + "data", + relay.TensorType(data_np.shape, str(data_np.dtype)), + ) + segment_ids = relay.var( + "segment_ids", relay.TensorType(segment_ids_np.shape, str(segment_ids_np.dtype)) + ) + z = relay.op.segment_sum(data, segment_ids, num_segments) + + func = relay.Function([data, segment_ids], z) + ref_res = ref_segment_sum(data_np, segment_ids_np, num_segments=num_segments) + segment_sum_result = run_infer_type(z) + assert segment_sum_result.checked_type.dtype == data_np.dtype + verify_func( + func, + [data_np, segment_ids_np], + ref_res, + ) + + verify_segment_sum(data_np, segment_ids_np, num_segments) + + def verify_func(func, data, ref_res, target_ctx=tvm.testing.enabled_targets()): assert isinstance(data, list) for target, ctx in target_ctx: