[Relay/TOPI][Op] Add TopK operator #3256
    The input data tensor.

    k : int, optional
        Number of top elements to select. Return all elements if k < 1.
Recommend returning a tensor of size 0 if k == 0.
I think the current graph runtime doesn't support 0-size tensors, so I assert k >= 1 in the TensorFlow frontend converter.
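For readers following this thread, the semantics under discussion can be sketched in plain NumPy. This is a reference sketch only, not the TOPI/Relay implementation from this PR: the names `topk_reference` and `check_k_for_frontend` are hypothetical, and the `axis` and `is_ascend` parameters are assumptions added for illustration. It follows the docstring rule that `k < 1` returns all elements, and mirrors the converter-side `k >= 1` assertion described above.

```python
import numpy as np


def topk_reference(data, k=1, axis=-1, is_ascend=False):
    """NumPy sketch of the documented top-k semantics (hypothetical helper)."""
    data = np.asarray(data)
    # A stable sort keeps tied elements in their original order.
    # Negating for descending order assumes a signed numeric dtype.
    keys = data if is_ascend else -data
    order = np.argsort(keys, axis=axis, kind="stable")
    # Documented rule: k < 1 means "return all elements".
    if k >= 1:
        keep = np.arange(min(k, data.shape[axis]))
        order = np.take(order, keep, axis=axis)
    values = np.take_along_axis(data, order, axis=axis)
    return values, order


def check_k_for_frontend(k):
    """Hypothetical stand-in for the converter-side check: reject k < 1."""
    assert k >= 1, "TopK with k < 1 is not supported by this converter"
    return k


values, indices = topk_reference([[1.0, 5.0, 3.0, 4.0]], k=2)
print(values, indices)
```

With `k=0` the sketch returns the whole row sorted in descending order, which is the behavior the reviewer suggests replacing with a size-0 result once the runtime supports empty tensors.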
Have you compared your current sort implementation with the existing one? Is there any performance regression for large workloads like (1, 8000)?
@Laurawly I evaluated the CUDA performance of argsort with input size (1, 8000) on p3. The existing one takes 28.57 ms to finish, and the one in this PR takes 15.11 ms.
LGTM
* init impl for topk
* Fix cpu for topk
* init cuda impl for topk
* Add cuda for topk
* fix
* Add doc
* update doc
* lint
* lint
* lint
* x
* fix warning
* [Relay] Add TopK in tf converter
* Add frontend converter
* fix
This PR includes the following changes
Thanks @yongwww for the help with the TensorFlow frontend converter.