Add segment_sum op to Relay and 7 corresponding TF ops, fix scatter_add dynamic bug #7562

Merged (26 commits) on Mar 4, 2021
126 changes: 126 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
@@ -1166,6 +1166,125 @@ def _impl(inputs, attr, params, mod):
    return _impl


def _math_segment_sum():
    def _impl(inputs, attr, params, mod):
        assert len(inputs) == 2, "There should be 2 input tensors"
        return get_relay_op("segment_sum")(inputs[0], inputs[1])

    return _impl


def _sparse_segment_sum():
    def _impl(inputs, attr, params, mod):
        assert len(inputs) == 3, "There should be 3 input tensors"
        data = _op.take(inputs[0], inputs[1], axis=0)
        return _op.segment_sum(data, inputs[2])

Member:
This is OK for now, but we definitely want a fused implementation here, just like TF/PT/C2 have. I don't expect this to work for the huge embedding tables people want to use in practice.

Contributor Author:

I agree. When you say "fused implementation", do you mean that all of it happens in a single IR?

Contributor Author:

Do you have any examples of what a "fused implementation" is? Does this mean that in a fused implementation, the frontend will always just be a one-liner?

@codeislife99 (Contributor Author), Mar 4, 2021:

In this case, I understand we must do the take and the addition from segment_sum simultaneously for performance. So a fused implementation in that case would be a new op?

Member:

By "fused" I meant we shouldn't materialize the result of take, which can be huge. In a fused implementation, we need to look up indices and accumulate the sum on the fly. This is why PT has EmbeddingBag op, see their doc https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html.

Yes, a complicated op like this will not likely be feasible if we rely only on Relay-level op fusion. We need a dedicated sparse_segment_sum TOPI/Relay op.

@masahi (Member), Mar 4, 2021:

I think he meant that scatter_nd exactly realizes the fused take and segment_sum above. I haven't put deep thought into this, but it made sense to me. I do remember, though, that parallelizing scatter_nd looked harder than scatter_add.

Contributor Author:

Yes, I'm having a bit of a mental block understanding how take plus segment_sum is essentially scatter_nd. Would either of you mind writing some small pseudocode?


Contributor:

Thinking about this more, I believe the take is necessary if we are using scatter_nd. We could make a more generic version of scatter_nd and gather_nd that has indices in both the input and output buffers. That would cover this case.
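
A rough sketch of the fused behavior under discussion, in NumPy-style pseudocode (a hypothetical helper, not code from this PR): the gather and the accumulation happen in one pass, so the result of take is never materialized.

import numpy as np

# Hypothetical fused SparseSegmentSum: accumulate data[indices[i]] straight
# into out[segment_ids[i]], with no intermediate gathered tensor.
def fused_sparse_segment_sum(data, indices, segment_ids, num_segments):
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    for i in range(len(indices)):
        out[segment_ids[i]] += data[indices[i]]  # gather + scatter-add in one step
    return out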

Member:

OK, I'll merge this as-is then.

    return _impl


def _sparse_segment_sum_with_num_segments():
    def _impl(inputs, attr, params, mod):
        assert len(inputs) == 4, "There should be 4 input tensors"
        data = _op.take(inputs[0], inputs[1], axis=0)
        num_segments = int(inputs[3].data.asnumpy().item())
        return _op.segment_sum(data, inputs[2], num_segments)

    return _impl


def row_wise_divide(multi_dim_tensor, one_dim_vector):
    """
    Performs row-wise division of multi_dim_tensor by one_dim_vector.
    To achieve this, the vector is first tiled to the appropriate shape
    and then divided elementwise.
    """
    multi_dim_tensor_offrow_shape = _op.strided_slice(
        _op.shape_of(multi_dim_tensor, "int32"), [1], [-1], slice_mode="size"
    )
    one_dim_vector_tiled_shape = _op.concatenate(
        [_op.reverse(multi_dim_tensor_offrow_shape, 0), _expr.const([1])], axis=0
    )
    one_dim_vector_tiled = _op.transpose(_op.tile(one_dim_vector, one_dim_vector_tiled_shape))
    return _op.divide(multi_dim_tensor, one_dim_vector_tiled)
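
For intuition, a NumPy equivalent of row_wise_divide (an illustrative sketch, not code from this PR): broadcasting does implicitly what the tile/transpose sequence above builds explicitly from symbolic shapes.

import numpy as np

def row_wise_divide_np(multi_dim_tensor, one_dim_vector):
    # Reshape the vector to (n, 1, ..., 1) so it broadcasts across each row.
    shape = (-1,) + (1,) * (multi_dim_tensor.ndim - 1)
    return multi_dim_tensor / one_dim_vector.reshape(shape)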


def count_all_indices(segment_ids, counts_dtype, num_segments=None):
    """
    Calculates the count of each index among all valid indices.
    Valid indices run from 0 to the max of [segment_ids, num_segments].
    """

    max_segments = _op.reshape(_op.max(segment_ids), -1) + _expr.const([1])
    if num_segments:
        max_segments = _op.maximum(max_segments, _expr.const([num_segments]))
    max_ones = _op.maximum(max_segments, _op.shape_of(segment_ids))
    counts = _op.segment_sum(
        _op.ones(max_ones, counts_dtype), segment_ids, num_segments=num_segments
    )
    real_counts = _op.clip(counts, 1, 2147483647)  # clip's max bound can't exceed int32 range
    return real_counts
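
Similarly, the counting logic reduces to a bincount with a floor of 1 (an illustrative NumPy sketch, assuming 1-D non-negative integer segment_ids; not code from this PR):

import numpy as np

def count_all_indices_np(segment_ids, num_segments=None):
    n = max(int(segment_ids.max()) + 1, num_segments or 0)
    counts = np.bincount(segment_ids, minlength=n)
    return np.clip(counts, 1, None)  # zero-count segments become 1, avoiding divide-by-zero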


def _sparse_segment_sum_sqrtn():
    def _impl(inputs, attr, params, mod):
        assert len(inputs) == 3, "There should be 3 input tensors"
        data = _op.take(inputs[0], inputs[1], axis=0)
        real_counts = count_all_indices(inputs[2], attr["T"].name)
        real_sqrt_counts = _op.sqrt(_op.cast_like(real_counts, data))

        # Calculate regular segment sum
        segment_sum = _op.segment_sum(data, inputs[2])

        return row_wise_divide(segment_sum, real_sqrt_counts)

    return _impl
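
Putting the helpers together, SparseSegmentSqrtN divides each segment's sum by the square root of its row count. A NumPy reference for the intended semantics (illustrative only, not code from this PR):

import numpy as np

def sparse_segment_sqrtn_np(data, indices, segment_ids):
    gathered = data[indices]  # the take step
    n = int(segment_ids.max()) + 1
    out = np.zeros((n,) + data.shape[1:], dtype=data.dtype)
    np.add.at(out, segment_ids, gathered)  # the segment sum
    counts = np.maximum(np.bincount(segment_ids, minlength=n), 1)
    return out / np.sqrt(counts).reshape((-1,) + (1,) * (data.ndim - 1))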


def _sparse_segment_sum_sqrtn_with_num_segments():
    def _impl(inputs, attr, params, mod):
        assert len(inputs) == 4, "There should be 4 input tensors"
        data = _op.take(inputs[0], inputs[1], axis=0)
        num_segments = int(inputs[3].data.asnumpy().item())
        real_counts = count_all_indices(inputs[2], attr["T"].name, num_segments=num_segments)
        real_sqrt_counts = _op.sqrt(_op.cast_like(real_counts, data))

        # Calculate regular segment sum
        segment_sum = _op.segment_sum(data, inputs[2], num_segments=num_segments)

        return row_wise_divide(segment_sum, real_sqrt_counts)

    return _impl


def _sparse_segment_mean():
    def _impl(inputs, attr, params, mod):
        assert len(inputs) == 3, "There should be 3 input tensors"
        data = _op.take(inputs[0], inputs[1], axis=0)
        real_counts = count_all_indices(inputs[2], attr["T"].name)

        # Calculate regular segment sum
        segment_sum = _op.segment_sum(data, inputs[2])

        return row_wise_divide(segment_sum, real_counts)

    return _impl


def _sparse_segment_mean_with_num_segments():
    def _impl(inputs, attr, params, mod):
        assert len(inputs) == 4, "There should be 4 input tensors"
        data = _op.take(inputs[0], inputs[1], axis=0)
        num_segments = int(inputs[3].data.asnumpy().item())
        real_counts = count_all_indices(inputs[2], attr["T"].name, num_segments=num_segments)

        # Calculate regular segment sum
        segment_sum = _op.segment_sum(data, inputs[2], num_segments=num_segments)

        return row_wise_divide(segment_sum, real_counts)

    return _impl


def _identity():
    def _impl(inputs, attr, params, mod):
        return inputs[0]
@@ -2660,6 +2779,13 @@ def _impl(inputs, attr, params, mod):
"SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(),
"SparseFillEmptyRows": _sparse_fill_empty_rows(),
"SparseReshape": _sparse_reshape(),
"SegmentSum": _math_segment_sum(),
"SparseSegmentSum": _sparse_segment_sum(),
"SparseSegmentSumWithNumSegments": _sparse_segment_sum_with_num_segments(),
"SparseSegmentSqrtN": _sparse_segment_sum_sqrtn(),
"SparseSegmentSqrtNWithNumSegments": _sparse_segment_sum_sqrtn_with_num_segments(),
"SparseSegmentMean": _sparse_segment_mean(),
"SparseSegmentMeanWithNumSegments": _sparse_segment_mean_with_num_segments(),
"Split": _split(False),
"SplitV": _split(True),
"Sqrt": AttrCvt("sqrt"),
69 changes: 69 additions & 0 deletions python/tvm/relay/op/transform.py
@@ -1450,6 +1450,75 @@ def sparse_reshape(sparse_indices, prev_shape, new_shape):
    return TupleWrapper(_make.sparse_reshape(sparse_indices, prev_shape, new_shape), 2)


def segment_sum(data, segment_ids, num_segments=None):
    """
    Computes the sum of data along axis 0, grouped by segment_ids. If multiple entries in
    segment_ids reference the same output location, their contributions add up:

        result[index, j, k, ...] = sum of data[i, j, k, ...] over all i with segment_ids[i] == index

    This op is best understood with the visualizations in the following links and the
    examples at the end of this docstring.

    https://www.tensorflow.org/api_docs/python/tf/math/unsorted_segment_sum
    https://caffe2.ai/docs/sparse-operations.html#null__unsorted-segment-reduction-ops

    Parameters
    ----------
    data : relay.Expr
        Input tensor. It can be of any type and multi-dimensional.
    segment_ids : relay.Expr
        A 1-D int32/int64 tensor containing the segment ids of the rows over which the output
        sum is calculated. It defines a mapping from the zeroth dimension of data onto
        segment_ids. The segment_ids tensor should be the size of the first dimension, d0,
        with consecutive IDs in the range 0 to k, where k < d0. In particular, a segmentation
        of a matrix tensor is a mapping of rows to segments. This tensor doesn't need to be
        sorted.
    num_segments : Optional[int]
        An integer describing the shape of the zeroth output dimension. If unspecified, it is
        computed as max(segment_ids) + 1.

    Returns
    -------
    result: relay.Expr
        Output tensor.

    Examples
    --------
    .. code-block:: python

        data = [[1, 2, 3, 4],
                [4, -3, 2, -1],
                [5, 6, 7, 8]]
        segment_ids = [0, 0, 1]
        result = segment_sum(data, segment_ids)
        result = [[5, -1, 5, 3], [5, 6, 7, 8]]

        data = [[1, 2, 3, 4],
                [4, -3, 2, -1],
                [5, 6, 7, 8]]
        segment_ids = [2, 0, 0]
        num_segments = 3
        result = segment_sum(data, segment_ids, num_segments)
        result = [[9, 3, 9, 7], [0, 0, 0, 0], [1, 2, 3, 4]]
    """

    one_tensor = cast_like(const([1]), segment_ids)
    if num_segments:
        if isinstance(num_segments, int):
            max_segments = const([num_segments])
            max_segments = cast_like(max_segments, segment_ids)
        else:
            max_segments = cast_like(num_segments, segment_ids)
    else:
        max_segments = _make.add(reshape(_make.max(segment_ids, [0], False, False), -1), one_tensor)

    data_offrow_shape = strided_slice(_make.shape_of(data, "int32"), [1], [-1], slice_mode="size")
    data_offrow_shape = cast_like(data_offrow_shape, max_segments)
    new_shape = _make.concatenate(Tuple([max_segments, data_offrow_shape]), 0)
    segment_ids_tiled_shape = _make.concatenate(
        Tuple([reverse(data_offrow_shape, 0), one_tensor]), 0
    )
    expanded_segment_ids = tile(segment_ids, segment_ids_tiled_shape)
    scatter_add_segment_ids = transpose(expanded_segment_ids)
    src = cast_like(_dyn_make.zeros(new_shape, "float64"), data)
    return scatter_add(src, scatter_add_segment_ids, data, axis=0)
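
A minimal usage sketch (assuming a TVM build that includes this PR, and that the op is re-exported at the top level as relay.segment_sum):

import numpy as np
import tvm
from tvm import relay

data = relay.var("data", shape=(3, 4), dtype="float32")
segment_ids = relay.const(np.array([0, 0, 1], dtype="int32"))
out = relay.segment_sum(data, segment_ids)  # shape (max(segment_ids) + 1, 4) == (2, 4)
mod = tvm.IRModule.from_expr(relay.Function([data], out))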


def cumsum(data, axis=None, dtype=None, exclusive=None):
    """Numpy style cumsum op. Return the cumulative inclusive sum of the elements along
    a given axis.
40 changes: 20 additions & 20 deletions python/tvm/topi/scatter_add.py
@@ -32,8 +32,8 @@ def _scatter_add_1d(data, indices, updates):
 @hybrid.script
 def _scatter_add_2d(data, indices, updates, axis):
     out = output_tensor(data.shape, data.dtype)
-    for i in const_range(data.shape[0]):
-        for j in const_range(data.shape[1]):
+    for i in range(data.shape[0]):
+        for j in range(data.shape[1]):
             out[i, j] = data[i, j]
     if axis == 0:
         for i in range(indices.shape[0]):
@@ -54,14 +54,14 @@ def _scatter_add_2d(data, indices, updates, axis):
 @hybrid.script
 def _scatter_add_3d(data, indices, updates, axis):
     out = output_tensor(data.shape, data.dtype)
-    for i in const_range(data.shape[0]):
-        for j in const_range(data.shape[1]):
-            for k in const_range(data.shape[2]):
+    for i in range(data.shape[0]):
+        for j in range(data.shape[1]):
+            for k in range(data.shape[2]):
                 out[i, j, k] = data[i, j, k]
     if axis == 0:
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
-                for k in const_range(indices.shape[2]):
+                for k in range(indices.shape[2]):
                     out[
                         indices[i, j, k]
                         if indices[i, j, k] >= 0
@@ -72,7 +72,7 @@ def _scatter_add_3d(data, indices, updates, axis):
     elif axis == 1:
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
-                for k in const_range(indices.shape[2]):
+                for k in range(indices.shape[2]):
                     out[
                         i,
                         indices[i, j, k]
@@ -83,7 +83,7 @@ def _scatter_add_3d(data, indices, updates, axis):
     else:
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
-                for k in const_range(indices.shape[2]):
+                for k in range(indices.shape[2]):
                     out[
                         i,
                         j,
@@ -98,17 +98,17 @@ def _scatter_add_3d(data, indices, updates, axis):
 @hybrid.script
 def _scatter_add_4d(data, indices, updates, axis):
     out = output_tensor(data.shape, data.dtype)
-    for i in const_range(data.shape[0]):
-        for j in const_range(data.shape[1]):
-            for k in const_range(data.shape[2]):
-                for l in const_range(data.shape[3]):
+    for i in range(data.shape[0]):
+        for j in range(data.shape[1]):
+            for k in range(data.shape[2]):
+                for l in range(data.shape[3]):
                     out[i, j, k, l] = data[i, j, k, l]

     if axis == 0:
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
-                for k in const_range(indices.shape[2]):
-                    for l in const_range(indices.shape[3]):
+                for k in range(indices.shape[2]):
+                    for l in range(indices.shape[3]):
                         out[
                             indices[i, j, k, l]
                             if indices[i, j, k, l] >= 0
@@ -120,8 +120,8 @@ def _scatter_add_4d(data, indices, updates, axis):
     elif axis == 1:
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
-                for k in const_range(indices.shape[2]):
-                    for l in const_range(indices.shape[3]):
+                for k in range(indices.shape[2]):
+                    for l in range(indices.shape[3]):
                         out[
                             i,
                             indices[i, j, k, l]
@@ -133,8 +133,8 @@ def _scatter_add_4d(data, indices, updates, axis):
     elif axis == 2:
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
-                for k in const_range(indices.shape[2]):
-                    for l in const_range(indices.shape[3]):
+                for k in range(indices.shape[2]):
+                    for l in range(indices.shape[3]):
                         out[
                             i,
                             j,
@@ -146,8 +146,8 @@ def _scatter_add_4d(data, indices, updates, axis):
     else:
         for i in range(indices.shape[0]):
             for j in range(indices.shape[1]):
-                for k in const_range(indices.shape[2]):
-                    for l in const_range(indices.shape[3]):
+                for k in range(indices.shape[2]):
+                    for l in range(indices.shape[3]):
                         out[
                             i,
                             j,
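
A note on the scatter_add change above: in TVM's hybrid script, const_range is unrolled at compile time and therefore needs a compile-time-constant extent, while range lowers to an ordinary serial loop that also accepts symbolic extents. Since segment_sum builds its scatter_add destination with a runtime-computed shape (the dyn.zeros call above), the copy loops over data.shape could not use const_range; switching them to range is the "scatter_add dynamic bug" fix named in the PR title. (This reading is inferred from the diff and the title; the thread does not spell it out.)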