Skip to content

Commit

Permalink
Add segment sum Op to relay and 7 corresponding TF Ops, fix scatter_…
Browse files Browse the repository at this point in the history
…add dynamic bug (apache#7562)

* Add segment sum Op

* Remove unnecessary

* Documentation

* Black

* Add GPU

* Uncomment

* Add documentation

* Add dynamic tests

* Add TF Op

* Add Sparse Segment Sum

* Add test coverage

* PR Comments

* Int64 tests

* Add SparseSegmentSqrtN

* Add SparseSegmentSqrtNOp

* Deduplicate code

* Add SparseSegmentMean

* Parametrize Tests

* Remove

* Modularize

* Black

* Modularize Code

* Pylint

* PR Comments

* Add scatter add tests

* Remove Test

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
2 people authored and trevor-m committed May 11, 2021
1 parent d215f9e commit 3b82b33
Show file tree
Hide file tree
Showing 5 changed files with 519 additions and 48 deletions.
126 changes: 126 additions & 0 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,6 +1167,125 @@ def _impl(inputs, attr, params, mod):
return _impl


def _math_segment_sum():
def _impl(inputs, attr, params, mod):
assert len(inputs) == 2, "There should be 2 input tensors"
return get_relay_op("segment_sum")(inputs[0], inputs[1])

return _impl


def _sparse_segment_sum():
def _impl(inputs, attr, params, mod):
assert len(inputs) == 3, "There should be 3 input tensors"
data = _op.take(inputs[0], inputs[1], axis=0)
return _op.segment_sum(data, inputs[2])

return _impl


def _sparse_segment_sum_with_num_segments():
def _impl(inputs, attr, params, mod):
assert len(inputs) == 4, "There should be 4 input tensors"
data = _op.take(inputs[0], inputs[1], axis=0)
num_segments = int(inputs[3].data.asnumpy().item())
return _op.segment_sum(data, inputs[2], num_segments)

return _impl


def row_wise_divide(multi_dim_tensor, one_dim_vector):
    """
    Divide every row of ``multi_dim_tensor`` by the matching entry of ``one_dim_vector``.

    ``one_dim_vector`` is tiled using the reversed non-leading dimensions of
    ``multi_dim_tensor`` (followed by a trailing 1) as repetition counts, the
    tiled result is transposed, and an element-wise division is applied.
    The transposed divisor is intended to line up with ``multi_dim_tensor``;
    this presumes ``one_dim_vector`` has one entry per row (not checked here).
    """
    full_shape = _op.shape_of(multi_dim_tensor, "int32")
    # Everything after the leading (row) dimension.
    trailing_dims = _op.strided_slice(full_shape, [1], [-1], slice_mode="size")
    reversed_trailing = _op.reverse(trailing_dims, 0)
    tile_reps = _op.concatenate([reversed_trailing, _expr.const([1])], axis=0)
    # Tile with the reversed layout, then transpose so rows come first again.
    tiled_vector = _op.tile(one_dim_vector, tile_reps)
    divisor = _op.transpose(tiled_vector)
    return _op.divide(multi_dim_tensor, divisor)


def count_all_indices(segment_ids, counts_dtype, num_segments=None):
    """
    Count how many times each segment id occurs, with a lower bound of 1.

    A tensor of ones (of dtype ``counts_dtype``) is reduced with ``segment_sum``
    over ``segment_ids``. The length of the ones tensor is the maximum of
    ``max(segment_ids) + 1``, ``num_segments`` (when given and non-zero) and the
    shape of ``segment_ids``. The resulting counts are clipped to
    ``[1, 2147483647]``, so ids that never occur report a count of 1.

    Note that no square root is taken here; callers that need sqrt counts
    apply it themselves.
    """
    highest_plus_one = _op.reshape(_op.max(segment_ids), -1) + _expr.const([1])
    if num_segments:
        requested = _expr.const([num_segments])
        highest_plus_one = _op.maximum(highest_plus_one, requested)
    ones_length = _op.maximum(highest_plus_one, _op.shape_of(segment_ids))
    ones = _op.ones(ones_length, counts_dtype)
    raw_counts = _op.segment_sum(ones, segment_ids, num_segments=num_segments)
    # Upper bound is the int32 maximum: clip max doesn't work over int32.
    return _op.clip(raw_counts, 1, 2147483647)


def _sparse_segment_sum_sqrtn():
def _impl(inputs, attr, params, mod):
assert len(inputs) == 3, "There should be 3 input tensors"
data = _op.take(inputs[0], inputs[1], axis=0)
real_counts = count_all_indices(inputs[2], attr["T"].name)
real_sqrt_counts = _op.sqrt(_op.cast_like(real_counts, data))

# Calculate regular segment sum
segment_sum = _op.segment_sum(data, inputs[2])

return row_wise_divide(segment_sum, real_sqrt_counts)

return _impl


def _sparse_segment_sum_sqrtn_with_num_segments():
def _impl(inputs, attr, params, mod):
assert len(inputs) == 4, "There should be 4 input tensors"
data = _op.take(inputs[0], inputs[1], axis=0)
num_segments = int(inputs[3].data.asnumpy().item())
real_counts = count_all_indices(inputs[2], attr["T"].name, num_segments=num_segments)
real_sqrt_counts = _op.sqrt(_op.cast_like(real_counts, data))

# Calculate regular segment sum
segment_sum = _op.segment_sum(data, inputs[2], num_segments=num_segments)

return row_wise_divide(segment_sum, real_sqrt_counts)

return _impl


def _sparse_segment_mean():
def _impl(inputs, attr, params, mod):
assert len(inputs) == 3, "There should be 3 input tensors"
data = _op.take(inputs[0], inputs[1], axis=0)
real_counts = count_all_indices(inputs[2], attr["T"].name)

# Calculate regular segment sum
segment_sum = _op.segment_sum(data, inputs[2])

return row_wise_divide(segment_sum, real_counts)

return _impl


def _sparse_segment_mean_with_num_segments():
def _impl(inputs, attr, params, mod):
assert len(inputs) == 4, "There should be 4 input tensors"
data = _op.take(inputs[0], inputs[1], axis=0)
num_segments = int(inputs[3].data.asnumpy().item())
real_counts = count_all_indices(inputs[2], attr["T"].name, num_segments=num_segments)

# Calculate regular segment sum
segment_sum = _op.segment_sum(data, inputs[2], num_segments=num_segments)

return row_wise_divide(segment_sum, real_counts)

return _impl


def _identity():
def _impl(inputs, attr, params, mod):
return inputs[0]
Expand Down Expand Up @@ -2661,6 +2780,13 @@ def _impl(inputs, attr, params, mod):
"SparseTensorDenseMatMul": _sparse_tensor_dense_matmul(),
"SparseFillEmptyRows": _sparse_fill_empty_rows(),
"SparseReshape": _sparse_reshape(),
"SegmentSum": _math_segment_sum(),
"SparseSegmentSum": _sparse_segment_sum(),
"SparseSegmentSumWithNumSegments": _sparse_segment_sum_with_num_segments(),
"SparseSegmentSqrtN": _sparse_segment_sum_sqrtn(),
"SparseSegmentSqrtNWithNumSegments": _sparse_segment_sum_sqrtn_with_num_segments(),
"SparseSegmentMean": _sparse_segment_mean(),
"SparseSegmentMeanWithNumSegments": _sparse_segment_mean_with_num_segments(),
"Split": _split(False),
"SplitV": _split(True),
"Sqrt": AttrCvt("sqrt"),
Expand Down
69 changes: 69 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1489,6 +1489,75 @@ def sparse_reshape(sparse_indices, prev_shape, new_shape):
return TupleWrapper(_make.sparse_reshape(sparse_indices, prev_shape, new_shape), 2)


def segment_sum(data, segment_ids, num_segments=None):
    """
    Computes the sum along segment_ids along axis 0. If multiple segment_ids reference the same
    location their contributions add up.

    result[index, j, k, ...] = sum over i of data[i, j, k, ...] where index = segment_ids[i]

    This op is much better understood with visualization articulated in the following links and
    examples at the end of this docstring.

    https://www.tensorflow.org/api_docs/python/tf/math/unsorted_segment_sum
    https://caffe2.ai/docs/sparse-operations.html#null__unsorted-segment-reduction-ops

    Parameters
    ----------
    data : relay.Expr
        Input Tensor. It can be of any type and multi-dimensional

    segment_ids : relay.Expr
        A 1-D int32/int64 tensor containing the segment_ids of the rows to calculate the output
        sum upon. It defines a mapping from the zeroth dimension of data onto segment_ids. The
        segment_ids tensor should be the size of the first dimension, d0, with IDs
        in the range 0 to k, where k<d0. In particular, a segmentation of a matrix tensor is a
        mapping of rows to segments. This tensor doesn't need to be sorted

    num_segments : Optional[int]
        An integer describing the size of the zeroth dimension of the output. If unspecified
        (``None``, or any other falsy value such as 0), it is calculated as
        ``max(segment_ids) + 1``. A non-int value is forwarded to ``cast_like`` as-is, so it
        is presumably expected to be a relay expression.

    Returns
    -------
    result: relay.Expr
        Output tensor.

    Examples
    --------
    .. code-block:: python

        data = [[1, 2, 3, 4],
                [4, -3, 2, -1],
                [5, 6, 7, 8]]
        segment_ids = [0, 0, 1]
        result = segment_sum(data, segment_ids)
        result = [[5, -1, 5, 3], [5, 6, 7, 8]]

        data = [[1, 2, 3, 4],
                [4, -3, 2, -1],
                [5, 6, 7, 8]]
        segment_ids = [2, 0, 0]
        num_segments = 3
        result = segment_sum(data, segment_ids, num_segments)
        result = [[9, 3, 9, 7], [0, 0, 0, 0], [1, 2, 3, 4]]
    """

    one_tensor = cast_like(const([1]), segment_ids)

    # Size of the output's zeroth dimension, as a 1-element tensor typed like segment_ids.
    if not num_segments:
        highest_id = reshape(_make.max(segment_ids, [0], False, False), -1)
        max_segments = _make.add(highest_id, one_tensor)
    elif isinstance(num_segments, int):
        max_segments = cast_like(const([num_segments]), segment_ids)
    else:
        max_segments = cast_like(num_segments, segment_ids)

    # Shape of data without its leading (row) dimension.
    data_offrow_shape = strided_slice(_make.shape_of(data, "int32"), [1], [-1], slice_mode="size")
    data_offrow_shape = cast_like(data_offrow_shape, max_segments)

    # Output shape: [number of segments] followed by the non-leading dimensions of data.
    new_shape = _make.concatenate(Tuple([max_segments, data_offrow_shape]), 0)

    # Tile segment_ids with the reversed non-leading dims (plus a trailing 1) as repetition
    # counts, then transpose; the result is used as the per-element index tensor below.
    segment_ids_tiled_shape = _make.concatenate(
        Tuple([reverse(data_offrow_shape, 0), one_tensor]), 0
    )
    expanded_segment_ids = tile(segment_ids, segment_ids_tiled_shape)
    scatter_add_segment_ids = transpose(expanded_segment_ids)

    # Zero accumulator of the output shape, created as float64 and then cast to data's type.
    src = cast_like(_dyn_make.zeros(new_shape, "float64"), data)
    return scatter_add(src, scatter_add_segment_ids, data, axis=0)


def cumsum(data, axis=None, dtype=None, exclusive=None):
"""Numpy style cumsum op. Return the cumulative inclusive sum of the elements along
a given axis.
Expand Down
40 changes: 20 additions & 20 deletions python/tvm/topi/scatter_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def _scatter_add_1d(data, indices, updates):
@hybrid.script
def _scatter_add_2d(data, indices, updates, axis):
out = output_tensor(data.shape, data.dtype)
for i in const_range(data.shape[0]):
for j in const_range(data.shape[1]):
for i in range(data.shape[0]):
for j in range(data.shape[1]):
out[i, j] = data[i, j]
if axis == 0:
for i in range(indices.shape[0]):
Expand All @@ -54,14 +54,14 @@ def _scatter_add_2d(data, indices, updates, axis):
@hybrid.script
def _scatter_add_3d(data, indices, updates, axis):
out = output_tensor(data.shape, data.dtype)
for i in const_range(data.shape[0]):
for j in const_range(data.shape[1]):
for k in const_range(data.shape[2]):
for i in range(data.shape[0]):
for j in range(data.shape[1]):
for k in range(data.shape[2]):
out[i, j, k] = data[i, j, k]
if axis == 0:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in const_range(indices.shape[2]):
for k in range(indices.shape[2]):
out[
indices[i, j, k]
if indices[i, j, k] >= 0
Expand All @@ -72,7 +72,7 @@ def _scatter_add_3d(data, indices, updates, axis):
elif axis == 1:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in const_range(indices.shape[2]):
for k in range(indices.shape[2]):
out[
i,
indices[i, j, k]
Expand All @@ -83,7 +83,7 @@ def _scatter_add_3d(data, indices, updates, axis):
else:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in const_range(indices.shape[2]):
for k in range(indices.shape[2]):
out[
i,
j,
Expand All @@ -98,17 +98,17 @@ def _scatter_add_3d(data, indices, updates, axis):
@hybrid.script
def _scatter_add_4d(data, indices, updates, axis):
out = output_tensor(data.shape, data.dtype)
for i in const_range(data.shape[0]):
for j in const_range(data.shape[1]):
for k in const_range(data.shape[2]):
for l in const_range(data.shape[3]):
for i in range(data.shape[0]):
for j in range(data.shape[1]):
for k in range(data.shape[2]):
for l in range(data.shape[3]):
out[i, j, k, l] = data[i, j, k, l]

if axis == 0:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in const_range(indices.shape[2]):
for l in const_range(indices.shape[3]):
for k in range(indices.shape[2]):
for l in range(indices.shape[3]):
out[
indices[i, j, k, l]
if indices[i, j, k, l] >= 0
Expand All @@ -120,8 +120,8 @@ def _scatter_add_4d(data, indices, updates, axis):
elif axis == 1:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in const_range(indices.shape[2]):
for l in const_range(indices.shape[3]):
for k in range(indices.shape[2]):
for l in range(indices.shape[3]):
out[
i,
indices[i, j, k, l]
Expand All @@ -133,8 +133,8 @@ def _scatter_add_4d(data, indices, updates, axis):
elif axis == 2:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in const_range(indices.shape[2]):
for l in const_range(indices.shape[3]):
for k in range(indices.shape[2]):
for l in range(indices.shape[3]):
out[
i,
j,
Expand All @@ -146,8 +146,8 @@ def _scatter_add_4d(data, indices, updates, axis):
else:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in const_range(indices.shape[2]):
for l in const_range(indices.shape[3]):
for k in range(indices.shape[2]):
for l in range(indices.shape[3]):
out[
i,
j,
Expand Down
Loading

0 comments on commit 3b82b33

Please sign in to comment.