
[ONNX] Support ScatterElements with reduction #13894

Merged
37 commits merged on Feb 16, 2023

Conversation

vvchernov
Contributor

@vvchernov commented Feb 1, 2023

Support ScatterElements on the ONNX front-end as described in the ONNX docs. Currently the Scatter op implementation is used, where the reduction attribute is not supported at all. CI tests for the op are also added.

P.S. In follow-up PRs I plan to: 1. remove scatter_add and reconnect all of its uses to ScatterElements(reduction="add"); 2. remove the scatter implementation and use ScatterElements(reduction="update") instead. This will remove the restriction on input tensor rank (currently rank <= 4).
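
For context, here is a minimal sketch (my illustration, not test code from this PR) of the kind of model this enables: an ONNX ScatterElements node with reduction="add" imported through relay.frontend.from_onnx. The shapes, tensor names, and opset version are assumptions.

```python
import numpy as np
import onnx
from onnx import TensorProto, helper
import tvm
from tvm import relay

# ScatterElements with a reduction attribute (added in ONNX opset 16).
node = helper.make_node(
    "ScatterElements",
    inputs=["data", "indices", "updates"],
    outputs=["out"],
    axis=0,
    reduction="add",  # previously the Scatter-based path ignored reduction
)
graph = helper.make_graph(
    [node],
    "scatter_elements_add",
    inputs=[
        helper.make_tensor_value_info("data", TensorProto.FLOAT, [3, 3]),
        helper.make_tensor_value_info("indices", TensorProto.INT64, [2, 3]),
        helper.make_tensor_value_info("updates", TensorProto.FLOAT, [2, 3]),
    ],
    outputs=[helper.make_tensor_value_info("out", TensorProto.FLOAT, [3, 3])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 16)])

# Import into Relay and run on CPU.
mod, params = relay.frontend.from_onnx(
    model, shape={"data": (3, 3), "indices": (2, 3), "updates": (2, 3)}
)
func = relay.create_executor("graph", mod=mod, device=tvm.cpu(), target="llvm").evaluate()
result = func(
    np.zeros((3, 3), "float32"),
    np.array([[0, 1, 2], [2, 0, 1]], "int64"),
    np.ones((2, 3), "float32"),
)
print(result)
```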

@tvm-bot
Collaborator

tvm-bot commented Feb 1, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@vvchernov force-pushed the vc/scatter branch 9 times, most recently from f0fdaf1 to 3511f6d on February 8, 2023 09:43
@vvchernov changed the title from "WIP: [ONNX] Support ScatterElements with reduction" to "[ONNX] Support ScatterElements with reduction" on Feb 9, 2023
@vvchernov
Contributor Author

@tvm-bot rerun

Contributor

@AndrewZhaoLuo left a comment


Did a first pass, looks good for the most part. One comment.

Will take a closer look tonight.


ind_fused = bx1 * max_threads + tx1
with ib.if_scope(ind_fused < ind_full_range):
    index_check = tir.LT(indices_ptr[ind_fused], tir.const(0, indices.dtype))
Contributor


Why can this block not be fused with the below block? Is it to prevent warp divergence?

Perhaps you can do something like:

index = index + (index < 0) * axis_range

If that is the case.
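
For concreteness, a sketch of that branchless adjustment written with tvm.tir expressions; the names index and axis_range are illustrative only (in the kernel they would correspond to indices_ptr[ind_fused] and the extent of the scatter axis), and this is not code from the PR:

```python
from tvm import tir

def shift_negative_index(index, axis_range, dtype="int64"):
    # (index < 0) is a boolean TIR expression; casting it to the index dtype
    # yields 0 or 1, so the wrap-around is pure arithmetic with no branch
    # (and hence no warp divergence).
    is_negative = tir.Cast(dtype, index < tir.const(0, dtype))
    return index + is_negative * axis_range
```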

Contributor Author


Hello @AndrewZhaoLuo. It was initially fused into the common loop. I decided that conditionals are not a good thing for GPU calculations, especially when wrapped in additional loops, but in this case it looks about the same (I can bring it back if you think that is the better way). Another thing is an assert for a shifted index that is still out of bounds (theoretically we should do this check), but I do not know how I can do this check on the ir_builder side.

Contributor Author


The expression you suggested is very convenient and clear. Is it the same as what I wrote, or does it use other tir procedures? I'm slightly wary of the extra multiplication in it.

Contributor

@AndrewZhaoLuo Feb 16, 2023


The arithmetic intensity per element loaded is very low (single digit), so I am not very worried about the multiplication. I would expect memory access to always be the bottleneck, so we can do a lot of computation while the GPU is fetching data from global memory. Regardless, it's a pretty common trick to get "branch-less" programming.

I am not that familiar with IR Builder. I will say that if you want to enforce indices being valid, you can just do something like taking the indices modulo the axis size internally (for wrapping behavior). I am not familiar with assertions in tir in general.

Personally I think it is ok not to check, and to assume the caller will check and guarantee good inputs, as we don't necessarily want the check in the base computation.
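
A small sketch of the modulo-based wrapping mentioned above (again illustrative, not code from this PR); floormod already maps negative indices into [0, axis_range):

```python
from tvm import tir

def wrap_index(index, axis_range):
    # tir.floormod(-1, n) == n - 1, so this both shifts negative indices and
    # wraps anything that would otherwise run past the axis extent.
    return tir.floormod(index, axis_range)
```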

Contributor Author


Ok. I have fused the block related to index shifting with the general one and used your expression.
The code related to scatter_add was refactored in another branch and a PR has been prepared (see below); the latter is waiting for this PR to be merged.

Contributor

@AndrewZhaoLuo left a comment


Can you compare the performance of your GPU kernel vs the current scatter_add implementation?

Looking at it, it may be easier to reuse the existing scatter_add topi implementation and extend it with new reduction functions.

@vvchernov
Contributor Author

@AndrewZhaoLuo I did not compare performance on GPU, but I have prepared a new PR that replaces scatter_add with scatter_elements and reuses the code for CUDA.

@vvchernov
Contributor Author

One note about scatter_add: like scatter, it is implemented only for 1d, 2d, 3d and 4d input tensors. I thought I could reuse the scatter_add approach, but in the end I implemented scatter_elements in a general way, without restrictions on the input data rank.

@vvchernov
Contributor Author

@tvm-bot rerun
