Add segment sum Op to relay and 7 corresponding TF Ops, fix scatter_add dynamic bug #7562
Merged
Commits
4fe855b Add segment sum Op
20bcbba Remove unnecessary
fa9fb48 Documentation
74e8370 Black
8d59675 Add GPU
6575048 Uncomment
9ea802d Add documentation
23f8b2b Add dynamic tests
a8d77fa Add TF Op
5d2d192 Add Sparse Segment Sum
c3d90a6 Add test coverage
bac338e PR Comments
e11f5e1 Int64 tests
9b76291 Add SparseSegmentSqrtN
20ece40 Add SparseSegmentSqrtNOp
59a0bdd Deduplicate code
3d495f3 Add SparseSegmentMean
e2fb098 Parametrize Tests
907c3d8 Remove
20c0f3e Modularize
93b87d3 Black
c20a026 Modularize Code
82b2a13 Pylint
97a2446 PR Comments
6e77c0a Add scatter add tests
bfd71b6 Remove Test
This is OK for now, but we definitely want a fused implementation here, just like TF/PT/C2 have. I don't expect this to work for the huge embedding tables people want to use in practice.
I agree. When you say a "fused implementation", do you mean that all of it happens in a single IR?
Do you have any examples of what a "fused implementation" is? Does this mean that in a fused implementation, the frontend will always just be a one-liner?
In this case, I understand we must do the `take` and the addition from `segment_sum` simultaneously for performance. So a fused implementation in that case would be a new op?
By "fused" I meant we shouldn't materialize the result of `take`, which can be huge. In a fused implementation, we need to look up indices and accumulate the sum on the fly. This is why PT has the `EmbeddingBag` op; see their doc: https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html.

Yes, a complicated op like this will not likely be feasible if we rely only on Relay-level op fusion. We need a dedicated `sparse_segment_sum` TOPI/Relay op.
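To make the distinction above concrete, here is a minimal pure-Python sketch (not the TVM, TF, or PyTorch implementation). The function names and the list-of-lists layout are assumptions chosen for illustration. The unfused version builds the full `take` result before summing; the fused version reads each row from the table and accumulates it directly into its segment.

```python
def take(data, indices):
    """Gather rows of `data`. Materializes a len(indices) x dim intermediate."""
    return [list(data[i]) for i in indices]


def segment_sum(rows, segment_ids, num_segments, dim):
    """Sum rows that share the same segment id."""
    out = [[0.0] * dim for _ in range(num_segments)]
    for row, seg in zip(rows, segment_ids):
        for j in range(dim):
            out[seg][j] += row[j]
    return out


def sparse_segment_sum_unfused(data, indices, segment_ids, num_segments):
    """take followed by segment_sum; the gathered rows exist as a temporary."""
    dim = len(data[0])
    gathered = take(data, indices)  # can be huge for large embedding lookups
    return segment_sum(gathered, segment_ids, num_segments, dim)


def sparse_segment_sum_fused(data, indices, segment_ids, num_segments):
    """Look up each index and accumulate on the fly; no gathered temporary."""
    dim = len(data[0])
    out = [[0.0] * dim for _ in range(num_segments)]
    for idx, seg in zip(indices, segment_ids):
        row = data[idx]  # direct read from the table
        for j in range(dim):
            out[seg][j] += row[j]
    return out


# Illustrative inputs (hypothetical): a 4 x 2 table, four lookups, two segments.
data = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
indices = [0, 2, 3, 1]
segment_ids = [0, 0, 1, 1]

unfused = sparse_segment_sum_unfused(data, indices, segment_ids, 2)
fused = sparse_segment_sum_fused(data, indices, segment_ids, 2)
```

Both paths give the same result; the difference is only the extra memory held by `gathered` in the unfused path.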
I think he meant that `scatter_nd` exactly realizes fused `take` and `segment_sum` above. I haven't put deep thought into this but it made sense to me. But I remember parallelizing `scatter_nd` looked harder than `scatter_add`.
Yes, I am having a bit of a mental block understanding how take and segment_sum are essentially scatter_nd. Would any of you mind writing some small pseudocode?
FWIW I did a few variants of `torch.nn.EmbeddingBag`, `c2::sparse_length_sum`, etc. in TVM IR in https://github.com/ajtulloch/tvm/blob/4b98beb75ca1505ec81ddca358ad61282ab6a05b/topi/python/topi/x86/sparse.py#L162-L257, https://github.com/ajtulloch/tvm/blob/sparse-ops/topi/python/topi/sparse/sparse_lengths_sum.py#L45-L98, https://github.com/ajtulloch/sparse-ads-baselines/blob/a495ea076882615d454d27a1a5b191ec675d3acc/lxu_cache_cpu_funcs.py#L8-L149, etc., if that's of interest.
Thinking about this more, I believe the take is necessary if we are using scatter_nd. We could make a more generic version of scatter_nd and gather_nd that has indices in both the input and output buffers. That would cover this case.
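As a rough illustration of that point, the sketch below uses a simplified, accumulate-mode scatter over the leading axis only. It is an assumption for illustration and does not follow TVM's actual `scatter_nd` index layout. With indices on the output side only, the updates must already be the gathered rows, so `take` has to run first. The second function, `gather_scatter_add`, is a hypothetical name for the more generic op with indices on both sides.

```python
def scatter_nd_add(updates, out_indices, num_rows):
    """Simplified scatter in add mode: out[out_indices[k]] += updates[k].

    Assumes `updates` is a non-empty list of equal-length rows.
    """
    dim = len(updates[0])
    out = [[0.0] * dim for _ in range(num_rows)]
    for k, dst in enumerate(out_indices):
        for j in range(dim):
            out[dst][j] += updates[k][j]
    return out


def gather_scatter_add(data, in_indices, out_indices, num_rows):
    """Hypothetical generalization: out[out_indices[k]] += data[in_indices[k]].

    Indices address both the input and the output buffer, so no
    intermediate gathered tensor is needed.
    """
    dim = len(data[0])
    out = [[0.0] * dim for _ in range(num_rows)]
    for src, dst in zip(in_indices, out_indices):
        for j in range(dim):
            out[dst][j] += data[src][j]
    return out


# Illustrative inputs (hypothetical).
data = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
indices = [0, 2, 3, 1]
segment_ids = [0, 0, 1, 1]

# Plain scatter needs the explicit take as its updates argument.
taken = [data[i] for i in indices]
via_scatter = scatter_nd_add(taken, segment_ids, 2)

# The generalized form consumes the table and both index lists directly.
via_generic = gather_scatter_add(data, indices, segment_ids, 2)
```

Here `segment_ids` plays the role of the scatter indices, which is how `segment_sum` maps onto an accumulate-mode scatter.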
ok I'll merge this as it is then.