-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add segment sum Op to relay and 7 corresponding TF Ops , fix scatter_add dynamic bug #7562
Conversation
@masahi @tkonolige @mbrookhart @ymwangg PTAL. |
Nice, are you going to add frontend? |
Yes, do you prefer I add it in this PR or the next one ? I want to add frontends for multiple framework ops based on this relay op. |
Yes, I think it's better to add frontends (TF, PT) to make sure they are supported by this op. |
@masahi I have added 3 TF Ops to the frontend, all of which use this op. Let me know if that's enough. |
Can you also try PT EmbeddingBag? |
Hey @masahi , upon closely reading the Embedding Bag documentation, it seems that: (Referencing the
Now all of these ops exist except Let me know your thoughts on the best way to reuse existing code. After that implementation would be only a trivial few lines. |
Ok lets do embedding bag later, then. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks pretty good. A couple documentation improvements would be nice though.
@tkonolige I have finished addressing your comments, please re-review |
Actually I would like to add another related op in this PR. I will ping you after I am done with that. |
@tkonolige @masahi . I am done with the PR Please review/ re-review. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A couple minor comments
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall LGTM.
Could you add a direct test for scatter_add with dynamic inputs? That would help identifying problems in the future.
@tkonolige int64 is not allowed with tf sparse ops, I put it on the relay op tests and the tf math ops. |
assert len(inputs) == 3, "There should be 3 input tensors" | ||
data = _op.take(inputs[0], inputs[1], axis=0) | ||
return _op.segment_sum(data, inputs[2]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is ok for now, but we definitely want a fused implementation here, just like TF/PT/C2 does. I don't expect this would work for a huge embedding table people want to use in practice.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree. When you say a "fused implementation" , do you mean that all of it happens in a single ir ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you have any examples of what a "fused implementation" is ? Does this mean that in a fused implementation, the frontend will always just be a one liner ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In this case, I understand we must do the take
and the addition from segment_sum
simultaneously for performance. So a fused implementation in that case would be a new op ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By "fused" I meant we shouldn't materialize the result of take
, which can be huge. In a fused implementation, we need to look up indices and accumulate the sum on the fly. This is why PT has EmbeddingBag
op, see their doc https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html.
Yes, a complicated op like this will not likely be feasible if we rely only on Relay-level op fusion. We need a dedicated sparse_segment_sum
TOPI/Relay op.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think he meant that scatter_nd
exactly realizes fused take
and segment_sum
above. I haven't put deep thought into this but it made sense to me. But I remember parallelizing scatter_nd
looked harder than scatter_add
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I am having a bit of a mind block understanding how take and segment_sum is essentially scatter_nd, do anyone of you mind writing small pseudocode ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FWIW I did a few variants of torch.nn.EmbeddingBag
, c2::sparse_length_sum
, etc in TVM IR in https://github.com/ajtulloch/tvm/blob/4b98beb75ca1505ec81ddca358ad61282ab6a05b/topi/python/topi/x86/sparse.py#L162-L257, https://github.com/ajtulloch/tvm/blob/sparse-ops/topi/python/topi/sparse/sparse_lengths_sum.py#L45-L98, https://github.com/ajtulloch/sparse-ads-baselines/blob/a495ea076882615d454d27a1a5b191ec675d3acc/lxu_cache_cpu_funcs.py#L8-L149, etc if that's of interest.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thinking about this more, I believe the take is necessary if we are using scatter_nd. We could make a more generic version of scatter_nd and gather_nd that has indices in both the input and output buffers. That would cover this case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok I'll merge this as it is then.
Thanks @codeislife99 @tkonolige @mbrookhart |
…add dynamic bug (apache#7562) * Add segment sum Op * Remove unnecessary * Documentation * Black * Add GPU * Uncomment * Add documentation * Add dynamic tests * Add TF Op * Add Sparse Segment Sum * Add test coverage * PR Comments * Int64 tests * Add SparseSegmentSqrtN * Add SparseSegmentSqrtNOp * Deduplicate code * Add SparseSegmentMean * Parametrize Tests * Remove * Modularize * Black * Modularize Code * Pylint * PR Comments * Add scatter add tests * Remove Test Co-authored-by: Ubuntu <[email protected]>
…add dynamic bug (apache#7562) * Add segment sum Op * Remove unnecessary * Documentation * Black * Add GPU * Uncomment * Add documentation * Add dynamic tests * Add TF Op * Add Sparse Segment Sum * Add test coverage * PR Comments * Int64 tests * Add SparseSegmentSqrtN * Add SparseSegmentSqrtNOp * Deduplicate code * Add SparseSegmentMean * Parametrize Tests * Remove * Modularize * Black * Modularize Code * Pylint * PR Comments * Add scatter add tests * Remove Test Co-authored-by: Ubuntu <[email protected]>
This PR adds the Segment Sum Op which will serve as a generic op for multiple framework specific ops
Tensorflow -- tf.math.segment_sum, tf.sparse.segment_sum
Caffe -- sparse length sum
PyTorch -- Embedding Bag
Since this PR uses scatter_add , it also makes some small changes which make it work for dynamic inputs.