[ONNX] Support ScatterElements with reduction #13894
Conversation
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot
@tvm-bot rerun
Did a first pass, looks good for the most part. One comment.
Will take a closer look tonight.
ind_fused = bx1 * max_threads + tx1
with ib.if_scope(ind_fused < ind_full_range):
    index_check = tir.LT(indices_ptr[ind_fused], tir.const(0, indices.dtype))
Why can this block not be fused with the block below? Is it to prevent warp divergence?
If that is the reason, perhaps you can do something like:
index = index + (index < 0) * axis_range
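The branchless normalization suggested above can be checked with a small NumPy sketch (names here are illustrative, not the TIR code in this PR):

```python
import numpy as np

# Hypothetical sketch of the branchless trick: negative indices
# (counted from the end of the axis) are shifted into range
# without a conditional branch.
axis_range = 5
index = np.array([-1, 0, 3, -5])

# (index < 0) evaluates to 0 or 1, so the multiply replaces the branch.
normalized = index + (index < 0) * axis_range
print(normalized)  # [4 0 3 0]
```

Because the comparison result is just 0 or 1, every thread executes the same instructions, which is what avoids warp divergence on the GPU.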
Hello @AndrewZhaoLuo. It was initially fused into the common loop. I decided that conditionals are not good for GPU computation, especially when wrapped in additional loops. But in this case it looks the same (I can bring it back if you think that is the better way). Another thing is an assert for a shifted index that is still out of bounds (theoretically we should do this check), but I do not know how to implement such a check on the ir_builder side.
The expression you suggested is very convenient and clear. Does it compile to the same thing as what I wrote, or does it use other tir procedures? I am slightly concerned about the extra multiplication in it.
The arithmetic intensity per element loaded is very low (single digit), so I am not worried about the multiplication. I would expect memory access to always be the bottleneck, so we can do a lot of computation while the GPU is fetching data from global memory. Regardless, it's a pretty common trick to get "branch-less" programming.
Yes, I am not as familiar with IR Builder. If you want to enforce that indices are valid, you can do something like applying a modulo to the indices internally (for wrapping behavior). I am not familiar with assertions in tir in general.
Personally, I think it is fine not to check and to assume the caller will guarantee good inputs, as we don't necessarily want the check in the base computation.
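The "wrap via modulo" idea mentioned above can be illustrated in plain Python (this is a sketch of the semantics, not the ir_builder code):

```python
# Sketch of wrapping out-of-range indices with modulo. In Python,
# % already maps negative values into [0, axis_range), so a single
# modulo handles both negative and overflowing indices.
axis_range = 4
indices = [-1, 2, 5, -6]
wrapped = [i % axis_range for i in indices]
print(wrapped)  # [3, 2, 1, 2]
```

Note that wrapping silently "repairs" genuinely invalid indices rather than reporting them, which is why leaving validation to the caller, as suggested above, is a reasonable design choice for the base computation.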
Ok. I have fused the block related to index shifting with the general one and used your expression.
The code related to scatter_add was refactored in another branch, and a PR was prepared (see below); the latter is waiting for this PR to be merged.
Can you compare the performance of your GPU kernel against the current scatter_add implementation?
Looking at it, it may be easier to reuse the existing scatter_add topi implementation and extend it with the new reduction functions.
@AndrewZhaoLuo I did not compare performance on GPU, but I prepared a new PR that replaces scatter_add with scatter_elements and reuses the code for CUDA.
One note about scatter_add: like scatter, it is implemented only for 1d, 2d, 3d and 4d input tensors. I thought I could reuse the scatter_add approach, but in the end I implemented scatter_elements in a general way, without restrictions on the input data rank.
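The rank-general behavior described above can be captured in a small NumPy reference (an illustrative sketch of ONNX ScatterElements semantics, not the TOPI kernel from this PR; the function name is hypothetical):

```python
import numpy as np

def scatter_elements_ref(data, indices, updates, axis=0, reduction="update"):
    """Illustrative NumPy reference for ScatterElements semantics.

    Works for any input rank: iterate over every position in `indices`,
    replace the coordinate along `axis` with the stored index (modulo
    maps negative indices into range), then apply the reduction.
    """
    out = data.copy()
    for idx in np.ndindex(indices.shape):
        pos = list(idx)
        pos[axis] = indices[idx] % data.shape[axis]  # wrap negative indices
        pos = tuple(pos)
        if reduction == "add":
            out[pos] += updates[idx]
        elif reduction == "mul":
            out[pos] *= updates[idx]
        else:  # "update"
            out[pos] = updates[idx]
    return out

data = np.ones((2, 2))
indices = np.array([[0, 1], [0, 1]])
updates = np.array([[1.0, 2.0], [3.0, 4.0]])
print(scatter_elements_ref(data, indices, updates, axis=1, reduction="add"))
# [[2. 3.]
#  [4. 5.]]
```

Because the loop runs over `np.ndindex(indices.shape)`, nothing here depends on the rank of `data`, which is exactly the restriction the hard-coded 1d/2d/3d/4d scatter_add kernels have.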
@tvm-bot rerun |
Support ScatterElements on the ONNX front-end, as described in the ONNX docs. Currently the Scatter op implementation is used, where the reduction attribute is not supported at all. CI tests for the op are also added.
P.S. In follow-up PRs I plan to: 1. remove scatter_add and reconnect all its uses to ScatterElements(reduction="add"); 2. remove the scatter implementation and use ScatterElements(reduction="update") instead. This will remove the restriction on input tensor rank (currently rank <= 4).