optimize index computation #33909
Conversation
Thanks for your contribution!
… roll_optimize_kernel
paddle/fluid/operators/roll_op.cu
Outdated
-  dim_idx = (idx / strides[i]) % sizes[i];
-  dim_idx_shift = (dim_idx + shifts[i]) % sizes[i];
-  output_idx = output_idx + (dim_idx_shift - dim_idx) * strides[i];
+  dim_idx = (idx / strides[i]) % sizes[i] + shifts[i];
Variable names should reflect what they actually represent; this is the original `dim_idx_shift`, and the temporary variable `dim_idx` is no longer needed and should be removed.
It is not `dim_idx_shift`; it is an estimate of the new `dim_idx` position.
paddle/fluid/operators/roll_op.cu
Outdated
@@ -40,9 +40,12 @@ __global__ void RollCudaKernel(const T* input, T* output, int64_t N,
#pragma unroll Rank
Writing just `#pragma unroll` here should be enough, right?
Done, thanks.
… roll_optimize_kernel
… roll_optimize_kernel
LGTM
PR types
Performance optimization
PR changes
OPs
Describe
Optimize the index computation in the roll CUDA kernel (`paddle/fluid/operators/roll_op.cu`).
Paddle vs PyTorch benchmark configurations (timing figures not recoverable from the page):
- axis=(1), shift=(5)
- axis=(0), shift=(5)
- axis=(0,1), shift=(5,5)